# Add BART DLM PyTorch pretraining example #18904

`examples/pytorch/language-modeling/README.md` — 96 changed lines (91 additions, 5 deletions)

# Language model training

Fine-tuning (or training from scratch) the library models for language modeling on a text dataset for GPT, GPT-2,
ALBERT, BERT, DistilBERT, RoBERTa, XLNet... GPT and GPT-2 are trained or fine-tuned using a causal language modeling
(CLM) loss, while ALBERT, BERT, DistilBERT and RoBERTa are trained or fine-tuned using a masked language modeling
(MLM) loss. XLNet uses permutation language modeling (PLM).
There are two sets of scripts provided: the first set leverages the Trainer API, the second (the `*_no_trainer.py`
scripts) uses 🤗 Accelerate. The following examples will run on datasets hosted on our
[hub](https://huggingface.co/datasets) or with your own text files for training and validation. We give examples of
both below.

## GPT-2/GPT and causal language modeling

The following example fine-tunes GPT-2 on WikiText-2. We're using the raw WikiText-2 (no tokens were replaced before
the tokenization). The loss here is that of causal language modeling.
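A typical Trainer-based invocation looks roughly like the following (a sketch; the dataset and hyperparameters shown
are illustrative):

```shell
python run_clm.py \
    --model_name_or_path gpt2 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm
```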
There is also a version of this script that uses 🤗 Accelerate instead of the Trainer. A minimal invocation (the
dataset and model shown here are just an example):

```shell
python run_clm_no_trainer.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --model_name_or_path gpt2 \
    --output_dir /tmp/test-clm
```

## RoBERTa/BERT/DistilBERT and masked language modeling

The following example fine-tunes RoBERTa on WikiText-2. Here too, we're using the raw WikiText-2. The loss is different
as BERT/RoBERTa have a bidirectional mechanism; we're therefore using the same loss that was used during their
pre-training: masked language modeling.
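A typical invocation of the Trainer-based script looks roughly like this (a sketch; dataset and hyperparameters are
illustrative). The 🤗 Accelerate-based variant, `run_mlm_no_trainer.py`, takes analogous arguments.

```shell
python run_mlm.py \
    --model_name_or_path roberta-base \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-mlm
```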
**Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make
sure all your batches have the same length.

## Whole word masking

This part was moved to `examples/research_projects/mlm_wwm`.

## XLNet and permutation language modeling

XLNet uses a different training objective, which is permutation language modeling. It is an autoregressive method
to learn bidirectional contexts by maximizing the expected likelihood over all permutations of the input
sequence factorization order.
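A sketch of a typical invocation (hyperparameters are illustrative):

```shell
python run_plm.py \
    --model_name_or_path xlnet-base-cased \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-plm
```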
By default, the script concatenates all texts and then splits them in blocks of the same length.
**Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make
sure all your batches have the same length.

## BART and denoising language modeling

BART is an encoder-decoder model that is trained with a denoising objective: the input text is corrupted and the model
must reconstruct it. The added noise includes token masking, text infilling (replacing a span of multiple tokens with a
single mask token), replacing tokens with random tokens, permuting the sentences in a sequence, and so on. The
implementation here borrows heavily from the original
[fairseq](https://github.com/facebookresearch/fairseq/blob/1bba712622b8ae4efb3eb793a8a40da386fe11d0/fairseq/data/denoising_dataset.py)
implementation and the
[FLAX training script](https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_bart_dlm_flax.py).
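To make the noising objective concrete, here is a toy illustration of two of the transforms (text infilling with
Poisson-sampled span lengths, and sentence permutation). This is a simplified sketch, not the code used in
`run_bart_dlm.py`:

```python
# Toy illustration of BART-style noising on lists of tokens.
# Simplified sketch only; run_bart_dlm.py works on token ids and batches.
import numpy as np

rng = np.random.default_rng(0)

def infill(tokens, mask_ratio=0.3, poisson_lambda=3.5, mask_token="<mask>"):
    """Replace spans of tokens with a single mask token until roughly
    `mask_ratio` of the original tokens have been removed."""
    tokens = list(tokens)
    to_remove = int(round(len(tokens) * mask_ratio))
    removed = 0
    while removed < to_remove and len(tokens) > 1:
        span = min(max(1, rng.poisson(poisson_lambda)), len(tokens) - 1)
        start = int(rng.integers(0, len(tokens) - span))
        tokens[start:start + span] = [mask_token]
        removed += span
    return tokens

def permute_sentences(sentences):
    """Shuffle the order of the sentences in a sample."""
    return [sentences[i] for i in rng.permutation(len(sentences))]

sample = ["the cat sat on the mat .".split(), "it was warm outside .".split()]
print(permute_sentences([infill(sentence) for sentence in sample]))
```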

### 1. Train a tokenizer on a dataset on the hub

```shell
python prepare_tokenizer.py \
oscar \
--dataset_config_name unshuffled_deduplicated_nl \
--dataset_split train \
--dout ./my-bart-model
```
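Conceptually, the tokenizer-preparation step boils down to something like the snippet below (a rough sketch using
`train_new_from_iterator`; the actual `prepare_tokenizer.py` in this PR may differ in its options and defaults):

```python
# Rough sketch of training a new BART-style tokenizer on a hub dataset.
from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset("oscar", "unshuffled_deduplicated_nl", split="train")
base_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

def batch_iterator(batch_size=1000):
    for start in range(0, len(dataset), batch_size):
        yield dataset[start : start + batch_size]["text"]

tokenizer = base_tokenizer.train_new_from_iterator(
    batch_iterator(), vocab_size=base_tokenizer.vocab_size
)
tokenizer.save_pretrained("./my-bart-model")
```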

### 2. Prepare a model config file based on an existing model

```shell
python prepare_config.py \
--pretrained_model_name facebook/bart-base \
--dout ./my-bart-model
```
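This is roughly equivalent to re-saving an existing configuration (a sketch; `prepare_config.py` may additionally
adjust values such as the vocabulary size):

```python
# Rough sketch: copy the facebook/bart-base config into the new model directory.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/bart-base")
config.save_pretrained("./my-bart-model")
```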

### 3. Train the model with the prepared tokenizer and config

By default, we use the denoising parameters described in
[this post](https://github.com/facebookresearch/fairseq/issues/1899#issuecomment-1069429320).

```shell
python run_bart_dlm.py \
--config_name ./my-bart-model \
--tokenizer_name ./my-bart-model \
--dataset_name oscar \
--dataset_config_name unshuffled_deduplicated_nl \
--output_dir ./my-bart-model \
--do_train \
--do_eval
```
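After training, the checkpoint in `--output_dir` can be loaded like any other BART model, for example:

```python
from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("./my-bart-model")
model = BartForConditionalGeneration.from_pretrained("./my-bart-model")
```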


### Some notes

#### Sentence splitting

As part of BART pretraining, the sentences in a sample may be permuted (reordered). To detect the sentences in each
sample, we need sentence splitting. By default, we use NLTK's English punkt sentence splitter, but by passing a spaCy
model name to `--spacy_model` (e.g. `en_core_web_sm`) you can rely on spaCy for better (but slower) sentence splitting.
You can also disable sentence splitting completely with `--no_sentence_splitting`. In that case, make sure the
sentences are already split, with a padding token between them (`<pad>`).
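For example, pre-splitting a sample yourself might look like this (a sketch, assuming the pad token `<pad>` is used as
the sentence separator, as described above):

```python
# Sketch: split a sample into sentences with NLTK and separate them with <pad>.
import nltk

nltk.download("punkt", quiet=True)

def presplit(text, pad_token="<pad>"):
    sentences = nltk.sent_tokenize(text)
    return pad_token.join(sentences)

print(presplit("This is a sentence. This is another sentence."))
```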


#### Default values
The defaults are set to the
[BART arguments given by the fairseq authors](https://github.com/facebookresearch/fairseq/issues/1899#issuecomment-1069429320).
This differs from the Flax defaults in one respect: `poisson_lambda` is set to `3.5` instead of `3.0`.
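For intuition, `poisson_lambda` is the mean of the Poisson distribution from which the lengths of the masked spans are
drawn during text infilling:

```python
import numpy as np

rng = np.random.default_rng(42)
print(rng.poisson(lam=3.5, size=10))  # sampled span lengths, mean ~ 3.5 tokens
```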


#### HF (Flax), fairseq, and current implementation

There are some differences in implementation between fairseq, the HF FLAX example, and this PyTorch implementation.

- `argwhere` in the Flax example
  [in this position](https://github.com/huggingface/transformers/blob/65fb71bc762c46bb067306c1fd083b1cba87a095/examples/flax/language-modeling/run_bart_dlm_flax.py#L319)
  does not do the same thing as fairseq. [In fairseq](https://github.com/facebookresearch/fairseq/blob/a6a63279422f846a3c2f6c45b9c96d6951cc4b82/fairseq/data/denoising_dataset.py#L230),
  there is an explicit check that the *previous* token was not a "full stop" (the padding token), whereas the Flax
  example only checks whether the *current* token is a full stop. In the current example I also explicitly check that
  the *next* token is not a full stop, to account for padding (see the sketch after this list). In practice this should
  be a non-issue, though, since all samples in a batch should have the same sequence length and there should not be any
  padding.
- I found that the result of sentence permutation was not consistent in terms of where the separating pad token ended
up ([bug report](https://github.com/facebookresearch/fairseq/issues/4695)), so I have reimplemented that method so
that sentences in a sequence are still separated by a padding token, even after permutation.
- In the HF FLAX script, the token mask is restricted to [non-special and non-padding tokens](https://github.com/huggingface/transformers/blob/65fb71bc762c46bb067306c1fd083b1cba87a095/examples/flax/language-modeling/run_bart_dlm_flax.py#L361).
  In fairseq, by default, only the first and last tokens are excluded and [all other tokens](https://github.com/facebookresearch/fairseq/blob/1bba712622b8ae4efb3eb793a8a40da386fe11d0/fairseq/data/denoising_dataset.py#L241)
  are candidates for masking. The HF behaviour seems sensible, so I follow it. Note that `get_special_tokens_mask`
  already includes the padding token, so there is no need to add it separately.
- The Flax example does not include fairseq's methods for adding extra noise; I have ported those as well.
- However, I did not adapt `add_insertion_noise` to work well with padded sequences, so the inserted noise may end up
  anywhere in the sequence, including within padding. It is unclear whether that is intended behaviour.
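A sketch of the boundary check described in the first bullet (illustrative only, not the exact code from
`run_bart_dlm.py`):

```python
# Find "full stop" (pad) positions that end a sentence: the token itself is a
# full stop, the previous token is not, and the next token is not (guards
# against trailing padding). Illustrative sketch only.
import torch

def sentence_end_indices(input_ids: torch.Tensor, full_stop_id: int) -> torch.Tensor:
    is_stop = input_ids == full_stop_id
    prev_not_stop = torch.ones_like(is_stop)
    prev_not_stop[1:] = ~is_stop[:-1]
    next_not_stop = torch.ones_like(is_stop)
    next_not_stop[:-1] = ~is_stop[1:]
    return torch.nonzero(is_stop & prev_not_stop & next_not_stop).squeeze(-1)

ids = torch.tensor([0, 5, 6, 1, 7, 8, 1, 1])  # 1 = pad token used as full stop
print(sentence_end_indices(ids, full_stop_id=1))  # tensor([3])
```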


## Creating a model on the fly

When training a model from scratch, configuration values may be overridden with `--config_overrides` (the override
values below are illustrative):

```shell
# Override values are illustrative.
python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \
    --config_overrides="n_embd=1024,n_head=16,n_layer=48,n_positions=1024" \
[...]
```

This feature is only available in `run_clm.py`, `run_plm.py` and `run_mlm.py`.
