diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py
index 5ca23e4b55..1498ec6b3b 100644
--- a/open_instruct/dataset_transformation.py
+++ b/open_instruct/dataset_transformation.py
@@ -44,6 +44,7 @@
 import hashlib
 import json
 import multiprocessing
 import os
+import warnings
 from dataclasses import asdict, dataclass, field
 from functools import cached_property
@@ -1821,6 +1822,14 @@ def get_cached_dataset_tulu(
     )[0]
 
 
+def remove_non_tensor_columns(dataset: Dataset) -> Dataset:
+    """Drop columns whose values are not torch tensors so collators only see tensor data."""
+    example = dataset[0]
+    cols_to_remove = [k for k, v in example.items() if not torch.is_tensor(v)]
+    warnings.warn(f"Removing non-tensor dataset columns {cols_to_remove}", stacklevel=1)
+    return dataset.remove_columns(cols_to_remove)
+
+
 def test_sft_dpo_same_tokenizer():
     base_to_sft_tc = TokenizerConfig(
         tokenizer_name_or_path="meta-llama/Llama-3.1-8B", tokenizer_revision="main", chat_template_name="tulu"
diff --git a/open_instruct/dpo_tune_cache.py b/open_instruct/dpo_tune_cache.py
index 1f288eb778..4c0b8c681e 100644
--- a/open_instruct/dpo_tune_cache.py
+++ b/open_instruct/dpo_tune_cache.py
@@ -66,6 +66,7 @@
     TokenizerConfig,
     get_cached_dataset_tulu,
+    remove_non_tensor_columns,
     visualize_token,
 )
 from open_instruct.dpo_utils import (
     DataCollatorForSeq2SeqDPO,
@@ -692,6 +693,9 @@ def load_model():
     else:
         collate_fn = DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=model, padding="longest")
 
+    # The collators expect tensor data, so drop any non-tensor columns now. Non-tensor columns
+    # are assumed to be non-crucial metadata such as `DATASET_ORIGIN_KEY`.
+    train_dataset = remove_non_tensor_columns(train_dataset)
     train_dataloader = DataLoader(
         train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
     )
diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index 2c3e08a618..157d891db1 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -57,7 +57,8 @@
     TOKENIZED_SFT_DATASET_KEYS,
     TokenizerConfig,
     get_cached_dataset_tulu,
+    remove_non_tensor_columns,
     visualize_token,
 )
 from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate
 from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening
@@ -620,6 +621,9 @@ def main(args: FlatArguments, tc: TokenizerConfig):
         collate_fn = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest")
 
     accelerator.print("Creating dataloader")
+    # The collators expect tensor data, so drop any non-tensor columns now. Non-tensor columns
+    # are assumed to be non-crucial metadata such as `DATASET_ORIGIN_KEY`.
+    train_dataset = remove_non_tensor_columns(train_dataset)
     train_dataloader = DataLoader(
         train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
    )
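For reviewers, here is a minimal sketch (not part of the patch) of what the new `remove_non_tensor_columns` helper does once this change is applied. The toy dataset and the `dataset_origin` column name are made up for illustration; they stand in for metadata columns such as the one keyed by `DATASET_ORIGIN_KEY`.

```python
# Illustrative sketch only; assumes this patch is applied. The column names below
# ("input_ids", "labels", "dataset_origin") are toy examples, not real dataset contents.
from datasets import Dataset

from open_instruct.dataset_transformation import remove_non_tensor_columns

ds = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3], [4, 5, 6]],
        "labels": [[1, 2, 3], [4, 5, 6]],
        "dataset_origin": ["tulu", "tulu"],  # string metadata, not convertible to a tensor
    }
)
# Format the token columns as torch tensors; keep the string column as a plain Python value.
ds.set_format("torch", columns=["input_ids", "labels"], output_all_columns=True)

ds = remove_non_tensor_columns(ds)  # warns about and drops "dataset_origin"
print(ds.column_names)  # ['input_ids', 'labels']
```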