From 295f6a7e515cf458c0671ffb354bc98266d386d1 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 19 Apr 2018 16:57:25 -0400 Subject: [PATCH] fixed how l2-regularize is performed when backstitch training is activated; added an example script for backstitch training using tdnn-lstm acoustic model on AMI --- egs/ami/s5b/local/chain/run_tdnn_lstm_bs.sh | 1 + .../local/chain/tuning/run_tdnn_lstm_bs_1a.sh | 309 ++++++++++++++++++ src/nnet3/nnet-chain-training.cc | 16 +- src/nnet3/nnet-training.cc | 16 +- 4 files changed, 324 insertions(+), 18 deletions(-) create mode 120000 egs/ami/s5b/local/chain/run_tdnn_lstm_bs.sh create mode 100755 egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_bs_1a.sh diff --git a/egs/ami/s5b/local/chain/run_tdnn_lstm_bs.sh b/egs/ami/s5b/local/chain/run_tdnn_lstm_bs.sh new file mode 120000 index 00000000000..c3c8dc56cc2 --- /dev/null +++ b/egs/ami/s5b/local/chain/run_tdnn_lstm_bs.sh @@ -0,0 +1 @@ +tuning/run_tdnn_lstm_bs_1a.sh \ No newline at end of file diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_bs_1a.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_bs_1a.sh new file mode 100755 index 00000000000..b672a44e572 --- /dev/null +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_bs_1a.sh @@ -0,0 +1,309 @@ +#!/bin/bash + +# same as tdnn_lstm_1o but use backstitch training. +# Also num-epochs and l2-regularize are tuned for best performance. + +# local/chain/tuning/run_tdnn_lstm_bs_1a.sh --mic sdm1 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned +# local/chain/compare_wer_general.sh sdm1 tdnn_lstm_bs_1a_sp_bi_ihmali_ld5 tdnn_lstm1o_sp_bi_ihmali_ld5 + +# System tdnn_lstm_bs_1a_sp_bi_ihmali_ld5 tdnn_lstm1o_sp_bi_ihmali_ld5 +# WER on dev 33.8 35.2 +# WER on eval 37.5 38.7 +# Final train prob -0.126056 -0.167549 +# Final valid prob -0.228452 -0.24847 +# Final train prob (xent) -1.51685 -1.7403 +# Final valid prob (xent) -2.04719 -2.13732 + + +set -e -o pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +mic=ihm +nj=30 +min_seg_len=1.55 +use_ihm_ali=false +train_set=train_cleaned +gmm=tri3_cleaned # the gmm for the target data +ihm_gmm=tri3 # the gmm for the IHM system (if --use-ihm-ali true). +num_threads_ubm=32 +nnet3_affix=_cleaned # cleanup affix for nnet3 and chain dirs, e.g. _cleaned +num_epochs=10 + +chunk_width=150 +chunk_left_context=40 +chunk_right_context=0 +label_delay=5 +remove_egs=true +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +train_stage=-10 +tree_affix= # affix for tree directory, e.g. "a" or "b", in case we change the configuration. +tlstm_affix=1a #affix for TDNN-LSTM directory, e.g. "a" or "b", in case we change the configuration. +common_egs_dir= # you can set this to use previously dumped egs. + +alpha=0.2 +back_interval=1 + +# decode options +extra_left_context=50 +frames_per_chunk= + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <data/lang_chain/topo + fi +fi + +if [ $stage -le 13 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 14 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --leftmost-questions-truncate -1 \ + --cmd "$train_cmd" 4200 ${lores_train_data_dir} data/lang_chain $ali_dir $tree_dir +fi + +xent_regularize=0.1 + +if [ $stage -le 15 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + tdnn_opts="l2-regularize=0.003" + lstm_opts="l2-regularize=0.005" + output_opts="l2-regularize=0.001" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=1024 $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=1024 $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=1024 $tdnn_opts + + # check steps/libs/nnet3/xconfig/lstm.py for the other options and defaults + fast-lstmp-layer name=lstm1 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $lstm_opts + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=1024 $tdnn_opts + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=1024 $tdnn_opts + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=1024 $tdnn_opts + fast-lstmp-layer name=lstm2 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $lstm_opts + relu-batchnorm-layer name=tdnn7 input=Append(-3,0,3) dim=1024 $tdnn_opts + relu-batchnorm-layer name=tdnn8 input=Append(-3,0,3) dim=1024 $tdnn_opts + relu-batchnorm-layer name=tdnn9 input=Append(-3,0,3) dim=1024 $tdnn_opts + fast-lstmp-layer name=lstm3 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $lstm_opts + + ## adding the layers for chain branch + output-layer name=output input=lstm3 output-delay=$label_delay include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + output-layer name=output-xent input=lstm3 output-delay=$label_delay dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts + +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 16 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/ami-$(date +'%m_%d_%H_%M')/s5b/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $chunk_width \ + --egs.chunk-left-context $chunk_left_context \ + --egs.chunk-right-context $chunk_right_context \ + --egs.chunk-left-context-initial 0 \ + --egs.chunk-right-context-final 0 \ + --trainer.num-chunk-per-minibatch 64,32 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 12 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.max-models-combine=30 \ + --trainer.optimization.backstitch-training-scale $alpha \ + --trainer.optimization.backstitch-training-interval $back_interval \ + --trainer.max-param-change 2.0 \ + --trainer.deriv-truncate-margin 8 \ + --cleanup.remove-egs $remove_egs \ + --cleanup true \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir $lat_dir \ + --dir $dir +fi + + +graph_dir=$dir/graph_${LM} +if [ $stage -le 17 ]; then + # Note: it might appear that this data/lang_chain directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_${LM} $dir $graph_dir +fi + +if [ $stage -le 18 ]; then + rm $dir/.error 2>/dev/null || true + + [ -z $extra_left_context ] && extra_left_context=$chunk_left_context; + [ -z $frames_per_chunk ] && frames_per_chunk=$chunk_width; + + for decode_set in dev eval; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$decode_cmd" \ + --extra-left-context $extra_left_context \ + --frames-per-chunk "$frames_per_chunk" \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --online-ivector-dir exp/$mic/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ + --scoring-opts "--min-lmwt 5 " \ + $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set} || exit 1; + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi +exit 0 diff --git a/src/nnet3/nnet-chain-training.cc b/src/nnet3/nnet-chain-training.cc index 1d149b6f193..2ec2699ec97 100644 --- a/src/nnet3/nnet-chain-training.cc +++ b/src/nnet3/nnet-chain-training.cc @@ -245,17 +245,15 @@ void NnetChainTrainer::TrainInternalBackstitch(const NnetChainExample &eg, // delta_nnet is scaled by 1 + backstitch_training_scale when added to nnet; max_change_scale = 1.0 + nnet_config.backstitch_training_scale; scale_adding = 1.0 + nnet_config.backstitch_training_scale; + // If relevant, add in the part of the gradient that comes from L2 + // regularization. It may not be optimally inefficient to do it on both + // passes of the backstitch, like we do here, but it probably minimizes + // any harmful interactions with the max-change. + ApplyL2Regularization(*nnet_, + 1.0 / scale_adding * GetNumNvalues(eg.inputs, false) * + nnet_config.l2_regularize_factor, delta_nnet_); } - // If relevant, add in the part of the gradient that comes from L2 - // regularization. It may not be optimally inefficient to do it on both - // passes of the backstitch, like we do here, but it probably minimizes - // any harmful interactions with the max-change. - ApplyL2Regularization(*nnet_, - scale_adding * GetNumNvalues(eg.inputs, false) * - nnet_config.l2_regularize_factor, - delta_nnet_); - // Updates the parameters of nnet UpdateNnetWithMaxChange(*delta_nnet_, nnet_config.max_param_change, max_change_scale, scale_adding, nnet_, diff --git a/src/nnet3/nnet-training.cc b/src/nnet3/nnet-training.cc index 49222549e4e..8fda24cd22d 100644 --- a/src/nnet3/nnet-training.cc +++ b/src/nnet3/nnet-training.cc @@ -153,17 +153,15 @@ void NnetTrainer::TrainInternalBackstitch(const NnetExample &eg, // delta_nnet is scaled by 1 + backstitch_training_scale when added to nnet; max_change_scale = 1.0 + config_.backstitch_training_scale; scale_adding = 1.0 + config_.backstitch_training_scale; + // If relevant, add in the part of the gradient that comes from L2 + // regularization. It may not be optimally inefficient to do it on both + // passes of the backstitch, like we do here, but it probably minimizes + // any harmful interactions with the max-change. + ApplyL2Regularization(*nnet_, + 1.0 / scale_adding * GetNumNvalues(eg.io, false) * + config_.l2_regularize_factor, delta_nnet_); } - // If relevant, add in the part of the gradient that comes from L2 - // regularization. It may not be optimally inefficient to do it on both - // passes of the backstitch, like we do here, but it probably minimizes - // any harmful interactions with the max-change. - ApplyL2Regularization(*nnet_, - scale_adding * GetNumNvalues(eg.io, false) * - config_.l2_regularize_factor, - delta_nnet_); - // Updates the parameters of nnet UpdateNnetWithMaxChange(*delta_nnet_, config_.max_param_change, max_change_scale, scale_adding, nnet_,