8 changes: 8 additions & 0 deletions open_instruct/dataset_transformation.py
@@ -44,6 +44,7 @@
import hashlib
import json
import multiprocessing
import os
import warnings
from dataclasses import asdict, dataclass, field
from functools import cached_property
@@ -1821,6 +1822,13 @@ def get_cached_dataset_tulu(
)[0]


def remove_non_tensor_columns(dataset: Dataset) -> Dataset:
    """Drop every column whose value in the first example is not a torch tensor (e.g. string metadata)."""
    example = dataset[0]
    cols_to_remove = [k for k, v in example.items() if not torch.is_tensor(v)]
    warnings.warn(f"Removing non-tensor dataset columns {cols_to_remove}", stacklevel=2)
    return dataset.remove_columns(cols_to_remove)


def test_sft_dpo_same_tokenizer():
base_to_sft_tc = TokenizerConfig(
tokenizer_name_or_path="meta-llama/Llama-3.1-8B", tokenizer_revision="main", chat_template_name="tulu"
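For context, here is a minimal sketch (not part of this diff) of how the new helper behaves. It assumes the dataset already carries torch formatting, as the cached tokenized datasets do, and `dataset_source` is a hypothetical stand-in for the metadata column keyed by `DATASET_ORIGIN_KEY`:

from datasets import Dataset
from open_instruct.dataset_transformation import remove_non_tensor_columns

ds = Dataset.from_dict({
    "input_ids": [[1, 2, 3], [4, 5, 6]],  # becomes a tensor under torch formatting
    "labels": [[1, 2, 3], [4, 5, 6]],     # becomes a tensor under torch formatting
    "dataset_source": ["tulu", "tulu"],   # stays a Python string
}).with_format("torch")

ds = remove_non_tensor_columns(ds)  # warns, then drops "dataset_source"
print(ds.column_names)  # ['input_ids', 'labels']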
4 changes: 4 additions & 0 deletions open_instruct/dpo_tune_cache.py
@@ -66,6 +66,7 @@
TokenizerConfig,
get_cached_dataset_tulu,
visualize_token,
    remove_non_tensor_columns,
)
from open_instruct.dpo_utils import (
DataCollatorForSeq2SeqDPO,
@@ -692,6 +693,9 @@ def load_model():
else:
collate_fn = DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=model, padding="longest")

    # The collators expect to act on tensor data, so remove any non-tensor entries now.
    # The non-tensor entries are assumed to be non-crucial metadata, such as the column
    # keyed by `DATASET_ORIGIN_KEY`.
train_dataset = remove_non_tensor_columns(train_dataset)
train_dataloader = DataLoader(
train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
)
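As an illustration (not from this diff) of the failure mode the new call guards against: once a torch-formatted batch reaches a collator that tensorizes every remaining column, string metadata cannot be converted and the step fails.

import torch

# Hypothetical metadata column surviving into a batch.
sources = ["tulu", "tulu"]
try:
    torch.tensor(sources)  # strings cannot be turned into a tensor
except (TypeError, ValueError):
    print("non-tensor metadata must be dropped before collation")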
6 changes: 5 additions & 1 deletion open_instruct/finetune.py
@@ -57,7 +57,8 @@
TOKENIZED_SFT_DATASET_KEYS,
TokenizerConfig,
get_cached_dataset_tulu,
visualize_token,
visualize_token_label,
    remove_non_tensor_columns,
)
from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate
from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening
@@ -620,6 +621,9 @@ def main(args: FlatArguments, tc: TokenizerConfig):
collate_fn = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest")

accelerator.print("Creating dataloader")
    # The collators expect to act on tensor data, so remove any non-tensor entries now.
    # The non-tensor entries are assumed to be non-crucial metadata, such as the column
    # keyed by `DATASET_ORIGIN_KEY`.
train_dataset = remove_non_tensor_columns(train_dataset)
train_dataloader = DataLoader(
train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
)
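One optional follow-up: because the helper announces what it drops via `warnings.warn`, callers that consider the message noise can suppress it with the standard-library filter:

import warnings

# Optional: silence the column-removal warning from remove_non_tensor_columns.
warnings.filterwarnings("ignore", message="Removing non-tensor dataset columns")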