-
Notifications
You must be signed in to change notification settings - Fork 1.1k
change load local ./ hf parquet dataset #398
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
2c7ec3e
1b4112a
5bf0b4a
00906a7
c9d67ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,7 @@ | |
| # follow in order to have a unified API and unified data format. | ||
| class PromptRawDataset(object): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. dataset_name is a confusing variable name; I recommend changing it to local_data_dir. The same comment applies to all other classes. |
||
| self.output_path = output_path | ||
| self.seed = seed | ||
| self.local_rank = local_rank | ||
|
|
@@ -45,11 +45,15 @@ def get_prompt_and_rejected(self, sample): | |
| # English dataset | ||
| class DahoasRmstaticDataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "Dahoas/rm-static" | ||
| self.dataset_name_clean = "Dahoas_rm_static" | ||
| self.raw_datasets = load_dataset("Dahoas/rm-static") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("Dahoas/rm-static") | ||
|
|
||
|
|
||
| def get_train_data(self): | ||
| return self.raw_datasets["train"] | ||
|
|
@@ -76,11 +80,14 @@ def get_prompt_and_rejected(self, sample): | |
| # English dataset | ||
| class DahoasFullhhrlhfDataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "Dahoas/full-hh-rlhf" | ||
| self.dataset_name_clean = "Dahoas_full_hh_rlhf" | ||
| self.raw_datasets = load_dataset("Dahoas/full-hh-rlhf") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("Dahoas/full-hh-rlhf") | ||
|
|
||
| def get_train_data(self): | ||
| return self.raw_datasets["train"] | ||
|
|
@@ -107,12 +114,15 @@ def get_prompt_and_rejected(self, sample): | |
| # English dataset | ||
| class DahoasSyntheticinstructgptjpairwiseDataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "Dahoas/synthetic-instruct-gptj-pairwise" | ||
| self.dataset_name_clean = "Dahoas_synthetic_instruct_gptj_pairwise" | ||
| self.raw_datasets = load_dataset( | ||
| "Dahoas/synthetic-instruct-gptj-pairwise") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset( | ||
| "Dahoas/synthetic-instruct-gptj-pairwise") | ||
|
|
||
| def get_train_data(self): | ||
| from .data_utils import get_raw_dataset_split_index | ||
|
|
@@ -154,11 +164,14 @@ def get_prompt_and_rejected(self, sample): | |
| # English dataset | ||
| class YitingxieRlhfrewarddatasetsDataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "yitingxie/rlhf-reward-datasets" | ||
| self.dataset_name_clean = "yitingxie_rlhf_reward_datasets" | ||
| self.raw_datasets = load_dataset("yitingxie/rlhf-reward-datasets") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("yitingxie/rlhf-reward-datasets") | ||
|
|
||
| def get_train_data(self): | ||
| return self.raw_datasets["train"] | ||
|
|
@@ -185,11 +198,14 @@ def get_prompt_and_rejected(self, sample): | |
| # English dataset | ||
| class OpenaiWebgptcomparisonsDataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "openai/webgpt_comparisons" | ||
| self.dataset_name_clean = "openai_webgpt_comparisons" | ||
| self.raw_datasets = load_dataset("openai/webgpt_comparisons") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("openai/webgpt_comparisons") | ||
|
|
||
| def get_train_data(self): | ||
| from .data_utils import get_raw_dataset_split_index | ||
|
|
@@ -259,11 +275,14 @@ def get_prompt_and_rejected(self, sample): | |
| # English dataset | ||
| class StanfordnlpSHPDataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "stanfordnlp/SHP" | ||
| self.dataset_name_clean = "stanfordnlp_SHP" | ||
| self.raw_datasets = load_dataset("stanfordnlp/SHP") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("stanfordnlp/SHP") | ||
|
|
||
| def get_train_data(self): | ||
| return self.raw_datasets["train"] | ||
|
|
@@ -306,11 +325,14 @@ def get_prompt_and_rejected(self, sample): | |
| # Chinese dataset | ||
| class Wangrui6ZhihuKOLDataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "wangrui6/Zhihu-KOL" | ||
| self.dataset_name_clean = "wangrui6_Zhihu_KOL" | ||
| self.raw_datasets = load_dataset("wangrui6/Zhihu-KOL") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("wangrui6/Zhihu-KOL") | ||
|
|
||
| def get_train_data(self): | ||
| from .data_utils import get_raw_dataset_split_index | ||
|
|
@@ -364,11 +386,14 @@ def get_prompt_and_rejected(self, sample): | |
| # Chinese dataset | ||
| class CohereMiraclzhqueries2212Dataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "Cohere/miracl-zh-queries-22-12" | ||
| self.dataset_name_clean = "Cohere_miracl_zh_queries_22_12" | ||
| self.raw_datasets = load_dataset("Cohere/miracl-zh-queries-22-12") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("Cohere/miracl-zh-queries-22-12") | ||
|
|
||
| def get_train_data(self): | ||
| return self.raw_datasets["train"] | ||
|
|
@@ -397,11 +422,14 @@ def get_prompt_and_rejected(self, sample): | |
| # Chinese dataset | ||
| class HelloSimpleAIHC3ChineseDataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "Hello-SimpleAI/HC3-Chinese" | ||
| self.dataset_name_clean = "Hello_SimpleAI_HC3_Chinese" | ||
| self.raw_datasets = load_dataset("Hello-SimpleAI/HC3-Chinese", "all") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("Hello-SimpleAI/HC3-Chinese", "all") | ||
|
|
||
| def get_train_data(self): | ||
| from .data_utils import get_raw_dataset_split_index | ||
|
|
@@ -456,11 +484,14 @@ def get_prompt_and_rejected(self, sample): | |
| # Chinese dataset | ||
| class MkqaChineseDataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "mkqa-Chinese" | ||
| self.dataset_name_clean = "mkqa" | ||
| self.raw_datasets = load_dataset("mkqa") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("mkqa") | ||
|
|
||
| def get_train_data(self): | ||
| from .data_utils import get_raw_dataset_split_index | ||
|
|
@@ -516,11 +547,14 @@ def get_prompt_and_rejected(self, sample): | |
| # Japanese dataset | ||
| class MkqaJapaneseDataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "mkqa-Japanese" | ||
| self.dataset_name_clean = "mkqa" | ||
| self.raw_datasets = load_dataset("mkqa") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("mkqa") | ||
|
|
||
| def get_train_data(self): | ||
| from .data_utils import get_raw_dataset_split_index | ||
|
|
@@ -575,11 +609,14 @@ def get_prompt_and_rejected(self, sample): | |
| # Japanese dataset | ||
| class CohereMiracljaqueries2212Dataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "Cohere/miracl-ja-queries-22-12" | ||
| self.dataset_name_clean = "Cohere_miracl_ja_queries_22_12" | ||
| self.raw_datasets = load_dataset("Cohere/miracl-ja-queries-22-12") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("Cohere/miracl-ja-queries-22-12") | ||
|
|
||
| def get_train_data(self): | ||
| return self.raw_datasets["train"] | ||
|
|
@@ -608,11 +645,14 @@ def get_prompt_and_rejected(self, sample): | |
| # Japanese dataset | ||
| class LmqgQgjaquadDataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "lmqg/qg_jaquad" | ||
| self.dataset_name_clean = "lmqg_qg_jaquad" | ||
| self.raw_datasets = load_dataset("lmqg/qg_jaquad") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("lmqg/qg_jaquad") | ||
|
|
||
| def get_train_data(self): | ||
| return self.raw_datasets["train"] | ||
|
|
@@ -646,11 +686,14 @@ def get_prompt_and_rejected(self, sample): | |
| # Japanese dataset | ||
| class LmqgQagjaquadDataset(PromptRawDataset): | ||
|
|
||
| def __init__(self, output_path, seed, local_rank): | ||
| def __init__(self, output_path, seed, local_rank, dataset_name=None): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = "lmqg/qag_jaquad" | ||
| self.dataset_name_clean = "lmqg_qag_jaquad" | ||
| self.raw_datasets = load_dataset("lmqg/qag_jaquad") | ||
| if dataset_name is not None: | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
| else: | ||
| self.raw_datasets = load_dataset("lmqg/qag_jaquad") | ||
|
|
||
| def get_train_data(self): | ||
| return self.raw_datasets["train"] | ||
|
|
@@ -679,3 +722,33 @@ def get_prompt_and_rejected(self, sample): | |
| f"Warning: dataset {self.dataset_name} does not include rejected response." | ||
| ) | ||
| return None | ||
|
|
||
|
|
||
| class LocalParquetDataset(PromptRawDataset): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need this class? This LocalParquetDataset class is never used in your PR. Please remove it if it's not necessary. |
||
| def __init__(self, output_path, seed, local_rank, dataset_name): | ||
| super().__init__(output_path, seed, local_rank) | ||
| self.dataset_name = dataset_name.strip("./") | ||
| self.dataset_name_clean = self.dataset_name.replace("/", "_") | ||
| self.raw_datasets = load_dataset("parquet", data_dir=dataset_name) | ||
|
|
||
| def get_train_data(self): | ||
| # TODO | ||
| raise NotImplementedError("you need to fill your data.") | ||
|
|
||
| def get_eval_data(self): | ||
| raise NotImplementedError("you need to fill your data.") | ||
|
|
||
| def get_prompt(self, sample): | ||
| raise NotImplementedError("you need to fill your data.") | ||
|
|
||
| def get_chosen(self, sample): | ||
| raise NotImplementedError("you need to fill your data.") | ||
|
|
||
| def get_rejected(self, sample): | ||
| raise NotImplementedError("you need to fill your data.") | ||
|
|
||
| def get_prompt_and_chosen(self, sample): | ||
| raise NotImplementedError("you need to fill your data.") | ||
|
|
||
| def get_prompt_and_rejected(self, sample): | ||
| raise NotImplementedError("you need to fill your data.") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please change dn to a longer and more meaningful name, such as local_data_dir.