
Commit 3909d7f

Add Flax BART pretraining script (#18297)
* add bart pretraining flax script
* fixup
* add bart pretraining flax script
* add BART to README
* add BART to README
* add BART to README
* add BART to README
* add BART to README
* add bos eos document
* Update README.md
* Update README.md
* Update examples/flax/language-modeling/run_bart_dlm_flax.py

Co-authored-by: Sanchit Gandhi <[email protected]>

* final
* final
* final
* remove use_auth_token in from_config

Co-authored-by: Sanchit Gandhi <[email protected]>
1 parent 941d233 commit 3909d7f

File tree

4 files changed: +1057 -3 lines


examples/flax/language-modeling/README.md (+92 lines)
@@ -338,6 +338,98 @@ of 2.36 and 57.0 respectively after 3 epochs on a single TPUv3-8.
This should take around 4.5 hours.
Training statistics can be accessed directly on the 🤗 [hub](https://huggingface.co/patrickvonplaten/t5-base-norwegian/tensorboard)

## BART: Denoising language modeling

In the following, we demonstrate how to train a BART model
using the denoising language modeling objective as introduced in [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461).
More specifically, we demonstrate how JAX/Flax can be leveraged
to pre-train [**`bart-base`**](https://huggingface.co/facebook/bart-base)
in Norwegian on a single TPUv3-8 pod.
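
To give a concrete feel for the objective, the snippet below is a small, self-contained toy sketch of the kind of corruption BART learns to reverse: sentence permutation followed by text infilling, with masked span lengths drawn from a Poisson distribution (λ = 3). It is only an illustration of the idea, not the exact noising logic implemented in `run_bart_dlm_flax.py`; the helper name `noise_document` and the Norwegian example sentences are made up for this sketch.

```python
import numpy as np


def noise_document(sentences, mask_token="<mask>", poisson_lambda=3.0, mask_ratio=0.3, seed=0):
    """Toy BART-style corruption: permute sentences, then replace token spans with a single mask."""
    rng = np.random.default_rng(seed)

    # 1. Sentence permutation: shuffle the order of the input sentences.
    sentences = list(sentences)
    rng.shuffle(sentences)

    # 2. Text infilling: replace spans (length ~ Poisson(3)) with one <mask> token
    #    until roughly `mask_ratio` of the tokens have been corrupted.
    tokens = " ".join(sentences).split()
    num_to_mask = int(len(tokens) * mask_ratio)
    masked = 0
    while masked < num_to_mask:
        span = max(1, int(rng.poisson(poisson_lambda)))
        start = int(rng.integers(0, len(tokens)))
        span = min(span, len(tokens) - start)
        tokens[start:start + span] = [mask_token]
        masked += span
    return " ".join(tokens)


print(noise_document(["Dette er den første setningen.", "Her kommer den andre setningen."]))
```

During pre-training, the model is optimised to reconstruct the original, uncorrupted text from such noised inputs.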

The example script uses the 🤗 Datasets library. You can easily customize it to your needs if you need extra processing on your datasets.

To set up all relevant files for training, let's create a directory.

```bash
mkdir ./norwegian-bart-base
```

### Train tokenizer

In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and subsequently saved in the model directory we just created.
This can take up to 10 minutes depending on your hardware ☕.

```python
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer

# Load the Norwegian split of the OSCAR dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")

# Instantiate a byte-level BPE tokenizer
tokenizer = ByteLevelBPETokenizer()


def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i: i + batch_size]["text"]


# Train the tokenizer, registering BART's special tokens
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

# Save the trained tokenizer to disk
tokenizer.save("./norwegian-bart-base/tokenizer.json")
```
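
If you want to sanity-check the result, you can load the saved `tokenizer.json` back with the 🤗 Tokenizers library and tokenize a sentence (a minimal sketch; the Norwegian example sentence is arbitrary):

```python
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("./norwegian-bart-base/tokenizer.json")
encoding = tokenizer.encode("Dette er en test.")

# Inspect the sub-word tokens and their vocabulary ids
print(encoding.tokens)
print(encoding.ids)
```
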
### Create configuration

Next, we create the model's configuration file. This is as simple
as loading and storing [**`facebook/bart-base`**](https://huggingface.co/facebook/bart-base)
in the local model folder:

```python
from transformers import BartConfig

# Reuse facebook/bart-base's architecture, with the vocab size matching our tokenizer
config = BartConfig.from_pretrained("facebook/bart-base", vocab_size=50265)
config.save_pretrained("./norwegian-bart-base")
```
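
Optionally, you can verify that the saved configuration loads back and can be used to build a randomly initialised Flax model, similar to what the training script does when starting from scratch. This is a minimal sketch (it assumes JAX/Flax and the Flax side of Transformers are installed) and instantiating the model allocates the full set of `bart-base`-sized parameters:

```python
from transformers import BartConfig, FlaxBartForConditionalGeneration

# Reload the config from the local folder and build a randomly initialised model from it
config = BartConfig.from_pretrained("./norwegian-bart-base")
model = FlaxBartForConditionalGeneration(config, seed=0)

print(config.vocab_size)    # 50265, matching the tokenizer
print(model.params.keys())  # top-level parameter collections of the Flax model
```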

Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
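
Since we pass `--push_to_hub` below, make sure you are authenticated with the Hugging Face Hub before launching training, for example by running `huggingface-cli login` in your shell or, equivalently, from Python (a small sketch, assuming the `huggingface_hub` library is installed):

```python
from huggingface_hub import login

# Prompts for a User Access Token (see https://huggingface.co/settings/tokens)
login()
```
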
### Train model

Next, we can run the example script to pre-train the model:

```bash
python run_bart_dlm_flax.py \
    --output_dir="./norwegian-bart-base" \
    --config_name="./norwegian-bart-base" \
    --tokenizer_name="./norwegian-bart-base" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --max_seq_length="1024" \
    --per_device_train_batch_size="32" \
    --per_device_eval_batch_size="32" \
    --learning_rate="1e-4" \
    --warmup_steps="2000" \
    --overwrite_output_dir \
    --logging_steps="500" \
    --save_steps="2000" \
    --eval_steps="2000" \
    --push_to_hub
```
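
Once training has finished (or from an intermediate checkpoint), you can load the model back and let it reconstruct a corrupted sentence. This is only a rough sketch: it assumes the output directory contains the saved Flax weights alongside the `config.json` and `tokenizer.json` created above, and the Norwegian example sentence is made up.

```python
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("./norwegian-bart-base")
model = FlaxBartForConditionalGeneration.from_pretrained("./norwegian-bart-base")

# Ask the model to fill in the masked span
text = "Dette er en <mask> om norsk språk."
inputs = tokenizer(text, return_tensors="np")
outputs = model.generate(inputs["input_ids"], max_length=32, num_beams=4)

print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
```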

Training should converge at a loss and accuracy
of 1.36 and 0.77 respectively after 3 epochs on a single TPUv3-8.
This should take less than 6 hours.
Training statistics can be accessed on [tensorboard.dev](https://tensorboard.dev/experiment/Maw62QlaSXWS0MOf2V2lbg/).

## Runtime evaluation

We also ran masked language modeling using PyTorch/XLA on a TPUv3-8, and PyTorch on 8 V100 GPUs. We report the
