examples/language-modeling/README.md (1 addition, 0 deletions)

@@ -91,6 +91,7 @@ to that word). This technique has been refined for Chinese in [this paper](https

 To fine-tune a model using whole word masking, use the following script:
 
+```bash
 python run_mlm_wwm.py \
     --model_name_or_path roberta-base \
     --dataset_name wikitext \
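
The hunk above cuts off partway through the command. For reference, a complete invocation would likely look something like the sketch below; everything past `--dataset_name wikitext` (the `wikitext-2-raw-v1` config, `--do_train`, `--do_eval`, and the output directory) is an assumption about the rest of the README, not part of this diff.

```bash
# Sketch of a full whole-word-masking fine-tuning run.
# Flags after --dataset_name are assumed, not taken from the hunk above.
python run_mlm_wwm.py \
    --model_name_or_path roberta-base \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-mlm-wwm
```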

examples/language-modeling/run_mlm_wwm.py (20 additions, 10 deletions)

@@ -90,6 +90,12 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """
 
+    dataset_name: Optional[str] = field(
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+    )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
@@ -200,15 +206,19 @@ def main():
     #
     # In distributed training, the load_dataset function guarantee that only one local process can concurrently
     # download the dataset.
-    data_files = {}
-    if data_args.train_file is not None:
-        data_files["train"] = data_args.train_file
-    if data_args.validation_file is not None:
-        data_files["validation"] = data_args.validation_file
-    extension = data_args.train_file.split(".")[-1]
-    if extension == "txt":
-        extension = "text"
-    datasets = load_dataset(extension, data_files=data_files)
+    if data_args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+    else:
+        data_files = {}
+        if data_args.train_file is not None:
+            data_files["train"] = data_args.train_file
+        if data_args.validation_file is not None:
+            data_files["validation"] = data_args.validation_file
+        extension = data_args.train_file.split(".")[-1]
+        if extension == "txt":
+            extension = "text"
+        datasets = load_dataset(extension, data_files=data_files)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
@@ -278,7 +288,7 @@ def tokenize_function(examples):
     # Add the chinese references if provided
     if data_args.train_ref_file is not None:
         tokenized_datasets["train"] = add_chinese_references(tokenized_datasets["train"], data_args.train_ref_file)
-    if data_args.valid_ref_file is not None:
+    if data_args.validation_ref_file is not None:
         tokenized_datasets["validation"] = add_chinese_references(
             tokenized_datasets["validation"], data_args.validation_ref_file
         )
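
With this change the local-file code path moves into the `else` branch but keeps working; it is only bypassed when `--dataset_name` is passed. As a rough sketch (the file paths and the `--do_train`/`--do_eval`/`--output_dir` flags below are illustrative assumptions, not taken from this diff), a local-file run would look like:

```bash
# Local-file code path (the else branch above); a .txt extension is remapped
# to the "text" dataset loading script by the snippet shown in the diff.
# All paths are placeholders; --do_train/--do_eval/--output_dir are standard
# TrainingArguments flags assumed to apply here.
python run_mlm_wwm.py \
    --model_name_or_path roberta-base \
    --train_file path/to/train.txt \
    --validation_file path/to/validation.txt \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-mlm-wwm
# For Chinese whole word masking, reference files would additionally be passed
# via --train_ref_file / --validation_ref_file, the fields used in the last hunk.
```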