1) Download the CNN and Daily Mail data and preprocess it into data files with non-tokenized cased samples.
Follow the instructions here to download the original CNN and Daily Mail datasets. To preprocess the data, refer to the pointers in this issue or check out the code here.
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
for SPLIT in train val
do
  for LANG in source target
  do
    python -m examples.roberta.multiprocessing_bpe_encoder \
      --encoder-json encoder.json \
      --vocab-bpe vocab.bpe \
      --inputs "cnn_dm/$SPLIT.$LANG" \
      --outputs "cnn_dm/$SPLIT.bpe.$LANG" \
      --workers 60 \
      --keep-empty;
  done
done
fairseq-preprocess \
  --source-lang "source" \
  --target-lang "target" \
  --trainpref "cnn_dm/train.bpe" \
  --validpref "cnn_dm/val.bpe" \
  --destdir "cnn_dm-bin/" \
  --workers 60 \
  --srcdict dict.txt \
  --tgtdict dict.txt;
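Because the data is parallel, every line in a `.source` file must pair with the same line number in the matching `.target` file (this is also why the BPE step passes `--keep-empty`). The following is a minimal sketch of a line-count check that could be run on the raw or BPE-encoded files before fine-tuning. The helper names `count_lines` and `check_parallel` are hypothetical and not part of fairseq.

```python
import os


def count_lines(path):
    """Return the number of lines in a UTF-8 text file."""
    with open(path, encoding="utf-8") as f:
        return sum(1 for _ in f)


def check_parallel(prefix, src_ext="source", tgt_ext="target"):
    """Compare line counts of <prefix>.<src_ext> and <prefix>.<tgt_ext>.

    Returns a tuple (n_source, n_target, aligned).
    """
    n_src = count_lines(f"{prefix}.{src_ext}")
    n_tgt = count_lines(f"{prefix}.{tgt_ext}")
    return n_src, n_tgt, n_src == n_tgt


if __name__ == "__main__":
    # Paths follow the layout used by the commands above; missing files are skipped.
    for split in ("train", "val", "train.bpe", "val.bpe"):
        prefix = os.path.join("cnn_dm", split)
        if os.path.exists(prefix + ".source") and os.path.exists(prefix + ".target"):
            n_src, n_tgt, ok = check_parallel(prefix)
            status = "OK" if ok else "MISMATCH"
            print(f"{prefix}: {n_src} source / {n_tgt} target lines [{status}]")
```

A mismatch usually means a sample was dropped or split across lines during preprocessing and is worth fixing before binarizing.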
Example fine-tuning command:
TOTAL_NUM_UPDATES=20000
WARMUP_UPDATES=500
LR=3e-05
MAX_TOKENS=2048
UPDATE_FREQ=4
BART_PATH=/path/to/bart/model.pt
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py cnn_dm-bin \
  --restore-file "$BART_PATH" \
  --max-tokens $MAX_TOKENS \
  --task translation \
  --source-lang source --target-lang target \
  --truncate-source \
  --layernorm-embedding \
  --share-all-embeddings \
  --share-decoder-input-output-embed \
  --reset-optimizer --reset-dataloader --reset-meters \
  --required-batch-size-multiple 1 \
  --arch bart_large \
  --criterion label_smoothed_cross_entropy \
  --label-smoothing 0.1 \
  --dropout 0.1 --attention-dropout 0.1 \
  --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
  --clip-norm 0.1 \
  --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
  --fp16 --update-freq $UPDATE_FREQ \
  --skip-invalid-size-inputs-valid-test \
  --find-unused-parameters;
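To give a feel for what `--lr-scheduler polynomial_decay` does with the values above, here is a rough sketch of the learning-rate curve. It assumes the scheduler warms up linearly for `WARMUP_UPDATES` steps and then decays with power 1.0 to an end learning rate of 0 (believed to be the fairseq defaults, but not confirmed by this document). The function `polynomial_decay_lr` is a hypothetical illustration, not fairseq code.

```python
def polynomial_decay_lr(step, peak_lr=3e-05, warmup_updates=500,
                        total_updates=20000, end_lr=0.0, power=1.0):
    """Approximate learning rate at a given update step.

    Assumption: linear warmup to peak_lr, then polynomial decay to end_lr.
    """
    if warmup_updates > 0 and step <= warmup_updates:
        # Linear warmup from 0 to peak_lr.
        return peak_lr * step / warmup_updates
    if step >= total_updates:
        return end_lr
    # Fraction of the decay phase that is still ahead.
    remaining = 1.0 - (step - warmup_updates) / (total_updates - warmup_updates)
    return (peak_lr - end_lr) * remaining ** power + end_lr


if __name__ == "__main__":
    for step in (0, 250, 500, 10250, 20000):
        print(step, polynomial_decay_lr(step))
```

Under these assumptions the rate peaks at `LR` after 500 updates and falls to half of that about midway through the remaining updates.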
The fine-tuning command above is expected to run on 1 node with 8 32GB V100 GPUs. Expected training time is about 5 hours. Training time can be reduced with distributed training on 4 nodes and --update-freq 1.
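The pairing of 4 nodes with `--update-freq 1` keeps the number of tokens per optimizer update unchanged. The arithmetic below is a small sketch of that relationship. It assumes each of the 4 nodes also has 8 GPUs, which the text does not state, and it treats `--max-tokens` as an upper bound per GPU per step, so real batches may be smaller. The function name `tokens_per_update` is hypothetical.

```python
def tokens_per_update(max_tokens, update_freq, gpus_per_node, nodes=1):
    """Upper bound on tokens contributing to one optimizer update."""
    return max_tokens * update_freq * gpus_per_node * nodes


# Single node, 8 GPUs, gradients accumulated over 4 steps.
single_node = tokens_per_update(max_tokens=2048, update_freq=4, gpus_per_node=8, nodes=1)

# Four nodes (assumed 8 GPUs each), no gradient accumulation.
four_nodes = tokens_per_update(max_tokens=2048, update_freq=1, gpus_per_node=8, nodes=4)

if __name__ == "__main__":
    print(single_node, four_nodes)
```

Both configurations come to at most 65,536 tokens per update, so the other hyperparameters should carry over without retuning.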
After training the model as described in the previous step, you can run inference with the checkpoints in the checkpoints/ directory using the following Python snippet:
import torch
from fairseq.models.bart import BARTModel

# Load the fine-tuned checkpoint together with the binarized data dictionary.
bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='cnn_dm-bin'
)

bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo', 'w') as fout:
    # Seed the first batch with the first line, then stream the rest.
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        if count % bsz == 0:
            # The batch is full: generate summaries and write them out.
            with torch.no_grad():
                hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)

            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
            fout.flush()
            slines = []

        slines.append(sline.strip())
        count += 1
    if slines != []:
        # Generate summaries for the final, possibly smaller, batch.
        with torch.no_grad():
            hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
        fout.flush()
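The counter-based batching in the snippet above can also be written with a small generator, which removes the special handling of the first line and the last batch. This is only a sketch: `batched` and `summarize_file` are hypothetical helpers, not part of fairseq, and `model` is assumed to be any object with a `sample(list_of_str, **kwargs)` method such as the `bart` instance loaded above.

```python
def batched(lines, batch_size):
    """Yield lists of up to batch_size stripped lines."""
    batch = []
    for line in lines:
        batch.append(line.strip())
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        # Emit the final, possibly smaller, batch.
        yield batch


def summarize_file(model, src_path, out_path, bsz=32, **gen_kwargs):
    """Write one hypothesis per source line; returns the number written."""
    written = 0
    with open(src_path) as source, open(out_path, "w") as fout:
        for slines in batched(source, bsz):
            for hypothesis in model.sample(slines, **gen_kwargs):
                fout.write(hypothesis + "\n")
                written += 1
            fout.flush()
    return written
```

With the model from the snippet above, a call might look like `summarize_file(bart, 'cnn_dm/test.source', 'cnn_dm/test.hypo', beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)`, ideally wrapped in `torch.no_grad()` as in the original loop.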