diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md
index 3442d4efa042..a491c5998aae 100644
--- a/examples/language-modeling/README.md
+++ b/examples/language-modeling/README.md
@@ -91,6 +91,7 @@ to that word). This technique has been refined for Chinese in [this paper](https
 
 To fine-tune a model using whole word masking, use the following script:
 
+```bash
 python run_mlm_wwm.py \
     --model_name_or_path roberta-base \
     --dataset_name wikitext \
diff --git a/examples/language-modeling/run_mlm_wwm.py b/examples/language-modeling/run_mlm_wwm.py
index ecc4c55e7c2a..7adad187a00d 100644
--- a/examples/language-modeling/run_mlm_wwm.py
+++ b/examples/language-modeling/run_mlm_wwm.py
@@ -90,6 +90,12 @@ class DataTrainingArguments:
     """
     Arguments pertaining to what data we are going to input our model for training and eval.
     """
+    dataset_name: Optional[str] = field(
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+    )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
@@ -200,15 +206,19 @@ def main():
     #
     # In distributed training, the load_dataset function guarantee that only one local process can concurrently
     # download the dataset.
-    data_files = {}
-    if data_args.train_file is not None:
-        data_files["train"] = data_args.train_file
-    if data_args.validation_file is not None:
-        data_files["validation"] = data_args.validation_file
-    extension = data_args.train_file.split(".")[-1]
-    if extension == "txt":
-        extension = "text"
-    datasets = load_dataset(extension, data_files=data_files)
+    if data_args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+    else:
+        data_files = {}
+        if data_args.train_file is not None:
+            data_files["train"] = data_args.train_file
+        if data_args.validation_file is not None:
+            data_files["validation"] = data_args.validation_file
+        extension = data_args.train_file.split(".")[-1]
+        if extension == "txt":
+            extension = "text"
+        datasets = load_dataset(extension, data_files=data_files)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
@@ -278,7 +288,7 @@ def tokenize_function(examples):
     # Add the chinese references if provided
     if data_args.train_ref_file is not None:
         tokenized_datasets["train"] = add_chinese_references(tokenized_datasets["train"], data_args.train_ref_file)
-    if data_args.valid_ref_file is not None:
+    if data_args.validation_ref_file is not None:
         tokenized_datasets["validation"] = add_chinese_references(
             tokenized_datasets["validation"], data_args.validation_ref_file
         )
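
With these changes, `run_mlm_wwm.py` can pull a dataset straight from the Hugging Face Hub via the new `dataset_name`/`dataset_config_name` arguments instead of requiring local `--train_file`/`--validation_file` paths. A minimal sketch of an invocation exercising them follows; everything past `--dataset_name wikitext` (the config name, `--do_train`/`--do_eval`, and the output directory) is an assumption for illustration, since the README hunk above is truncated at that point.

```bash
# Sketch only: flags beyond --dataset_name are assumed, not taken from the diff above.
python run_mlm_wwm.py \
    --model_name_or_path roberta-base \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-mlm-wwm
```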