39 changes: 31 additions & 8 deletions open_instruct/finetune.py
@@ -27,6 +27,7 @@
from functools import partial
from typing import List, Optional, Union

import pandas as pd
import datasets
import deepspeed
import torch
@@ -697,7 +698,7 @@ def main(args: FlatArguments):
configs=args.dataset_config_name,
splits=["train"],
save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None,
columns_to_keep=["messages"],
columns_to_keep=["messages","tools","documents"],
)
elif args.dataset_mixer_list is not None:
# mixing datasets via config
@@ -706,19 +707,41 @@ def main(args: FlatArguments):
configs=args.dataset_config_name,
splits=["train"],
save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None,
columns_to_keep=["messages"],
columns_to_keep=["messages","tools","documents"],
)
else:
data_files = {}
dataset_args = {}
if args.train_file is not None:
data_files["train"] = args.train_file
with accelerator.main_process_first():
raw_datasets = load_dataset(
args.train_file_type,
data_files=data_files,
**dataset_args,
)
with accelerator.main_process_first():
try:
raw_datasets = load_dataset(
args.train_file_type,
data_files=data_files,
**dataset_args,
)
except Exception:
# load_dataset applies strict schema checks and can fail when the
# tools / documents columns contain nested or inconsistent values,
# so fall back to pandas for the read.
train_files = args.train_file
if isinstance(train_files, str):
train_files = [train_files]

dfs = []
reader = (
partial(pd.read_json, lines=True, orient="records")
if args.train_file_type == "json"
else partial(pd.read_parquet, engine="auto")
)
for file in train_files:
dfs.append(reader(file))

raw_datasets = datasets.DatasetDict({
"train": datasets.Dataset.from_pandas(pd.concat(dfs))
})
del dfs

# Load pretrained model and tokenizer
if args.config_name:
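The finetune.py change above adds a pandas fallback when datasets' schema inference rejects the nested tools / documents columns. A minimal standalone sketch of that fallback path, assuming a hypothetical train.jsonl input and the json file type (neither taken from the repository's configs):

```python
import datasets
import pandas as pd
from datasets import load_dataset

# Hypothetical inputs for illustration only.
train_files = ["train.jsonl"]
file_type = "json"  # or "parquet"

try:
    raw_datasets = load_dataset(file_type, data_files={"train": train_files})
except Exception:
    # datasets' schema inference can fail on nested "tools" / "documents"
    # columns; pandas is more permissive, so read with pandas and convert.
    if file_type == "json":
        dfs = [pd.read_json(f, lines=True, orient="records") for f in train_files]
    else:
        dfs = [pd.read_parquet(f, engine="auto") for f in train_files]
    raw_datasets = datasets.DatasetDict(
        {"train": datasets.Dataset.from_pandas(pd.concat(dfs, ignore_index=True))}
    )
```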
23 changes: 21 additions & 2 deletions open_instruct/utils.py
@@ -23,10 +23,11 @@
import time
from dataclasses import dataclass
from typing import Any, List, NewType, Optional, Tuple, Union
import pandas as pd

import requests
from accelerate.logging import get_logger
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk, Dataset
from datasets.builder import DatasetGenerationError
from dateutil import parser
from huggingface_hub import HfApi
@@ -245,7 +246,13 @@ def get_datasets(
for split in splits:
# if dataset ends with .json or .jsonl, load from file
if ds.endswith(".json") or ds.endswith(".jsonl"):
dataset = load_dataset("json", data_files=ds, split=split)
try:
dataset = load_dataset("json", data_files=ds, split=split)
except Exception:
# load_dataset can fail when the dataset has nested tools / documents
# columns, so fall back to pandas for the read
df = pd.read_json(ds, lines=True, orient="records")
dataset = Dataset.from_pandas(df)
else:
try:
# Try first if dataset on a Hub repo
@@ -307,6 +314,10 @@

# if id not in dataset, create it as ds-{index}
if "id" not in dataset.column_names:
logger.warning(
"Adding an 'id' column to the dataset via add_column; this can be very slow. "
"Consider pre-processing the dataset to add the column."
)
id_col = [f"{ds}_{i}" for i in range(len(dataset))]
dataset = dataset.add_column("id", id_col)

@@ -317,12 +328,20 @@

# if add_source_col, add that column
if add_source_col:
logger.warning(
"Adding a 'source' column to the dataset via add_column; this can be very slow. "
"Consider pre-processing the dataset to add the column."
)
source_col = [ds] * len(dataset)
dataset = dataset.add_column("source", source_col)

# for each col in columns_to_keep that is missing, add a column of None values
for col in columns_to_keep:
if col not in dataset.column_names:
logger.warning(
f"Adding a '{col}' column to the dataset via add_column; this can be very slow. "
"Consider pre-processing the dataset to add the column."
)
dataset = dataset.add_column(col, [None] * len(dataset))

# add tag to the dataset corresponding to where it was sourced from, for
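The new warnings in utils.py recommend pre-processing a dataset so that the id, source, and any optional columns already exist before mixing, rather than relying on repeated add_column calls. A rough sketch of one way to do that with Dataset.map; the dataset name and contents are made-up placeholders:

```python
from datasets import Dataset

ds_name = "my_dataset"  # placeholder dataset name
dataset = Dataset.from_dict(
    {"messages": [[{"role": "user", "content": "hi"}]]}  # toy example row
)

# Add the columns up front so get_datasets does not need add_column at mix time.
dataset = dataset.map(
    lambda example, idx: {
        "id": f"{ds_name}_{idx}",          # mirrors the id scheme used above
        "source": ds_name,                 # mirrors the source column added above
        "tools": example.get("tools"),     # backfill optional columns with None
        "documents": example.get("documents"),
    },
    with_indices=True,
)
```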