From d340f3ab3ff72245366845fdadfd26398e58fc0e Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 4 Nov 2020 10:33:17 -0500 Subject: [PATCH] Fix validation file loading in scripts --- examples/language-modeling/run_clm.py | 2 +- examples/language-modeling/run_mlm.py | 2 +- examples/language-modeling/run_mlm_wwm.py | 2 +- examples/language-modeling/run_plm.py | 2 +- .../run_{{cookiecutter.example_shortcut}}.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index b5d2c6d3aac8..d2231e1703ee 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -190,7 +190,7 @@ def main(): if data_args.train_file is not None: data_files["train"] = data_args.train_file if data_args.validation_file is not None: - data_files["validation"] = data_args.train_file + data_files["validation"] = data_args.validation_file extension = data_args.train_file.split(".")[-1] if extension == "txt": extension = "text" diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py index 8d743822dd03..d6653b18057d 100644 --- a/examples/language-modeling/run_mlm.py +++ b/examples/language-modeling/run_mlm.py @@ -201,7 +201,7 @@ def main(): if data_args.train_file is not None: data_files["train"] = data_args.train_file if data_args.validation_file is not None: - data_files["validation"] = data_args.train_file + data_files["validation"] = data_args.validation_file extension = data_args.train_file.split(".")[-1] if extension == "txt": extension = "text" diff --git a/examples/language-modeling/run_mlm_wwm.py b/examples/language-modeling/run_mlm_wwm.py index 9a830f5f929f..ecc4c55e7c2a 100644 --- a/examples/language-modeling/run_mlm_wwm.py +++ b/examples/language-modeling/run_mlm_wwm.py @@ -204,7 +204,7 @@ def main(): if data_args.train_file is not None: data_files["train"] = data_args.train_file if data_args.validation_file is not None: - data_files["validation"] = data_args.train_file + data_files["validation"] = data_args.validation_file extension = data_args.train_file.split(".")[-1] if extension == "txt": extension = "text" diff --git a/examples/language-modeling/run_plm.py b/examples/language-modeling/run_plm.py index fc5a5dbb7656..337ebb3e7ef6 100644 --- a/examples/language-modeling/run_plm.py +++ b/examples/language-modeling/run_plm.py @@ -198,7 +198,7 @@ def main(): if data_args.train_file is not None: data_files["train"] = data_args.train_file if data_args.validation_file is not None: - data_files["validation"] = data_args.train_file + data_files["validation"] = data_args.validation_file extension = data_args.train_file.split(".")[-1] if extension == "txt": extension = "text" diff --git a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py index 85d9d9d11ea1..778ee04afa96 100644 --- a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py +++ b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py @@ -205,7 +205,7 @@ def main(): if data_args.train_file is not None: data_files["train"] = data_args.train_file if data_args.validation_file is not None: - data_files["validation"] = data_args.train_file + data_files["validation"] = data_args.validation_file extension = data_args.train_file.split(".")[-1] if extension == "txt": extension = "text"