From 594c56939a5299fab86331eadce598648e28fdfc Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Mon, 27 May 2024 17:48:17 -0400
Subject: [PATCH] fix dataset load error

Signed-off-by: Wang, Yi
---
 .../stack_llama/scripts/reward_modeling.py                | 8 ++++++--
 .../research_projects/stack_llama/scripts/rl_training.py  | 4 +++-
 .../research_projects/stack_llama_2/scripts/dpo_llama2.py | 1 +
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/examples/research_projects/stack_llama/scripts/reward_modeling.py b/examples/research_projects/stack_llama/scripts/reward_modeling.py
index 77f1b3722a..2692de451b 100644
--- a/examples/research_projects/stack_llama/scripts/reward_modeling.py
+++ b/examples/research_projects/stack_llama/scripts/reward_modeling.py
@@ -99,10 +99,14 @@ class ScriptArguments:
 script_args = parser.parse_args_into_dataclasses()[0]
 set_seed(script_args.seed)
 # Load the human stack-exchange-paired dataset for tuning the reward model.
-train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/reward", split="train")
+train_dataset = load_dataset(
+    "lvwerra/stack-exchange-paired", data_dir="data/reward", split="train", verification_mode="no_checks"
+)
 if script_args.train_subset > 0:
     train_dataset = train_dataset.select(range(script_args.train_subset))
-eval_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train")
+eval_dataset = load_dataset(
+    "lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train", verification_mode="no_checks"
+)
 if script_args.eval_subset > 0:
     eval_dataset = eval_dataset.select(range(script_args.eval_subset))
 # Define the training args. Needs to be done before the model is loaded if you are using deepspeed.
diff --git a/examples/research_projects/stack_llama/scripts/rl_training.py b/examples/research_projects/stack_llama/scripts/rl_training.py
index a41f98552f..95523b0b2c 100644
--- a/examples/research_projects/stack_llama/scripts/rl_training.py
+++ b/examples/research_projects/stack_llama/scripts/rl_training.py
@@ -90,7 +90,9 @@ class ScriptArguments:
     adap_kl_ctrl=script_args.adap_kl_ctrl,
 )
 
-train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/rl", split="train")
+train_dataset = load_dataset(
+    "lvwerra/stack-exchange-paired", data_dir="data/rl", split="train", verification_mode="no_checks"
+)
 train_dataset = train_dataset.select(range(100000))
 original_columns = train_dataset.column_names
 
diff --git a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
index 46b47df7f9..435dc37fbc 100644
--- a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
+++ b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
@@ -110,6 +110,7 @@ def get_stack_exchange_paired(
         split="train",
         cache_dir=cache_dir,
         data_dir=data_dir,
+        verification_mode="no_checks",
     )
     original_columns = dataset.column_names
 