STIL - Simultaneous Slot Filling, Translation, Intent Classification, and Language Identification: Initial Results using mBART on MultiATIS++
This repository contains some of the code used in the paper named above. The purpose of this repo is only to allow other researchers to reproduce the study and results presented in the paper. Jack and Amazon likely will not improve this repo over time.
I used AWS SageMaker for this work, including an m5.xlarge instance for data preparation and analysis and a p3.16xlarge instance for training.
This study uses the pretrained mBART CC25 model. It can be downloaded from fairseq here.
The dataset is based on the MultiATIS++ and MultiATIS datasets.
The 2020 paper by Xu et al. entitled "End-to-End Slot Alignment and Recognition for Cross-Lingual NLU" describes the dataset in more detail. At the time of writing, the dataset was still under review by LDC. Please contact [email protected] to obtain a copy.
Once you have the data, place it in a folder called MultiATISpp-RAW/.
At the time of my research, there were a small number of English alignment problems with the Japanese, Hindi, and Turkish data. For this reason, Japanese was excluded, and the Hindi and Turkish data from MultiATIS were used. To ensure a fair comparison with the work by Xu et al., extract the Hindi and Turkish validation sets from MultiATIS++ and remove those examples from the MultiATIS training sets. Place the two validation files in another folder called hi_tr_devsets/.
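The removal step described above could be sketched as follows. This is only an illustration, not the procedure used for the paper: it assumes that a training example counts as a duplicate when its non-English utterance (the fifth column of the MultiATIS TSVs) matches a validation utterance after lowercasing and whitespace normalization. The function name and the matching rule are hypothetical.

```python
def remove_dev_examples(train_rows, dev_utterances, utterance_index=4):
    """Drop training rows whose utterance also appears in the validation set.

    train_rows: list of rows, each a list of column values.
    dev_utterances: iterable of validation utterances.
    utterance_index: position of the non-English utterance in a training row
                     (assumed to be the fifth column, index 4).
    """
    def normalize(text):
        # Assumption: case and repeated whitespace are irrelevant for matching.
        return " ".join(text.lower().split())

    held_out = {normalize(u) for u in dev_utterances}
    return [row for row in train_rows
            if normalize(row[utterance_index]) not in held_out]
```

If the two datasets differ in tokenization or punctuation, a stricter or looser matching rule may be needed, so it is worth checking how many rows are actually removed.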
The TSVs should have the following columns, and headers should be included in the files.
id
utterance
slot_labels
intent
MultiATIS is available from LDC here. Please place the TSVs for Hindi and Turkish in a folder called MultiATIS-RAW/. The TSVs should have the following columns, and headers should be excluded.
English utterance
English annotations
machine translation back to English
intent
non-English utterance
non-English annotations
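Because the MultiATIS files have no header row, it can help to attach column names when inspecting them. The sketch below parses headerless six-column TSV text into dictionaries. The column names are hypothetical labels chosen for this example, and the parsing settings (tab delimiter, no quoting) are assumptions about the files.

```python
import csv
import io

# Hypothetical names for the six columns listed above, in order.
COLUMNS = [
    "en_utterance",
    "en_annotations",
    "back_translation",
    "intent",
    "utterance",
    "annotations",
]

def read_multiatis(tsv_text):
    """Parse headerless MultiATIS-style TSV text into a list of dicts."""
    reader = csv.reader(io.StringIO(tsv_text), delimiter="\t",
                        quoting=csv.QUOTE_NONE)
    rows = []
    for line_no, fields in enumerate(reader, start=1):
        if not fields:
            continue  # skip blank lines
        if len(fields) != len(COLUMNS):
            raise ValueError(
                f"line {line_no}: expected {len(COLUMNS)} columns, "
                f"got {len(fields)}")
        rows.append(dict(zip(COLUMNS, fields)))
    return rows
```

A file can be checked with `read_multiatis(open(path, encoding="utf-8").read())`; a `ValueError` points at the first line with an unexpected column count.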
MultiATISpp-RAW/
|- dev_DE.tsv
|- dev_ES.tsv
|- dev_ZH.tsv
|- test_EN.tsv
|- test_FR.tsv
|- train_DE.tsv
|- train_ES.tsv
|- train_ZH.tsv
|- dev_EN.tsv
|- dev_FR.tsv
|- test_DE.tsv
|- test_ES.tsv
|- test_ZH.tsv
|- train_EN.tsv
|- train_FR.tsv
MultiATIS-RAW/
|- Hindi-test.tsv
|- Hindi-train_1600.tsv
|- Turkish-test.tsv
|- Turkish-train_638.tsv
hi_tr_devsets/
|- dev_HI.tsv
|- dev_TR.tsv
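Before running the preprocessing scripts, the layout above can be verified with a short check. The following is a minimal sketch that lists any expected file that is missing. The helper name `missing_files` is made up for this example, and the file names are taken directly from the tree above.

```python
from pathlib import Path

SPLITS = ("train", "dev", "test")
LANGS = ("DE", "EN", "ES", "FR", "ZH")

# Expected files per folder, as shown in the directory listing above.
EXPECTED = {
    "MultiATISpp-RAW": [f"{split}_{lang}.tsv"
                        for split in SPLITS for lang in LANGS],
    "MultiATIS-RAW": ["Hindi-test.tsv", "Hindi-train_1600.tsv",
                      "Turkish-test.tsv", "Turkish-train_638.tsv"],
    "hi_tr_devsets": ["dev_HI.tsv", "dev_TR.tsv"],
}

def missing_files(root="."):
    """Return the relative paths of expected files not found under root."""
    root = Path(root)
    return [str(Path(folder) / name)
            for folder, names in EXPECTED.items()
            for name in names
            if not (root / folder / name).is_file()]
```

Calling `missing_files()` from the directory that contains the three folders should return an empty list when everything is in place.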
To create the STIL dataset, run the preprocess_atis_stil.py script on the data described above. Example:
python path/to/preprocess_atis_stil.py MultiATISpp-RAW/ MultiATIS-RAW/ hi_tr_devsets/ MultiATISpp-FLAT/
To create the traditional NLU dataset (no translation of the slots), run the preprocess_atis_traditional.py script on the data described above. Example:
python path/to/preprocess_atis_traditional.py MultiATISpp-RAW/ MultiATIS-RAW/ hi_tr_devsets/ MultiATISpp-FLAT/
The mBART model uses sentencepiece tokenization. Information can be found in the sentencepiece repo. The following commands can be used to build sentencepiece.
git clone https://github.com/google/sentencepiece.git
cd sentencepiece
mkdir build
cd build
cmake ..
make -j $(nproc)
sudo make install
sudo ldconfig -v
pip install sentencepiece
Once sentencepiece has been built, tokenize the datasets:
SPM=path/to/sentencepiece/build/src/spm_encode
MODEL=path/to/mbart.cc25/sentence.bpe.model
DATA_PATH=path/to/MultiATISpp-FLAT
for SPLIT in train dev test; do for INOUT in input output; do $SPM --model=$MODEL < ${DATA_PATH}/${SPLIT}.${INOUT} > ${DATA_PATH}/${SPLIT}.spm.${INOUT}; done; done
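For spot-checking a few lines without the command-line tool, the same tokenization can be approximated with the sentencepiece Python package installed above. This is a sketch under assumptions: it relies on `spm_encode` writing space-separated pieces by default, and the helper names `encode_lines` and `load_encoder` are invented for this example. It was not used to produce the results in the paper.

```python
def encode_lines(lines, encode):
    """Tokenize each line with `encode` and join the pieces with spaces.

    encode: a callable mapping a string to a list of string pieces.
    """
    return [" ".join(encode(line.rstrip("\n"))) for line in lines]

def load_encoder(model_path):
    """Build an `encode` callable from a sentencepiece model file."""
    # Imported lazily so encode_lines has no hard dependency on the package.
    import sentencepiece as spm
    processor = spm.SentencePieceProcessor(model_file=model_path)
    return lambda text: processor.encode(text, out_type=str)
```

For example, `encode_lines(open("train.input", encoding="utf-8"), load_encoder("path/to/mbart.cc25/sentence.bpe.model"))` should yield lines comparable to those in `train.spm.input`.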
The model requires binarized data.
The gcc in my instances of SageMaker was too old. Upgrade first if needed:
conda install -c psi4 gcc-5
Install fairseq as editable:
git clone https://github.com/pytorch/fairseq.git
cd fairseq
pip install --editable .
Run binarization:
FAIRSEQ_PATH=path/to/fairseq
DATA_PATH=path/to/tokenized_data
DICT_PATH=path/to/mbart.cc25
python ${FAIRSEQ_PATH}/preprocess.py --source-lang input --target-lang output --trainpref ${DATA_PATH}/train.spm --validpref ${DATA_PATH}/dev.spm --testpref ${DATA_PATH}/test.spm --srcdict $DICT_PATH/dict.txt --tgtdict $DICT_PATH/dict.txt --workers 8 --destdir MultiATISpp-BIN
I used a p3.16xlarge instance for training, which has 8 NVIDIA V100 GPUs. With a max sentences value of 2 per GPU and an update frequency of 2, the effective batch size is 2 × 2 × 8 = 32.
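When training on a machine with a different number of GPUs, the update frequency can be adjusted to keep the same effective batch size. The sketch below shows the arithmetic; the function names are illustrative and not part of fairseq.

```python
def effective_batch_size(max_sentences, update_freq, num_gpus):
    """Sentences contributing to each optimizer step across all GPUs."""
    return max_sentences * update_freq * num_gpus

def update_freq_for(target_batch, max_sentences, num_gpus):
    """Update frequency needed to reach target_batch on a given machine."""
    per_step = max_sentences * num_gpus
    if target_batch % per_step != 0:
        raise ValueError("target batch size is not reachable exactly")
    return target_batch // per_step
```

With the settings used here, `effective_batch_size(2, 2, 8)` gives 32, and on a 4-GPU machine `update_freq_for(32, 2, 4)` gives 4.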
PRETRAINED_BART=path/to/mbart.cc25
DATA_PATH=path/to/data-bin
FAIRSEQ_PATH=path/to/fairseq
CHECKPOINT_PATH=path/to/checkpoints
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN
python ${FAIRSEQ_PATH}/train.py ${DATA_PATH} --num-workers 32 --encoder-normalize-before --decoder-normalize-before --arch mbart_large --task translation_from_pretrained_bart --source-lang input --target-lang output --criterion label_smoothed_cross_entropy --label-smoothing 0.2 --dataset-impl mmap --optimizer adam --adam-eps 1e-08 --adam-betas '(0.9, 0.999)' --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 936 --total-num-update 20000 --dropout 0.2 --attention-dropout 0.1 --weight-decay 0.01 --max-sentences 2 --update-freq 2 --save-interval 1 --max-epoch 40 --save-dir ${CHECKPOINT_PATH} --validate-interval 1 --seed 222 --log-format json --log-interval 60 --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler --restore-file ${PRETRAINED_BART}/model.pt --langs $langs --layernorm-embedding --ddp-backend no_c10d --memory-efficient-fp16 |& tee train_history.log
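To get a feel for the learning rate implied by the polynomial_decay settings above, the schedule can be approximated as below. This is a sketch based on my understanding of fairseq's scheduler, not a copy of it: it assumes a linear warmup from zero, a decay power of 1.0, and an end learning rate of 0, which I believe are the defaults when those options are not set. The function name is hypothetical.

```python
def polynomial_decay_lr(step, peak_lr=3e-05, warmup_updates=936,
                        total_updates=20000, end_lr=0.0, power=1.0):
    """Approximate learning rate after `step` optimizer updates."""
    if warmup_updates > 0 and step <= warmup_updates:
        # Linear warmup from 0 to peak_lr.
        return peak_lr * step / warmup_updates
    if step >= total_updates:
        return end_lr
    # Fraction of the decay phase that is still ahead.
    remaining = 1.0 - (step - warmup_updates) / (total_updates - warmup_updates)
    return (peak_lr - end_lr) * remaining ** power + end_lr
```

Note that training may stop at the maximum epoch before the total number of updates is reached, in which case the learning rate never decays fully.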
Watch GPU utilization:
nvidia-smi -l
Parse training and validation losses as tables from the log file:
python path/to/parse_fairseq_train_logs.py train_history.log prefix_for_output_file
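For readers who want to inspect the log by hand, the idea behind this step can be sketched as follows. This is not the code in parse_fairseq_train_logs.py, which may work differently. It assumes each relevant log line contains one JSON object, possibly after a logging prefix, and that end-of-epoch records carry keys named `train_loss` or `valid_loss`; those key names are assumptions and can vary across fairseq versions.

```python
import json

def parse_log_lines(lines):
    """Split JSON log records into training and validation lists."""
    train, valid = [], []
    for line in lines:
        start = line.find("{")
        if start == -1:
            continue  # no JSON payload on this line
        try:
            record = json.loads(line[start:])
        except json.JSONDecodeError:
            continue  # not a well-formed JSON record
        if not isinstance(record, dict):
            continue
        if "valid_loss" in record:
            valid.append(record)
        elif "train_loss" in record:
            train.append(record)
    return train, valid
```

The two lists can then be written out as tables or plotted to compare training and validation curves when choosing a checkpoint.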
The following command splits the test set into 8 shards and runs one shard on each GPU. Be sure to pick the right model checkpoint based on the validation curves.
for SHARD_ID in {0..7}; do (CUDA_VISIBLE_DEVICES=$SHARD_ID python $FAIRSEQ_PATH/generate.py data-bin/ --path $CHECKPOINT_PATH/checkpoint19.pt --task translation_from_pretrained_bart --gen-subset test -t output -s input --sacrebleu --remove-bpe 'sentencepiece' --langs $langs --memory-efficient-fp16 --max-sentences 64 --num-workers 4 --num-shards 8 --shard-id $SHARD_ID |& tee hyps_test19_${SHARD_ID}.log &); done
Combine the data from the 8 shards into one file:
for file in hyps_test19_*; do cat $file >> hyps_test_epoch19.log; done
rm hyps_test19_*
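Since each shard only sees part of the test set, the combined file is not in the original example order. The sketch below collects the hypotheses and sorts them by example ID. It assumes the fairseq generate output marks hypothesis lines as `H-<id>`, followed by a tab, a score, another tab, and the text; depending on the fairseq version, the post-processed text may instead appear on `D-<id>` lines, which the `prefix` argument allows for. The function name is hypothetical.

```python
def collect_hypotheses(lines, prefix="H"):
    """Return hypothesis strings ordered by their example ID."""
    marker = prefix + "-"
    hyps = {}
    for line in lines:
        if not line.startswith(marker):
            continue
        fields = line.rstrip("\n").split("\t")
        if len(fields) < 3:
            continue  # expected: ID, score, text
        try:
            sample_id = int(fields[0][len(marker):])
        except ValueError:
            continue  # prefix matched but no numeric ID followed
        hyps[sample_id] = fields[2]
    return [hyps[i] for i in sorted(hyps)]
```

Running it over `open("hyps_test_epoch19.log", encoding="utf-8")` gives one hypothesis per test example, which can then be compared against references collected the same way from the `T-<id>` lines, where the text is the second field rather than the third.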
See the file entitled LICENSE
Note: This work is dependent on fairseq and sentencepiece, which were licensed under the MIT License and the Apache 2.0 license, respectively, at the time this work was conducted.
@inproceedings{fitzgerald2020mbartmultiatis,
title = {STIL - Simultaneous Slot Filling, Translation, Intent Classification, and Language Identification: Initial Results using mBART on MultiATIS++},
author = {Jack G. M. FitzGerald},
booktitle = {Proceedings of 1st Conference of the Asia-Pacific Chapter of the Association for Computational Linguistics},
year = {2020},
url = {https://arxiv.org/abs/2010.00760}
}