3 changes: 2 additions & 1 deletion applications/DeepSpeed-Chat/README.md
@@ -234,7 +234,8 @@ bash training_scripts/single_gpu/run_1.3b.sh
### 🐼 Adding and using your own datasets in DeepSpeed-Chat
In addition to the datasets used in our example scripts, you can also add and use your own datasets. To do so, you first need to add a new class in [training/utils/data/raw_datasets.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py) that defines the format of your data. Make sure to follow the APIs and format defined in the PromptRawDataset class, so that the data format DeepSpeed-Chat relies on stays consistent. You can look at the existing classes to learn how to do so.
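A minimal sketch of such a class is shown below. It assumes the updated constructor signature from this PR (which passes `dataset_name` through to the base class) and uses a hypothetical dataset `my_org/my_dataset` with `prompt`/`chosen`/`rejected` columns; adapt the names, splits, and prompt formatting to your own data and to the existing classes in raw_datasets.py.

```python
# Hypothetical example of a new dataset class, to be added inside raw_datasets.py
# where PromptRawDataset is defined. "my_org/my_dataset" and the column names
# below are placeholders, not a real dataset.
class MyOrgMyDatasetDataset(PromptRawDataset):

    def __init__(self, output_path, seed, local_rank, dataset_name):
        # The base class stores the args and calls load_dataset(dataset_name).
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "my_org/my_dataset"        # name used in --data_path
        self.dataset_name_clean = "my_org_my_dataset"  # filesystem-safe variant

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        # Adjust the split name ("test", "eval", ...) to match your dataset.
        return self.raw_datasets["test"]

    # The accessors below must return plain strings; mirror the prompt formatting
    # (e.g. " Human: ... Assistant:") used by the existing classes.
    def get_prompt(self, sample):
        return " Human: " + sample["prompt"] + " Assistant:"

    def get_chosen(self, sample):
        return " " + sample["chosen"]

    def get_rejected(self, sample):
        return " " + sample["rejected"]

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample["prompt"] + " Assistant: " + sample["chosen"]

    def get_prompt_and_rejected(self, sample):
        return " Human: " + sample["prompt"] + " Assistant: " + sample["rejected"]
```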

Second, you need to add an if condition in the function get_raw_dataset in [training/utils/data/data_utils.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/data/data_utils.py) corresponding to your new dataset. The dataset_name string in the if condition should be the dataset name you will provide as an arg to the training scripts. Last, you need to add your new dataset's dataset_name to the "--data_path" arg in your training scripts.
If you have downloaded Hugging Face datasets manually, you can pass your local path to "--data_path", such as "--data_path ./relative/Dahoas/rm-static" or "--data_path /absolute/Dahoas/rm-static". Remember not to include `data/` in your local path, as it may cause an exception in `load_dataset`.
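Concretely, and keeping the hypothetical `my_org/my_dataset` name from above, the new branch in `get_raw_dataset` would follow the same pattern as the existing ones after this PR, using a substring check so that a local path containing the dataset name is matched as well:

```python
# Fragment of get_raw_dataset() in training/utils/data/data_utils.py; the
# "my_org/my_dataset" branch and MyOrgMyDatasetDataset class are hypothetical.
def get_raw_dataset(dataset_name, output_path, seed, local_rank):

    if "Dahoas/rm-static" in dataset_name:
        return raw_datasets.DahoasRmstaticDataset(output_path, seed,
                                                  local_rank, dataset_name)
    # ... existing branches elided ...
    elif "my_org/my_dataset" in dataset_name:  # new branch for your dataset
        return raw_datasets.MyOrgMyDatasetDataset(output_path, seed, local_rank,
                                                  dataset_name)
    else:
        raise RuntimeError(
            f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py."
        )
```

With that in place, `--data_path my_org/my_dataset` (or a local path such as `./relative/my_org/my_dataset`) will be routed to the new class.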

One thing to note is that some datasets may only have one response instead of two. Those datasets can only be used in step 1, and in that case you should add the dataset_name to the "--sft_only_data_path" arg instead of the "--data_path" arg. If you plan to only do step 1 SFT, adding more single-response datasets is definitely beneficial. However, if you plan to do steps 2 and 3, adding too many single-response datasets during SFT could backfire: that data could differ from the data used in steps 2/3, producing different distributions that may cause training instability or worse model quality in steps 2/3. That is part of the reason we focused on datasets with two responses and a preference, and always split a dataset across all 3 steps.

61 changes: 33 additions & 28 deletions applications/DeepSpeed-Chat/training/utils/data/data_utils.py
@@ -18,44 +18,49 @@


 def get_raw_dataset(dataset_name, output_path, seed, local_rank):
-    if dataset_name == "Dahoas/rm-static":
+
+    if "Dahoas/rm-static" in dataset_name:
         return raw_datasets.DahoasRmstaticDataset(output_path, seed,
-                                                  local_rank)
-    elif dataset_name == "Dahoas/full-hh-rlhf":
+                                                  local_rank, dataset_name)
+    elif "Dahoas/full-hh-rlhf" in dataset_name:
         return raw_datasets.DahoasFullhhrlhfDataset(output_path, seed,
-                                                    local_rank)
-    elif dataset_name == "Dahoas/synthetic-instruct-gptj-pairwise":
+                                                    local_rank, dataset_name)
+    elif "Dahoas/synthetic-instruct-gptj-pairwise" in dataset_name:
         return raw_datasets.DahoasSyntheticinstructgptjpairwiseDataset(
-            output_path, seed, local_rank)
-    elif dataset_name == "yitingxie/rlhf-reward-datasets":
+            output_path, seed, local_rank, dataset_name)
+    elif "yitingxie/rlhf-reward-datasets" in dataset_name:
         return raw_datasets.YitingxieRlhfrewarddatasetsDataset(
-            output_path, seed, local_rank)
-    elif dataset_name == "openai/webgpt_comparisons":
+            output_path, seed, local_rank, dataset_name)
+    elif "openai/webgpt_comparisons" in dataset_name:
         return raw_datasets.OpenaiWebgptcomparisonsDataset(
-            output_path, seed, local_rank)
-    elif dataset_name == "stanfordnlp/SHP":
+            output_path, seed, local_rank, dataset_name)
+    elif "stanfordnlp/SHP" in dataset_name:
         return raw_datasets.StanfordnlpSHPDataset(output_path, seed,
-                                                  local_rank)
-    elif dataset_name == "wangrui6/Zhihu-KOL":
+                                                  local_rank, dataset_name)
+    elif "wangrui6/Zhihu-KOL" in dataset_name:
         return raw_datasets.Wangrui6ZhihuKOLDataset(output_path, seed,
-                                                    local_rank)
-    elif dataset_name == "Cohere/miracl-zh-queries-22-12":
+                                                    local_rank, dataset_name)
+    elif "Cohere/miracl-zh-queries-22-12" in dataset_name:
         return raw_datasets.CohereMiraclzhqueries2212Dataset(
-            output_path, seed, local_rank)
-    elif dataset_name == "Hello-SimpleAI/HC3-Chinese":
+            output_path, seed, local_rank, dataset_name)
+    elif "Hello-SimpleAI/HC3-Chinese" in dataset_name:
         return raw_datasets.HelloSimpleAIHC3ChineseDataset(
-            output_path, seed, local_rank)
-    elif dataset_name == "mkqa-Chinese":
-        return raw_datasets.MkqaChineseDataset(output_path, seed, local_rank)
-    elif dataset_name == "mkqa-Japanese":
-        return raw_datasets.MkqaJapaneseDataset(output_path, seed, local_rank)
-    elif dataset_name == "Cohere/miracl-ja-queries-22-12":
+            output_path, seed, local_rank, dataset_name)
+    elif "mkqa-Chinese" in dataset_name:
+        return raw_datasets.MkqaChineseDataset(output_path, seed, local_rank,
+                                               dataset_name)
+    elif "mkqa-Japanese" in dataset_name:
+        return raw_datasets.MkqaJapaneseDataset(output_path, seed, local_rank,
+                                                dataset_name)
+    elif "Cohere/miracl-ja-queries-22-12" in dataset_name:
         return raw_datasets.CohereMiracljaqueries2212Dataset(
-            output_path, seed, local_rank)
-    elif dataset_name == "lmqg/qg_jaquad":
-        return raw_datasets.LmqgQgjaquadDataset(output_path, seed, local_rank)
-    elif dataset_name == "lmqg/qag_jaquad":
-        return raw_datasets.LmqgQagjaquadDataset(output_path, seed, local_rank)
+            output_path, seed, local_rank, dataset_name)
+    elif "lmqg/qg_jaquad" in dataset_name:
+        return raw_datasets.LmqgQgjaquadDataset(output_path, seed, local_rank,
+                                                dataset_name)
+    elif "lmqg/qag_jaquad" in dataset_name:
+        return raw_datasets.LmqgQagjaquadDataset(output_path, seed, local_rank,
+                                                 dataset_name)
     else:
         raise RuntimeError(
             f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py."
74 changes: 30 additions & 44 deletions applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py
@@ -11,10 +11,11 @@
# follow in order to have a unified API and unified data format.
class PromptRawDataset(object):

-    def __init__(self, output_path, seed, local_rank):
+    def __init__(self, output_path, seed, local_rank, dataset_name):
         self.output_path = output_path
         self.seed = seed
         self.local_rank = local_rank
+        self.raw_datasets = load_dataset(dataset_name)

def get_train_data(self):
return
@@ -45,11 +46,10 @@ def get_prompt_and_rejected(self, sample):
# English dataset
class DahoasRmstaticDataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "Dahoas/rm-static"
         self.dataset_name_clean = "Dahoas_rm_static"
-        self.raw_datasets = load_dataset("Dahoas/rm-static")

def get_train_data(self):
return self.raw_datasets["train"]
@@ -76,11 +76,10 @@ def get_prompt_and_rejected(self, sample):
# English dataset
class DahoasFullhhrlhfDataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "Dahoas/full-hh-rlhf"
         self.dataset_name_clean = "Dahoas_full_hh_rlhf"
-        self.raw_datasets = load_dataset("Dahoas/full-hh-rlhf")

def get_train_data(self):
return self.raw_datasets["train"]
Expand All @@ -107,12 +106,10 @@ def get_prompt_and_rejected(self, sample):
# English dataset
class DahoasSyntheticinstructgptjpairwiseDataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "Dahoas/synthetic-instruct-gptj-pairwise"
         self.dataset_name_clean = "Dahoas_synthetic_instruct_gptj_pairwise"
-        self.raw_datasets = load_dataset(
-            "Dahoas/synthetic-instruct-gptj-pairwise")

def get_train_data(self):
from .data_utils import get_raw_dataset_split_index
@@ -154,11 +151,10 @@ def get_prompt_and_rejected(self, sample):
# English dataset
class YitingxieRlhfrewarddatasetsDataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "yitingxie/rlhf-reward-datasets"
         self.dataset_name_clean = "yitingxie_rlhf_reward_datasets"
-        self.raw_datasets = load_dataset("yitingxie/rlhf-reward-datasets")

def get_train_data(self):
return self.raw_datasets["train"]
@@ -185,11 +181,10 @@ def get_prompt_and_rejected(self, sample):
# English dataset
class OpenaiWebgptcomparisonsDataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "openai/webgpt_comparisons"
         self.dataset_name_clean = "openai_webgpt_comparisons"
-        self.raw_datasets = load_dataset("openai/webgpt_comparisons")

def get_train_data(self):
from .data_utils import get_raw_dataset_split_index
@@ -259,11 +254,10 @@ def get_prompt_and_rejected(self, sample):
# English dataset
class StanfordnlpSHPDataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "stanfordnlp/SHP"
         self.dataset_name_clean = "stanfordnlp_SHP"
-        self.raw_datasets = load_dataset("stanfordnlp/SHP")

def get_train_data(self):
return self.raw_datasets["train"]
@@ -306,11 +300,10 @@ def get_prompt_and_rejected(self, sample):
# Chinese dataset
class Wangrui6ZhihuKOLDataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "wangrui6/Zhihu-KOL"
         self.dataset_name_clean = "wangrui6_Zhihu_KOL"
-        self.raw_datasets = load_dataset("wangrui6/Zhihu-KOL")

def get_train_data(self):
from .data_utils import get_raw_dataset_split_index
@@ -364,11 +357,10 @@ def get_prompt_and_rejected(self, sample):
# Chinese dataset
class CohereMiraclzhqueries2212Dataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "Cohere/miracl-zh-queries-22-12"
         self.dataset_name_clean = "Cohere_miracl_zh_queries_22_12"
-        self.raw_datasets = load_dataset("Cohere/miracl-zh-queries-22-12")

def get_train_data(self):
return self.raw_datasets["train"]
@@ -397,11 +389,10 @@ def get_prompt_and_rejected(self, sample):
# Chinese dataset
class HelloSimpleAIHC3ChineseDataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "Hello-SimpleAI/HC3-Chinese"
         self.dataset_name_clean = "Hello_SimpleAI_HC3_Chinese"
-        self.raw_datasets = load_dataset("Hello-SimpleAI/HC3-Chinese", "all")

def get_train_data(self):
from .data_utils import get_raw_dataset_split_index
@@ -456,11 +447,10 @@ def get_prompt_and_rejected(self, sample):
# Chinese dataset
class MkqaChineseDataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "mkqa-Chinese"
         self.dataset_name_clean = "mkqa"
-        self.raw_datasets = load_dataset("mkqa")

def get_train_data(self):
from .data_utils import get_raw_dataset_split_index
@@ -516,11 +506,10 @@ def get_prompt_and_rejected(self, sample):
# Japanese dataset
class MkqaJapaneseDataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "mkqa-Japanese"
         self.dataset_name_clean = "mkqa"
-        self.raw_datasets = load_dataset("mkqa")

def get_train_data(self):
from .data_utils import get_raw_dataset_split_index
@@ -575,11 +564,10 @@ def get_prompt_and_rejected(self, sample):
# Japanese dataset
class CohereMiracljaqueries2212Dataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "Cohere/miracl-ja-queries-22-12"
         self.dataset_name_clean = "Cohere_miracl_ja_queries_22_12"
-        self.raw_datasets = load_dataset("Cohere/miracl-ja-queries-22-12")

def get_train_data(self):
return self.raw_datasets["train"]
@@ -608,11 +596,10 @@ def get_prompt_and_rejected(self, sample):
# Japanese dataset
class LmqgQgjaquadDataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "lmqg/qg_jaquad"
         self.dataset_name_clean = "lmqg_qg_jaquad"
-        self.raw_datasets = load_dataset("lmqg/qg_jaquad")

def get_train_data(self):
return self.raw_datasets["train"]
@@ -646,11 +633,10 @@ def get_prompt_and_rejected(self, sample):
# Japanese dataset
class LmqgQagjaquadDataset(PromptRawDataset):

-    def __init__(self, output_path, seed, local_rank):
-        super().__init__(output_path, seed, local_rank)
+    def __init__(self, output_path, seed, local_rank, dataset_name):
+        super().__init__(output_path, seed, local_rank, dataset_name)
         self.dataset_name = "lmqg/qag_jaquad"
         self.dataset_name_clean = "lmqg_qag_jaquad"
-        self.raw_datasets = load_dataset("lmqg/qag_jaquad")

def get_train_data(self):
return self.raw_datasets["train"]