diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index 50f99ccc84..808787d7d5 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -27,6 +27,7 @@
 from functools import partial
 from typing import List, Optional, Union
 
+import pandas as pd
 import datasets
 import deepspeed
 import torch
@@ -697,7 +698,7 @@ def main(args: FlatArguments):
             configs=args.dataset_config_name,
             splits=["train"],
             save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None,
-            columns_to_keep=["messages"],
+            columns_to_keep=["messages", "tools", "documents"],
         )
     elif args.dataset_mixer_list is not None:
         # mixing datasets via config
@@ -706,19 +707,41 @@ def main(args: FlatArguments):
             configs=args.dataset_config_name,
             splits=["train"],
             save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None,
-            columns_to_keep=["messages"],
+            columns_to_keep=["messages", "tools", "documents"],
         )
     else:
         data_files = {}
         dataset_args = {}
         if args.train_file is not None:
             data_files["train"] = args.train_file
-        with accelerator.main_process_first():
-            raw_datasets = load_dataset(
-                args.train_file_type,
-                data_files=data_files,
-                **dataset_args,
-            )
+        with accelerator.main_process_first():
+            try:
+                raw_datasets = load_dataset(
+                    args.train_file_type,
+                    data_files=data_files,
+                    **dataset_args,
+                )
+            except Exception:
+                # load_dataset applies strict schema checks and can fail when
+                # tools / documents columns are present, so fall back to pandas
+                train_files = args.train_file
+                if isinstance(train_files, str):
+                    train_files = [train_files]
+
+                dfs = []
+                reader = (
+                    partial(pd.read_json, lines=True, orient="records")
+                    if args.train_file_type == "json"
+                    else partial(pd.read_parquet, engine="auto")
+                )
+                for file in train_files:
+                    dfs.append(reader(file))
+
+                raw_datasets = datasets.DatasetDict({
+                    "train": datasets.Dataset.from_pandas(pd.concat(dfs, ignore_index=True))
+                })
+                # free the intermediate dataframes
+                del dfs
 
     # Load pretrained model and tokenizer
     if args.config_name:
diff --git a/open_instruct/utils.py b/open_instruct/utils.py
index f4f1f9e86d..7f03be0613 100644
--- a/open_instruct/utils.py
+++ b/open_instruct/utils.py
@@ -23,10 +23,11 @@
 import time
 from dataclasses import dataclass
 from typing import Any, List, NewType, Optional, Tuple, Union
+import pandas as pd
 
 import requests
 from accelerate.logging import get_logger
-from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
+from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset, load_from_disk
 from datasets.builder import DatasetGenerationError
 from dateutil import parser
 from huggingface_hub import HfApi
@@ -245,7 +246,13 @@ def get_datasets(
         for split in splits:
             # if dataset ends with .json or .jsonl, load from file
             if ds.endswith(".json") or ds.endswith(".jsonl"):
-                dataset = load_dataset("json", data_files=ds, split=split)
+                try:
+                    dataset = load_dataset("json", data_files=ds, split=split)
+                except Exception:
+                    # if tools / documents columns are present, load_dataset can fail
+                    # on schema checks, so fall back to loading via pandas
+                    df = pd.read_json(ds, lines=True, orient="records")
+                    dataset = Dataset.from_pandas(df)
             else:
                 try:
                     # Try first if dataset on a Hub repo
@@ -307,6 +314,10 @@ def get_datasets(
 
             # if id not in dataset, create it as ds-{index}
             if "id" not in dataset.column_names:
+                logger.warning(
+                    "Adding an 'id' column to the dataset via add_column; this can be very slow. "
+                    "Consider pre-processing the dataset to add the column instead."
+                )
                 id_col = [f"{ds}_{i}" for i in range(len(dataset))]
                 dataset = dataset.add_column("id", id_col)
 
@@ -317,12 +328,20 @@ def get_datasets(
 
             # if add_source_col, add that column
            if add_source_col:
+                logger.warning(
+                    "Adding a 'source' column to the dataset via add_column; this can be very slow. "
+                    "Consider pre-processing the dataset to add the column instead."
+                )
                 source_col = [ds] * len(dataset)
                 dataset = dataset.add_column("source", source_col)
 
             # for cols in columns_to_keep, if one is not present, add "None" to the column
             for col in columns_to_keep:
                 if col not in dataset.column_names:
+                    logger.warning(
+                        f"Adding a '{col}' column to the dataset via add_column; this can be very slow. "
+                        "Consider pre-processing the dataset to add the column instead."
+                    )
                     dataset = dataset.add_column(col, [None] * len(dataset))
 
             # add tag to the dataset corresponding to where it was sourced from, for