Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 31 additions & 30 deletions applications/DeepSpeed-Chat/training/utils/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,44 +17,45 @@


def get_raw_dataset(dataset_name, output_path, seed, local_rank):
if dataset_name == "Dahoas/rm-static":
return raw_datasets.DahoasRmstaticDataset(output_path, seed,
local_rank)
elif dataset_name == "Dahoas/full-hh-rlhf":
return raw_datasets.DahoasFullhhrlhfDataset(output_path, seed,
local_rank)
elif dataset_name == "Dahoas/synthetic-instruct-gptj-pairwise":
dn = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change dn to longer and meaningful name, such as local_data_dir

if dataset_name.startswith("./"):
dn = dataset_name
if "Dahoas/rm-static" in dataset_name:
return raw_datasets.DahoasRmstaticDataset(output_path, seed, local_rank, dn)
elif "Dahoas/full-hh-rlhf" in dataset_name:
return raw_datasets.DahoasFullhhrlhfDataset(output_path, seed, local_rank, dn)
elif "Dahoas/synthetic-instruct-gptj-pairwise" in dataset_name:
return raw_datasets.DahoasSyntheticinstructgptjpairwiseDataset(
output_path, seed, local_rank)
elif dataset_name == "yitingxie/rlhf-reward-datasets":
output_path, seed, local_rank, dn)
elif "yitingxie/rlhf-reward-datasets" in dataset_name:
return raw_datasets.YitingxieRlhfrewarddatasetsDataset(
output_path, seed, local_rank)
elif dataset_name == "openai/webgpt_comparisons":
output_path, seed, local_rank, dn)
elif "openai/webgpt_comparisons" in dataset_name:
return raw_datasets.OpenaiWebgptcomparisonsDataset(
output_path, seed, local_rank)
elif dataset_name == "stanfordnlp/SHP":
output_path, seed, local_rank, dn)
elif "stanfordnlp/SHP" in dataset_name:
return raw_datasets.StanfordnlpSHPDataset(output_path, seed,
local_rank)
elif dataset_name == "wangrui6/Zhihu-KOL":
local_rank, dn)
elif "wangrui6/Zhihu-KOL" in dataset_name:
return raw_datasets.Wangrui6ZhihuKOLDataset(output_path, seed,
local_rank)
elif dataset_name == "Cohere/miracl-zh-queries-22-12":
local_rank, dn)
elif "Cohere/miracl-zh-queries-22-12" in dataset_name:
return raw_datasets.CohereMiraclzhqueries2212Dataset(
output_path, seed, local_rank)
elif dataset_name == "Hello-SimpleAI/HC3-Chinese":
output_path, seed, local_rank, dn)
elif "Hello-SimpleAI/HC3-Chinese" in dataset_name:
return raw_datasets.HelloSimpleAIHC3ChineseDataset(
output_path, seed, local_rank)
elif dataset_name == "mkqa-Chinese":
return raw_datasets.MkqaChineseDataset(output_path, seed, local_rank)
elif dataset_name == "mkqa-Japanese":
return raw_datasets.MkqaJapaneseDataset(output_path, seed, local_rank)
elif dataset_name == "Cohere/miracl-ja-queries-22-12":
output_path, seed, local_rank, dn)
elif "mkqa-Chinese" in dataset_name:
return raw_datasets.MkqaChineseDataset(output_path, seed, local_rank, dn)
elif "mkqa-Japanese" in dataset_name:
return raw_datasets.MkqaJapaneseDataset(output_path, seed, local_rank, dn)
elif "Cohere/miracl-ja-queries-22-12" in dataset_name:
return raw_datasets.CohereMiracljaqueries2212Dataset(
output_path, seed, local_rank)
elif dataset_name == "lmqg/qg_jaquad":
return raw_datasets.LmqgQgjaquadDataset(output_path, seed, local_rank)
elif dataset_name == "lmqg/qag_jaquad":
return raw_datasets.LmqgQagjaquadDataset(output_path, seed, local_rank)
output_path, seed, local_rank, dn)
elif "lmqg/qg_jaquad" in dataset_name:
return raw_datasets.LmqgQgjaquadDataset(output_path, seed, local_rank, dn)
elif "lmqg/qag_jaquad" in dataset_name:
return raw_datasets.LmqgQagjaquadDataset(output_path, seed, local_rank, dn)
else:
raise RuntimeError(
f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py."
Expand Down
133 changes: 103 additions & 30 deletions applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# follow in order to have a unified API and unified data format.
class PromptRawDataset(object):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dataset_name is a confusing variable name, I recommend to change it to local_data_dir. Same comment apply to all other classes.

self.output_path = output_path
self.seed = seed
self.local_rank = local_rank
Expand Down Expand Up @@ -45,11 +45,15 @@ def get_prompt_and_rejected(self, sample):
# English dataset
class DahoasRmstaticDataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "Dahoas/rm-static"
self.dataset_name_clean = "Dahoas_rm_static"
self.raw_datasets = load_dataset("Dahoas/rm-static")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("Dahoas/rm-static")


def get_train_data(self):
return self.raw_datasets["train"]
Expand All @@ -76,11 +80,14 @@ def get_prompt_and_rejected(self, sample):
# English dataset
class DahoasFullhhrlhfDataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "Dahoas/full-hh-rlhf"
self.dataset_name_clean = "Dahoas_full_hh_rlhf"
self.raw_datasets = load_dataset("Dahoas/full-hh-rlhf")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("Dahoas/full-hh-rlhf")

def get_train_data(self):
return self.raw_datasets["train"]
Expand All @@ -107,12 +114,15 @@ def get_prompt_and_rejected(self, sample):
# English dataset
class DahoasSyntheticinstructgptjpairwiseDataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "Dahoas/synthetic-instruct-gptj-pairwise"
self.dataset_name_clean = "Dahoas_synthetic_instruct_gptj_pairwise"
self.raw_datasets = load_dataset(
"Dahoas/synthetic-instruct-gptj-pairwise")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset(
"Dahoas/synthetic-instruct-gptj-pairwise")

def get_train_data(self):
from .data_utils import get_raw_dataset_split_index
Expand Down Expand Up @@ -154,11 +164,14 @@ def get_prompt_and_rejected(self, sample):
# English dataset
class YitingxieRlhfrewarddatasetsDataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "yitingxie/rlhf-reward-datasets"
self.dataset_name_clean = "yitingxie_rlhf_reward_datasets"
self.raw_datasets = load_dataset("yitingxie/rlhf-reward-datasets")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("yitingxie/rlhf-reward-datasets")

def get_train_data(self):
return self.raw_datasets["train"]
Expand All @@ -185,11 +198,14 @@ def get_prompt_and_rejected(self, sample):
# English dataset
class OpenaiWebgptcomparisonsDataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "openai/webgpt_comparisons"
self.dataset_name_clean = "openai_webgpt_comparisons"
self.raw_datasets = load_dataset("openai/webgpt_comparisons")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("openai/webgpt_comparisons")

def get_train_data(self):
from .data_utils import get_raw_dataset_split_index
Expand Down Expand Up @@ -259,11 +275,14 @@ def get_prompt_and_rejected(self, sample):
# English dataset
class StanfordnlpSHPDataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "stanfordnlp/SHP"
self.dataset_name_clean = "stanfordnlp_SHP"
self.raw_datasets = load_dataset("stanfordnlp/SHP")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("stanfordnlp/SHP")

def get_train_data(self):
return self.raw_datasets["train"]
Expand Down Expand Up @@ -306,11 +325,14 @@ def get_prompt_and_rejected(self, sample):
# Chinese dataset
class Wangrui6ZhihuKOLDataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "wangrui6/Zhihu-KOL"
self.dataset_name_clean = "wangrui6_Zhihu_KOL"
self.raw_datasets = load_dataset("wangrui6/Zhihu-KOL")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("wangrui6/Zhihu-KOL")

def get_train_data(self):
from .data_utils import get_raw_dataset_split_index
Expand Down Expand Up @@ -364,11 +386,14 @@ def get_prompt_and_rejected(self, sample):
# Chinese dataset
class CohereMiraclzhqueries2212Dataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "Cohere/miracl-zh-queries-22-12"
self.dataset_name_clean = "Cohere_miracl_zh_queries_22_12"
self.raw_datasets = load_dataset("Cohere/miracl-zh-queries-22-12")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("Cohere/miracl-zh-queries-22-12")

def get_train_data(self):
return self.raw_datasets["train"]
Expand Down Expand Up @@ -397,11 +422,14 @@ def get_prompt_and_rejected(self, sample):
# Chinese dataset
class HelloSimpleAIHC3ChineseDataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "Hello-SimpleAI/HC3-Chinese"
self.dataset_name_clean = "Hello_SimpleAI_HC3_Chinese"
self.raw_datasets = load_dataset("Hello-SimpleAI/HC3-Chinese", "all")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("Hello-SimpleAI/HC3-Chinese", "all")

def get_train_data(self):
from .data_utils import get_raw_dataset_split_index
Expand Down Expand Up @@ -456,11 +484,14 @@ def get_prompt_and_rejected(self, sample):
# Chinese dataset
class MkqaChineseDataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "mkqa-Chinese"
self.dataset_name_clean = "mkqa"
self.raw_datasets = load_dataset("mkqa")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("mkqa")

def get_train_data(self):
from .data_utils import get_raw_dataset_split_index
Expand Down Expand Up @@ -516,11 +547,14 @@ def get_prompt_and_rejected(self, sample):
# Japanese dataset
class MkqaJapaneseDataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "mkqa-Japanese"
self.dataset_name_clean = "mkqa"
self.raw_datasets = load_dataset("mkqa")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("mkqa")

def get_train_data(self):
from .data_utils import get_raw_dataset_split_index
Expand Down Expand Up @@ -575,11 +609,14 @@ def get_prompt_and_rejected(self, sample):
# Japanese dataset
class CohereMiracljaqueries2212Dataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "Cohere/miracl-ja-queries-22-12"
self.dataset_name_clean = "Cohere_miracl_ja_queries_22_12"
self.raw_datasets = load_dataset("Cohere/miracl-ja-queries-22-12")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("Cohere/miracl-ja-queries-22-12")

def get_train_data(self):
return self.raw_datasets["train"]
Expand Down Expand Up @@ -608,11 +645,14 @@ def get_prompt_and_rejected(self, sample):
# Japanese dataset
class LmqgQgjaquadDataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "lmqg/qg_jaquad"
self.dataset_name_clean = "lmqg_qg_jaquad"
self.raw_datasets = load_dataset("lmqg/qg_jaquad")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("lmqg/qg_jaquad")

def get_train_data(self):
return self.raw_datasets["train"]
Expand Down Expand Up @@ -646,11 +686,14 @@ def get_prompt_and_rejected(self, sample):
# Japanese dataset
class LmqgQagjaquadDataset(PromptRawDataset):

def __init__(self, output_path, seed, local_rank):
def __init__(self, output_path, seed, local_rank, dataset_name=None):
super().__init__(output_path, seed, local_rank)
self.dataset_name = "lmqg/qag_jaquad"
self.dataset_name_clean = "lmqg_qag_jaquad"
self.raw_datasets = load_dataset("lmqg/qag_jaquad")
if dataset_name is not None:
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)
else:
self.raw_datasets = load_dataset("lmqg/qag_jaquad")

def get_train_data(self):
return self.raw_datasets["train"]
Expand Down Expand Up @@ -679,3 +722,33 @@ def get_prompt_and_rejected(self, sample):
f"Warning: dataset {self.dataset_name} does not include rejected response."
)
return None


class LocalParquetDataset(PromptRawDataset):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this class? This LocalParquetDataset class is never used in your PR. Please remove it if it's not necessary.

def __init__(self, output_path, seed, local_rank, dataset_name):
super().__init__(output_path, seed, local_rank)
self.dataset_name = dataset_name.strip("./")
self.dataset_name_clean = self.dataset_name.replace("/", "_")
self.raw_datasets = load_dataset("parquet", data_dir=dataset_name)

def get_train_data(self):
# TODO
raise NotImplementedError("you need to fill your data.")

def get_eval_data(self):
raise NotImplementedError("you need to fill your data.")

def get_prompt(self, sample):
raise NotImplementedError("you need to fill your data.")

def get_chosen(self, sample):
raise NotImplementedError("you need to fill your data.")

def get_rejected(self, sample):
raise NotImplementedError("you need to fill your data.")

def get_prompt_and_chosen(self, sample):
raise NotImplementedError("you need to fill your data.")

def get_prompt_and_rejected(self, sample):
raise NotImplementedError("you need to fill your data.")