diff --git a/egs/fisher_english/s5/local/fisher_create_test_lang.sh b/egs/fisher_english/s5/local/fisher_create_test_lang.sh index f0926d2ceab..ac3e16c9c78 100755 --- a/egs/fisher_english/s5/local/fisher_create_test_lang.sh +++ b/egs/fisher_english/s5/local/fisher_create_test_lang.sh @@ -1,23 +1,25 @@ #!/bin/bash -# -if [ -f path.sh ]; then . ./path.sh; fi - -mkdir -p data/lang_test +# This script formats ARPA LM into G.fst. arpa_lm=data/local/lm/3gram-mincount/lm_unpruned.gz +dir=data/lang_test + +if [ -f ./path.sh ]; then . ./path.sh; fi +. utils/parse_options.sh + [ ! -f $arpa_lm ] && echo No such file $arpa_lm && exit 1; -mkdir -p data/lang_test -cp -r data/lang/* data/lang_test +mkdir -p $dir +cp -r data/lang/* $dir gunzip -c "$arpa_lm" | \ arpa2fst --disambig-symbol=#0 \ - --read-symbol-table=data/lang_test/words.txt - data/lang_test/G.fst + --read-symbol-table=$dir/words.txt - $dir/G.fst echo "Checking how stochastic G is (the first of these numbers should be small):" -fstisstochastic data/lang_test/G.fst +fstisstochastic $dir/G.fst ## Check lexicon. ## just have a look and make sure it seems sane. @@ -27,22 +29,21 @@ fstprint --isymbols=data/lang/phones.txt --osymbols=data/lang/words.txt data/l echo Performing further checks # Checking that G.fst is determinizable. -fstdeterminize data/lang_test/G.fst /dev/null || echo Error determinizing G. +fstdeterminize $dir/G.fst /dev/null || echo Error determinizing G. # Checking that L_disambig.fst is determinizable. -fstdeterminize data/lang_test/L_disambig.fst /dev/null || echo Error determinizing L. +fstdeterminize $dir/L_disambig.fst /dev/null || echo Error determinizing L. # Checking that disambiguated lexicon times G is determinizable # Note: we do this with fstdeterminizestar not fstdeterminize, as # fstdeterminize was taking forever (presumbaly relates to a bug # in this version of OpenFst that makes determinization slow for # some case). -fsttablecompose data/lang_test/L_disambig.fst data/lang_test/G.fst | \ +fsttablecompose $dir/L_disambig.fst $dir/G.fst | \ fstdeterminizestar >/dev/null || echo Error # Checking that LG is stochastic: -fsttablecompose data/lang/L_disambig.fst data/lang_test/G.fst | \ +fsttablecompose data/lang/L_disambig.fst $dir/G.fst | \ fstisstochastic || echo "[log:] LG is not stochastic" - echo "$0 succeeded" diff --git a/egs/fisher_english/s5/local/fisher_train_lms_pocolm.sh b/egs/fisher_english/s5/local/fisher_train_lms_pocolm.sh new file mode 100755 index 00000000000..ebdd63034c1 --- /dev/null +++ b/egs/fisher_english/s5/local/fisher_train_lms_pocolm.sh @@ -0,0 +1,170 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Vimal Manohar +# Apache 2.0 +# +# It is based on the example scripts distributed with PocoLM + +set -e +stage=0 + +text=data/train/text +lexicon=data/local/dict/lexicon.txt +dir=data/local/pocolm + +num_ngrams_large=5000000 +num_ngrams_small=2500000 + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. 
+ else
+   echo "$0: Please install the PocoLM toolkit with: "
+   echo " cd ../../../tools; extras/install_pocolm.sh; cd -"
+   exit 1;
+ fi
+) || exit 1;
+
+for f in "$text" "$lexicon"; do
+  [ ! -f $f ] && echo "$0: No such file $f" && exit 1;
+done
+
+num_dev_sentences=10000
+
+#bypass_metaparam_optim_opt=
+# If you want to bypass the metaparameter optimization steps with specific metaparameters
+# un-comment the following line, and change the numbers to some appropriate values.
+# You can find the values from the output log of train_lm.py.
+# These example metaparameters are for a 4-gram model (with min-counts)
+# running with train_lm.py.
+# The dev perplexity should be close to the non-bypassed model.
+#bypass_metaparam_optim_opt="--bypass-metaparameter-optimization=0.854,0.0722,0.5808,0.338,0.166,0.015,0.999,0.6228,0.340,0.172,0.999,0.788,0.501,0.406"
+# Note: to use these example parameters, you may need to remove the .done files
+# to make sure that make_lm_dir.py is called and trains only a 3-gram model:
+#for order in 3; do
+#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done
+
+if [ $stage -le 0 ]; then
+  mkdir -p ${dir}/data
+  mkdir -p ${dir}/data/text
+
+  echo "$0: Getting the Data sources"
+
+  rm ${dir}/data/text/* 2>/dev/null || true
+
+  cleantext=$dir/text_all.gz
+
+  cut -d ' ' -f 2- $text | awk -v lex=$lexicon '
+  BEGIN{
+    while((getline <lex) >0) { seen[$1]=1; }
+  }
+  {
+    for(n=1; n<=NF;n++) {
+      if (seen[$n]) {
+        printf("%s ", $n);
+      } else {
+        printf("<unk> ");
+      }
+    }
+    printf("\n");
+  }' | gzip -c > $cleantext || exit 1;
+
+  # This is for reporting perplexities
+  gunzip -c $dir/text_all.gz | head -n $num_dev_sentences > \
+    ${dir}/data/test.txt
+
+  # use a subset of the annotated training data as the dev set.
+  # Note: the name 'dev' is treated specially by pocolm, it automatically
+  # becomes the dev set.
+  gunzip -c $dir/text_all.gz | tail -n +$[num_dev_sentences+1] | \
+    head -n $num_dev_sentences > ${dir}/data/text/dev.txt
+
+  gunzip -c $dir/text_all.gz | tail -n +$[2*num_dev_sentences+1] > \
+    ${dir}/data/text/train.txt
+
+  # for reporting perplexities, we'll use the "real" dev set.
+  # (a subset of the training data is used as ${dir}/data/text/dev.txt to work
+  # out interpolation weights.)
+  # note, we can't put it in ${dir}/data/text/, because then pocolm would use
+  # it as one of the data sources.
+  cut -d " " -f 2- < data/dev_and_test/text > ${dir}/data/real_dev_set.txt
+
+  cat $lexicon | awk '{print $1}' | sort | uniq | awk '
+  {
+    if ($1 == "<s>") {
+      print "<s> is in the vocabulary!" | "cat 1>&2"
+      exit 1;
+    }
+    if ($1 == "</s>") {
+      print "</s> is in the vocabulary!" | "cat 1>&2"
+      exit 1;
+    }
+    printf("%s\n", $1);
+  }' > $dir/data/wordlist || exit 1;
+fi
+
+order=4
+wordlist=${dir}/data/wordlist
+
+lm_name="`basename ${wordlist}`_${order}"
+min_counts='train=1'
+if [ -n "${min_counts}" ]; then
+  lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`"
+fi
+
+unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm
+
+if [ $stage -le 1 ]; then
+  # decide on the vocabulary.
+  # Note: you'd use --wordlist if you had a previously determined word-list
+  # that you wanted to use.
+ # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + train_lm.py --wordlist=${wordlist} --num-splits=10 --warm-start-ratio=20 \ + --limit-unk-history=true \ + --fold-dev-into=train ${bypass_metaparam_optim_opt} \ + --min-counts="${min_counts}" \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/test.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' | tee ${unpruned_lm_dir}/perplexity_test.log + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' | tee ${unpruned_lm_dir}/perplexity_real_dev_set.log +fi + +if [ $stage -le 2 ]; then + echo "$0: pruning the LM (to larger size)" + # Using 5 million n-grams for a big LM for rescoring purposes. + prune_lm_dir.py --target-num-ngrams=$num_ngrams_large --initial-threshold=0.02 ${unpruned_lm_dir} ${dir}/data/lm_${order}_prune_big + + get_data_prob.py ${dir}/data/test.txt ${dir}/data/lm_${order}_prune_big 2>&1 | grep -F '[perplexity' | tee ${dir}/data/lm_${order}_prune_big/perplexity_test.log + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_big 2>&1 | grep -F '[perplexity' | tee ${dir}/data/lm_${order}_prune_big/perplexity_real_dev_set.log + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${dir}/data/lm_${order}_prune_big | gzip -c > ${dir}/data/arpa/${order}gram_big.arpa.gz +fi + +if [ $stage -le 3 ]; then + echo "$0: pruning the LM (to smaller size)" + # Using 2.5 million n-grams for a smaller LM for graph building. + # Prune from the bigger-pruned LM, it'll be faster. + prune_lm_dir.py --target-num-ngrams=$num_ngrams_small ${dir}/data/lm_${order}_prune_big ${dir}/data/lm_${order}_prune_small + + get_data_prob.py ${dir}/data/test.txt ${dir}/data/lm_${order}_prune_small 2>&1 | grep -F '[perplexity' | tee ${dir}/data/lm_${order}_prune_small/perplexity_test.log + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_small 2>&1 | grep -F '[perplexity' | tee ${dir}/data/lm_${order}_prune_small/perplexity_real_dev_set.log + + format_arpa_lm.py ${dir}/data/lm_${order}_prune_small | gzip -c > ${dir}/data/arpa/${order}gram_small.arpa.gz +fi diff --git a/egs/fisher_english/s5/local/nnet3/run_ivector_common.sh b/egs/fisher_english/s5/local/nnet3/run_ivector_common.sh index f6dc67991f5..9ef3cf8877e 100755 --- a/egs/fisher_english/s5/local/nnet3/run_ivector_common.sh +++ b/egs/fisher_english/s5/local/nnet3/run_ivector_common.sh @@ -1,21 +1,21 @@ #!/bin/bash +# Copyright 2017 Hossein Hadian +# 2017 Vimal Manohar +# Apache 2.0 . ./cmd.sh set -e stage=1 -generate_alignments=true # false if doing chain training speed_perturb=true -train_set=train +train_set=train # Supervised training set +ivector_train_set= # data set for training i-vector extractor. + # If not provided, train_set will be used. -lda_train_set=train_100k nnet3_affix= -gmm=tri2_ali # should also contain alignments for $lda_train_set . ./path.sh . ./utils/parse_options.sh -gmm_dir=exp/$gmm - # perturbed data preparation if [ "$speed_perturb" == "true" ]; then if [ $stage -le 1 ]; then @@ -23,32 +23,22 @@ if [ "$speed_perturb" == "true" ]; then # to perturb the normal data to get the alignments. 
# _sp stands for speed-perturbed - for datadir in ${train_set}; do - utils/perturb_data_dir_speed.sh 0.9 data/${datadir} data/temp1 - utils/perturb_data_dir_speed.sh 1.1 data/${datadir} data/temp2 - utils/combine_data.sh data/${datadir}_tmp data/temp1 data/temp2 - utils/validate_data_dir.sh --no-feats data/${datadir}_tmp - rm -r data/temp1 data/temp2 + for datadir in ${train_set} ${ivector_train_set}; do + utils/data/perturb_data_dir_speed_3way.sh data/${datadir} data/${datadir}_sp + utils/fix_data_dir.sh data/${datadir}_sp mfccdir=mfcc_perturbed steps/make_mfcc.sh --cmd "$train_cmd" --nj 50 \ - data/${datadir}_tmp exp/make_mfcc/${datadir}_tmp $mfccdir || exit 1; - steps/compute_cmvn_stats.sh data/${datadir}_tmp exp/make_mfcc/${datadir}_tmp $mfccdir || exit 1; - utils/fix_data_dir.sh data/${datadir}_tmp - - utils/copy_data_dir.sh --spk-prefix sp1.0- --utt-prefix sp1.0- data/${datadir} data/temp0 - utils/combine_data.sh data/${datadir}_sp data/${datadir}_tmp data/temp0 + data/${datadir}_sp exp/make_mfcc/${datadir}_sp $mfccdir || exit 1; + steps/compute_cmvn_stats.sh \ + data/${datadir}_sp exp/make_mfcc/${datadir}_sp $mfccdir || exit 1; utils/fix_data_dir.sh data/${datadir}_sp - rm -r data/temp0 data/${datadir}_tmp done fi - - if [ $stage -le 2 ] && [ "$generate_alignments" == "true" ]; then - #obtain the alignment of the perturbed data - steps/align_fmllr.sh --nj 100 --cmd "$train_cmd" \ - data/${train_set}_sp data/lang exp/tri5a exp/tri5a_ali_${train_set}_sp || exit 1 - fi train_set=${train_set}_sp + if ! [ -z "$ivector_train_set" ]; then + ivector_train_set=${ivector_train_set}_sp + fi fi if [ $stage -le 3 ]; then @@ -58,24 +48,9 @@ if [ $stage -le 3 ]; then utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/mfcc/fisher_english-$date/s5b/$mfccdir/storage $mfccdir/storage fi - # the 100k directory is copied seperately, as - # we want to use exp/tri2_ali for lda_mllt training - # the main train directory might be speed_perturbed - for dataset in $train_set $lda_train_set; do + for dataset in $ivector_train_set $train_set; do utils/copy_data_dir.sh data/$dataset data/${dataset}_hires - - # scale the waveforms, this is useful as we don't use CMVN - data_dir=data/${dataset}_hires - cat $data_dir/wav.scp | python -c " -import sys, os, subprocess, re, random -scale_low = 1.0/8 -scale_high = 2.0 -for line in sys.stdin.readlines(): - if len(line.strip()) == 0: - continue - print '{0} sox --vol {1} -t wav - -t wav - |'.format(line.strip(), random.uniform(scale_low, scale_high)) -"| sort -k1,1 -u > $data_dir/wav.scp_scaled || exit 1; - mv $data_dir/wav.scp_scaled $data_dir/wav.scp + utils/data/perturb_data_dir_volume.sh data/${dataset}_hires steps/make_mfcc.sh --nj 70 --mfcc-config conf/mfcc_hires.conf \ --cmd "$train_cmd" data/${dataset}_hires exp/make_hires/$dataset $mfccdir; @@ -94,53 +69,51 @@ for line in sys.stdin.readlines(): steps/compute_cmvn_stats.sh data/${dataset}_hires exp/make_hires/$dataset $mfccdir; utils/fix_data_dir.sh data/${dataset}_hires # remove segments with problems done +fi - # Take the first 30k utterances (about 1/8th of the data) this will be used - # for the diagubm training - utils/subset_data_dir.sh --first data/${train_set}_hires 30000 data/${train_set}_30k_hires - utils/data/remove_dup_utts.sh 200 data/${train_set}_30k_hires data/${train_set}_30k_nodup_hires # 33hr +if [ -z "$ivector_train_set" ]; then + ivector_train_set=$train_set fi # ivector extractor training if [ $stage -le 4 ]; then - # We need to build a small system just because we need the 
LDA+MLLT transform - # to train the diag-UBM on top of. We use --num-iters 13 because after we get - # the transform (12th iter is the last), any further training is pointless. - # this decision is based on fisher_english - steps/train_lda_mllt.sh --cmd "$train_cmd" --num-iters 13 \ + steps/online/nnet2/get_pca_transform.sh --cmd "$train_cmd" \ --splice-opts "--left-context=3 --right-context=3" \ - 5500 90000 data/${lda_train_set}_hires \ - data/lang $gmm_dir exp/nnet3${nnet3_affix}/tri3a + --max-utts 10000 --subsample 2 \ + data/${ivector_train_set}_hires \ + exp/nnet3${nnet3_affix}/pca_transform fi if [ $stage -le 5 ]; then - # To train a diagonal UBM we don't need very much data, so use the smallest subset. steps/online/nnet2/train_diag_ubm.sh --cmd "$train_cmd" --nj 30 --num-frames 200000 \ - data/${train_set}_30k_nodup_hires 512 exp/nnet3${nnet3_affix}/tri3a exp/nnet3${nnet3_affix}/diag_ubm + data/${ivector_train_set}_hires 512 \ + exp/nnet3${nnet3_affix}/pca_transform exp/nnet3${nnet3_affix}/diag_ubm fi if [ $stage -le 6 ]; then - # iVector extractors can be sensitive to the amount of data, but this one has a - # fairly small dim (defaults to 100) so we don't use all of it, we use just the - # 100k subset (just under half the data). steps/online/nnet2/train_ivector_extractor.sh --cmd "$train_cmd" --nj 10 \ - data/${lda_train_set}_hires exp/nnet3${nnet3_affix}/diag_ubm exp/nnet3${nnet3_affix}/extractor || exit 1; + data/${ivector_train_set}_hires exp/nnet3${nnet3_affix}/diag_ubm \ + exp/nnet3${nnet3_affix}/extractor || exit 1; fi if [ $stage -le 7 ]; then # We extract iVectors on all the ${train_set} data, which will be what we # train the system on. - # having a larger number of speakers is helpful for generalization, and to # handle per-utterance decoding well (iVector starts at zero). - steps/online/nnet2/copy_data_dir.sh --utts-per-spk-max 2 data/${train_set}_hires data/${train_set}_max2_hires + utils/data/modify_speaker_info.sh --utts-per-spk-max 2 \ + data/${ivector_train_set}_hires data/${ivector_train_set}_max2_hires steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj 30 \ - data/${train_set}_max2_hires exp/nnet3${nnet3_affix}/extractor exp/nnet3${nnet3_affix}/ivectors_${train_set}_hires || exit 1; + data/${ivector_train_set}_max2_hires exp/nnet3${nnet3_affix}/extractor \ + exp/nnet3${nnet3_affix}/ivectors_${ivector_train_set}_hires || exit 1; +fi +if [ $stage -le 8 ]; then for dataset in test dev; do steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj 30 \ - data/${dataset}_hires exp/nnet3${nnet3_affix}/extractor exp/nnet3${nnet3_affix}/ivectors_${dataset}_hires || exit 1; + data/${dataset}_hires exp/nnet3${nnet3_affix}/extractor \ + exp/nnet3${nnet3_affix}/ivectors_${dataset}_hires || exit 1; done fi diff --git a/egs/fisher_english/s5/local/run_unk_model.sh b/egs/fisher_english/s5/local/run_unk_model.sh new file mode 100755 index 00000000000..4a54c31e7b7 --- /dev/null +++ b/egs/fisher_english/s5/local/run_unk_model.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar + +utils/lang/make_unk_lm.sh data/local/dict exp/unk_lang_model || exit 1 + +utils/prepare_lang.sh \ + --unk-fst exp/unk_lang_model/unk_fst.txt \ + data/local/dict "" data/local/lang data/lang_unk + +# note: it's important that the LM we built in data/lang/G.fst was created using +# pocolm with the option --limit-unk-history=true (see ted_train_lm.sh). This +# keeps the graph compact after adding the unk model (we only have to add one +# copy of it). 
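+# (In this recipe the pocolm LM is built by local/fisher_train_lms_pocolm.sh,
+# which passes --limit-unk-history=true to train_lm.py, so this condition holds.)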
+ +exit 0 + +## Caution: if you use this unk-model stuff, be sure that the scoring script +## does not use lattice-align-words-lexicon, because it's not compatible with +## the unk-model. Instead you should use lattice-align-words (of course, this +## only works if you have position-dependent phones). diff --git a/egs/fisher_english/s5/local/score.sh b/egs/fisher_english/s5/local/score.sh deleted file mode 100755 index c381abf7277..00000000000 --- a/egs/fisher_english/s5/local/score.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash -# Copyright Johns Hopkins University (Author: Daniel Povey) 2012. Apache 2.0. - -# begin configuration section. -cmd=run.pl -min_lmwt=5 -max_lmwt=17 -#end configuration section. - -[ -f ./path.sh ] && . ./path.sh -. parse_options.sh || exit 1; - -if [ $# -ne 3 ]; then - echo "Usage: $0 [--cmd (run.pl|queue.pl...)] " - echo " Options:" - echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." - echo " --min_lmwt # minumum LM-weight for lattice rescoring " - echo " --max_lmwt # maximum LM-weight for lattice rescoring " - exit 1; -fi - -data=$1 -lang=$2 # Note: may be graph directory not lang directory, but has the necessary stuff copied. -dir=$3 - -model=$dir/../final.mdl # assume model one level up from decoding dir. - -for f in $data/text $lang/words.txt $dir/lat.1.gz; do - [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; -done - -name=`basename $data`; # e.g. eval2000 - -mkdir -p $dir/scoring/log - - -function filter_text { - perl -e 'foreach $w (@ARGV) { $bad{$w} = 1; } - while() { @A = split(" ", $_); $id = shift @A; print "$id "; - foreach $a (@A) { if (!defined $bad{$a}) { print "$a "; }} print "\n"; }' \ - '[NOISE]' '[LAUGHTER]' '[VOCALIZED-NOISE]' '' '%HESITATION' -} - -$cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/best_path.LMWT.log \ - lattice-best-path --lm-scale=LMWT --word-symbol-table=$lang/words.txt \ - "ark:gunzip -c $dir/lat.*.gz|" ark,t:$dir/scoring/LMWT.tra || exit 1; - -for lmwt in `seq $min_lmwt $max_lmwt`; do - utils/int2sym.pl -f 2- $lang/words.txt <$dir/scoring/$lmwt.tra | \ - filter_text > $dir/scoring/$lmwt.txt || exit 1; -done - -filter_text <$data/text >$dir/scoring/text.filt - -$cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.log \ - compute-wer --text --mode=present \ - ark:$dir/scoring/text.filt ark:$dir/scoring/LMWT.txt ">&" $dir/wer_LMWT || exit 1; - -exit 0 diff --git a/egs/fisher_english/s5/local/score.sh b/egs/fisher_english/s5/local/score.sh new file mode 120000 index 00000000000..6a200b42ed3 --- /dev/null +++ b/egs/fisher_english/s5/local/score.sh @@ -0,0 +1 @@ +../steps/scoring/score_kaldi_wer.sh \ No newline at end of file diff --git a/egs/fisher_english/s5/local/semisup/chain/run_tdnn.sh b/egs/fisher_english/s5/local/semisup/chain/run_tdnn.sh new file mode 120000 index 00000000000..34499362831 --- /dev/null +++ b/egs/fisher_english/s5/local/semisup/chain/run_tdnn.sh @@ -0,0 +1 @@ +tuning/run_tdnn_1a.sh \ No newline at end of file diff --git a/egs/fisher_english/s5/local/semisup/chain/run_tdnn_100k_semisupervised.sh b/egs/fisher_english/s5/local/semisup/chain/run_tdnn_100k_semisupervised.sh new file mode 120000 index 00000000000..705b1a1dd12 --- /dev/null +++ b/egs/fisher_english/s5/local/semisup/chain/run_tdnn_100k_semisupervised.sh @@ -0,0 +1 @@ +tuning/run_tdnn_100k_semisupervised_1a.sh \ No newline at end of file diff --git a/egs/fisher_english/s5/local/semisup/chain/run_tdnn_50k_semisupervised.sh b/egs/fisher_english/s5/local/semisup/chain/run_tdnn_50k_semisupervised.sh new 
file mode 120000 index 00000000000..70ebebf3c13 --- /dev/null +++ b/egs/fisher_english/s5/local/semisup/chain/run_tdnn_50k_semisupervised.sh @@ -0,0 +1 @@ +tuning/run_tdnn_50k_semisupervised_1a.sh \ No newline at end of file diff --git a/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_100k_semisupervised_1a.sh b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_100k_semisupervised_1a.sh new file mode 100644 index 00000000000..e6c3ff09e04 --- /dev/null +++ b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_100k_semisupervised_1a.sh @@ -0,0 +1,436 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# Apache 2.0 + +# This script is semi-supervised recipe with 100 hours of supervised data +# and 250 hours unsupervised data with naive splitting. +# We use only the supervised data for i-vector extractor training. +# We use 4-gram LM trained on 1250 hours of data excluding the 250 hours +# unsupervised data to create LM for decoding. Rescoring is done with +# a larger 4-gram LM. +# This script uses the same tree as that for the seed model. + +# Unsupervised set: train_unsup100k_250k +# unsup_frames_per_eg=150 +# Deriv weights: Lattice posterior of best path pdf +# Unsupervised weight: 1.0 +# Weights for phone LM (supervised, unsupervised): 3,2 +# LM for decoding unsupervised data: 4gram +# Supervision: Naive split lattices + +set -u -e -o pipefail + +stage=0 # Start from -1 for supervised seed system training +train_stage=-100 +nj=40 +decode_nj=40 +exp=exp/semisup_100k + +# Datasets -- Expects data/$supervised_set and data/$unsupervised_set to be +# present +unsupervised_set=train_unsup100k_250k # set this to your choice of unsupervised data +supervised_set=train_sup + +# Seed model options +nnet3_affix= # affix for nnet3 dir -- relates to i-vector used +chain_affix= # affix for chain dir +tdnn_affix=1a # affix for the supervised chain-model directory +tree_affix=bi_a # affix for the tree of the supervised model +train_supervised_opts="--stage -10 --train-stage -10" +gmm=tri4a # GMM model to get supervision for supervised data + +# Unsupervised options +decode_affix= # affix for decoded lattices +egs_affix= # affix for the egs that are generated from unsupervised data and for the comined egs dir +unsup_frames_per_eg=150 # if empty, will be equal to the supervised model's config +lattice_lm_scale=0.5 # lm-scale for using the weights from unsupervised lattices when creating numerator supervision +lattice_prune_beam=4.0 # If supplied, will prune the lattices prior to getting egs for unsupervised data +tolerance=1 # frame-tolerance for chain training +phone_insertion_penalty= + +rescore_unsup_lattices=false # const ARPA rescoring with a bigger LM -- false here because we have only LM text from 100 hours of data +unsup_rescoring_affix=big # affix for const ARPA lang dir + +# Semi-supervised options +comb_affix=comb250k_1a # affix for new chain-model directory trained on the combined supervised+unsupervised subsets +supervision_weights=1.0,1.0 # Weights for supervised, unsupervised data egs +lm_weights=3,2 # Weights on phone counts from supervised, unsupervised data for denominator FST creation + +sup_egs_dir= # Supply this to skip supervised egs creation +unsup_egs_dir= # Supply this to skip unsupervised egs creation +unsup_egs_opts= # Extra options to pass to unsupervised egs creation + +# Neural network opts +apply_deriv_weights=true +xent_regularize=0.1 +hidden_dim=725 + +decode_iter= # Iteration to decode with + +# End configuration section. 
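+# Example invocation (illustrative only; the option names below map onto the
+# variables declared above via utils/parse_options.sh, and the values shown
+# are just the defaults from this script):
+#   local/semisup/chain/tuning/run_tdnn_100k_semisupervised_1a.sh \
+#     --supervised-set train_sup --unsupervised-set train_unsup100k_250k \
+#     --exp exp/semisup_100k --stage 0 --nj 40
+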
+echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +if [ -f ./path.sh ]; then . ./path.sh; fi +. ./utils/parse_options.sh + +egs_affix=${egs_affix}_prun${lattice_prune_beam}_lmwt${lattice_lm_scale}_tol${tolerance} + +RANDOM=0 + +if ! cuda-compiled; then + cat < $chaindir/best_path_${unsupervised_set}${decode_affix}/frame_subsampling_factor + steps/nnet3/chain/make_weighted_den_fst.sh --num-repeats $lm_weights --cmd "$train_cmd" \ + ${treedir} ${chaindir}/best_path_${unsupervised_set}${decode_affix} \ + $dir +fi + +if [ $stage -le 11 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=$hidden_dim + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$hidden_dim + relu-batchnorm-layer name=tdnn3 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn6 input=Append(-6,-3,0) dim=$hidden_dim + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=$hidden_dim target-rms=0.5 + output-layer name=output input=prefinal-chain include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=$hidden_dim target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + + # We use separate outputs for supervised and unsupervised data + # so we can properly track the train and valid objectives. + + output name=output-0 input=output.affine + output name=output-1 input=output.affine + + output name=output-0-xent input=output-xent.log-softmax + output name=output-1-xent input=output-xent.log-softmax +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +. 
$dir/configs/vars + +left_context=$model_left_context +right_context=$model_right_context +left_context_initial=0 +right_context_final=0 + +egs_left_context=`perl -e "print int($left_context + $frame_subsampling_factor / 2)"` +egs_right_context=`perl -e "print int($right_context + $frame_subsampling_factor / 2)"` +egs_left_context_initial=`perl -e "print int($left_context_initial + $frame_subsampling_factor / 2)"` +egs_right_context_final=`perl -e "print int($right_context_final + $frame_subsampling_factor / 2)"` + +if [ -z "$sup_egs_dir" ]; then + sup_egs_dir=$dir/egs_${supervised_set} + frames_per_eg=$(cat $chaindir/egs/info/frames_per_eg) + + if [ $stage -le 12 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $sup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$sup_egs_dir/storage $sup_egs_dir/storage + fi + mkdir -p $sup_egs_dir/ + touch $sup_egs_dir/.nodelete # keep egs around when that run dies. + + echo "$0: generating egs from the supervised data" + steps/nnet3/chain/get_egs.sh --cmd "$decode_cmd" \ + --left-context $egs_left_context --right-context $egs_right_context \ + --left-context-initial $egs_left_context_initial --right-context-final $egs_right_context_final \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 3 \ + --frames-per-eg $frames_per_eg \ + --frames-per-iter 1500000 \ + --cmvn-opts "$cmvn_opts" \ + --online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${supervised_set}_hires \ + --generate-egs-scp true \ + data/${supervised_set}_hires $dir \ + $sup_lat_dir $sup_egs_dir + fi +else + frames_per_eg=$(cat $sup_egs_dir/info/frames_per_eg) +fi + +unsup_lat_dir=${chaindir}/decode_${unsupervised_set}${decode_affix} +if [ -z "$unsup_egs_dir" ]; then + [ -z $unsup_frames_per_eg ] && [ ! -z "$frames_per_eg" ] && unsup_frames_per_eg=$frames_per_eg + unsup_egs_dir=$dir/egs_${unsupervised_set}${decode_affix}${egs_affix} + + if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $unsup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$unsup_egs_dir/storage $unsup_egs_dir/storage + fi + mkdir -p $unsup_egs_dir + touch $unsup_egs_dir/.nodelete # keep egs around when that run dies. 
+ + echo "$0: generating egs from the unsupervised data" + steps/nnet3/chain/get_egs.sh \ + --cmd "$decode_cmd" --alignment-subsampling-factor 1 \ + --left-tolerance $tolerance --right-tolerance $tolerance \ + --left-context $egs_left_context --right-context $egs_right_context \ + --left-context-initial $egs_left_context_initial --right-context-final $egs_right_context_final \ + --frames-per-eg $unsup_frames_per_eg --frames-per-iter 1500000 \ + --frame-subsampling-factor $frame_subsampling_factor \ + --cmvn-opts "$cmvn_opts" --lattice-lm-scale $lattice_lm_scale \ + --lattice-prune-beam "$lattice_prune_beam" \ + --phone-insertion-penalty "$phone_insertion_penalty" \ + --deriv-weights-scp $chaindir/best_path_${unsupervised_set}${decode_affix}/weights.scp \ + --online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${unsupervised_set}_hires \ + --generate-egs-scp true $unsup_egs_opts \ + data/${unsupervised_set}_hires $dir \ + $unsup_lat_dir $unsup_egs_dir + fi +fi + +comb_egs_dir=$dir/${comb_affix}_egs${decode_affix}${egs_affix}_multi + +if [ $stage -le 14 ]; then + steps/nnet3/multilingual/combine_egs.sh --cmd "$train_cmd" \ + --minibatch-size 128 --frames-per-iter 1500000 \ + --lang2weight $supervision_weights --egs-prefix cegs. 2 \ + $sup_egs_dir $unsup_egs_dir $comb_egs_dir + touch $comb_egs_dir/.nodelete # keep egs around when that run dies. +fi + +if [ $train_stage -le -4 ]; then + train_stage=-4 +fi + +if [ $stage -le 15 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --egs.dir "$comb_egs_dir" \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${supervised_set}_hires \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights $apply_deriv_weights \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $frames_per_eg \ + --trainer.num-chunk-per-minibatch 128 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs 4 \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs false \ + --feat-dir data/${supervised_set}_hires \ + --tree-dir $treedir \ + --lat-dir $sup_lat_dir \ + --dir $dir || exit 1; +fi + +test_graph_dir=$dir/graph${test_graph_affix} +if [ $stage -le 17 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 ${test_lang} $dir $test_graph_dir +fi + +if [ $stage -le 18 ]; then + iter_opts= + if [ ! 
-z $decode_iter ]; then + nnet3-copy --edits="remove-output-nodes name=output;rename-node old-name=output-0 new-name=output" $dir/${decode_iter}.mdl - | \ + nnet3-am-copy --set-raw-nnet=- $dir/${decode_iter}.mdl $dir/${decode_iter}-output.mdl || exit 1 + iter_opts=" --iter ${decode_iter}-output " + else + nnet3-copy --edits="remove-output-nodes name=output;rename-node old-name=output-0 new-name=output" $dir/final.mdl - | \ + nnet3-am-copy --set-raw-nnet=- $dir/final.mdl $dir/final-output.mdl || exit 1 + iter_opts=" --iter final-output " + fi + + for decode_set in dev test; do + ( + num_jobs=`cat data/${decode_set}_hires/utt2spk|cut -d' ' -f2|sort -u|wc -l` + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $num_jobs --cmd "$decode_cmd" $iter_opts \ + --online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ + $test_graph_dir data/${decode_set}_hires \ + $dir/decode${test_graph_affix}_${decode_set}${decode_iter:+_iter$decode_iter} || exit 1; + ) & + done +fi + +wait; +exit 0; diff --git a/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_1a.sh b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_1a.sh new file mode 100755 index 00000000000..8f39d46cc23 --- /dev/null +++ b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_1a.sh @@ -0,0 +1,213 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# Apache 2.0 + +set -e +set -o pipefail + +# This is fisher chain recipe for training a model on a subset of around +# 100-300 hours of supervised data. +# This system uses phone LM to model UNK. + +# configs for 'chain' +stage=0 +train_stage=-10 +get_egs_stage=-10 +exp=exp/semisup_100k + +tdnn_affix=1a +train_set=train_sup +ivector_train_set= # dataset for training i-vector extractor + +nnet3_affix= # affix for nnet3 dir -- relates to i-vector used +chain_affix= # affix for chain dir +tree_affix=bi_a +gmm=tri4a # Expect GMM model in $exp/$gmm for alignment + +# Neural network opts +xent_regularize=0.1 +hidden_dim=725 + +# training options +num_epochs=4 + +remove_egs=false +common_egs_dir= # if provided, will skip egs generation +common_treedir= # if provided, will skip the tree building stage + +decode_iter= + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +if [ -f ./path.sh ]; then . ./path.sh; fi +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ -z "$common_treedir" ]; then + if [ $stage -le 11 ]; then + # Build a tree using our new topology. 
+ steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --leftmost-questions-truncate -1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 7000 data/${train_set}_sp $lang $lat_dir $treedir || exit 1 + fi +else + treedir=$common_treedir +fi + +if [ $stage -le 12 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=$hidden_dim + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$hidden_dim + relu-batchnorm-layer name=tdnn3 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn6 input=Append(-6,-3,0) dim=$hidden_dim + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=$hidden_dim target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=$hidden_dim target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$dir/egs/storage $dir/egs/storage + fi + + mkdir -p $dir/egs + touch $dir/egs/.nodelete # keep egs around when that run dies. 
+ + steps/nnet3/chain/train.py --stage $train_stage \ + --egs.dir "$common_egs_dir" \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize 0.1 \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0 --generate-egs-scp true" \ + --egs.chunk-width 160,140,110,80 \ + --trainer.num-chunk-per-minibatch 128 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs $remove_egs \ + --feat-dir $train_data_dir \ + --tree-dir $treedir \ + --lat-dir $lat_dir \ + --dir $dir || exit 1; +fi + +graph_dir=$dir/graph_poco_unk +if [ $stage -le 14 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_poco_test_unk $dir $graph_dir +fi + +decode_suff= +if [ $stage -le 15 ]; then + iter_opts= + if [ ! -z $decode_iter ]; then + iter_opts=" --iter $decode_iter " + fi + for decode_set in dev test; do + ( + num_jobs=`cat data/${decode_set}_hires/utt2spk|cut -d' ' -f2|sort -u|wc -l` + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $num_jobs --cmd "$decode_cmd" $iter_opts \ + --online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ + $graph_dir data/${decode_set}_hires $dir/decode_poco_unk_${decode_set}${decode_iter:+_$decode_iter}${decode_suff} || exit 1; + ) & + done +fi +wait; +exit 0; diff --git a/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_50k_semisupervised_1a.sh b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_50k_semisupervised_1a.sh new file mode 100755 index 00000000000..b110991e084 --- /dev/null +++ b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_50k_semisupervised_1a.sh @@ -0,0 +1,415 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# Apache 2.0 + +# This script is semi-supervised recipe with around 50 hours of supervised data +# and 250 hours unsupervised data with naive splitting. +# We use the combined data for i-vector extractor training. +# We use 4-gram LM trained on 1250 hours of data excluding the 250 hours +# unsupervised data to create LM for decoding. Rescoring is done with +# a larger 4-gram LM. +# This script uses phone LM to model UNK. +# This script uses the same tree as that for the seed model. 
+ +# Unsupervised set: train_unsup100k_250k +# unsup_frames_per_eg=150 +# Deriv weights: Lattice posterior of best path pdf +# Unsupervised weight: 1.0 +# Weights for phone LM (supervised, unsupervised): 5,2 +# LM for decoding unsupervised data: 4gram +# Supervision: Naive split lattices + +set -u -e -o pipefail + +stage=0 # Start from -1 for supervised seed system training +train_stage=-100 +nj=40 +decode_nj=40 +exp=exp/semisup_50k + +# Datasets -- Expects data/$supervised_set and data/$unsupervised_set to be +# present +supervised_set=train_sup50k +unsupervised_set=train_unsup100k_250k +semisup_train_set=semisup50k_100k_250k + +# Seed model options +nnet3_affix=_semi50k_100k_250k # affix for nnet3 dir -- relates to i-vector used +chain_affix=_semi50k_100k_250k # affix for chain dir +tdnn_affix=1a # affix for the supervised chain-model directory +tree_affix=bi_a # affix for the tree of the supervised model +train_supervised_opts="--stage -10 --train-stage -10" +gmm=tri4a # GMM model to get supervision for supervised data + +# Unsupervised options +decode_affix= # affix for decoded lattices +egs_affix= # affix for the egs that are generated from unsupervised data and for the comined egs dir +unsup_frames_per_eg=150 # if empty, will be equal to the supervised model's config +lattice_lm_scale=0.5 # lm-scale for using the weights from unsupervised lattices when creating numerator supervision +lattice_prune_beam=4.0 # If supplied, will prune the lattices prior to getting egs for unsupervised data +tolerance=1 # frame-tolerance for chain training +phone_insertion_penalty= + +rescore_unsup_lattices=true # const ARPA rescoring with a bigger LM +unsup_rescoring_affix=big # affix for const ARPA lang dir + +# Semi-supervised options +comb_affix=comb1a # affix for new chain-model directory trained on the combined supervised+unsupervised subsets +supervision_weights=1.0,1.0 # Weights for supervised, unsupervised data egs +lm_weights=3,2 # Weights on phone counts from supervised, unsupervised data for denominator FST creation + +sup_egs_dir= # Supply this to skip supervised egs creation +unsup_egs_dir= # Supply this to skip unsupervised egs creation +unsup_egs_opts= # Extra options to pass to unsupervised egs creation + +# Neural network opts +apply_deriv_weights=true +xent_regularize=0.1 +hidden_dim=725 + +decode_iter= # Iteration to decode with + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +if [ -f ./path.sh ]; then . ./path.sh; fi +. ./utils/parse_options.sh + +egs_affix=${egs_affix}_prun${lattice_prune_beam}_lmwt${lattice_lm_scale}_tol${tolerance} + +RANDOM=0 + +if ! 
cuda-compiled; then + cat < $chaindir/best_path_${unsupervised_set}${decode_affix}/frame_subsampling_factor + steps/nnet3/chain/make_weighted_den_fst.sh --num-repeats $lm_weights --cmd "$train_cmd" \ + ${treedir} ${chaindir}/best_path_${unsupervised_set}${decode_affix} \ + $dir +fi + +if [ $stage -le 11 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=$hidden_dim + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$hidden_dim + relu-batchnorm-layer name=tdnn3 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn6 input=Append(-6,-3,0) dim=$hidden_dim + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=$hidden_dim target-rms=0.5 + output-layer name=output input=prefinal-chain include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=$hidden_dim target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + + # We use separate outputs for supervised and unsupervised data + # so we can properly track the train and valid objectives. + + output name=output-0 input=output.affine + output name=output-1 input=output.affine + + output name=output-0-xent input=output-xent.log-softmax + output name=output-1-xent input=output-xent.log-softmax +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +. 
$dir/configs/vars + +left_context=$model_left_context +right_context=$model_right_context +left_context_initial=0 +right_context_final=0 + +egs_left_context=`perl -e "print int($left_context + $frame_subsampling_factor / 2)"` +egs_right_context=`perl -e "print int($right_context + $frame_subsampling_factor / 2)"` +egs_left_context_initial=`perl -e "print int($left_context_initial + $frame_subsampling_factor / 2)"` +egs_right_context_final=`perl -e "print int($right_context_final + $frame_subsampling_factor / 2)"` + +if [ -z "$sup_egs_dir" ]; then + sup_egs_dir=$dir/egs_${supervised_set} + frames_per_eg=$(cat $chaindir/egs/info/frames_per_eg) + + if [ $stage -le 12 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $sup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$sup_egs_dir/storage $sup_egs_dir/storage + fi + mkdir -p $sup_egs_dir/ + touch $sup_egs_dir/.nodelete # keep egs around when that run dies. + + echo "$0: generating egs from the supervised data" + steps/nnet3/chain/get_egs.sh --cmd "$decode_cmd" \ + --left-context $egs_left_context --right-context $egs_right_context \ + --left-context-initial $egs_left_context_initial --right-context-final $egs_right_context_final \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 3 \ + --frames-per-eg $frames_per_eg \ + --frames-per-iter 1500000 \ + --cmvn-opts "$cmvn_opts" \ + --online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${supervised_set}_hires \ + --generate-egs-scp true \ + data/${supervised_set}_hires $dir \ + $sup_lat_dir $sup_egs_dir + fi +else + frames_per_eg=$(cat $sup_egs_dir/info/frames_per_eg) +fi + +unsup_lat_dir=${chaindir}/decode_${unsupervised_set}${decode_affix} +if [ -z "$unsup_egs_dir" ]; then + [ -z $unsup_frames_per_eg ] && [ ! -z "$frames_per_eg" ] && unsup_frames_per_eg=$frames_per_eg + unsup_egs_dir=$dir/egs_${unsupervised_set}${decode_affix}${egs_affix} + + if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $unsup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$unsup_egs_dir/storage $unsup_egs_dir/storage + fi + mkdir -p $unsup_egs_dir + touch $unsup_egs_dir/.nodelete # keep egs around when that run dies. 
+ + echo "$0: generating egs from the unsupervised data" + steps/nnet3/chain/get_egs.sh \ + --cmd "$decode_cmd" --alignment-subsampling-factor 1 \ + --left-tolerance $tolerance --right-tolerance $tolerance \ + --left-context $egs_left_context --right-context $egs_right_context \ + --left-context-initial $egs_left_context_initial --right-context-final $egs_right_context_final \ + --frames-per-eg $unsup_frames_per_eg --frames-per-iter 1500000 \ + --frame-subsampling-factor $frame_subsampling_factor \ + --cmvn-opts "$cmvn_opts" --lattice-lm-scale $lattice_lm_scale \ + --lattice-prune-beam "$lattice_prune_beam" \ + --phone-insertion-penalty "$phone_insertion_penalty" \ + --deriv-weights-scp $chaindir/best_path_${unsupervised_set}${decode_affix}/weights.scp \ + --online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${unsupervised_set}_hires \ + --generate-egs-scp true $unsup_egs_opts \ + data/${unsupervised_set}_hires $dir \ + $unsup_lat_dir $unsup_egs_dir + fi +fi + +comb_egs_dir=$dir/${comb_affix}_egs${decode_affix}${egs_affix}_multi + +if [ $stage -le 14 ]; then + steps/nnet3/multilingual/combine_egs.sh --cmd "$train_cmd" \ + --minibatch-size 128 --frames-per-iter 1500000 \ + --lang2weight $supervision_weights --egs-prefix cegs. 2 \ + $sup_egs_dir $unsup_egs_dir $comb_egs_dir + touch $comb_egs_dir/.nodelete # keep egs around when that run dies. +fi + +if [ $train_stage -le -4 ]; then + train_stage=-4 +fi + +if [ $stage -le 15 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --egs.dir "$comb_egs_dir" \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${supervised_set}_hires \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights $apply_deriv_weights \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $frames_per_eg \ + --trainer.num-chunk-per-minibatch 128 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs 4 \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs false \ + --feat-dir data/${supervised_set}_hires \ + --tree-dir $treedir \ + --lat-dir $sup_lat_dir \ + --dir $dir || exit 1; +fi + +test_graph_dir=$dir/graph${test_graph_affix} +if [ $stage -le 17 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 ${test_lang} $dir $test_graph_dir +fi + +if [ $stage -le 18 ]; then + iter_opts= + if [ ! 
-z $decode_iter ]; then + nnet3-copy --edits="remove-output-nodes name=output;rename-node old-name=output-0 new-name=output" $dir/${decode_iter}.mdl - | \ + nnet3-am-copy --set-raw-nnet=- $dir/${decode_iter}.mdl $dir/${decode_iter}-output.mdl || exit 1 + iter_opts=" --iter ${decode_iter}-output " + else + nnet3-copy --edits="remove-output-nodes name=output;rename-node old-name=output-0 new-name=output" $dir/final.mdl - | \ + nnet3-am-copy --set-raw-nnet=- $dir/final.mdl $dir/final-output.mdl || exit 1 + iter_opts=" --iter final-output " + fi + + for decode_set in dev test; do + ( + num_jobs=`cat data/${decode_set}_hires/utt2spk|cut -d' ' -f2|sort -u|wc -l` + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $num_jobs --cmd "$decode_cmd" $iter_opts \ + --online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ + $test_graph_dir data/${decode_set}_hires \ + $dir/decode${test_graph_affix}_${decode_set}${decode_iter:+_iter$decode_iter} || exit 1; + ) & + done +fi + +wait; +exit 0; diff --git a/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_50k_semisupervised_1b.sh b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_50k_semisupervised_1b.sh new file mode 100644 index 00000000000..c625e5ada9f --- /dev/null +++ b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_50k_semisupervised_1b.sh @@ -0,0 +1,428 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# Apache 2.0 + +# This script is semi-supervised recipe with 50 hours of supervised data +# and 250 hours unsupervised data with naive splitting. +# We use the combined data for i-vector extractor training. +# We use 4-gram LM trained on 1250 hours of data excluding the 250 hours +# unsupervised data to create LM for decoding. Rescoring is done with +# a larger 4-gram LM. +# This script uses phone LM to model UNK. +# This script builds a new tree using stats from both supervised and +# unsupervised data. 
+ +# Unsupervised set: train_unsup100k_250k +# unsup_frames_per_eg=150 +# Deriv weights: Lattice posterior of best path pdf +# Unsupervised weight: 1.0 +# Weights for phone LM (supervised, unsupervised): 5,2 +# LM for decoding unsupervised data: 4gram +# Supervision: Naive split lattices + +set -u -e -o pipefail + +stage=0 # Start from -1 for supervised seed system training +train_stage=-100 +nj=40 +decode_nj=40 +exp=exp/semisup_50k + +# Datasets -- Expects data/$supervised_set and data/$unsupervised_set to be +# present +supervised_set=train_sup50k +unsupervised_set=train_unsup100k_250k +semisup_train_set=semisup50k_100k_250k + +# Seed model options +nnet3_affix=_semi50k_100k_250k # affix for nnet3 dir -- relates to i-vector used +chain_affix=_semi50k_100k_250k # affix for chain dir +tdnn_affix=1b # affix for the supervised chain-model directory +train_supervised_opts="--stage -10 --train-stage -10" +gmm=tri4a # GMM model to get supervision for supervised data + +# Unsupervised options +decode_affix= # affix for decoded lattices +egs_affix= # affix for the egs that are generated from unsupervised data and for the comined egs dir +unsup_frames_per_eg=150 # if empty, will be equal to the supervised model's config +lattice_lm_scale=0.5 # lm-scale for using the weights from unsupervised lattices when creating numerator supervision +lattice_prune_beam=4.0 # If supplied, will prune the lattices prior to getting egs for unsupervised data +tolerance=1 # frame-tolerance for chain training +phone_insertion_penalty= + +rescore_unsup_lattices=true # const ARPA rescoring with a bigger LM +unsup_rescoring_affix=big # affix for const ARPA lang dir + +# Semi-supervised options +comb_affix=comb1b # affix for new chain-model directory trained on the combined supervised+unsupervised subsets +supervision_weights=1.0,1.0 # Weights for supervised, unsupervised data egs +lm_weights=3,2 # Weights on phone counts from supervised, unsupervised data for denominator FST creation + +sup_egs_dir= # Supply this to skip supervised egs creation +unsup_egs_dir= # Supply this to skip unsupervised egs creation +unsup_egs_opts= # Extra options to pass to unsupervised egs creation + +# Neural network opts +apply_deriv_weights=true +xent_regularize=0.1 +hidden_dim=725 + +decode_iter= # Iteration to decode with + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +if [ -f ./path.sh ]; then . ./path.sh; fi +. ./utils/parse_options.sh + +egs_affix=${egs_affix}_prun${lattice_prune_beam}_lmwt${lattice_lm_scale}_tol${tolerance} + +RANDOM=0 + +if ! 
cuda-compiled; then + cat < $chaindir/best_path_${unsupervised_set}${decode_affix}/frame_subsampling_factor + + # Build a new tree using stats from both supervised and unsupervised data + steps/nnet3/chain/build_tree_multiple_sources.sh \ + --use-fmllr false --context-opts "--context-width=2 --central-position=1" \ + --frame-subsampling-factor 3 \ + 7000 $lang \ + data/${supervised_set} \ + ${sup_ali_dir} \ + data/${unsupervised_set} \ + $chaindir/best_path_${unsupervised_set}${decode_affix} \ + $treedir || exit 1 +fi + +dir=$exp/chain${chain_affix}/tdnn${tdnn_affix}${decode_affix}${egs_affix}${comb_affix:+_$comb_affix} + +# Train denominator FST using phone alignments from +# supervised and unsupervised data +if [ $stage -le 10 ]; then + steps/nnet3/chain/make_weighted_den_fst.sh --num-repeats $lm_weights --cmd "$train_cmd" \ + ${treedir} ${chaindir}/best_path_${unsupervised_set}${decode_affix} \ + $dir +fi + +if [ $stage -le 11 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=$hidden_dim + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$hidden_dim + relu-batchnorm-layer name=tdnn3 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn6 input=Append(-6,-3,0) dim=$hidden_dim + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=$hidden_dim target-rms=0.5 + output-layer name=output input=prefinal-chain include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=$hidden_dim target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + + # We use separate outputs for supervised and unsupervised data + # so we can properly track the train and valid objectives. + + output name=output-0 input=output.affine + output name=output-1 input=output.affine + + output name=output-0-xent input=output-xent.log-softmax + output name=output-1-xent input=output-xent.log-softmax +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +. 
$dir/configs/vars + +left_context=$model_left_context +right_context=$model_right_context +left_context_initial=0 +right_context_final=0 + +egs_left_context=`perl -e "print int($left_context + $frame_subsampling_factor / 2)"` +egs_right_context=`perl -e "print int($right_context + $frame_subsampling_factor / 2)"` +egs_left_context_initial=`perl -e "print int($left_context_initial + $frame_subsampling_factor / 2)"` +egs_right_context_final=`perl -e "print int($right_context_final + $frame_subsampling_factor / 2)"` + +if [ -z "$sup_egs_dir" ]; then + sup_egs_dir=$dir/egs_${supervised_set} + frames_per_eg=$(cat $chaindir/egs/info/frames_per_eg) + + if [ $stage -le 12 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $sup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$sup_egs_dir/storage $sup_egs_dir/storage + fi + mkdir -p $sup_egs_dir/ + touch $sup_egs_dir/.nodelete # keep egs around when that run dies. + + echo "$0: generating egs from the supervised data" + steps/nnet3/chain/get_egs.sh --cmd "$decode_cmd" \ + --left-context $egs_left_context --right-context $egs_right_context \ + --left-context-initial $egs_left_context_initial --right-context-final $egs_right_context_final \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 3 \ + --frames-per-eg $frames_per_eg \ + --frames-per-iter 1500000 \ + --cmvn-opts "$cmvn_opts" \ + --online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${supervised_set}_hires \ + --generate-egs-scp true \ + data/${supervised_set}_hires $dir \ + $sup_lat_dir $sup_egs_dir + fi +else + frames_per_eg=$(cat $sup_egs_dir/info/frames_per_eg) +fi + +unsup_lat_dir=${chaindir}/decode_${unsupervised_set}${decode_affix} +if [ -z "$unsup_egs_dir" ]; then + [ -z $unsup_frames_per_eg ] && [ ! -z "$frames_per_eg" ] && unsup_frames_per_eg=$frames_per_eg + unsup_egs_dir=$dir/egs_${unsupervised_set}${decode_affix}${egs_affix} + + if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $unsup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$unsup_egs_dir/storage $unsup_egs_dir/storage + fi + mkdir -p $unsup_egs_dir + touch $unsup_egs_dir/.nodelete # keep egs around when that run dies. 
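+
+    # Compared with the supervised egs above, the call below takes its numerator
+    # supervision from the decoded lattices: --alignment-subsampling-factor is 1
+    # because the lattices from the chain-model decode are already at the
+    # subsampled output frame rate, --lattice-lm-scale controls how much of the
+    # lattice LM scores is kept in that supervision, and --deriv-weights-scp
+    # supplies per-frame weights from the best-path posteriors so less reliable
+    # frames contribute less to the derivatives.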
+ + echo "$0: generating egs from the unsupervised data" + steps/nnet3/chain/get_egs.sh \ + --cmd "$decode_cmd" --alignment-subsampling-factor 1 \ + --left-tolerance $tolerance --right-tolerance $tolerance \ + --left-context $egs_left_context --right-context $egs_right_context \ + --left-context-initial $egs_left_context_initial --right-context-final $egs_right_context_final \ + --frames-per-eg $unsup_frames_per_eg --frames-per-iter 1500000 \ + --frame-subsampling-factor $frame_subsampling_factor \ + --cmvn-opts "$cmvn_opts" --lattice-lm-scale $lattice_lm_scale \ + --lattice-prune-beam "$lattice_prune_beam" \ + --phone-insertion-penalty "$phone_insertion_penalty" \ + --deriv-weights-scp $chaindir/best_path_${unsupervised_set}${decode_affix}/weights.scp \ + --online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${unsupervised_set}_hires \ + --generate-egs-scp true $unsup_egs_opts \ + data/${unsupervised_set}_hires $dir \ + $unsup_lat_dir $unsup_egs_dir + fi +fi + +comb_egs_dir=$dir/${comb_affix}_egs${decode_affix}${egs_affix}_multi + +if [ $stage -le 14 ]; then + steps/nnet3/multilingual/combine_egs.sh --cmd "$train_cmd" \ + --minibatch-size 128 --frames-per-iter 1500000 \ + --lang2weight $supervision_weights --egs-prefix cegs. 2 \ + $sup_egs_dir $unsup_egs_dir $comb_egs_dir + touch $comb_egs_dir/.nodelete # keep egs around when that run dies. +fi + +if [ $train_stage -le -4 ]; then + train_stage=-4 +fi + +if [ $stage -le 15 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --egs.dir "$comb_egs_dir" \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${supervised_set}_hires \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights $apply_deriv_weights \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $frames_per_eg \ + --trainer.num-chunk-per-minibatch 128 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs 4 \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs false \ + --feat-dir data/${supervised_set}_hires \ + --tree-dir $treedir \ + --lat-dir $sup_lat_dir \ + --dir $dir || exit 1; +fi + +test_graph_dir=$dir/graph${test_graph_affix} +if [ $stage -le 17 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 ${test_lang} $dir $test_graph_dir +fi + +if [ $stage -le 18 ]; then + iter_opts= + if [ ! 
-z $decode_iter ]; then + nnet3-copy --edits="remove-output-nodes name=output;rename-node old-name=output-0 new-name=output" $dir/${decode_iter}.mdl - | \ + nnet3-am-copy --set-raw-nnet=- $dir/${decode_iter}.mdl $dir/${decode_iter}-output.mdl || exit 1 + iter_opts=" --iter ${decode_iter}-output " + else + nnet3-copy --edits="remove-output-nodes name=output;rename-node old-name=output-0 new-name=output" $dir/final.mdl - | \ + nnet3-am-copy --set-raw-nnet=- $dir/final.mdl $dir/final-output.mdl || exit 1 + iter_opts=" --iter final-output " + fi + + for decode_set in dev test; do + ( + num_jobs=`cat data/${decode_set}_hires/utt2spk|cut -d' ' -f2|sort -u|wc -l` + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $num_jobs --cmd "$decode_cmd" $iter_opts \ + --online-ivector-dir $exp/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ + $test_graph_dir data/${decode_set}_hires \ + $dir/decode${test_graph_affix}_${decode_set}${decode_iter:+_iter$decode_iter} || exit 1; + ) & + done +fi + +wait; +exit 0; diff --git a/egs/fisher_english/s5/local/semisup/nnet3/run_ivector_common.sh b/egs/fisher_english/s5/local/semisup/nnet3/run_ivector_common.sh new file mode 100755 index 00000000000..99410aa79e9 --- /dev/null +++ b/egs/fisher_english/s5/local/semisup/nnet3/run_ivector_common.sh @@ -0,0 +1,175 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# Apache 2.0 + +# This script is similar to local/nnet3/run_ivector_common.sh, but +# designed specifically for semi-supervised recipes. +# This script accepts an optional argument of --unsup-train-set for +# unsupervised data directory, which will be combined with the supervised +# data directory to create semi-supervised data directory (whose name +# is taken from the argument --semisup-train-set") + +. ./cmd.sh +set -e +stage=-1 +speed_perturb=true +train_set=train # Supervised training set + +# Unsupervised training set. +# If provided, it will be combined with supervised training set to +# create "semisup_train_set". This is the set that will be used to +# train the PCA transform and i-vector extractor. +unsup_train_set= +semisup_train_set= + +nnet3_affix= +exp=exp # experiments directory. It could be something like exp/semisup_15k. + +. ./path.sh +. ./utils/parse_options.sh + +if [ ! -z "$unsup_train_set" ] && [ -z "$semisup_train_set" ]; then + echo "$0: --semisup-train-set must be provided if --unsup-train-set is provided" + exit 1 +fi + +if [ -z "$unsup_train_set" ] && [ ! -z "$semisup_train_set" ]; then + echo "$0: --unsup-train-set must be provided if --semisup-train-set is provided" + exit 1 +fi + +if [ ! -f data/$train_set/utt2spk ]; then + echo "$0: data/$train_set/utt2spk does not exist" + exit 1 +fi + +if [ ! -z "$unsup_train_set" ]; then + if [ ! -f data/$unsup_train_set/utt2spk ]; then + echo "$0: data/$unsup_train_set/utt2spk does not exist" + exit 1 + fi + + # Combine supervised and unsupervised sets to create the + # semi-supervised training set. + if [ $stage -le 0 ]; then + utils/combine_data.sh data/$semisup_train_set \ + data/$train_set data/$unsup_train_set || exit 1 + fi +fi + +# perturbed data preparation +if [ "$speed_perturb" == "true" ]; then + if [ $stage -le 1 ]; then + # Although the nnet will be trained by high resolution data, we still have + # to perturb the normal data to get the alignments. 
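+    # utils/data/perturb_data_dir_speed_3way.sh creates 0.9x and 1.1x speed
+    # copies of the data in addition to the original, so the perturbed
+    # directory has roughly three times as many utterances.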
+ # _sp stands for speed-perturbed + + for datadir in ${train_set} ${unsup_train_set}; do + utils/data/perturb_data_dir_speed_3way.sh data/${datadir} data/${datadir}_sp + utils/fix_data_dir.sh data/${datadir}_sp + + mfccdir=mfcc_perturbed + steps/make_mfcc.sh --cmd "$train_cmd" --nj 50 \ + data/${datadir}_sp exp/make_mfcc/${datadir}_sp $mfccdir || exit 1; + steps/compute_cmvn_stats.sh data/${datadir}_sp exp/make_mfcc/${datadir}_sp $mfccdir || exit 1; + utils/fix_data_dir.sh data/${datadir}_sp + done + fi +fi + +if [ ! -z "$unsup_train_set" ]; then + if [ -f data/${semisup_train_set}_sp/feats.scp ]; then + echo "$0: data/${semisup_train_set}_sp/feats.scp already exists! Remove it and try again." + exit 1 + fi + + if [ $stage -le 2 ]; then + utils/combine_data.sh data/${semisup_train_set}_sp \ + data/${train_set}_sp data/${unsup_train_set}_sp + fi +fi + +if [ $stage -le 3 ]; then + mfccdir=mfcc_hires + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/fisher_english-$date/s5b/$mfccdir/storage $mfccdir/storage + fi + + for dataset in $train_set $unsup_train_set; do + utils/copy_data_dir.sh data/${dataset}_sp data/${dataset}_sp_hires + utils/data/perturb_data_dir_volume.sh data/${dataset}_sp_hires + + steps/make_mfcc.sh --nj 70 --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/${dataset}_sp_hires exp/make_hires/${dataset}_sp $mfccdir; + steps/compute_cmvn_stats.sh data/${dataset}_sp_hires exp/make_hires/${dataset}_sp $mfccdir; + + # Remove the small number of utterances that couldn't be extracted for some + # reason (e.g. too short; no such file). + utils/fix_data_dir.sh data/${dataset}_sp_hires; + done + + for dataset in test dev; do + # Create MFCCs for the eval set + utils/copy_data_dir.sh data/$dataset data/${dataset}_hires + steps/make_mfcc.sh --cmd "$train_cmd" --nj 10 --mfcc-config conf/mfcc_hires.conf \ + data/${dataset}_hires exp/make_hires/$dataset $mfccdir; + steps/compute_cmvn_stats.sh data/${dataset}_hires exp/make_hires/$dataset $mfccdir; + utils/fix_data_dir.sh data/${dataset}_hires # remove segments with problems + done +fi + +ivector_train_set=${train_set}_sp +if [ ! -z "$unsup_train_set" ]; then + if [ -f data/${semisup_train_set}_sp_hires/feats.scp ]; then + echo "$0: data/${semisup_train_set}_sp_hires/feats.scp already exists! Remove it and try again." + exit 1 + fi + + if [ $stage -le 3 ]; then + utils/combine_data.sh data/${semisup_train_set}_sp_hires \ + data/${train_set}_sp_hires data/${unsup_train_set}_sp_hires + fi + ivector_train_set=${semisup_train_set}_sp +fi + +# ivector extractor training +if [ $stage -le 4 ]; then + steps/online/nnet2/get_pca_transform.sh --cmd "$train_cmd" \ + --splice-opts "--left-context=3 --right-context=3" \ + --max-utts 10000 --subsample 2 \ + data/${ivector_train_set}_hires \ + $exp/nnet3${nnet3_affix}/pca_transform +fi + +if [ $stage -le 5 ]; then + steps/online/nnet2/train_diag_ubm.sh --cmd "$train_cmd" --nj 30 --num-frames 200000 \ + data/${ivector_train_set}_hires 512 \ + $exp/nnet3${nnet3_affix}/pca_transform $exp/nnet3${nnet3_affix}/diag_ubm +fi + +if [ $stage -le 6 ]; then + steps/online/nnet2/train_ivector_extractor.sh --cmd "$train_cmd" --nj 10 \ + data/${ivector_train_set}_hires $exp/nnet3${nnet3_affix}/diag_ubm $exp/nnet3${nnet3_affix}/extractor || exit 1; +fi + +if [ $stage -le 7 ]; then + # We extract iVectors on all the ${train_set} data, which will be what we + # train the system on. 
+ # having a larger number of speakers is helpful for generalization, and to + # handle per-utterance decoding well (iVector starts at zero). + steps/online/nnet2/copy_data_dir.sh --utts-per-spk-max 2 data/${ivector_train_set}_hires data/${ivector_train_set}_max2_hires + + steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj 30 \ + data/${ivector_train_set}_max2_hires $exp/nnet3${nnet3_affix}/extractor $exp/nnet3${nnet3_affix}/ivectors_${ivector_train_set}_hires || exit 1; +fi + +if [ $stage -le 8 ]; then + for dataset in test dev; do + steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj 30 \ + data/${dataset}_hires $exp/nnet3${nnet3_affix}/extractor $exp/nnet3${nnet3_affix}/ivectors_${dataset}_hires || exit 1; + done +fi + +exit 0; diff --git a/egs/fisher_english/s5/local/semisup/run_100k.sh b/egs/fisher_english/s5/local/semisup/run_100k.sh new file mode 100644 index 00000000000..14162872ee1 --- /dev/null +++ b/egs/fisher_english/s5/local/semisup/run_100k.sh @@ -0,0 +1,121 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# Apache 2.0 + +. cmd.sh +. path.sh + +stage=-1 +train_stage=-10 + +. utils/parse_options.sh + +set -o pipefail +exp=exp/semisup_100k + +for f in data/train_sup/utt2spk data/train_unsup100k_250k/utt2spk; do + if [ ! -f $f ]; then + echo "$0: Could not find $f" + exit 1 + fi +done + +utils/subset_data_dir.sh --shortest data/train_sup 100000 data/train_sup_100kshort +utils/subset_data_dir.sh data/train_sup_100kshort 10000 data/train_sup_10k +utils/data/remove_dup_utts.sh 100 data/train_sup_10k data/train_sup_10k_nodup +utils/subset_data_dir.sh --speakers data/train_sup 30000 data/train_sup_30k + +steps/train_mono.sh --nj 10 --cmd "$train_cmd" \ + data/train_sup_10k_nodup data/lang $exp/mono0a || exit 1 + +steps/align_si.sh --nj 30 --cmd "$train_cmd" \ + data/train_sup_30k data/lang $exp/mono0a $exp/mono0a_ali || exit 1 + +steps/train_deltas.sh --cmd "$train_cmd" \ + 2500 20000 data/train_sup_30k data/lang $exp/mono0a_ali $exp/tri1 || exit 1 + +(utils/mkgraph.sh data/lang_test $exp/tri1 $exp/tri1/graph + steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + $exp/tri1/graph data/dev $exp/tri1/decode_dev)& + +steps/align_si.sh --nj 30 --cmd "$train_cmd" \ + data/train_sup_30k data/lang $exp/tri1 $exp/tri1_ali || exit 1; + +steps/train_lda_mllt.sh --cmd "$train_cmd" \ + 2500 20000 data/train_sup_30k data/lang $exp/tri1_ali $exp/tri2 || exit 1; + +(utils/mkgraph.sh data/lang_test $exp/tri2 $exp/tri2/graph + steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + $exp/tri2/graph data/dev $exp/tri2/decode_dev)& + +steps/align_si.sh --nj 30 --cmd "$train_cmd" \ + data/train_sup data/lang $exp/tri2 $exp/tri2_ali || exit 1; + +steps/train_lda_mllt.sh --cmd "$train_cmd" \ + --splice-opts "--left-context=3 --right-context=3" \ + 5000 40000 data/train_sup data/lang $exp/tri2_ali $exp/tri3a || exit 1; + +( + utils/mkgraph.sh data/lang_test $exp/tri3a $exp/tri3a/graph || exit 1; + steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + $exp/tri3a/graph data/dev $exp/tri3a/decode_dev || exit 1; +)& + +steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ + data/train_sup data/lang $exp/tri3a $exp/tri3a_ali || exit 1; + +steps/train_sat.sh --cmd "$train_cmd" \ + 5000 100000 data/train_sup data/lang $exp/tri3a_ali $exp/tri4a || exit 1; + +( + utils/mkgraph.sh data/lang_test $exp/tri4a $exp/tri4a/graph + steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + 
$exp/tri4a/graph data/dev $exp/tri4a/decode_dev +)& + +utils/combine_data.sh data/semisup100k_250k \ + data/train_sup data/train_unsup100k_250k || exit 1 + +if [ ! -f data/lang_poco_test_sup100k/G.fst ]; then + local/fisher_train_lms_pocolm.sh \ + --text data/train_sup/text \ + --dir data/local/lm_sup100k + + local/fisher_create_test_lang.sh \ + --arpa-lm data/local/pocolm_sup100k/data/arpa/4gram_small.arpa.gz \ + --dir data/lang_poco_test_sup100k +fi + +local/run_unk_model.sh || exit 1 + +for lang_dir in data/lang_poco_test_sup100k; do + rm -r ${lang_dir}_unk 2>/dev/null || true + mkdir -p ${lang_dir}_unk + cp -r data/lang_unk ${lang_dir}_unk + if [ -f ${lang_dir}/G.fst ]; then cp ${lang_dir}/G.fst ${lang_dir}_unk/G.fst; fi + if [ -f ${lang_dir}/G.carpa ]; then cp ${lang_dir}/G.carpa ${lang_dir}_unk/G.carpa; fi +done + +local/semisup/chain/run_tdnn.sh \ + --train-set train_sup \ + --ivector-train-set "" \ + --nnet3-affix "" --chain-affix "" \ + --tdnn-affix 1a --tree-affix bi_a \ + --gmm tri4a --exp $exp || exit 1 + +local/semisup/chain/run_tdnn_100k_semisupervised.sh \ + --supervised-set train_sup \ + --unsupervised-set train_unsup100k_250k \ + --semisup-train-set semisup100k_250k \ + --nnet3-affix "" --chain-affix "" \ + --tdnn-affix 1a --tree-affix bi_a \ + --gmm tri4a --exp $exp --stage 0 || exit 1 + +local/semisup/chain/run_tdnn.sh \ + --train-set semisup100k_250k \ + --nnet3-affix "" --chain-affix "" \ + --common-treedir exp/chain/tree_bi_a \ + --tdnn-affix 1a_oracle \ + --gmm tri4a --exp $exp \ + --stage 9 || exit 1 diff --git a/egs/fisher_english/s5/local/semisup/run_15k.sh b/egs/fisher_english/s5/local/semisup/run_15k.sh new file mode 100644 index 00000000000..381c35fdf79 --- /dev/null +++ b/egs/fisher_english/s5/local/semisup/run_15k.sh @@ -0,0 +1,117 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# Apache 2.0 + +. cmd.sh +. path.sh + +. utils/parse_options.sh + +set -o pipefail +exp=exp/semisup_15k + +for f in data/train_sup/utt2spk data/train_unsup100k_250k/utt2spk; do + if [ ! 
-f $f ]; then + echo "$0: Could not find $f" + exit 1 + fi +done + +utils/subset_data_dir.sh --speakers data/train_sup 15000 data/train_sup15k || exit 1 +utils/subset_data_dir.sh --shortest data/train_sup15k 5000 data/train_sup15k_short || exit 1 +utils/subset_data_dir.sh data/train_sup15k 7500 data/train_sup15k_half || exit 1 + +steps/train_mono.sh --nj 10 --cmd "$train_cmd" \ + data/train_sup15k_short data/lang $exp/mono0a || exit 1 + +steps/align_si.sh --nj 30 --cmd "$train_cmd" \ + data/train_sup15k_half data/lang $exp/mono0a $exp/mono0a_ali || exit 1 + +steps/train_deltas.sh --cmd "$train_cmd" \ + 2000 10000 data/train_sup15k_half data/lang $exp/mono0a_ali $exp/tri1 || exit 1 + +(utils/mkgraph.sh data/lang_test $exp/tri1 $exp/tri1/graph + steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + $exp/tri1/graph data/dev $exp/tri1/decode_dev)& + +steps/align_si.sh --nj 30 --cmd "$train_cmd" \ + data/train_sup15k data/lang $exp/tri1 $exp/tri1_ali || exit 1; + +steps/train_lda_mllt.sh --cmd "$train_cmd" \ + 2500 15000 data/train_sup15k data/lang $exp/tri1_ali $exp/tri2 || exit 1; + +(utils/mkgraph.sh data/lang_test $exp/tri2 $exp/tri2/graph + steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + $exp/tri2/graph data/dev $exp/tri2/decode_dev)& + +steps/align_si.sh --nj 30 --cmd "$train_cmd" \ + data/train_sup15k data/lang $exp/tri2 $exp/tri2_ali || exit 1; + +steps/train_sat.sh --cmd "$train_cmd" \ + 2500 15000 data/train_sup15k data/lang $exp/tri2_ali $exp/tri3 || exit 1; + +( + utils/mkgraph.sh data/lang_test $exp/tri3 $exp/tri3/graph + steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + $exp/tri3/graph data/dev $exp/tri3/decode_dev +)& + +utils/combine_data.sh data/semisup15k_100k_250k \ + data/train_sup15k data/train_unsup100k_250k || exit 1 + +mkdir -p data/local/pocolm_ex250k + +utils/filter_scp.pl --exclude data/train_unsup100k_250k/utt2spk \ + data/train/text > data/local/pocolm_ex250k/text.tmp + +if [ ! 
-f data/lang_test_poco_ex250k_big/G.carpa ]; then + local/fisher_train_lms_pocolm.sh \ + --text data/local/pocolm_ex250k/text.tmp \ + --dir data/local/pocolm_ex250k + + local/fisher_create_test_lang.sh \ + --arpa-lm data/local/pocolm_ex250k/data/arpa/4gram_small.arpa.gz \ + --dir data/lang_test_poco_ex250k + + utils/build_const_arpa_lm.sh \ + data/local/pocolm_ex250k/data/arpa/4gram_big.arpa.gz \ + data/lang_test_poco_ex250k data/lang_test_poco_ex250k_big +fi + +local/run_unk_model.sh || exit 1 + +for lang_dir in data/lang_test_poco_ex250k_big data/lang_test_poco_ex250k; do + rm -r ${lang_dir}_unk 2>/dev/null || true + mkdir -p ${lang_dir}_unk + cp -r data/lang_unk ${lang_dir}_unk + if [ -f ${lang_dir}/G.fst ]; then cp ${lang_dir}/G.fst ${lang_dir}_unk/G.fst; fi + if [ -f ${lang_dir}/G.carpa ]; then cp ${lang_dir}/G.carpa ${lang_dir}_unk/G.carpa; fi +done + +local/semisup/chain/run_tdnn.sh \ + --train-set train_sup15k \ + --ivector-train-set semisup15k_100k_250k \ + --nnet3-affix _semi15k_100k_250k \ + --chain-affix _semi15k_100k_250k \ + --tdnn-affix 1a --tree-affix bi_a \ + --hidden-dim 500 \ + --gmm tri3 --exp $exp || exit 1 + +local/semisup/chain/run_tdnn_50k_semisupervised.sh \ + --supervised-set train_sup15k \ + --unsupervised-set train_unsup100k_250k \ + --semisup-train-set semisup15k_100k_250k \ + --nnet3-affix _semi15k_100k_250k \ + --chain-affix _semi15k_100k_250k \ + --tdnn-affix 1a --tree-affix bi_a \ + --gmm tri3 --exp $exp --stage 0 || exit 1 + +local/semisup/chain/run_tdnn.sh \ + --train-set semisup15k_100k_250k \ + --nnet3-affix _semi15k_100k_250k \ + --chain-affix _semi15k_100k_250k \ + --common-treedir exp/chain_semi15k_100k_250k/tree_bi_a \ + --tdnn-affix 1a_oracle \ + --gmm tri3 --exp $exp \ + --stage 9 || exit 1 diff --git a/egs/fisher_english/s5/local/semisup/run_50k.sh b/egs/fisher_english/s5/local/semisup/run_50k.sh new file mode 100644 index 00000000000..5e7e69cef39 --- /dev/null +++ b/egs/fisher_english/s5/local/semisup/run_50k.sh @@ -0,0 +1,126 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# Apache 2.0 + +. cmd.sh +. path.sh + +. utils/parse_options.sh + +set -o pipefail +exp=exp/semisup_50k + +for f in data/train_sup/utt2spk data/train_unsup100k_250k/utt2spk; do + if [ ! 
-f $f ]; then + echo "$0: Could not find $f" + exit 1 + fi +done + +utils/subset_data_dir.sh --speakers data/train_sup 50000 data/train_sup50k || exit 1 +utils/subset_data_dir.sh --shortest data/train_sup50k 25000 data/train_sup50k_short || exit 1 +utils/subset_data_dir.sh --speakers data/train_sup50k 30000 data/train_sup50k_30k || exit 1; + +steps/train_mono.sh --nj 10 --cmd "$train_cmd" \ + data/train_sup50k_short data/lang $exp/mono0a || exit 1 + +steps/align_si.sh --nj 30 --cmd "$train_cmd" \ + data/train_sup50k_30k data/lang $exp/mono0a $exp/mono0a_ali || exit 1 + +steps/train_deltas.sh --cmd "$train_cmd" \ + 2500 20000 data/train_sup50k_30k data/lang $exp/mono0a_ali $exp/tri1 || exit 1 + +(utils/mkgraph.sh data/lang_test $exp/tri1 $exp/tri1/graph + steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + $exp/tri1/graph data/dev $exp/tri1/decode_dev)& + +steps/align_si.sh --nj 30 --cmd "$train_cmd" \ + data/train_sup50k_30k data/lang $exp/tri1 $exp/tri1_ali || exit 1; + +steps/train_deltas.sh --cmd "$train_cmd" \ + 2500 20000 data/train_sup50k_30k data/lang $exp/tri1_ali $exp/tri2 || exit 1 + +(utils/mkgraph.sh data/lang_test $exp/tri2 $exp/tri2/graph + steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + $exp/tri2/graph data/dev $exp/tri2/decode_dev)& + +steps/align_si.sh --nj 30 --cmd "$train_cmd" \ + data/train_sup50k data/lang $exp/tri2 $exp/tri2_ali || exit 1; + +steps/train_lda_mllt.sh --cmd "$train_cmd" \ + 4000 30000 data/train_sup50k data/lang $exp/tri2_ali $exp/tri3a || exit 1; + +(utils/mkgraph.sh data/lang_test $exp/tri3a $exp/tri3a/graph + steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + $exp/tri3a/graph data/dev $exp/tri3a/decode_dev)& + +steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ + data/train_sup50k data/lang $exp/tri3a $exp/tri3a_ali || exit 1; + +steps/train_sat.sh --cmd "$train_cmd" \ + 4000 50000 data/train_sup50k data/lang $exp/tri3a_ali $exp/tri4a || exit 1; + +( + utils/mkgraph.sh data/lang_test $exp/tri4a $exp/tri4a/graph + steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + $exp/tri4a/graph data/dev $exp/tri4a/decode_dev +)& + +utils/combine_data.sh data/semisup50k_100k_250k \ + data/train_sup50k data/train_unsup100k_250k || exit 1 + +mkdir -p data/local/pocolm_ex250k + +utils/filter_scp.pl --exclude data/train_unsup100k_250k/utt2spk \ + data/train/text > data/local/pocolm_ex250k/text.tmp + +if [ ! 
-f data/lang_poco_test_ex250k_big/G.carpa ]; then + local/fisher_train_lms_pocolm.sh \ + --text data/local/pocolm_ex250k/text.tmp \ + --dir data/local/pocolm_ex250k + + local/fisher_create_test_lang.sh \ + --arpa-lm data/local/pocolm_ex250k/data/arpa/4gram_small.arpa.gz \ + --dir data/lang_poco_test_ex250k + + utils/build_const_arpa_lm.sh \ + data/local/pocolm_ex250k/data/arpa/4gram_big.arpa.gz \ + data/lang_poco_test_ex250k data/lang_poco_test_ex250k_big +fi + +local/run_unk_model.sh || exit 1 + +for lang_dir in data/lang_poco_test_ex250k_big data/lang_poco_test_ex250k; do + rm -r ${lang_dir}_unk 2>/dev/null || true + mkdir -p ${lang_dir}_unk + cp -r data/lang_unk ${lang_dir}_unk + if [ -f ${lang_dir}/G.fst ]; then cp ${lang_dir}/G.fst ${lang_dir}_unk/G.fst; fi + if [ -f ${lang_dir}/G.carpa ]; then cp ${lang_dir}/G.carpa ${lang_dir}_unk/G.carpa; fi +done + +local/semisup/chain/run_tdnn.sh \ + --train-set train_sup50k \ + --ivector-train-set semisup50k_100k_250k \ + --nnet3-affix _semi50k_100k_250k \ + --chain-affix _semi50k_100k_250k \ + --tdnn-affix 1a --tree-affix bi_a \ + --gmm tri4a --exp $exp || exit 1 + +local/semisup/chain/run_tdnn_50k_semisupervised.sh \ + --supervised-set train_sup50k \ + --unsupervised-set train_unsup100k_250k \ + --semisup-train-set semisup50k_100k_250k \ + --nnet3-affix _semi50k_100k_250k \ + --chain-affix _semi50k_100k_250k \ + --tdnn-affix 1a --tree-affix bi_a \ + --gmm tri4a --exp $exp --stage 0 || exit 1 + +local/semisup/chain/run_tdnn.sh \ + --train-set semisup50k_100k_250k \ + --nnet3-affix _semi50k_100k_250k \ + --chain-affix _semi50k_100k_250k \ + --common-treedir exp/chain_semi50k_100k_250k/tree_bi_a \ + --tdnn-affix 1a_oracle \ + --gmm tri4a --exp $exp \ + --stage 9 || exit 1 diff --git a/egs/fisher_english/s5/local/wer_output_filter b/egs/fisher_english/s5/local/wer_output_filter new file mode 100755 index 00000000000..2514c385038 --- /dev/null +++ b/egs/fisher_english/s5/local/wer_output_filter @@ -0,0 +1,16 @@ +#!/usr/bin/perl + +@filter_words = ('[NOISE]', '[LAUGHTER]', '[VOCALIZED-NOISE]', '', '%HESITATION'); +foreach $w (@filter_words) { + $bad{$w} = 1; $w = lc $w; $bad{$w} = 1; +} +while() { + @A = split(" ", $_); + $id = shift @A; + print "$id "; + + foreach $a (@A) { + if (!defined $bad{$a}) { print "$a "; } + } + print "\n"; +} diff --git a/egs/fisher_english/s5/path.sh b/egs/fisher_english/s5/path.sh index 1a6fb5f891b..7cad3842ab3 100755 --- a/egs/fisher_english/s5/path.sh +++ b/egs/fisher_english/s5/path.sh @@ -2,4 +2,6 @@ export KALDI_ROOT=`pwd`/../../.. export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 . $KALDI_ROOT/tools/config/common_path.sh +export PYTHONPATH=${PYTHONPATH:+$PYTHONPATH:}$KALDI_ROOT/tools/tensorflow_build/.local/lib/python2.7/site-packages +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$KALDI_ROOT/tools/tensorflow/bazel-bin/tensorflow/:/usr/local/cuda/lib64:/export/a11/hlyu/cudnn/lib64:/home/dpovey/libs/ export LC_ALL=C diff --git a/egs/wsj/s5/steps/best_path_weights.sh b/egs/wsj/s5/steps/best_path_weights.sh new file mode 100755 index 00000000000..67c303a7213 --- /dev/null +++ b/egs/wsj/s5/steps/best_path_weights.sh @@ -0,0 +1,183 @@ +#!/bin/bash + +# Copyright 2014-17 Vimal Manohar + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script combines frame-level posteriors from different decode +# directories. The first decode directory is assumed to be the primary +# and is used to get the best path. The posteriors from other decode +# directories are interpolated with the posteriors of the best path. +# The output is a new directory with final.mdl, tree from the primary +# decode-dir and the best path alignments and weights in a decode-directory +# with the same basename as the primary directory. +# This is typically used to get better posteriors for semisupervised training +# of DNN +# e.g. steps/best_path_weights.sh exp/tri6_nnet/decode_train_unt.seg +# exp/sgmm_mmi_b0.1/decode_fmllr_train_unt.seg_it4 exp/combine_dnn_sgmm +# Here the final.mdl and tree are copied from exp/tri6_nnet to +# exp/combine_dnn_sgmm. ali.*.gz obtained from the primary dir and +# the interpolated posteriors in weights.scp are placed in +# exp/combine_dnn_sgmm/decode_train_unt.seg + +set -e + +# begin configuration section. +cmd=run.pl +stage=-10 +acwt=0.1 +write_words=false # Dump the word-level transcript in addition to the best path alignments +#end configuration section. + +cat < [:weight] [:weight] [[:weight] ... ] + E.g. $0 data/train_unt.seg data/lang exp/tri1/decode:0.5 exp/tri2/decode:0.25 exp/tri3/decode:0.25 exp/combine + Options: + --cmd (run.pl|queue.pl...) # specify how to run the sub-processes. +EOF + +[ -f ./path.sh ] && . ./path.sh +. 
parse_options.sh || exit 1; + +if [ $# -lt 4 ]; then + printf "$help_message\n"; + exit 1; +fi + +data=$1 +lang=$2 +dir=${@: -1} # last argument to the script +shift 2; +decode_dirs=( $@ ) # read the remaining arguments into an array +unset decode_dirs[${#decode_dirs[@]}-1] # 'pop' the last argument which is odir +num_sys=${#decode_dirs[@]} # number of systems to combine + +mkdir -p $dir +mkdir -p $dir/log + +decode_dir=`echo ${decode_dirs[0]} | cut -d: -f1` +nj=`cat $decode_dir/num_jobs` + +mkdir -p $dir + +words_wspecifier=ark:/dev/null +if $write_words; then + words_wspecifier="ark,t:| utils/int2sym.pl -f 2- $lang/words.txt > $dir/text.JOB" +fi + +if [ $stage -lt -1 ]; then + mkdir -p $dir/log + $cmd JOB=1:$nj $dir/log/best_path.JOB.log \ + lattice-best-path --acoustic-scale=$acwt \ + "ark,s,cs:gunzip -c $decode_dir/lat.JOB.gz |" \ + "$words_wspecifier" "ark:| gzip -c > $dir/ali.JOB.gz" || exit 1 +fi + +if [ -f `dirname $decode_dir`/final.mdl ]; then + src_dir=`dirname $decode_dir` +else + src_dir=$decode_dir +fi + +cp $src_dir/cmvn_opts $dir/ || exit 1 +for f in final.mat splice_opts frame_subsampling_factor; do + [ -f $src_dir/$f ] && cp $src_dir/$f $dir +done + +weights_sum=0.0 + +for i in `seq 0 $[num_sys-1]`; do + decode_dir=${decode_dirs[$i]} + + weight=`echo $decode_dir | cut -d: -s -f2` + [ -z "$weight" ] && weight=1.0 + + if [ $i -eq 0 ]; then + file_list="\"ark:vector-scale --scale=$weight ark:$dir/weights.$i.JOB.ark ark:- |\"" + else + file_list="$file_list \"ark,s,cs:vector-scale --scale=$weight ark:$dir/weights.$i.JOB.ark ark:- |\"" + fi + + weights_sum=`perl -e "print STDOUT $weights_sum + $weight"` +done + +inv_weights_sum=`perl -e "print STDOUT 1.0/$weights_sum"` + +fdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $dir ${PWD}` + +for i in `seq 0 $[num_sys-1]`; do + if [ $stage -lt $i ]; then + decode_dir=`echo ${decode_dirs[$i]} | cut -d: -f1` + if [ -f `dirname $decode_dir`/final.mdl ]; then + # model one level up from decode dir + this_srcdir=`dirname $decode_dir` + else + this_srcdir=$decode_dir + fi + + model=$this_srcdir/final.mdl + tree=$this_srcdir/tree + + for f in $model $decode_dir/lat.1.gz $tree; do + [ ! 
-f $f ] && echo "$0: expecting file $f to exist" && exit 1; + done + if [ $i -eq 0 ]; then + nj=`cat $decode_dir/num_jobs` || exit 1; + cp $model $dir || exit 1 + cp $tree $dir || exit 1 + echo $nj > $dir/num_jobs + else + if [ $nj != `cat $decode_dir/num_jobs` ]; then + echo "$0: number of decoding jobs mismatches, $nj versus `cat $decode_dir/num_jobs`" + exit 1; + fi + fi + + $cmd JOB=1:$nj $dir/log/get_post.$i.JOB.log \ + lattice-to-post --acoustic-scale=$acwt \ + "ark,s,cs:gunzip -c $decode_dir/lat.JOB.gz|" ark:- \| \ + post-to-pdf-post $model ark,s,cs:- ark:- \| \ + get-post-on-ali ark,s,cs:- "ark,s,cs:gunzip -c $dir/ali.JOB.gz | convert-ali $dir/final.mdl $model $tree ark,s,cs:- ark:- | ali-to-pdf $model ark,s,cs:- ark:- |" "ark,scp:$fdir/weights.$i.JOB.ark,$fdir/weights.$i.JOB.scp" || exit 1 + fi +done + +if [ $stage -lt $num_sys ]; then + if [ "$num_sys" -eq 1 ]; then + for n in `seq $nj`; do + cat $dir/weights.0.$n.scp + done > $dir/weights.scp + else + $cmd JOB=1:$nj $dir/log/interpolate_post.JOB.log \ + vector-sum $file_list ark:- \| \ + vector-scale --scale=$inv_weights_sum ark:- \ + ark,scp:$fdir/weights.JOB.ark,$fdir/weights.JOB.scp || exit 1 + + for n in `seq $nj`; do + cat $dir/weights.$n.scp + done > $dir/weights.scp + fi +fi + +for n in `seq 1 $[num_sys-1]`; do + rm $dir/weights.$n.*.ark $dir/weights.$n.*.scp +done + +if $write_words; then + for n in `seq $nj`; do + cat $dir/text.$n + done > $dir/text +fi + +exit 0 diff --git a/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py b/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py index d5f2575d582..fe16663e3d8 100755 --- a/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py +++ b/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py @@ -399,7 +399,7 @@ def generate_acc_logprob_report(exp_dir, key="accuracy", output="output"): except: tb = traceback.format_exc() logger.warning("Error getting info from logs, exception was: " + tb) - times = [] + times = {} report = [] report.append("%Iter\tduration\ttrain_objective\tvalid_objective\tdifference") @@ -413,7 +413,7 @@ def generate_acc_logprob_report(exp_dir, key="accuracy", output="output"): try: report.append("%d\t%s\t%g\t%g\t%g" % (x[0], str(times[x[0]]), x[1], x[2], x[2]-x[1])) - except KeyError: + except KeyError, IndexError: continue total_time = 0 diff --git a/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py b/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py index 5b640510ea1..c7b3514428b 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py @@ -129,7 +129,8 @@ def train_new_models(dir, iter, srand, num_jobs, momentum, max_param_change, shuffle_buffer_size, num_chunk_per_minibatch_str, frame_subsampling_factor, run_opts, - backstitch_training_scale=0.0, backstitch_training_interval=1): + backstitch_training_scale=0.0, backstitch_training_interval=1, + use_multitask_egs=False): """ Called from train_one_iteration(), this method trains new models with 'num_jobs' jobs, and @@ -140,6 +141,12 @@ def train_new_models(dir, iter, srand, num_jobs, to use for each job is a little complex, so we spawn each one separately. this is no longer true for RNNs as we use do not use the --frame option but we use the same script for consistency with FF-DNN code + + use_multitask_egs : True, if different examples used to train multiple + tasks or outputs, e.g.multilingual training. 
+ multilingual egs can be generated using get_egs.sh and + steps/nnet3/multilingual/allocate_multilingual_examples.py, + those are the top-level scripts. """ deriv_time_opts = [] @@ -167,6 +174,12 @@ def train_new_models(dir, iter, srand, num_jobs, frame_shift = ((archive_index + k/num_archives) % frame_subsampling_factor) + multitask_egs_opts = common_train_lib.get_multitask_egs_opts( + egs_dir, + egs_prefix="cegs.", + archive_index=archive_index, + use_multitask_egs=use_multitask_egs) + scp_or_ark = "scp" if use_multitask_egs else "ark" cache_io_opts = (("--read-cache={dir}/cache.{iter}".format(dir=dir, iter=iter) if iter > 0 else "") + @@ -187,9 +200,9 @@ def train_new_models(dir, iter, srand, num_jobs, --l2-regularize-factor={l2_regularize_factor} \ --srand={srand} \ "{raw_model}" {dir}/den.fst \ - "ark,bg:nnet3-chain-copy-egs \ + "ark,bg:nnet3-chain-copy-egs {multitask_egs_opts} \ --frame-shift={fr_shft} \ - ark:{egs_dir}/cegs.{archive_index}.ark ark:- | \ + {scp_or_ark}:{egs_dir}/cegs.{archive_index}.{scp_or_ark} ark:- | \ nnet3-chain-shuffle-egs --buffer-size={buf_size} \ --srand={srand} ark:- ark:- | nnet3-chain-merge-egs \ --minibatch-size={num_chunk_per_mb} ark:- ark:- |" \ @@ -212,17 +225,17 @@ def train_new_models(dir, iter, srand, num_jobs, raw_model=raw_model_string, egs_dir=egs_dir, archive_index=archive_index, buf_size=shuffle_buffer_size, - num_chunk_per_mb=num_chunk_per_minibatch_str), + num_chunk_per_mb=num_chunk_per_minibatch_str, + multitask_egs_opts=multitask_egs_opts, + scp_or_ark=scp_or_ark) require_zero_status=True) threads.append(thread) - for thread in threads: thread.join() - def train_one_iteration(dir, iter, srand, egs_dir, num_jobs, num_archives_processed, num_archives, learning_rate, shrinkage_value, @@ -234,7 +247,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, momentum, max_param_change, shuffle_buffer_size, frame_subsampling_factor, run_opts, dropout_edit_string="", - backstitch_training_scale=0.0, backstitch_training_interval=1): + backstitch_training_scale=0.0, backstitch_training_interval=1, + use_multitask_egs=False): """ Called from steps/nnet3/chain/train.py for one iteration for neural network training with LF-MMI objective @@ -264,7 +278,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, compute_train_cv_probabilities( dir=dir, iter=iter, egs_dir=egs_dir, l2_regularize=l2_regularize, xent_regularize=xent_regularize, - leaky_hmm_coefficient=leaky_hmm_coefficient, run_opts=run_opts) + leaky_hmm_coefficient=leaky_hmm_coefficient, run_opts=run_opts, + use_multitask_egs=use_multitask_egs) if iter > 0: # Runs in the background @@ -306,12 +321,14 @@ def train_one_iteration(dir, iter, srand, egs_dir, shuffle_buffer_size=shuffle_buffer_size, num_chunk_per_minibatch_str=cur_num_chunk_per_minibatch_str, frame_subsampling_factor=frame_subsampling_factor, + truncate_deriv_weights=truncate_deriv_weights, run_opts=run_opts, # linearly increase backstitch_training_scale during the # first few iterations (hard-coded as 15) backstitch_training_scale=(backstitch_training_scale * iter / 15 if iter < 15 else backstitch_training_scale), - backstitch_training_interval=backstitch_training_interval) + backstitch_training_interval=backstitch_training_interval, + use_multitask_egs=use_multitask_egs) [models_to_average, best_model] = common_train_lib.get_successful_models( num_jobs, '{0}/log/train.{1}.%.log'.format(dir, iter)) @@ -351,11 +368,13 @@ def train_one_iteration(dir, iter, srand, egs_dir, os.remove("{0}/cache.{1}".format(dir, iter)) -def 
check_for_required_files(feat_dir, tree_dir, lat_dir): +def check_for_required_files(feat_dir, tree_dir, lat_dir=None): files = ['{0}/feats.scp'.format(feat_dir), '{0}/ali.1.gz'.format(tree_dir), - '{0}/final.mdl'.format(tree_dir), '{0}/tree'.format(tree_dir), + '{0}/final.mdl'.format(tree_dir), '{0}/tree'.format(tree_dir)] + if lat_dir is not None: + files += [ '{0}/lat.1.gz'.format(lat_dir), '{0}/final.mdl'.format(lat_dir), - '{0}/num_jobs'.format(lat_dir), '{0}/splice_opts'.format(lat_dir)] + '{0}/num_jobs'.format(lat_dir)] for file in files: if not os.path.isfile(file): raise Exception('Expected {0} to exist.'.format(file)) @@ -363,7 +382,7 @@ def check_for_required_files(feat_dir, tree_dir, lat_dir): def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts, max_lda_jobs=None, rand_prune=4.0, - lda_opts=None): + lda_opts=None, use_multitask_egs=False): """ Function to estimate and write LDA matrix from cegs This function is exactly similar to the version in module @@ -373,17 +392,28 @@ def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts, if max_lda_jobs is not None: if num_lda_jobs > max_lda_jobs: num_lda_jobs = max_lda_jobs + multitask_egs_opts = common_train_lib.get_multitask_egs_opts( + egs_dir, + egs_prefix="cegs.", + archive_index="JOB", + use_multitask_egs=use_multitask_egs) + scp_or_ark = "scp" if use_multitask_egs else "ark" + egs_rspecifier = ( + "ark:nnet3-chain-copy-egs {multitask_egs_opts} " + "{scp_or_ark}:{egs_dir}/cegs.JOB.{scp_or_ark} ark:- |" + "".format(egs_dir=egs_dir, scp_or_ark=scp_or_ark, + multitask_egs_opts=multitask_egs_opts)) # Write stats with the same format as stats for LDA. common_lib.execute_command( """{command} JOB=1:{num_lda_jobs} {dir}/log/get_lda_stats.JOB.log \ nnet3-chain-acc-lda-stats --rand-prune={rand_prune} \ - {dir}/init.raw "ark:{egs_dir}/cegs.JOB.ark" \ + {dir}/init.raw "{egs_rspecifier}" \ {dir}/JOB.lda_stats""".format( command=run_opts.command, num_lda_jobs=num_lda_jobs, dir=dir, - egs_dir=egs_dir, + egs_rspecifier=egs_rspecifier, rand_prune=rand_prune)) # the above command would have generated dir/{1..num_lda_jobs}.lda_stats @@ -444,32 +474,50 @@ def prepare_initial_acoustic_model(dir, run_opts, srand=-1, input_model=None): def compute_train_cv_probabilities(dir, iter, egs_dir, l2_regularize, xent_regularize, leaky_hmm_coefficient, - run_opts): + run_opts, + use_multitask_egs=False): model = '{0}/{1}.mdl'.format(dir, iter) + scp_or_ark = "scp" if use_multitask_egs else "ark" + egs_suffix = ".scp" if use_multitask_egs else ".cegs" + + multitask_egs_opts = common_train_lib.get_multitask_egs_opts( + egs_dir, + egs_prefix="valid_diagnostic.", + use_multitask_egs=use_multitask_egs) + common_lib.background_command( """{command} {dir}/log/compute_prob_valid.{iter}.log \ nnet3-chain-compute-prob --l2-regularize={l2} \ --leaky-hmm-coefficient={leaky} --xent-regularize={xent_reg} \ "nnet3-am-copy --raw=true {model} - |" {dir}/den.fst \ - "ark,bg:nnet3-chain-copy-egs ark:{egs_dir}/valid_diagnostic.cegs \ + "ark,bg:nnet3-chain-copy-egs {multitask_egs_opts} {scp_or_ark}:{egs_dir}/valid_diagnostic{egs_suffix} \ ark:- | nnet3-chain-merge-egs --minibatch-size=1:64 ark:- ark:- |" \ """.format(command=run_opts.command, dir=dir, iter=iter, model=model, l2=l2_regularize, leaky=leaky_hmm_coefficient, xent_reg=xent_regularize, - egs_dir=egs_dir)) + egs_dir=egs_dir, + multitask_egs_opts=multitask_egs_opts, + scp_or_ark=scp_or_ark, egs_suffix=egs_suffix)) + + multitask_egs_opts = common_train_lib.get_multitask_egs_opts( + 
egs_dir, + egs_prefix="train_diagnostic.", + use_multitask_egs=use_multitask_egs) common_lib.background_command( """{command} {dir}/log/compute_prob_train.{iter}.log \ nnet3-chain-compute-prob --l2-regularize={l2} \ --leaky-hmm-coefficient={leaky} --xent-regularize={xent_reg} \ "nnet3-am-copy --raw=true {model} - |" {dir}/den.fst \ - "ark,bg:nnet3-chain-copy-egs ark:{egs_dir}/train_diagnostic.cegs \ + "ark,bg:nnet3-chain-copy-egs {multitask_egs_opts} {scp_or_ark}:{egs_dir}/train_diagnostic{egs_suffix} \ ark:- | nnet3-chain-merge-egs --minibatch-size=1:64 ark:- ark:- |" \ """.format(command=run_opts.command, dir=dir, iter=iter, model=model, l2=l2_regularize, leaky=leaky_hmm_coefficient, xent_reg=xent_regularize, - egs_dir=egs_dir)) + egs_dir=egs_dir, + multitask_egs_opts=multitask_egs_opts, + scp_or_ark=scp_or_ark, egs_suffix=egs_suffix)) def compute_progress(dir, iter, run_opts): @@ -489,10 +537,12 @@ def compute_progress(dir, iter, run_opts): model=model, prev_model=prev_model)) + def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_str, egs_dir, leaky_hmm_coefficient, l2_regularize, xent_regularize, run_opts, - max_objective_evaluations=30): + max_objective_evaluations=30, + use_multitask_egs=False): """ Function to do model combination In the nnet3 setup, the logic @@ -515,6 +565,14 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_st print("{0}: warning: model file {1} does not exist " "(final combination)".format(sys.argv[0], model_file)) + scp_or_ark = "scp" if use_multitask_egs else "ark" + egs_suffix = ".scp" if use_multitask_egs else ".cegs" + + multitask_egs_opts = common_train_lib.get_multitask_egs_opts( + egs_dir, + egs_prefix="combine.", + use_multitask_egs=use_multitask_egs) + # We reverse the order of the raw model strings so that the freshest one # goes first. This is important for systems that include batch # normalization-- it means that the freshest batch-norm stats are used. 
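For reference, when use_multitask_egs is true the suffix/prefix logic above makes the combination stage read the merged examples through the combine.scp index in the combined egs directory rather than a single combine.cegs archive; the read pipeline assembled in the hunk below then has roughly this shape (a sketch only, with multitask_egs_opts expanded from get_multitask_egs_opts):

    nnet3-chain-copy-egs $multitask_egs_opts scp:$egs_dir/combine.scp ark:- | \
      nnet3-chain-merge-egs --minibatch-size=$num_chunk_per_minibatch_str ark:- ark:-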
@@ -529,7 +587,7 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_st --max-objective-evaluations={max_objective_evaluations} \ --l2-regularize={l2} --leaky-hmm-coefficient={leaky} \ --verbose=3 {dir}/den.fst {raw_models} \ - "ark,bg:nnet3-chain-copy-egs ark:{egs_dir}/combine.cegs ark:- | \ + "ark,bg:nnet3-chain-copy-egs {multitask_egs_opts} {scp_or_ark}:{egs_dir}/combine{egs_suffix} ark:- | \ nnet3-chain-merge-egs --minibatch-size={num_chunk_per_mb} \ ark:- ark:- |" - \| \ nnet3-am-copy --set-raw-nnet=- {dir}/{num_iters}.mdl \ @@ -541,7 +599,9 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_st dir=dir, raw_models=" ".join(raw_model_strings), num_chunk_per_mb=num_chunk_per_minibatch_str, num_iters=num_iters, - egs_dir=egs_dir)) + egs_dir=egs_dir, + multitask_egs_opts=multitask_egs_opts, + scp_or_ark=scp_or_ark, egs_suffix=egs_suffix)) # Compute the probability of the final, combined model with # the same subset we used for the previous compute_probs, as the @@ -550,4 +610,5 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_st dir=dir, iter='final', egs_dir=egs_dir, l2_regularize=l2_regularize, xent_regularize=xent_regularize, leaky_hmm_coefficient=leaky_hmm_coefficient, - run_opts=run_opts) + run_opts=run_opts, + use_multitask_egs=use_multitask_egs) diff --git a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py index 8bdcd160409..1d8efc249a9 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py @@ -326,21 +326,32 @@ def train_one_iteration(dir, iter, srand, egs_dir, def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts, max_lda_jobs=None, rand_prune=4.0, - lda_opts=None): + lda_opts=None, use_multitask_egs=False): if max_lda_jobs is not None: if num_lda_jobs > max_lda_jobs: num_lda_jobs = max_lda_jobs + multitask_egs_opts = common_train_lib.get_multitask_egs_opts( + egs_dir, + egs_prefix="egs.", + archive_index="JOB", + use_multitask_egs=use_multitask_egs) + scp_or_ark = "scp" if use_multitask_egs else "ark" + egs_rspecifier = ( + "ark:nnet3-copy-egs {multitask_egs_opts} " + "{scp_or_ark}:{egs_dir}/egs.JOB.{scp_or_ark} ark:- |" + "".format(egs_dir=egs_dir, scp_or_ark=scp_or_ark, + multitask_egs_opts=multitask_egs_opts)) # Write stats with the same format as stats for LDA. common_lib.execute_command( """{command} JOB=1:{num_lda_jobs} {dir}/log/get_lda_stats.JOB.log \ nnet3-acc-lda-stats --rand-prune={rand_prune} \ - {dir}/init.raw "ark:{egs_dir}/egs.JOB.ark" \ + {dir}/init.raw "{egs_rspecifier}" \ {dir}/JOB.lda_stats""".format( command=run_opts.command, num_lda_jobs=num_lda_jobs, dir=dir, - egs_dir=egs_dir, + egs_rspecifier=egs_rspecifier, rand_prune=rand_prune)) # the above command would have generated dir/{1..num_lda_jobs}.lda_stats diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py index 05ae5bcdc18..891a038030e 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py @@ -367,6 +367,14 @@ class XconfigTrivialOutputLayer(XconfigLayerBase): This is for outputs that are not really output "layers" (there is no affine transform or nonlinearity), they just directly map to an output-node in nnet3. 
+ + Parameters of the class, and their defaults: + input='[-1]' : Descriptor giving the input of the layer. + objective-type=linear : the only other choice currently is + 'quadratic', for use in regression problems + output-delay=0 : Can be used to shift the frames on the output, equivalent + to delaying labels by this many frames (positive value increases latency + in online decoding but may help if you're using unidirectional LSTMs. """ def __init__(self, first_token, key_to_value, prev_names=None): @@ -378,11 +386,17 @@ def set_default_configs(self): # note: self.config['input'] is a descriptor, '[-1]' means output # the most recent layer. - self.config = {'input': '[-1]', 'dim': -1} + self.config = {'input': '[-1]', 'dim': -1, + 'objective-type': 'linear', + 'output-delay': 0} def check_configs(self): - pass # nothing to check; descriptor-parsing can't happen in this function. + if self.config['objective-type'] != 'linear' and \ + self.config['objective-type'] != 'quadratic': + raise RuntimeError("In output, objective-type has" + " invalid value {0}" + "".format(self.config['objective-type'])) def output_name(self, auxiliary_outputs=None): @@ -413,11 +427,19 @@ def get_full_config(self): # by 'output-string' we mean a string that can appear in # config-files, i.e. it contains the 'final' names of nodes. descriptor_final_str = self.descriptors['input']['final-string'] + objective_type = self.config['objective-type'] + output_delay = self.config['output-delay'] - for config_name in ['init', 'ref', 'final']: + if output_delay != 0: + descriptor_final_str = ( + 'Offset({0}, {1})'.format(descriptor_final_str, output_delay)) + + for config_name in ['ref', 'final']: ans.append((config_name, - 'output-node name={0} input={1}'.format( - self.name, descriptor_final_str))) + 'output-node name={0} input={1} ' + 'objective={2}'.format( + self.name, descriptor_final_str, + objective_type))) return ans @@ -509,28 +531,38 @@ def check_configs(self): " invalid value {0}" "".format(self.config['learning-rate-factor'])) - # you cannot access the output of this layer from other layers... see - # comment in output_name for the reason why. def auxiliary_outputs(self): - return [] + auxiliary_outputs = ['affine'] + if self.config['include-log-softmax']: + auxiliary_outputs.append('log-softmax') - def output_name(self, auxiliary_outputs=None): + return auxiliary_outputs + + def output_name(self, auxiliary_output=None): - # Note: nodes of type output-node in nnet3 may not be accessed in - # Descriptors, so calling this with auxiliary_outputs=None doesn't - # make sense. But it might make sense to make the output of the softmax - # layer and/or the output of the affine layer available as inputs to - # other layers, in some circumstances. - # we'll implement that when it's needed. - raise RuntimeError("Outputs of output-layer may not be used by other" - " layers") + if auxiliary_output is None: + # Note: nodes of type output-node in nnet3 may not be accessed in + # Descriptors, so calling this with auxiliary_outputs=None doesn't + # make sense. + raise RuntimeError("Outputs of output-layer may not be used by other" + " layers") + + if auxiliary_output in self.auxiliary_outputs(): + return '{0}.{1}'.format(self.name, auxiliary_output) + else: + raise RuntimeError("Unknown auxiliary output name {0}" + "".format(auxiliary_output)) def output_dim(self, auxiliary_output=None): - # see comment in output_name(). 
- raise RuntimeError("Outputs of output-layer may not be used by other" - " layers") + if auxiliary_output is None: + # Note: nodes of type output-node in nnet3 may not be accessed in + # Descriptors, so calling this with auxiliary_outputs=None doesn't + # make sense. + raise RuntimeError("Outputs of output-layer may not be used by other" + " layers") + return self.config['dim'] def get_full_config(self): diff --git a/egs/wsj/s5/steps/lmrescore_const_arpa_undeterminized.sh b/egs/wsj/s5/steps/lmrescore_const_arpa_undeterminized.sh new file mode 100755 index 00000000000..7673aa0960c --- /dev/null +++ b/egs/wsj/s5/steps/lmrescore_const_arpa_undeterminized.sh @@ -0,0 +1,93 @@ +#!/bin/bash + +# Copyright 2014 Guoguo Chen +# 2017 Vimal Manohar +# Apache 2.0 + +# This script rescores non-compact undeterminized lattices with the +# ConstArpaLm format language model. +# This is similar to steps/lmrescore_const_arpa.sh, but expects +# undeterminized non-compact lattices as input. + +# Begin configuration section. +cmd=run.pl +skip_scoring=false +stage=1 +scoring_opts= +write_compact=true +acwt=0.1 +beam=8.0 # beam used in determinization + +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +. ./utils/parse_options.sh + +if [ $# != 5 ]; then + cat < \\ + + options: [--cmd (run.pl|queue.pl [queue opts])] + See also: steps/lmrescore_const_arpa.sh +EOF + exit 1; +fi + +[ -f path.sh ] && . ./path.sh; + +oldlang=$1 +newlang=$2 +data=$3 +indir=$4 +outdir=$5 + +oldlm=$oldlang/G.fst +newlm=$newlang/G.carpa +! cmp $oldlang/words.txt $newlang/words.txt &&\ + echo "$0: Warning: vocabularies may be incompatible." +[ ! -f $oldlm ] && echo "$0: Missing file $oldlm" && exit 1; +[ ! -f $newlm ] && echo "$0: Missing file $newlm" && exit 1; +! ls $indir/lat.*.gz >/dev/null &&\ + echo "$0: No lattices input directory $indir" && exit 1; + +if ! cmp -s $oldlang/words.txt $newlang/words.txt; then + echo "$0: $oldlang/words.txt and $newlang/words.txt differ: make sure you know what you are doing."; +fi + +oldlmcommand="fstproject --project_output=true $oldlm |" + +mkdir -p $outdir/log +nj=`cat $indir/num_jobs` || exit 1; +cp $indir/num_jobs $outdir + +lats_rspecifier="ark:gunzip -c $indir/lat.JOB.gz |" + +lats_wspecifier="ark:| gzip -c > $outdir/lat.JOB.gz" + +if [ $stage -le 1 ]; then + $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \ + lattice-determinize-pruned --acoustic-scale=$acwt --beam=$beam \ + "ark:gunzip -c $indir/lat.JOB.gz |" ark:- \| \ + lattice-scale --lm-scale=0.0 --acoustic-scale=0.0 ark:- ark:- \| \ + lattice-lmrescore --lm-scale=-1.0 ark:- "$oldlmcommand" ark:- \| \ + lattice-lmrescore-const-arpa --lm-scale=1.0 \ + ark:- "$newlm" ark:- \| \ + lattice-project ark:- ark:- \| \ + lattice-compose --write-compact=$write_compact \ + "$lats_rspecifier" \ + ark,s,cs:- "$lats_wspecifier" || exit 1 +fi + +if ! $skip_scoring && [ $stage -le 2 ]; then + err_msg="Not scoring because local/score.sh does not exist or not executable." + [ ! -x local/score.sh ] && echo $err_msg && exit 1; + local/score.sh --cmd "$cmd" $scoring_opts $data $newlang $outdir +else + echo "Not scoring because requested so..." +fi + +exit 0; diff --git a/egs/wsj/s5/steps/nnet3/chain/build_tree_multiple_sources.sh b/egs/wsj/s5/steps/nnet3/chain/build_tree_multiple_sources.sh new file mode 100755 index 00000000000..6892a2ff1ee --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain/build_tree_multiple_sources.sh @@ -0,0 +1,275 @@ +#!/bin/bash +# Copyright 2012-2015 Johns Hopkins University (Author: Daniel Povey). 
+# 2017 Vimal Manohar +# Apache 2.0. + +# This script is similar to steps/nnet3/chain/build_tree.sh but supports +# getting statistics from multiple alignment sources. + + +# Begin configuration section. +stage=-5 +exit_stage=-100 # you can use this to require it to exit at the + # beginning of a specific stage. Not all values are + # supported. +cmd=run.pl +use_fmllr=true # If true, fmllr transforms will be applied from the alignment directories. + # Otherwise, no fmllr will be applied even if alignment directory contains trans.* +context_opts= # e.g. set this to "--context-width 5 --central-position 2" for quinphone. +cluster_thresh=-1 # for build-tree control final bottom-up clustering of leaves +frame_subsampling_factor=1 # frame subsampling factor of output w.r.t. to the input features +tree_stats_opts= +cluster_phones_opts= +repeat_frames=false +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -lt 5 ]; then + echo "Usage: steps/nnet3/chain/build_tree_multiple_sources.sh <#leaves> [ ... ] " + echo " e.g.: steps/nnet3/chain/build_tree_multiple_sources.sh 15000 data/lang data/train_sup exp/tri3_ali data/train_unsup exp/tri3/best_path_train_unsup exp/tree_semi" + echo "Main options (for others, see top of script file)" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --config # config containing options" + echo " --stage # stage to do partial re-run from." + echo " --repeat-frames # Only affects alignment conversion at" + echo " # the end. If true, generate an " + echo " # alignment using the frame-subsampled " + echo " # topology that is repeated " + echo " # --frame-subsampling-factor times " + echo " # and interleaved, to be the same " + echo " # length as the original alignment " + echo " # (useful for cross-entropy training " + echo " # of reduced frame rate systems)." + exit 1; +fi + +numleaves=$1 +lang=$2 +dir=${@: -1} # last argument to the script +shift 2; +data_and_alidirs=( $@ ) # read the remaining arguments into an array +unset data_and_alidirs[${#data_and_alidirs[@]}-1] # 'pop' the last argument which is odir +num_sys=$[${#data_and_alidirs[@]}] # number of systems to combine + +if (( $num_sys % 2 != 0 )); then + echo "$0: The data and alignment arguments must be an even number of arguments." + exit 1 +fi + +num_sys=$((num_sys / 2)) + +data=$dir/data_tmp +mkdir -p $data + +mkdir -p $dir +alidir=`echo ${data_and_alidirs[1]}` + +datadirs=() +alidirs=() +for n in `seq 0 $[num_sys-1]`; do + datadirs[$n]=${data_and_alidirs[$[2*n]]} + alidirs[$n]=${data_and_alidirs[$[2*n+1]]} +done + +utils/combine_data.sh $data ${datadirs[@]} || exit 1 + +for f in $data/feats.scp $lang/phones.txt $alidir/final.mdl $alidir/ali.1.gz; do + [ ! -f $f ] && echo "$0: no such file $f" && exit 1; +done + +oov=`cat $lang/oov.int` +nj=`cat $alidir/num_jobs` || exit 1; +silphonelist=`cat $lang/phones/silence.csl` +ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; +sdata=$data/split$nj; +splice_opts=`cat $alidir/splice_opts 2>/dev/null` # frame-splicing options. +cmvn_opts=`cat $alidir/cmvn_opts 2>/dev/null` || exit 1 +delta_opts=`cat $alidir/delta_opts 2>/dev/null` + +mkdir -p $dir/log +cp $alidir/splice_opts $dir 2>/dev/null # frame-splicing options. +cp $alidir/cmvn_opts $dir 2>/dev/null # cmn/cmvn option. +cp $alidir/delta_opts $dir 2>/dev/null # delta option. 
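+
+# Note: the cmvn/splice/delta options (and final.mat, if present) are taken from
+# the first alignment directory and applied to every data source below, so all
+# the sources are assumed to share the same front-end configuration.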
+
+utils/lang/check_phones_compatible.sh $lang/phones.txt $alidir/phones.txt || exit 1;
+cp $lang/phones.txt $dir || exit 1;
+
+echo $nj >$dir/num_jobs
+[[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1;
+
+# Set up features.
+if [ -f $alidir/final.mat ]; then feat_type=lda; else feat_type=delta; fi
+
+echo "$0: feature type is $feat_type"
+
+feats=()
+feats_one=()
+for n in `seq 0 $[num_sys-1]`; do
+  this_nj=$(cat ${alidirs[$n]}/num_jobs) || exit 1
+  this_sdata=${datadirs[$n]}/split$this_nj
+  [[ -d $this_sdata && ${datadirs[$n]}/feats.scp -ot $this_sdata ]] || split_data.sh ${datadirs[$n]} $this_nj || exit 1;
+  ## Set up speaker-independent features.
+  case $feat_type in
+    delta) feats[$n]="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$this_sdata/JOB/utt2spk scp:$this_sdata/JOB/cmvn.scp scp:$this_sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |"
+      feats_one[$n]="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$this_sdata/1/utt2spk scp:$this_sdata/1/cmvn.scp scp:$this_sdata/1/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |";;
+    lda) feats[$n]="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$this_sdata/JOB/utt2spk scp:$this_sdata/JOB/cmvn.scp scp:$this_sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $alidir/final.mat ark:- ark:- |"
+      feats_one[$n]="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$this_sdata/1/utt2spk scp:$this_sdata/1/cmvn.scp scp:$this_sdata/1/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $alidir/final.mat ark:- ark:- |"
+      cp $alidir/final.mat $dir
+      cp $alidir/full.mat $dir 2>/dev/null
+      ;;
+    *) echo "$0: invalid feature type $feat_type" && exit 1;
+  esac
+
+  if $use_fmllr; then
+    if [ ! -f ${alidirs[$n]}/trans.1 ]; then
+      echo "$0: Could not find fMLLR transforms in ${alidirs[$n]}"
+      exit 1
+    fi
+
+    echo "$0: Using transforms from ${alidirs[$n]}"
+    feats[$n]="${feats[$n]} transform-feats --utt2spk=ark:$this_sdata/JOB/utt2spk ark,s,cs:${alidirs[$n]}/trans.JOB ark:- ark:- |"
+    feats_one[$n]="${feats_one[$n]} transform-feats --utt2spk=ark:$this_sdata/1/utt2spk ark,s,cs:${alidirs[$n]}/trans.1 ark:- ark:- |"
+  fi
+
+  # Do subsampling of feats, if needed
+  if [ $frame_subsampling_factor -gt 1 ]; then
+    feats[$n]="${feats[$n]} subsample-feats --n=$frame_subsampling_factor ark:- ark:- |"
+    feats_one[$n]="${feats_one[$n]} subsample-feats --n=$frame_subsampling_factor ark:- ark:- |"
+  fi
+done
+
+if [ $stage -le -5 ]; then
+  echo "$0: Initializing monophone model (for alignment conversion, in case topology changed)"
+
+  [ ! -f $lang/phones/sets.int ] && exit 1;
+  shared_phones_opt="--shared-phones=$lang/phones/sets.int"
+  # get feature dimension
+  example_feats="`echo ${feats[0]} | sed s/JOB/1/g`";
+  if ! feat_dim=$(feat-to-dim "$example_feats" - 2>/dev/null) || [ -z $feat_dim ]; then
+    feat-to-dim "$example_feats" - # to see the error message.
+    echo "error getting feature dimension"
+    exit 1;
+  fi
+
+  for n in `seq 0 $[num_sys-1]`; do
+    copy-feats "${feats_one[$n]}" ark:-
+  done | copy-feats ark:- ark:$dir/tmp.ark
+
+  $cmd $dir/log/init_mono.log \
+    gmm-init-mono $shared_phones_opt \
+      "--train-feats=ark:subset-feats --n=10 ark:$dir/tmp.ark ark:- |" $lang/topo $feat_dim \
+      $dir/mono.mdl $dir/mono.tree || exit 1
+fi
+
+
+if [ $stage -le -4 ]; then
+  # Get tree stats.
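+  # Each source may have been aligned at a different frame rate, so below we
+  # work out the subsampling factor *relative* to each alignment directory.
+  # Illustrative example: with --frame-subsampling-factor 3, a source whose
+  # alignments are at the full frame rate (factor 1) is converted with
+  # factor 3, while a source that already has frame_subsampling_factor=3
+  # gets a relative factor of 3 / 3 = 1 (i.e. no further subsampling).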
+ + for n in `seq 0 $[num_sys-1]`; do + echo "$0: Accumulating tree stats" + this_data=${datadirs[$n]} + this_alidir=${alidirs[$n]} + this_nj=$(cat $this_alidir/num_jobs) || exit 1 + this_frame_subsampling_factor=1 + if [ -f $this_alidir/frame_subsampling_factor ]; then + this_frame_subsampling_factor=$(cat $this_alidir/frame_subsampling_factor) + fi + + if (( $frame_subsampling_factor % $this_frame_subsampling_factor != 0 )); then + echo "$0: frame-subsampling-factor=$frame_subsampling_factor is not " + echo "divisible by $this_frame_subsampling_factor (that of $this_alidir)" + exit 1 + fi + + this_frame_subsampling_factor=$((frame_subsampling_factor / this_frame_subsampling_factor)) + $cmd JOB=1:$this_nj $dir/log/acc_tree.$n.JOB.log \ + convert-ali --frame-subsampling-factor=$this_frame_subsampling_factor \ + $this_alidir/final.mdl $dir/mono.mdl $dir/mono.tree "ark:gunzip -c $this_alidir/ali.JOB.gz|" ark:- \| \ + acc-tree-stats $context_opts $tree_stats_opts --ci-phones=$ciphonelist $dir/mono.mdl \ + "${feats[$n]}" ark:- $dir/$n.JOB.treeacc || exit 1; + [ "`ls $dir/$n.*.treeacc | wc -w`" -ne "$this_nj" ] && echo "$0: Wrong #tree-accs for data $n $this_data" && exit 1; + done + + $cmd $dir/log/sum_tree_acc.log \ + sum-tree-stats $dir/treeacc $dir/*.treeacc || exit 1; + rm $dir/*.treeacc +fi + +if [ $stage -le -3 ] && $train_tree; then + echo "$0: Getting questions for tree clustering." + # preparing questions, roots file... + $cmd $dir/log/questions.log \ + cluster-phones $cluster_phones_opts $context_opts $dir/treeacc \ + $lang/phones/sets.int $dir/questions.int || exit 1; + cat $lang/phones/extra_questions.int >> $dir/questions.int + $cmd $dir/log/compile_questions.log \ + compile-questions \ + $context_opts $lang/topo $dir/questions.int $dir/questions.qst || exit 1; + + echo "$0: Building the tree" + $cmd $dir/log/build_tree.log \ + build-tree $context_opts --verbose=1 --max-leaves=$numleaves \ + --cluster-thresh=$cluster_thresh $dir/treeacc $lang/phones/roots.int \ + $dir/questions.qst $lang/topo $dir/tree || exit 1; +fi + +if [ $stage -le -2 ]; then + echo "$0: Initializing the model" + gmm-init-model --write-occs=$dir/1.occs \ + $dir/tree $dir/treeacc $lang/topo $dir/1.mdl 2> $dir/log/init_model.log || exit 1; + grep 'no stats' $dir/log/init_model.log && echo "This is a bad warning."; + rm $dir/treeacc +fi + +if [ $stage -le -1 ]; then + # Convert the alignments to the new tree. Note: we likely will not use these + # converted alignments in the CTC system directly, but they could be useful + # for other purposes. 
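+  # Illustrative example of --repeat-frames: with --frame-subsampling-factor 3
+  # and --repeat-frames true, each pdf-id of the frame-subsampled alignment is
+  # repeated 3 times and interleaved, so the converted alignment has the same
+  # number of frames as the original full-frame-rate alignment (see the
+  # --repeat-frames help text above).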
+
+  for n in `seq 0 $[num_sys-1]`; do
+    this_alidir=${alidirs[$n]}
+    this_nj=$(cat $this_alidir/num_jobs) || exit 1
+
+    this_frame_subsampling_factor=1
+    if [ -f $this_alidir/frame_subsampling_factor ]; then
+      this_frame_subsampling_factor=$(cat $this_alidir/frame_subsampling_factor)
+    fi
+
+    if (( $frame_subsampling_factor % $this_frame_subsampling_factor != 0 )); then
+      echo "$0: frame-subsampling-factor=$frame_subsampling_factor is not "
+      echo "divisible by $this_frame_subsampling_factor (that of $this_alidir)"
+      exit 1
+    fi
+
+    echo "$0: frame-subsampling-factor for $this_alidir is $this_frame_subsampling_factor"
+
+    this_frame_subsampling_factor=$((frame_subsampling_factor / this_frame_subsampling_factor))
+    echo "$0: Converting alignments from $this_alidir to use current tree"
+    $cmd JOB=1:$this_nj $dir/log/convert.$n.JOB.log \
+      convert-ali --repeat-frames=$repeat_frames \
+        --frame-subsampling-factor=$this_frame_subsampling_factor \
+        $this_alidir/final.mdl $dir/1.mdl $dir/tree "ark:gunzip -c $this_alidir/ali.JOB.gz |" \
+        ark,scp:$dir/ali.$n.JOB.ark,$dir/ali.$n.JOB.scp || exit 1
+
+    for i in `seq $this_nj`; do
+      cat $dir/ali.$n.$i.scp
+    done > $dir/ali.$n.scp || exit 1
+  done
+
+  for n in `seq 0 $[num_sys-1]`; do
+    cat $dir/ali.$n.scp
+  done | sort -k1,1 > $dir/ali.scp || exit 1
+
+  utils/split_data.sh $data $nj
+  $cmd JOB=1:$nj $dir/log/copy_alignments.JOB.log \
+    copy-int-vector "scp:utils/filter_scp.pl $data/split$nj/JOB/utt2spk $dir/ali.scp |" \
+      "ark:| gzip -c > $dir/ali.JOB.gz" || exit 1
+fi
+
+cp $dir/1.mdl $dir/final.mdl
+
+echo "$0: Done building tree"
diff --git a/egs/wsj/s5/steps/nnet3/chain/get_egs.sh b/egs/wsj/s5/steps/nnet3/chain/get_egs.sh
index cec6f8e166f..fb663591969 100755
--- a/egs/wsj/s5/steps/nnet3/chain/get_egs.sh
+++ b/egs/wsj/s5/steps/nnet3/chain/get_egs.sh
@@ -52,9 +52,9 @@ left_tolerance=
 transform_dir=  # If supplied, overrides latdir as the place to find fMLLR transforms
 stage=0
-nj=15            # This should be set to the maximum number of jobs you are
-                 # comfortable to run in parallel; you can increase it if your disk
-                 # speed is greater and you have more machines.
+max_jobs_run=15  # This should be set to the maximum number of jobs you are
+                 # comfortable to run in parallel; you can increase it if your disk
+                 # speed is greater and you have more machines.
 max_shuffle_jobs_run=50  # the shuffle jobs now include the nnet3-chain-normalize-egs command,
                          # which is fairly CPU intensive, so we can run quite a few at once
                          # without overloading the disks.
@@ -63,6 +63,17 @@ online_ivector_dir= # can be used if we are including speaker information as iV
 cmvn_opts=  # can be used for specifying CMVN options, if feature type is not lda (if lda,
             # it doesn't make sense to use different options than were used as input to the
             # LDA transform).  This is used to turn off CMVN in the online-nnet experiments.
+lattice_lm_scale=    # If supplied, the graph/lm weight of the lattices will be
+                     # used (with this scale) in generating supervisions
+egs_weight=1.0       # The weight which determines how much each training example
+                     # contributes to gradients while training (can be used
+                     # to down/up-weight a dataset)
+lattice_prune_beam=  # If supplied, the lattices will be pruned to this beam,
+                     # before being used to get supervisions.
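+                     # E.g. --lattice-prune-beam 4.0 prunes with beam 4.0;
+                     # the special value 0 keeps only the best path
+                     # (lattice-1best) -- see how $lattice_prune_beam is
+                     # handled below.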
+acwt=0.1 # For pruning +phone_insertion_penalty= +deriv_weights_scp= +generate_egs_scp=false echo "$0 $@" # Print the command line for logging @@ -80,7 +91,7 @@ if [ $# != 4 ]; then echo "" echo "Main options (for others, see top of script file)" echo " --config # config file containing options" - echo " --nj # The maximum number of jobs you want to run in" + echo " --max-jobs-run # The maximum number of jobs you want to run in" echo " # parallel (increase this only if you have good disk and" echo " # network speed). default=6" echo " --cmd (utils/run.pl;utils/queue.pl ) # how to run jobs." @@ -94,7 +105,7 @@ if [ $# != 4 ]; then echo " --left-context-initial # If >= 0, left-context for first chunk of an utterance" echo " --right-context-final # If >= 0, right-context for last chunk of an utterance" echo " --num-egs-diagnostic <#frames;4000> # Number of egs used in computing (train,valid) diagnostics" - echo " --num-valid-egs-combine <#frames;10000> # Number of egss used in getting combination weights at the" + echo " --num-valid-egs-combine <#frames;10000> # Number of egs used in getting combination weights at the" echo " # very end." echo " --stage # Used to run a partially-completed training process from somewhere in" echo " # the middle." @@ -116,13 +127,13 @@ for f in $data/feats.scp $latdir/lat.1.gz $latdir/final.mdl \ [ ! -f $f ] && echo "$0: no such file $f" && exit 1; done +nj=$(cat $latdir/num_jobs) || exit 1 + sdata=$data/split$nj utils/split_data.sh $data $nj mkdir -p $dir/log $dir/info -num_lat_jobs=$(cat $latdir/num_jobs) || exit 1; - # Get list of validation utterances. frame_shift=$(utils/data/get_frame_shift.sh $data) || exit 1 @@ -184,6 +195,8 @@ if [ -f $dir/trans.scp ]; then train_subset_feats="$train_subset_feats transform-feats --utt2spk=ark:$data/utt2spk scp:$dir/trans.scp ark:- ark:- |" fi +tree-info $chaindir/tree | grep num-pdfs | awk '{print $2}' > $dir/info/num_pdfs || exit 1 + if [ ! -z "$online_ivector_dir" ]; then ivector_dim=$(feat-to-dim scp:$online_ivector_dir/ivector_online.scp -) || exit 1; echo $ivector_dim > $dir/info/ivector_dim @@ -257,20 +270,11 @@ if [ -e $dir/storage ]; then done fi -if [ $stage -le 2 ]; then - echo "$0: copying training lattices" - - $cmd --max-jobs-run 6 JOB=1:$num_lat_jobs $dir/log/lattice_copy.JOB.log \ - lattice-copy "ark:gunzip -c $latdir/lat.JOB.gz|" ark,scp:$dir/lat.JOB.ark,$dir/lat.JOB.scp || exit 1; - - for id in $(seq $num_lat_jobs); do cat $dir/lat.$id.scp; done > $dir/lat.scp -fi - - egs_opts="--left-context=$left_context --right-context=$right_context --num-frames=$frames_per_eg --frame-subsampling-factor=$frame_subsampling_factor --compress=$compress" [ $left_context_initial -ge 0 ] && egs_opts="$egs_opts --left-context-initial=$left_context_initial" [ $right_context_final -ge 0 ] && egs_opts="$egs_opts --right-context-final=$right_context_final" +[ ! -z "$deriv_weights_scp" ] && egs_opts="$egs_opts --deriv-weights-rspecifier=scp:$deriv_weights_scp" chain_supervision_all_opts="--lattice-input=true --frame-subsampling-factor=$alignment_subsampling_factor" [ ! -z $right_tolerance ] && \ @@ -279,19 +283,47 @@ chain_supervision_all_opts="--lattice-input=true --frame-subsampling-factor=$ali [ ! -z $left_tolerance ] && \ chain_supervision_all_opts="$chain_supervision_all_opts --left-tolerance=$left_tolerance" +normalization_scale=1.0 + +lats_rspecifier="ark:gunzip -c $latdir/lat.JOB.gz |" +if [ ! 
-z $lattice_prune_beam ]; then + if [ "$lattice_prune_beam" == "0" ] || [ "$lattice_prune_beam" == "0.0" ]; then + lats_rspecifier="$lats_rspecifier lattice-1best --acoustic-scale=$acwt ark:- ark:- |" + else + lats_rspecifier="$lats_rspecifier lattice-prune --acoustic-scale=$acwt --beam=$lattice_prune_beam ark:- ark:- |" + fi +fi + +if [ ! -z "$lattice_lm_scale" ]; then + chain_supervision_all_opts="$chain_supervision_all_opts --lm-scale=$lattice_lm_scale" + + normalization_scale=$(perl -e " + if ($lattice_lm_scale >= 1.0 || $lattice_lm_scale < 0) { + print STDERR \"Invalid --lattice-lm-scale $lattice_lm_scale\"; + exit(1); + } + print (1.0 - $lattice_lm_scale);") +fi + +[ ! -z $phone_insertion_penalty ] && \ + chain_supervision_all_opts="$chain_supervision_all_opts --phone-ins-penalty=$phone_insertion_penalty" + echo $left_context > $dir/info/left_context echo $right_context > $dir/info/right_context echo $left_context_initial > $dir/info/left_context_initial echo $right_context_final > $dir/info/right_context_final -if [ $stage -le 3 ]; then - echo "$0: Getting validation and training subset examples." +if [ $stage -le 2 ]; then + echo "$0: Getting validation and training subset examples in background." rm $dir/.error 2>/dev/null - echo "$0: ... extracting validation and training-subset alignments." - # do the filtering just once, as lat.scp may be long. - utils/filter_scp.pl <(cat $dir/valid_uttlist $dir/train_subset_uttlist) \ - <$dir/lat.scp >$dir/lat_special.scp + ( + $cmd --max-jobs-run 6 JOB=1:$nj $dir/log/lattice_copy.JOB.log \ + lattice-copy --include="cat $dir/valid_uttlist $dir/train_subset_uttlist |" --ignore-missing \ + "$lats_rspecifier" \ + ark,scp:$dir/lat_special.JOB.ark,$dir/lat_special.JOB.scp || exit 1 + + for id in $(seq $nj); do cat $dir/lat_special.$id.scp; done > $dir/lat_special.scp $cmd $dir/log/create_valid_subset.log \ utils/filter_scp.pl $dir/valid_uttlist $dir/lat_special.scp \| \ @@ -299,40 +331,54 @@ if [ $stage -le 3 ]; then chain-get-supervision $chain_supervision_all_opts $chaindir/tree $chaindir/0.trans_mdl \ ark:- ark:- \| \ nnet3-chain-get-egs $ivector_opts --srand=$srand \ - $egs_opts $chaindir/normalization.fst \ - "$valid_feats" ark,s,cs:- "ark:$dir/valid_all.cegs" || touch $dir/.error & + $egs_opts --normalization-scale=$normalization_scale $chaindir/normalization.fst \ + "$valid_feats" ark,s,cs:- "ark:$dir/valid_all.cegs" || exit 1 & $cmd $dir/log/create_train_subset.log \ utils/filter_scp.pl $dir/train_subset_uttlist $dir/lat_special.scp \| \ lattice-align-phones --replace-output-symbols=true $latdir/final.mdl scp:- ark:- \| \ chain-get-supervision $chain_supervision_all_opts \ $chaindir/tree $chaindir/0.trans_mdl ark:- ark:- \| \ nnet3-chain-get-egs $ivector_opts --srand=$srand \ - $egs_opts $chaindir/normalization.fst \ - "$train_subset_feats" ark,s,cs:- "ark:$dir/train_subset_all.cegs" || touch $dir/.error & - wait; - [ -f $dir/.error ] && echo "Error detected while creating train/valid egs" && exit 1 + $egs_opts --normalization-scale=$normalization_scale $chaindir/normalization.fst \ + "$train_subset_feats" ark,s,cs:- "ark:$dir/train_subset_all.cegs" || exit 1 & + wait + sleep 5 # wait for file system to sync. echo "... Getting subsets of validation examples for diagnostics and combination." 
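+  # If --generate-egs-scp true, the diagnostic and combine egs below are also
+  # written with .scp indexes (e.g. valid_diagnostic.scp); these scp files are
+  # what steps/nnet3/multilingual/combine_egs.sh expects when combining egs
+  # from several sources.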
+ if $generate_egs_scp; then + valid_diagnostic_output="ark,scp:$dir/valid_diagnostic.cegs,$dir/valid_diagnostic.scp" + train_diagnostic_output="ark,scp:$dir/train_diagnostic.cegs,$dir/train_diagnostic.scp" + else + valid_diagnostic_output="ark:$dir/valid_diagnostic.cegs" + train_diagnostic_output="ark:$dir/train_diagnostic.cegs" + fi $cmd $dir/log/create_valid_subset_combine.log \ nnet3-chain-subset-egs --n=$num_valid_egs_combine ark:$dir/valid_all.cegs \ - ark:$dir/valid_combine.cegs || touch $dir/.error & + ark:$dir/valid_combine.cegs || exit 1 & $cmd $dir/log/create_valid_subset_diagnostic.log \ nnet3-chain-subset-egs --n=$num_egs_diagnostic ark:$dir/valid_all.cegs \ - ark:$dir/valid_diagnostic.cegs || touch $dir/.error & + $valid_diagnostic_output || exit 1 & $cmd $dir/log/create_train_subset_combine.log \ nnet3-chain-subset-egs --n=$num_train_egs_combine ark:$dir/train_subset_all.cegs \ - ark:$dir/train_combine.cegs || touch $dir/.error & + ark:$dir/train_combine.cegs || exit 1 & $cmd $dir/log/create_train_subset_diagnostic.log \ nnet3-chain-subset-egs --n=$num_egs_diagnostic ark:$dir/train_subset_all.cegs \ - ark:$dir/train_diagnostic.cegs || touch $dir/.error & + $train_diagnostic_output || exit 1 & wait sleep 5 # wait for file system to sync. - cat $dir/valid_combine.cegs $dir/train_combine.cegs > $dir/combine.cegs + if $generate_egs_scp; then + cat $dir/valid_combine.cegs $dir/train_combine.cegs | \ + nnet3-chain-copy-egs ark:- ark,scp:$dir/combine.cegs,$dir/combine.scp + rm $dir/{train,valid}_combine.scp + else + cat $dir/valid_combine.cegs $dir/train_combine.cegs > $dir/combine.cegs + fi for f in $dir/{combine,train_diagnostic,valid_diagnostic}.cegs; do [ ! -s $f ] && echo "No examples in file $f" && exit 1; done rm $dir/valid_all.cegs $dir/train_subset_all.cegs $dir/{train,valid}_combine.cegs + ) || touch $dir/.error & fi if [ $stage -le 4 ]; then @@ -353,10 +399,12 @@ if [ $stage -le 4 ]; then # there can be too many small files to deal with, because the total number of # files is the product of 'nj' by 'num_archives_intermediate', which might be # quite large. - $cmd JOB=1:$nj $dir/log/get_egs.JOB.log \ - utils/filter_scp.pl $sdata/JOB/utt2spk $dir/lat.scp \| \ - lattice-align-phones --replace-output-symbols=true $latdir/final.mdl scp:- ark:- \| \ + + $cmd --max-jobs-run $max_jobs_run JOB=1:$nj $dir/log/get_egs.JOB.log \ + lattice-align-phones --replace-output-symbols=true $latdir/final.mdl \ + "$lats_rspecifier" ark:- \| \ chain-get-supervision $chain_supervision_all_opts \ + --weight=$egs_weight \ $chaindir/tree $chaindir/0.trans_mdl ark:- ark:- \| \ nnet3-chain-get-egs $ivector_opts --srand=\$[JOB+$srand] $egs_opts \ --num-frames-overlap=$frames_overlap_per_eg \ @@ -376,16 +424,34 @@ if [ $stage -le 5 ]; then done if [ $archives_multiple == 1 ]; then # normal case. 
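    # (archives_multiple > 1 means each intermediate archive is further split
    # into several final archives; that case is handled in the 'else' branch
    # below.)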
+ if $generate_egs_scp; then + output_archive="ark,scp:$dir/cegs.JOB.ark,$dir/cegs.JOB.scp" + else + output_archive="ark:$dir/cegs.JOB.ark" + fi $cmd --max-jobs-run $max_shuffle_jobs_run --mem 8G JOB=1:$num_archives_intermediate $dir/log/shuffle.JOB.log \ - nnet3-chain-normalize-egs $chaindir/normalization.fst "ark:cat $egs_list|" ark:- \| \ - nnet3-chain-shuffle-egs --srand=\$[JOB+$srand] ark:- ark:$dir/cegs.JOB.ark || exit 1; + nnet3-chain-normalize-egs --normalization-scale=$normalization_scale $chaindir/normalization.fst "ark:cat $egs_list|" ark:- \| \ + nnet3-chain-shuffle-egs --srand=\$[JOB+$srand] ark:- $output_archive || exit 1; + + if $generate_egs_scp; then + #concatenate cegs.JOB.scp in single cegs.scp + rm -rf $dir/cegs.scp + for j in $(seq $num_archives_intermediate); do + cat $dir/cegs.$j.scp || exit 1; + done > $dir/cegs.scp || exit 1; + for f in $dir/cegs.*.scp; do rm $f; done + fi else # we need to shuffle the 'intermediate archives' and then split into the # final archives. we create soft links to manage this splitting, because # otherwise managing the output names is quite difficult (and we don't want # to submit separate queue jobs for each intermediate archive, because then # the --max-jobs-run option is hard to enforce). - output_archives="$(for y in $(seq $archives_multiple); do echo ark:$dir/cegs.JOB.$y.ark; done)" + if $generate_egs_scp; then + output_archives="$(for y in $(seq $archives_multiple); do echo ark,scp:$dir/cegs.JOB.$y.ark,$dir/cegs.JOB.$y.scp; done)" + else + output_archives="$(for y in $(seq $archives_multiple); do echo ark:$dir/cegs.JOB.$y.ark; done)" + fi for x in $(seq $num_archives_intermediate); do for y in $(seq $archives_multiple); do archive_index=$[($x-1)*$archives_multiple+$y] @@ -394,12 +460,26 @@ if [ $stage -le 5 ]; then done done $cmd --max-jobs-run $max_shuffle_jobs_run --mem 8G JOB=1:$num_archives_intermediate $dir/log/shuffle.JOB.log \ - nnet3-chain-normalize-egs $chaindir/normalization.fst "ark:cat $egs_list|" ark:- \| \ + nnet3-chain-normalize-egs --normalization-scale=$normalization_scale $chaindir/normalization.fst "ark:cat $egs_list|" ark:- \| \ nnet3-chain-shuffle-egs --srand=\$[JOB+$srand] ark:- ark:- \| \ nnet3-chain-copy-egs ark:- $output_archives || exit 1; + + if $generate_egs_scp; then + #concatenate cegs.JOB.scp in single cegs.scp + rm -rf $dir/cegs.scp + for j in $(seq $num_archives_intermediate); do + for y in $(seq $archives_multiple); do + cat $dir/cegs.$j.$y.scp || exit 1; + done + done > $dir/cegs.scp || exit 1; + for f in $dir/cegs.*.*.scp; do rm $f; done + fi fi fi +wait +[ -f $dir/.error ] && echo "Error detected while creating train/valid egs" && exit 1 + if [ $stage -le 6 ]; then echo "$0: removing temporary archives" ( @@ -413,8 +493,6 @@ if [ $stage -le 6 ]; then # there are some extra soft links that we should delete. for f in $dir/cegs.*.*.ark; do rm $f; done fi - echo "$0: removing temporary lattices" - rm $dir/lat.* echo "$0: removing temporary alignments and transforms" # Ignore errors below because trans.* might not exist. rm $dir/{ali,trans}.{ark,scp} 2>/dev/null diff --git a/egs/wsj/s5/steps/nnet3/chain/train.py b/egs/wsj/s5/steps/nnet3/chain/train.py index b62f5510e3c..033649a599a 100755 --- a/egs/wsj/s5/steps/nnet3/chain/train.py +++ b/egs/wsj/s5/steps/nnet3/chain/train.py @@ -272,7 +272,8 @@ def train(args, run_opts): # Check files chain_lib.check_for_required_files(args.feat_dir, args.tree_dir, - args.lat_dir) + args.lat_dir if args.egs_dir is None + else None) # Set some variables. 
    num_jobs = common_lib.get_number_of_jobs(args.tree_dir)
@@ -404,6 +405,15 @@ def train(args, run_opts):
         logger.info("Copying the properties from {0} to {1}".format(egs_dir, args.dir))
         common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir)
+    if not os.path.exists('{0}/valid_diagnostic.cegs'.format(egs_dir)):
+        if (not os.path.exists('{0}/valid_diagnostic.scp'.format(egs_dir))):
+            raise Exception('neither {0}/valid_diagnostic.cegs nor '
+                            '{0}/valid_diagnostic.scp exists. '
+                            'This script expects one of them.'.format(egs_dir))
+        use_multitask_egs = True
+    else:
+        use_multitask_egs = False
+
     if ((args.stage <= -2) and (os.path.exists(args.dir+"/configs/init.config")) and (args.input_model is None)):
         logger.info('Computing the preconditioning matrix for input features')
@@ -411,7 +421,8 @@ def train(args, run_opts):
         chain_lib.compute_preconditioning_matrix(
             args.dir, egs_dir, num_archives, run_opts,
             max_lda_jobs=args.max_lda_jobs,
-            rand_prune=args.rand_prune)
+            rand_prune=args.rand_prune,
+            use_multitask_egs=use_multitask_egs)
 
     if (args.stage <= -1):
         logger.info("Preparing the initial acoustic model.")
@@ -519,7 +530,8 @@ def train(args, run_opts):
                 frame_subsampling_factor=args.frame_subsampling_factor,
                 run_opts=run_opts,
                 backstitch_training_scale=args.backstitch_training_scale,
-                backstitch_training_interval=args.backstitch_training_interval)
+                backstitch_training_interval=args.backstitch_training_interval,
+                use_multitask_egs=use_multitask_egs)
 
             if args.cleanup:
                 # do a clean up everythin but the last 2 models, under certain
@@ -554,11 +566,18 @@ def train(args, run_opts):
             l2_regularize=args.l2_regularize, xent_regularize=args.xent_regularize,
             run_opts=run_opts,
-            max_objective_evaluations=args.max_objective_evaluations)
+            max_objective_evaluations=args.max_objective_evaluations,
+            use_multitask_egs=use_multitask_egs)
     else:
         logger.info("Copying the last-numbered model to final.mdl")
         common_lib.force_symlink("{0}.mdl".format(num_iters),
                                  "{0}/final.mdl".format(args.dir))
+        chain_lib.compute_train_cv_probabilities(
+            dir=args.dir, iter=num_iters, egs_dir=egs_dir,
+            l2_regularize=args.l2_regularize, xent_regularize=args.xent_regularize,
+            leaky_hmm_coefficient=args.leaky_hmm_coefficient,
+            run_opts=run_opts,
+            use_multitask_egs=use_multitask_egs)
         common_lib.force_symlink("compute_prob_valid.{iter}.log"
                                  "".format(iter=num_iters-1),
                                  "{dir}/log/compute_prob_valid.final.log".format(
diff --git a/egs/wsj/s5/steps/nnet3/decode.sh b/egs/wsj/s5/steps/nnet3/decode.sh
index 50e02629db0..86ccb1c6302 100755
--- a/egs/wsj/s5/steps/nnet3/decode.sh
+++ b/egs/wsj/s5/steps/nnet3/decode.sh
@@ -32,12 +32,14 @@ extra_left_context_initial=-1
 extra_right_context_final=-1
 online_ivector_dir=
 minimize=false
+word_determinize=false
+write_compact=true
 # End configuration section.
 
 echo "$0 $@"  # Print the command line for logging
 
 [ -f ./path.sh ] && . ./path.sh; # source the path.
-. parse_options.sh || exit 1;
+. utils/parse_options.sh || exit 1;
 
 if [ $# -ne 3 ]; then
   echo "Usage: $0 [options] <graph-dir> <data-dir> <decode-dir>"
@@ -118,10 +120,17 @@ if [ ! -z "$online_ivector_dir" ]; then
   ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period"
 fi
 
+extra_opts=
+lats_wspecifier="ark:|"
+if ! 
$write_compact; then + extra_opts="--determinize-lattice=false" + lats_wspecifier="ark:| lattice-determinize-phone-pruned --beam=$lattice_beam --acoustic-scale=$acwt --minimize=$minimize --word-determinize=$word_determinize --write-compact=false $model ark:- ark:- |" +fi + if [ "$post_decode_acwt" == 1.0 ]; then - lat_wspecifier="ark:|gzip -c >$dir/lat.JOB.gz" + lats_wspecifier="$lats_wspecifier gzip -c >$dir/lat.JOB.gz" else - lat_wspecifier="ark:|lattice-scale --acoustic-scale=$post_decode_acwt ark:- ark:- | gzip -c >$dir/lat.JOB.gz" + lats_wspecifier="$lats_wspecifier lattice-scale --acoustic-scale=$post_decode_acwt --write-compact=$write_compact ark:- ark:- | gzip -c >$dir/lat.JOB.gz" fi frame_subsampling_opt= @@ -138,10 +147,12 @@ if [ $stage -le 1 ]; then --extra-right-context=$extra_right_context \ --extra-left-context-initial=$extra_left_context_initial \ --extra-right-context-final=$extra_right_context_final \ - --minimize=$minimize --max-active=$max_active --min-active=$min_active --beam=$beam \ + --minimize=$minimize --word-determinize=$word_determinize \ + --max-active=$max_active --min-active=$min_active --beam=$beam \ --lattice-beam=$lattice_beam --acoustic-scale=$acwt --allow-partial=true \ - --word-symbol-table=$graphdir/words.txt "$model" \ - $graphdir/HCLG.fst "$feats" "$lat_wspecifier" || exit 1; + --word-symbol-table=$graphdir/words.txt ${extra_opts} \ + "$model" \ + $graphdir/HCLG.fst "$feats" "$lats_wspecifier" || exit 1; fi diff --git a/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py b/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py index 6372ba25e5e..3f9ec568505 100755 --- a/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py +++ b/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py @@ -95,9 +95,16 @@ def get_args(): 'output-2'.""", epilog="Called by steps/nnet3/multilingual/combine_egs.sh") - parser.add_argument("--samples-per-iter", type=int, default=40000, + parser.add_argument("--samples-per-iter", type=int, default=None, help="The target number of egs in each archive of egs, " - "(prior to merging egs). ") + "(prior to merging egs). [DEPRECATED]") + parser.add_argument("--frames-per-iter", type=int, default=400000, + help="The target number of frames in each archive of " + "egs") + parser.add_argument("--frames-per-eg-list", type=str, default=None, + action=common_lib.NullstrToNoneAction, + help="Number of frames per eg for each input language " + "as a comma separated list") parser.add_argument("--num-jobs", type=int, default=20, help="This can be used for better randomization in distributing " "examples for different languages across egs.*.scp files, " @@ -107,7 +114,7 @@ def get_args(): help="If true, egs.ranges.*.txt are generated " "randomly w.r.t distribution of remaining examples in " "each language, otherwise it is generated sequentially.", - default=True, choices = ["false", "true"]) + default=True, choices=["false", "true"]) parser.add_argument("--max-archives", type=int, default=1000, help="max number of archives used to generate egs.*.scp") parser.add_argument("--seed", type=int, default=1, @@ -129,7 +136,7 @@ def get_args(): # now the positional arguments parser.add_argument("egs_scp_lists", nargs='+', help="list of egs.scp files per input language." - "e.g. exp/lang1/egs/egs.scp exp/lang2/egs/egs.scp") + "e.g. exp/lang1/egs/egs.scp exp/lang2/egs/egs.scp") parser.add_argument("egs_dir", help="Name of egs directory e.g. 
exp/tdnn_multilingual_sp/egs") @@ -137,6 +144,10 @@ def get_args(): print(sys.argv, file=sys.stderr) args = parser.parse_args() + if args.samples_per_iter is not None: + args.frames_per_iter = args.samples_per_iter + args.frames_per_eg_list = None + return args @@ -153,7 +164,7 @@ def select_random_lang(lang_len, tot_egs, random_selection): count = 0 for l in range(len(lang_len)): if random_selection: - if rand_int <= (count + lang_len[l]): + if rand_int <= (count + lang_len[l]): return l else: count += lang_len[l] @@ -172,6 +183,10 @@ def process_multilingual_egs(args): scp_lists = args.egs_scp_lists num_langs = len(scp_lists) + frames_per_eg = ([1 for x in scp_lists] + if args.frames_per_eg_list is None + else [int(x) for x in args.frames_per_eg_list.split(',')]) + scp_files = [open(scp_lists[lang], 'r') for lang in range(num_langs)] lang2len = [0] * num_langs @@ -182,7 +197,7 @@ def process_multilingual_egs(args): # If weights are not provided, the weights are 1.0. if args.lang2weight is None: - lang2weight = [ 1.0 ] * num_langs + lang2weight = [1.0] * num_langs else: lang2weight = args.lang2weight.split(",") assert(len(lang2weight) == num_langs) @@ -195,10 +210,16 @@ def process_multilingual_egs(args): # Each element of all_egs (one per num_archive * num_jobs) is # an array of 3-tuples (lang-id, local-start-egs-line, num-egs) all_egs = [] - lang_len = lang2len[:] - # total num of egs in all languages - tot_num_egs = sum(lang2len[i] for i in range(len(lang2len))) - num_archives = max(1, min(args.max_archives, tot_num_egs / args.samples_per_iter)) + num_frames_in_lang = [frames_per_eg[i] * lang2len[i] + for i in range(num_langs)] + for lang in range(num_langs): + logger.info("Number of frames for language {0} " + "is {1}.".format(lang, num_frames_in_lang[lang])) + + # total num of frames in all languages + tot_num_frames = sum(num_frames_in_lang[i] for i in range(num_langs)) + num_archives = max(1, min(args.max_archives, + tot_num_frames / args.frames_per_iter)) num_arch_file = open("{0}/info/{1}num_archives".format( args.egs_dir, @@ -206,7 +227,7 @@ def process_multilingual_egs(args): "w") print("{0}".format(num_archives), file=num_arch_file) num_arch_file.close() - this_num_egs_per_archive = tot_num_egs / (num_archives * args.num_jobs) + this_num_frames_per_archive = tot_num_frames / (num_archives * args.num_jobs) logger.info("Generating {0}scp.. 
temporary files used to " "generate {0}.scp.".format(args.egs_prefix)) @@ -216,29 +237,36 @@ def process_multilingual_egs(args): "".format(args.egs_dir, args.egs_prefix, job + 1, archive_index + 1), "w") - this_egs = [] # this will be array of 2-tuples (lang-id start-frame num-frames) + # this will be array of 2-tuples (lang-id start-frame num-frames) + this_egs = [] num_egs = 0 - while num_egs <= this_num_egs_per_archive: - num_left_egs = sum(num_left_egs_per_lang for - num_left_egs_per_lang in lang_len) - if num_left_egs > 0: - lang_id = select_random_lang(lang_len, num_left_egs, rand_select) - start_egs = lang2len[lang_id] - lang_len[lang_id] + num_frames = 0 + while num_frames <= this_num_frames_per_archive: + num_frames_left = sum(num_frames_in_lang) + if num_frames_left > 0: + lang_id = select_random_lang(num_frames_in_lang, + num_frames_left, rand_select) + start_egs = ( + lang2len[lang_id] + - num_frames_in_lang[lang_id] / frames_per_eg[lang_id]) this_egs.append((lang_id, start_egs, args.minibatch_size)) for scpline in range(args.minibatch_size): scp_key = scp_files[lang_id].readline().splitlines()[0] print("{0} {1}".format(scp_key, lang_id), file=archfile) - lang_len[lang_id] = lang_len[lang_id] - args.minibatch_size - num_egs = num_egs + args.minibatch_size + num_frames_in_lang[lang_id] -= ( + args.minibatch_size * frames_per_eg[lang_id]) + num_egs += args.minibatch_size + num_frames += args.minibatch_size * frames_per_eg[lang_id] # If num of remaining egs in each lang is less than minibatch_size, # they are discarded. - if lang_len[lang_id] < args.minibatch_size: - lang_len[lang_id] = 0 - logger.info("Done processing data for language {0}".format( - lang_id)) + if (num_frames_in_lang[lang_id] + < args.minibatch_size * frames_per_eg[lang_id]): + num_frames_in_lang[lang_id] = 0 + logger.info("Done processing data for language {0}" + "".format(lang_id)) else: logger.info("Done processing data for all languages.") break @@ -315,4 +343,4 @@ def main(): if __name__ == "__main__": - main() + main() diff --git a/egs/wsj/s5/steps/nnet3/multilingual/combine_egs.sh b/egs/wsj/s5/steps/nnet3/multilingual/combine_egs.sh index 3826dad11a9..75a49e1004e 100755 --- a/egs/wsj/s5/steps/nnet3/multilingual/combine_egs.sh +++ b/egs/wsj/s5/steps/nnet3/multilingual/combine_egs.sh @@ -19,13 +19,15 @@ minibatch_size=512 # it is the number of consecutive egs that we take from # access. This does not have to be the actual minibatch size; num_jobs=10 # helps for better randomness across languages # per archive. -samples_per_iter=400000 # this is the target number of egs in each archive of egs +frames_per_iter=400000 # this is the target number of egs in each archive of egs # (prior to merging egs). We probably should have called # it egs_per_iter. This is just a guideline; it will pick # a number that divides the number of samples in the # entire data. lang2weight= # array of weights one per input languge to scale example's output # w.r.t its input language during training. +allocate_opts= +egs_prefix=egs. stage=0 echo "$0 $@" # Print the command line for logging @@ -33,6 +35,12 @@ echo "$0 $@" # Print the command line for logging if [ -f path.sh ]; then . ./path.sh; fi . parse_options.sh || exit 1; +if [ $# -lt 3 ]; then + echo "Usage:$0 [opts] ... 
" + echo "Usage:$0 [opts] 2 exp/lang1/egs exp/lang2/egs exp/multi/egs" + exit 1; +fi + num_langs=$1 shift 1 @@ -47,7 +55,8 @@ if [ ${#args[@]} != $[$num_langs+1] ]; then exit 1; fi -required="egs.scp combine.scp train_diagnostic.scp valid_diagnostic.scp" +required="${egs_prefix}scp combine.scp train_diagnostic.scp valid_diagnostic.scp" +frames_per_eg_list= train_scp_list= train_diagnostic_scp_list= valid_diagnostic_scp_list= @@ -55,13 +64,15 @@ combine_scp_list= # read paramter from $egs_dir[0]/info and cmvn_opts # to write in multilingual egs_dir. -check_params="info/feat_dim info/ivector_dim info/left_context info/right_context info/frames_per_eg cmvn_opts" +check_params="info/feat_dim info/ivector_dim info/left_context info/right_context cmvn_opts" ivec_dim=`cat ${args[0]}/info/ivector_dim` if [ $ivec_dim -ne 0 ];then check_params="$check_params info/final.ie.id"; fi for param in $check_params; do - cat ${args[0]}/$param > $megs_dir/$param || exit 1; + cat ${args[0]}/$param > $megs_dir/$param || exit 1; done +cat ${args[0]}/cmvn_opts > $megs_dir/cmvn_opts || exit 1; # caution: the top-level nnet training +cp ${args[0]}/info/frames_per_eg $megs_dir/info/frames_per_eg || exit 1; for lang in $(seq 0 $[$num_langs-1]);do multi_egs_dir[$lang]=${args[$lang]} @@ -70,10 +81,21 @@ for lang in $(seq 0 $[$num_langs-1]);do echo "$0: no such file ${multi_egs_dir[$lang]}/$f." && exit 1; fi done - train_scp_list="$train_scp_list ${args[$lang]}/egs.scp" + train_scp_list="$train_scp_list ${args[$lang]}/${egs_prefix}scp" train_diagnostic_scp_list="$train_diagnostic_scp_list ${args[$lang]}/train_diagnostic.scp" valid_diagnostic_scp_list="$valid_diagnostic_scp_list ${args[$lang]}/valid_diagnostic.scp" combine_scp_list="$combine_scp_list ${args[$lang]}/combine.scp" + + this_frames_per_eg=$(cat ${args[$lang]}/info/frames_per_eg | \ + awk -F, '{for (i=1; i<=NF; i++) sum += $i;} END{print int(sum / NF)}') # use average frames-per-eg + + # frames_per_eg_list stores the average frames-per-eg for each language. + # The average does not have to be exact. + if [ $lang -eq 0 ]; then + frames_per_eg_list="$this_frames_per_eg" + else + frames_per_eg_list="$frames_per_eg_list,$this_frames_per_eg" + fi # check parameter dimension to be the same in all egs dirs for f in $check_params; do @@ -90,16 +112,18 @@ for lang in $(seq 0 $[$num_langs-1]);do done done +if [ ! -z "$lang2weight" ]; then + egs_opt="--lang2weight '$lang2weight'" +fi + if [ $stage -le 0 ]; then echo "$0: allocating multilingual examples for training." - if [ ! -z "$lang2weight" ]; then - egs_opt="--lang2weight '$lang2weight'" - fi - # Generate egs.*.scp for multilingual setup. + # Generate ${egs_prefix}*.scp for multilingual setup. $cmd $megs_dir/log/allocate_multilingual_examples_train.log \ steps/nnet3/multilingual/allocate_multilingual_examples.py $egs_opt \ - --minibatch-size $minibatch_size \ - --samples-per-iter $samples_per_iter \ + ${allocate_opts} --minibatch-size $minibatch_size \ + --frames-per-iter $frames_per_iter --frames-per-eg-list $frames_per_eg_list \ + --egs-prefix "$egs_prefix" \ $train_scp_list $megs_dir || exit 1; fi @@ -107,20 +131,20 @@ if [ $stage -le 1 ]; then echo "$0: combine combine.scp examples from all langs in $megs_dir/combine.scp." # Generate combine.scp for multilingual setup. 
$cmd $megs_dir/log/allocate_multilingual_examples_combine.log \ - steps/nnet3/multilingual/allocate_multilingual_examples.py \ - --random-lang false \ - --max-archives 1 --num-jobs 1 \ - --minibatch-size $minibatch_size \ + steps/nnet3/multilingual/allocate_multilingual_examples.py $egs_opt \ + --random-lang false --max-archives 1 --num-jobs 1 \ + --frames-per-eg-list $frames_per_eg_list \ + ${allocate_opts} --minibatch-size $minibatch_size \ --egs-prefix "combine." \ $combine_scp_list $megs_dir || exit 1; echo "$0: combine train_diagnostic.scp examples from all langs in $megs_dir/train_diagnostic.scp." # Generate train_diagnostic.scp for multilingual setup. $cmd $megs_dir/log/allocate_multilingual_examples_train_diagnostic.log \ - steps/nnet3/multilingual/allocate_multilingual_examples.py \ - --random-lang false \ - --max-archives 1 --num-jobs 1 \ - --minibatch-size $minibatch_size \ + steps/nnet3/multilingual/allocate_multilingual_examples.py $egs_opt \ + --random-lang false --max-archives 1 --num-jobs 1 \ + --frames-per-eg-list $frames_per_eg_list \ + ${allocate_opts} --minibatch-size $minibatch_size \ --egs-prefix "train_diagnostic." \ $train_diagnostic_scp_list $megs_dir || exit 1; @@ -128,9 +152,10 @@ if [ $stage -le 1 ]; then echo "$0: combine valid_diagnostic.scp examples from all langs in $megs_dir/valid_diagnostic.scp." # Generate valid_diagnostic.scp for multilingual setup. $cmd $megs_dir/log/allocate_multilingual_examples_valid_diagnostic.log \ - steps/nnet3/multilingual/allocate_multilingual_examples.py \ + steps/nnet3/multilingual/allocate_multilingual_examples.py $egs_opt \ --random-lang false --max-archives 1 --num-jobs 1\ - --minibatch-size $minibatch_size \ + --frames-per-eg-list $frames_per_eg_list \ + ${allocate_opts} --minibatch-size $minibatch_size \ --egs-prefix "valid_diagnostic." \ $valid_diagnostic_scp_list $megs_dir || exit 1; @@ -140,6 +165,6 @@ for egs_type in combine train_diagnostic valid_diagnostic; do mv $megs_dir/${egs_type}.weight.1.ark $megs_dir/${egs_type}.weight.ark || exit 1; mv $megs_dir/${egs_type}.1.scp $megs_dir/${egs_type}.scp || exit 1; done -mv $megs_dir/info/egs.num_archives $megs_dir/info/num_archives || exit 1; -mv $megs_dir/info/egs.num_tasks $megs_dir/info/num_tasks || exit 1; +mv $megs_dir/info/${egs_prefix}num_archives $megs_dir/info/num_archives || exit 1; +mv $megs_dir/info/${egs_prefix}num_tasks $megs_dir/info/num_tasks || exit 1; echo "$0: Finished preparing multilingual training example." diff --git a/egs/wsj/s5/steps/nnet3/report/generate_plots.py b/egs/wsj/s5/steps/nnet3/report/generate_plots.py index 8ec283492ef..6f7987c425f 100755 --- a/egs/wsj/s5/steps/nnet3/report/generate_plots.py +++ b/egs/wsj/s5/steps/nnet3/report/generate_plots.py @@ -732,6 +732,7 @@ def main(): output_nodes.append(tuple(parts)) elif args.is_chain: output_nodes.append(('output', 'chain')) + output_nodes.append(('output-xent', 'chain')) else: output_nodes.append(('output', 'linear')) diff --git a/egs/wsj/s5/steps/scoring/score_kaldi_wer.sh b/egs/wsj/s5/steps/scoring/score_kaldi_wer.sh index 9988c941441..6651a744e4d 100755 --- a/egs/wsj/s5/steps/scoring/score_kaldi_wer.sh +++ b/egs/wsj/s5/steps/scoring/score_kaldi_wer.sh @@ -16,6 +16,7 @@ word_ins_penalty=0.0,0.5,1.0 min_lmwt=7 max_lmwt=17 iter=final +scoring_affix=_kaldi #end configuration section. 
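+# With the default --scoring-affix _kaldi the results go to $dir/scoring_kaldi
+# as before; passing a different affix (e.g. --scoring-affix _kaldi_rescored,
+# a hypothetical example) lets several scoring runs coexist under the same
+# decode directory.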
echo "$0 $@" # Print the command line for logging @@ -59,15 +60,14 @@ else fi -mkdir -p $dir/scoring_kaldi -cat $data/text | $ref_filtering_cmd > $dir/scoring_kaldi/test_filt.txt || exit 1; -if [ $stage -le 0 ]; then - - for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do - mkdir -p $dir/scoring_kaldi/penalty_$wip/log +mkdir -p $dir/scoring${scoring_affix} +cat $data/text | $ref_filtering_cmd > $dir/scoring${scoring_affix}/test_filt.txt || exit 1; +for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + mkdir -p $dir/scoring${scoring_affix}/penalty_$wip/log + if [ $stage -le 0 ]; then if $decode_mbr ; then - $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring_kaldi/penalty_$wip/log/best_path.LMWT.log \ + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring${scoring_affix}/penalty_$wip/log/best_path.LMWT.log \ acwt=\`perl -e \"print 1.0/LMWT\"\`\; \ lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \ @@ -75,37 +75,38 @@ if [ $stage -le 0 ]; then lattice-mbr-decode --word-symbol-table=$symtab \ ark:- ark,t:- \| \ utils/int2sym.pl -f 2- $symtab \| \ - $hyp_filtering_cmd '>' $dir/scoring_kaldi/penalty_$wip/LMWT.txt || exit 1; + $hyp_filtering_cmd '>' $dir/scoring${scoring_affix}/penalty_$wip/LMWT.txt || exit 1; else - $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring_kaldi/penalty_$wip/log/best_path.LMWT.log \ + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring${scoring_affix}/penalty_$wip/log/best_path.LMWT.log \ lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \ lattice-best-path --word-symbol-table=$symtab ark:- ark,t:- \| \ utils/int2sym.pl -f 2- $symtab \| \ - $hyp_filtering_cmd '>' $dir/scoring_kaldi/penalty_$wip/LMWT.txt || exit 1; + $hyp_filtering_cmd '>' $dir/scoring${scoring_affix}/penalty_$wip/LMWT.txt || exit 1; fi + fi - $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring_kaldi/penalty_$wip/log/score.LMWT.log \ - cat $dir/scoring_kaldi/penalty_$wip/LMWT.txt \| \ - compute-wer --text --mode=present \ - ark:$dir/scoring_kaldi/test_filt.txt ark,p:- ">&" $dir/wer_LMWT_$wip || exit 1; - - done -fi + if [ $stage -le 1 ]; then + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring${scoring_affix}/penalty_$wip/log/score.LMWT.log \ + cat $dir/scoring${scoring_affix}/penalty_$wip/LMWT.txt \| \ + compute-wer --text --mode=present \ + ark:$dir/scoring${scoring_affix}/test_filt.txt ark,p:- ">&" $dir/wer_LMWT_$wip || exit 1; + fi +done -if [ $stage -le 1 ]; then +if [ $stage -le 2 ]; then for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do for lmwt in $(seq $min_lmwt $max_lmwt); do # adding /dev/null to the command list below forces grep to output the filename grep WER $dir/wer_${lmwt}_${wip} /dev/null done - done | utils/best_wer.sh >& $dir/scoring_kaldi/best_wer || exit 1 + done | utils/best_wer.sh >& $dir/scoring${scoring_affix}/best_wer || exit 1 - best_wer_file=$(awk '{print $NF}' $dir/scoring_kaldi/best_wer) + best_wer_file=$(awk '{print $NF}' $dir/scoring${scoring_affix}/best_wer) best_wip=$(echo $best_wer_file | awk -F_ '{print $NF}') best_lmwt=$(echo $best_wer_file | awk -F_ '{N=NF-1; print $N}') @@ -115,25 +116,25 @@ if [ $stage -le 1 ]; then fi if $stats; then - mkdir -p $dir/scoring_kaldi/wer_details - echo $best_lmwt > $dir/scoring_kaldi/wer_details/lmwt # record best language model weight - echo $best_wip > $dir/scoring_kaldi/wer_details/wip # record best word insertion penalty - - $cmd $dir/scoring_kaldi/log/stats1.log \ - cat 
$dir/scoring_kaldi/penalty_$best_wip/$best_lmwt.txt \| \ - align-text --special-symbol="'***'" ark:$dir/scoring_kaldi/test_filt.txt ark:- ark,t:- \| \ - utils/scoring/wer_per_utt_details.pl --special-symbol "'***'" \| tee $dir/scoring_kaldi/wer_details/per_utt \|\ - utils/scoring/wer_per_spk_details.pl $data/utt2spk \> $dir/scoring_kaldi/wer_details/per_spk || exit 1; - - $cmd $dir/scoring_kaldi/log/stats2.log \ - cat $dir/scoring_kaldi/wer_details/per_utt \| \ + mkdir -p $dir/scoring${scoring_affix}/wer_details + echo $best_lmwt > $dir/scoring${scoring_affix}/wer_details/lmwt # record best language model weight + echo $best_wip > $dir/scoring${scoring_affix}/wer_details/wip # record best word insertion penalty + + $cmd $dir/scoring${scoring_affix}/log/stats1.log \ + cat $dir/scoring${scoring_affix}/penalty_$best_wip/$best_lmwt.txt \| \ + align-text --special-symbol="'***'" ark:$dir/scoring${scoring_affix}/test_filt.txt ark:- ark,t:- \| \ + utils/scoring/wer_per_utt_details.pl --special-symbol "'***'" \| tee $dir/scoring${scoring_affix}/wer_details/per_utt \|\ + utils/scoring/wer_per_spk_details.pl $data/utt2spk \> $dir/scoring${scoring_affix}/wer_details/per_spk || exit 1; + + $cmd $dir/scoring${scoring_affix}/log/stats2.log \ + cat $dir/scoring${scoring_affix}/wer_details/per_utt \| \ utils/scoring/wer_ops_details.pl --special-symbol "'***'" \| \ - sort -b -i -k 1,1 -k 4,4rn -k 2,2 -k 3,3 \> $dir/scoring_kaldi/wer_details/ops || exit 1; + sort -b -i -k 1,1 -k 4,4rn -k 2,2 -k 3,3 \> $dir/scoring${scoring_affix}/wer_details/ops || exit 1; - $cmd $dir/scoring_kaldi/log/wer_bootci.log \ + $cmd $dir/scoring${scoring_affix}/log/wer_bootci.log \ compute-wer-bootci --mode=present \ - ark:$dir/scoring_kaldi/test_filt.txt ark:$dir/scoring_kaldi/penalty_$best_wip/$best_lmwt.txt \ - '>' $dir/scoring_kaldi/wer_details/wer_bootci || exit 1; + ark:$dir/scoring${scoring_affix}/test_filt.txt ark:$dir/scoring${scoring_affix}/penalty_$best_wip/$best_lmwt.txt \ + '>' $dir/scoring${scoring_affix}/wer_details/wer_bootci || exit 1; fi fi diff --git a/egs/wsj/s5/steps/subset_ali_dir.sh b/egs/wsj/s5/steps/subset_ali_dir.sh new file mode 100755 index 00000000000..c086ea39959 --- /dev/null +++ b/egs/wsj/s5/steps/subset_ali_dir.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# Apache 2.0. + +cmd=run.pl + +. path.sh + +. 
utils/parse_options.sh + +if [ $# -ne 4 ]; then + cat < + e.g.: data/train data/train_sp exp/tri3_ali_sp exp/tri3_ali +EOF +fi + +subset_data=$1 +data=$2 +ali_dir=$3 +dir=$4 + +nj=$(cat $ali_dir/num_jobs) || exit 1 +utils/split_data.sh $data $nj + +mkdir -p $dir +cp $ali_dir/{final.mdl,*.mat,*_opts,tree} $dir/ || true +cp -r $ali_dir/phones $dir 2>/dev/null || true + +$cmd JOB=1:$nj $dir/log/copy_alignments.JOB.log \ + copy-int-vector "ark:gunzip -c $ali_dir/ali.JOB.gz |" \ + ark,scp:$dir/ali_tmp.JOB.ark,$dir/ali_tmp.JOB.scp || exit 1 + +for n in `seq $nj`; do + cat $dir/ali_tmp.$n.scp +done > $dir/ali_tmp.scp + +num_spk=$(cat $subset_data/spk2utt | wc -l) +if [ $num_spk -lt $nj ]; then + nj=$num_spk +fi + +utils/split_data.sh $subset_data $nj +$cmd JOB=1:$nj $dir/log/filter_alignments.JOB.log \ + copy-int-vector \ + "scp:utils/filter_scp.pl $subset_data/split${nj}/JOB/utt2spk $dir/ali_tmp.scp |" \ + "ark:| gzip -c > $dir/ali.JOB.gz" || exit 1 + +echo $nj > $dir/num_jobs + +rm $dir/ali_tmp.*.{ark,scp} $dir/ali_tmp.scp + +exit 0 diff --git a/src/chain/chain-supervision.cc b/src/chain/chain-supervision.cc index b5597b15667..fc22a2786f0 100644 --- a/src/chain/chain-supervision.cc +++ b/src/chain/chain-supervision.cc @@ -74,7 +74,9 @@ void ProtoSupervision::Write(std::ostream &os, bool binary) const { void SupervisionOptions::Check() const { KALDI_ASSERT(left_tolerance >= 0 && right_tolerance >= 0 && frame_subsampling_factor > 0 && - left_tolerance + right_tolerance >= frame_subsampling_factor); + left_tolerance + right_tolerance + 1 >= frame_subsampling_factor); + + KALDI_ASSERT(lm_scale >= 0.0 && lm_scale < 1.0); } bool AlignmentToProtoSupervision(const SupervisionOptions &opts, @@ -142,9 +144,10 @@ bool ProtoSupervision::operator == (const ProtoSupervision &other) const { fst::Equal(fst, other.fst)); } -bool PhoneLatticeToProtoSupervision(const SupervisionOptions &opts, - const CompactLattice &lat, - ProtoSupervision *proto_supervision) { +bool PhoneLatticeToProtoSupervisionInternalSimple( + const SupervisionOptions &opts, + const CompactLattice &lat, + ProtoSupervision *proto_supervision) { opts.Check(); if (lat.NumStates() == 0) { KALDI_WARN << "Empty lattice provided"; @@ -176,20 +179,24 @@ bool PhoneLatticeToProtoSupervision(const SupervisionOptions &opts, return false; } proto_supervision->fst.AddArc(state, - fst::StdArc(phone, phone, - fst::TropicalWeight::One(), - lat_arc.nextstate)); + fst::StdArc(phone, phone, + fst::TropicalWeight( + lat_arc.weight.Weight().Value1() + * opts.lm_scale + opts.phone_ins_penalty), + lat_arc.nextstate)); + int32 t_begin = std::max(0, (state_time - opts.left_tolerance)), t_end = std::min(num_frames, (next_state_time + opts.right_tolerance)), - t_begin_subsampled = (t_begin + factor - 1)/ factor, - t_end_subsampled = (t_end + factor - 1)/ factor; + t_begin_subsampled = (t_begin + factor - 1)/ factor, + t_end_subsampled = (t_end + factor - 1)/ factor; for (int32 t_subsampled = t_begin_subsampled; t_subsampled < t_end_subsampled; t_subsampled++) proto_supervision->allowed_phones[t_subsampled].push_back(phone); } if (lat.Final(state) != CompactLatticeWeight::Zero()) { - proto_supervision->fst.SetFinal(state, fst::TropicalWeight::One()); + proto_supervision->fst.SetFinal(state, fst::TropicalWeight( + lat.Final(state).Weight().Value1() * opts.lm_scale)); if (state_times[state] != num_frames) { KALDI_WARN << "Time of final state " << state << " in lattice is " << "not equal to number of frames " << num_frames @@ -207,6 +214,18 @@ bool 
PhoneLatticeToProtoSupervision(const SupervisionOptions &opts, return true; } +bool PhoneLatticeToProtoSupervision(const SupervisionOptions &opts, + const CompactLattice &lat, + ProtoSupervision *proto_supervision) { + + if (!PhoneLatticeToProtoSupervisionInternalSimple(opts, lat, proto_supervision)) + return false; + if (opts.lm_scale != 0.0) + fst::Push(&(proto_supervision->fst), + fst::REWEIGHT_TO_INITIAL, fst::kDelta, true); + + return true; +} bool TimeEnforcerFst::GetArc(StateId s, Label ilabel, fst::StdArc* oarc) { // the following call will do the range-check on 'ilabel'. @@ -657,8 +676,10 @@ bool AddWeightToSupervisionFst(const fst::StdVectorFst &normalization_fst, fst::StdVectorFst supervision_fst_noeps(supervision->fst); fst::RmEpsilon(&supervision_fst_noeps); if (!TryDeterminizeMinimize(kSupervisionMaxStates, - &supervision_fst_noeps)) + &supervision_fst_noeps)) { + KALDI_WARN << "Failed to determinize supervision fst"; return false; + } // note: by default, 'Compose' will call 'Connect', so if the // resulting FST is not connected, it will end up empty. @@ -671,8 +692,10 @@ bool AddWeightToSupervisionFst(const fst::StdVectorFst &normalization_fst, // determinize and minimize to make it as compact as possible. if (!TryDeterminizeMinimize(kSupervisionMaxStates, - &composed_fst)) + &composed_fst)) { + KALDI_WARN << "Failed to determinize normalized supervision fst"; return false; + } supervision->fst = composed_fst; // Make sure the states are numbered in increasing order of time. diff --git a/src/chain/chain-supervision.h b/src/chain/chain-supervision.h index a94f68ade90..b9ccfad8680 100644 --- a/src/chain/chain-supervision.h +++ b/src/chain/chain-supervision.h @@ -50,10 +50,16 @@ struct SupervisionOptions { int32 left_tolerance; int32 right_tolerance; int32 frame_subsampling_factor; + BaseFloat weight; + BaseFloat lm_scale; + BaseFloat phone_ins_penalty; SupervisionOptions(): left_tolerance(5), right_tolerance(5), - frame_subsampling_factor(1) { } + frame_subsampling_factor(1), + weight(1.0), + lm_scale(0.0), + phone_ins_penalty(0.0) { } void Register(OptionsItf *opts) { opts->Register("left-tolerance", &left_tolerance, "Left tolerance for " @@ -65,6 +71,13 @@ struct SupervisionOptions { "frame-rate of the original alignment. 
Applied after " "left-tolerance and right-tolerance are applied (so they are " "in terms of the original num-frames."); + opts->Register("weight", &weight, + "Use this to set the supervision weight for training"); + opts->Register("lm-scale", &lm_scale, "The scale with which the graph/lm " + "weights from the phone lattice are included in the " + "supervision fst."); + opts->Register("phone-ins-penalty", &phone_ins_penalty, + "The penalty to penalize longer paths"); } void Check() const; }; diff --git a/src/chainbin/nnet3-chain-copy-egs.cc b/src/chainbin/nnet3-chain-copy-egs.cc index 4f26e145ac5..b81e9a93b68 100644 --- a/src/chainbin/nnet3-chain-copy-egs.cc +++ b/src/chainbin/nnet3-chain-copy-egs.cc @@ -1,8 +1,9 @@ // chainbin/nnet3-chain-copy-egs.cc // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) -// 2014 Vimal Manohar +// 2014-2017 Vimal Manohar // 2016 Gaofeng Cheng +// 2017 Pegah Ghahremani // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,6 +27,41 @@ namespace kaldi { namespace nnet3 { +// renames name of NnetIo object from "old_name" to "new_name" +void RenameIoNames(const std::string &old_name, + const std::string &new_name, + NnetChainExample *eg_modified) { + // Get list of io-names in eg_modified. + std::vector orig_output_names; + int32 output_size = eg_modified->outputs.size(); + for (int32 output_ind = 0; output_ind < output_size; output_ind++) + orig_output_names.push_back(eg_modified->outputs[output_ind].name); + + // find the io in eg with name "old_name". + int32 rename_output_ind = + std::find(orig_output_names.begin(), orig_output_names.end(), old_name) - + orig_output_names.begin(); + + if (rename_output_ind >= output_size) + KALDI_ERR << "No io-node with name " << old_name + << "exists in eg."; + eg_modified->outputs[rename_output_ind].name = new_name; +} + +// renames NnetIo object with name 'output' to "new_output_name" +// and scales the supervision for 'output' by a factor of "weight" +void SetWeightAndRenameOutput(BaseFloat weight, + const std::string &new_output_name, + NnetChainExample *eg) { + // Scale the supervision weight for egs. + for (int32 i = 0; i < eg->outputs.size(); i++) + if (eg->outputs[i].name == "output") + if (weight != 0.0 && weight != 1.0) + eg->outputs[i].supervision.weight *= weight; + // Rename output io name to 'new_output_name'. + RenameIoNames("output", new_output_name, eg); +} + // returns an integer randomly drawn with expected value "expected_count" // (will be either floor(expected_count) or ceil(expected_count)). 
int32 GetCount(double expected_count) { @@ -240,6 +276,7 @@ void ModifyChainExampleContext(const NnetChainExample &eg, min_output_t, max_output_t, eg_out); } // ModifyChainExampleContext + } // namespace nnet3 } // namespace kaldi @@ -268,6 +305,8 @@ int main(int argc, char *argv[]) { int32 frame_subsampling_factor = -1; BaseFloat keep_proportion = 1.0; int32 left_context = -1, right_context = -1; + std::string eg_weight_rspecifier, eg_output_rspecifier; + ParseOptions po(usage); po.Register("random", &random, "If true, will write frames to output " "archives randomly, not round-robin."); @@ -285,6 +324,15 @@ int main(int argc, char *argv[]) { "feature left-context that we output."); po.Register("right-context", &right_context, "Can be used to truncate the " "feature right-context that we output."); + po.Register("weights", &eg_weight_rspecifier, + "Rspecifier indexed by the key of egs, providing a weight by " + "which we will scale the supervision matrix for that eg. " + "Used in multilingual training."); + po.Register("outputs", &eg_output_rspecifier, + "Rspecifier indexed by the key of egs, providing a string-valued " + "output name, e.g. 'output-0'. If provided, the NnetIo with " + "name 'output' will be renamed to the provided name. Used in " + "multilingual training."); po.Read(argc, argv); srand(srand_seed); @@ -297,6 +345,8 @@ int main(int argc, char *argv[]) { std::string examples_rspecifier = po.GetArg(1); SequentialNnetChainExampleReader example_reader(examples_rspecifier); + RandomAccessTokenReader output_reader(eg_output_rspecifier); + RandomAccessBaseFloatReader egs_weight_reader(eg_weight_rspecifier); int32 num_outputs = po.NumArgs() - 1; std::vector example_writers(num_outputs); @@ -307,8 +357,9 @@ int main(int argc, char *argv[]) { // not configurable for now. exclude_names.push_back(std::string("ivector")); - int64 num_read = 0, num_written = 0; - + int64 num_read = 0, num_written = 0, num_err = 0; + bool modify_eg_output = !(eg_output_rspecifier.empty() && + eg_weight_rspecifier.empty()); for (; !example_reader.Done(); example_reader.Next(), num_read++) { if (frame_subsampling_factor == -1) CalculateFrameSubsamplingFactor(example_reader.Value(), @@ -316,11 +367,40 @@ int main(int argc, char *argv[]) { // count is normally 1; could be 0, or possibly >1. int32 count = GetCount(keep_proportion); std::string key = example_reader.Key(); - if (frame_shift == 0 && - left_context == -1 && right_context == -1) { - const NnetChainExample &eg = example_reader.Value(); + NnetChainExample eg_modified_output; + const NnetChainExample &eg_orig = example_reader.Value(), + &eg = (modify_eg_output ? eg_modified_output : eg_orig); + // Note: in the normal case we just use 'eg'; eg_modified_output is + // for the case when the --outputs or --weights option is specified + // (only for multilingual training). + BaseFloat weight = 1.0; + std::string new_output_name; + if (modify_eg_output) { // This branch is only taken for multilingual training. 
+ eg_modified_output = eg_orig; + if (!eg_weight_rspecifier.empty()) { + if (!egs_weight_reader.HasKey(key)) { + KALDI_WARN << "No weight for example key " << key; + num_err++; + continue; + } + weight = egs_weight_reader.Value(key); + } + if (!eg_output_rspecifier.empty()) { + if (!output_reader.HasKey(key)) { + KALDI_WARN << "No new output-name for example key " << key; + num_err++; + continue; + } + new_output_name = output_reader.Value(key); + } + } + if (frame_shift == 0 && left_context == -1 && right_context == -1) { for (int32 c = 0; c < count; c++) { int32 index = (random ? Rand() : num_written) % num_outputs; + if (modify_eg_output) // Only for multilingual training + SetWeightAndRenameOutput(weight, new_output_name, + &eg_modified_output); + example_writers[index]->Write(key, eg); num_written++; } @@ -336,6 +416,8 @@ int main(int argc, char *argv[]) { eg_out.Swap(&eg); for (int32 c = 0; c < count; c++) { int32 index = (random ? Rand() : num_written) % num_outputs; + if (modify_eg_output) + SetWeightAndRenameOutput(weight, new_output_name, &eg_out); example_writers[index]->Write(key, eg_out); num_written++; } diff --git a/src/chainbin/nnet3-chain-get-egs.cc b/src/chainbin/nnet3-chain-get-egs.cc index c8c251900ec..2287bedeb7b 100644 --- a/src/chainbin/nnet3-chain-get-egs.cc +++ b/src/chainbin/nnet3-chain-get-egs.cc @@ -43,6 +43,8 @@ static bool ProcessFile(const fst::StdVectorFst &normalization_fst, const MatrixBase *ivector_feats, int32 ivector_period, const chain::Supervision &supervision, + const VectorBase *deriv_weights, + int32 supervision_length_tolerance, const std::string &utt_id, bool compress, UtteranceSplitter *utt_splitter, @@ -51,7 +53,18 @@ static bool ProcessFile(const fst::StdVectorFst &normalization_fst, int32 num_input_frames = feats.NumRows(), num_output_frames = supervision.frames_per_sequence; - if (!utt_splitter->LengthsMatch(utt_id, num_input_frames, num_output_frames)) + int32 frame_subsampling_factor = utt_splitter->Config().frame_subsampling_factor; + + if (deriv_weights && (std::abs(deriv_weights->Dim() - num_output_frames) + > supervision_length_tolerance)) { + KALDI_WARN << "For utterance " << utt_id + << ", mismatch between deriv-weights dim and num-output-frames" + << "; " << deriv_weights->Dim() << " vs " << num_output_frames; + return false; + } + + if (!utt_splitter->LengthsMatch(utt_id, num_input_frames, num_output_frames, + supervision_length_tolerance)) return false; // LengthsMatch() will have printed a warning. std::vector chunks; @@ -65,8 +78,6 @@ static bool ProcessFile(const fst::StdVectorFst &normalization_fst, return false; } - int32 frame_subsampling_factor = utt_splitter->Config().frame_subsampling_factor; - chain::SupervisionSplitter sup_splitter(supervision); for (size_t c = 0; c < chunks.size(); c++) { @@ -88,23 +99,41 @@ static bool ProcessFile(const fst::StdVectorFst &normalization_fst, << (chunk.first_frame + chunk.num_frames) << ", FST was empty after composing with normalization FST. " << "This should be extremely rare (a few per corpus, at most)"; + return false; } int32 first_frame = 0; // we shift the time-indexes of all these parts so // that the supervised part starts from frame 0. 
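A note on the new --weights and --outputs options of nnet3-chain-copy-egs above: both are ordinary Kaldi tables indexed by the eg key, one giving a per-eg supervision weight and the other the output-node name the eg should be renamed to. A minimal sketch of how such archives could be produced (the helper name, the language ids/weights, and the 'output-<lang>' naming are illustrative assumptions, not part of this patch):

#include <sstream>
#include <string>
#include <vector>
#include "base/kaldi-common.h"
#include "util/common-utils.h"

// Writes one output-name token and one weight per eg key, e.g.
//   lang2-utt0017 output-2     (read back via --outputs=ark:...)
//   lang2-utt0017 0.5          (read back via --weights=ark:...)
void WriteMultilingualEgMaps(const std::vector<std::string> &keys,
                             const std::vector<kaldi::int32> &lang_ids,
                             const std::vector<kaldi::BaseFloat> &lang_weights,
                             const std::string &outputs_wspecifier,
                             const std::string &weights_wspecifier) {
  kaldi::TokenWriter output_writer(outputs_wspecifier);
  kaldi::BaseFloatWriter weight_writer(weights_wspecifier);
  for (size_t i = 0; i < keys.size(); i++) {
    std::ostringstream name;
    name << "output-" << lang_ids[i];  // matches the "e.g. 'output-0'" help text
    output_writer.Write(keys[i], name.str());
    weight_writer.Write(keys[i], lang_weights[i]);
  }
}

With archives like these, passing --outputs=ark:output_names.ark --weights=ark:weights.ark makes the copied egs point at language-specific output nodes with per-language supervision weights, which is what the multilingual branch below relies on.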
+ + NnetChainExample nnet_chain_eg; + nnet_chain_eg.outputs.resize(1); SubVector<BaseFloat> output_weights( &(chunk.output_weights[0]), static_cast<int32>(chunk.output_weights.size())); - NnetChainSupervision nnet_supervision("output", supervision_part, - output_weights, - first_frame, - frame_subsampling_factor); + if (!deriv_weights) { + NnetChainSupervision nnet_supervision("output", supervision_part, + output_weights, + first_frame, + frame_subsampling_factor); + nnet_chain_eg.outputs[0].Swap(&nnet_supervision); + } else { + Vector<BaseFloat> this_deriv_weights(num_frames_subsampled); + for (int32 i = 0; i < num_frames_subsampled; i++) { + int32 t = i + start_frame_subsampled; + if (t < deriv_weights->Dim()) + this_deriv_weights(i) = (*deriv_weights)(t); + } + KALDI_ASSERT(output_weights.Dim() == num_frames_subsampled); + this_deriv_weights.MulElements(output_weights); + NnetChainSupervision nnet_supervision("output", supervision_part, + this_deriv_weights, + first_frame, + frame_subsampling_factor); + nnet_chain_eg.outputs[0].Swap(&nnet_supervision); + } - NnetChainExample nnet_chain_eg; - nnet_chain_eg.outputs.resize(1); - nnet_chain_eg.outputs[0].Swap(&nnet_supervision); nnet_chain_eg.inputs.resize(ivector_feats != NULL ? 2 : 1); int32 tot_input_frames = chunk.left_context + chunk.num_frames + @@ -176,13 +205,15 @@ int main(int argc, char *argv[]) { "chain-get-supervision.\n"; bool compress = true; - int32 length_tolerance = 100, online_ivector_period = 1; + int32 length_tolerance = 100, online_ivector_period = 1, + supervision_length_tolerance = 1; ExampleGenerationConfig eg_config; // controls num-frames, // left/right-context, etc. + BaseFloat scale = 1.0; int32 srand_seed = 0; - std::string online_ivector_rspecifier; + std::string online_ivector_rspecifier, deriv_weights_rspecifier; ParseOptions po(usage); po.Register("compress", &compress, "If true, write egs with input features " @@ -200,6 +231,16 @@ po.Register("srand", &srand_seed, "Seed for random number generator "); po.Register("length-tolerance", &length_tolerance, "Tolerance for " "difference in num-frames between feat and ivector matrices"); + po.Register("supervision-length-tolerance", &supervision_length_tolerance, "Tolerance for " + "difference in num-frames-subsampled between supervision and deriv weights"); + po.Register("deriv-weights-rspecifier", &deriv_weights_rspecifier, + "Per-frame weights (only binary - 0 or 1) that specify " + "whether a frame's gradient must be backpropagated or not.
" + "Not specifying this is equivalent to specifying a vector of " + "all 1s."); + po.Register("normalization-scale", &scale, "Scale the weights from the " + "'normalization' FST before applying them to the examples."); + eg_config.Register(&po); po.Read(argc, argv); @@ -235,6 +276,14 @@ int main(int argc, char *argv[]) { if (!normalization_fst_rxfilename.empty()) { ReadFstKaldi(normalization_fst_rxfilename, &normalization_fst); KALDI_ASSERT(normalization_fst.NumStates() > 0); + + if (scale <= 0.0) { + KALDI_ERR << "Invalid scale on normalization FST; must be > 0.0"; + } + + if (scale != 1.0) { + ScaleFst(scale, &normalization_fst); + } } // Read as GeneralMatrix so we don't need to un-compress and re-compress @@ -245,6 +294,8 @@ int main(int argc, char *argv[]) { NnetChainExampleWriter example_writer(examples_wspecifier); RandomAccessBaseFloatMatrixReader online_ivector_reader( online_ivector_rspecifier); + RandomAccessBaseFloatVectorReader deriv_weights_reader( + deriv_weights_rspecifier); int32 num_err = 0; @@ -278,10 +329,24 @@ int main(int argc, char *argv[]) { num_err++; continue; } + + const Vector *deriv_weights = NULL; + if (!deriv_weights_rspecifier.empty()) { + if (!deriv_weights_reader.HasKey(key)) { + KALDI_WARN << "No deriv weights for utterance " << key; + num_err++; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. + deriv_weights = &(deriv_weights_reader.Value(key)); + } + } if (!ProcessFile(normalization_fst, feats, online_ivector_feats, online_ivector_period, - supervision, key, compress, + supervision, deriv_weights, supervision_length_tolerance, + key, compress, &utt_splitter, &example_writer)) num_err++; } diff --git a/src/chainbin/nnet3-chain-normalize-egs.cc b/src/chainbin/nnet3-chain-normalize-egs.cc index 9d3f56f756a..7b99c6bd1da 100644 --- a/src/chainbin/nnet3-chain-normalize-egs.cc +++ b/src/chainbin/nnet3-chain-normalize-egs.cc @@ -41,7 +41,11 @@ int main(int argc, char *argv[]) { "e.g.\n" "nnet3-chain-normalize-egs dir/normalization.fst ark:train_in.cegs ark:train_out.cegs\n"; + BaseFloat scale = 1.0; + ParseOptions po(usage); + po.Register("normalization-scale", &scale, "Scale the weights from the " + "'normalization' FST before applying them to the examples."); po.Read(argc, argv); @@ -57,6 +61,14 @@ int main(int argc, char *argv[]) { fst::StdVectorFst normalization_fst; ReadFstKaldi(normalization_fst_rxfilename, &normalization_fst); + if (scale < 0.0) { + KALDI_ERR << "Invalid scale on normalization FST; must be >= 0.0"; + } + + if (scale != 1.0) { + ScaleFst(scale, &normalization_fst); + } + SequentialNnetChainExampleReader example_reader(examples_rspecifier); NnetChainExampleWriter example_writer(examples_wspecifier); diff --git a/src/lat/lattice-functions.cc b/src/lat/lattice-functions.cc index b04b23702fb..d48cb403f39 100644 --- a/src/lat/lattice-functions.cc +++ b/src/lat/lattice-functions.cc @@ -388,6 +388,11 @@ BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post, if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) { KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob << ", while total backward probability = " << tot_backward_prob; + + if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-2)) { + KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob + << ", while total backward probability = " << tot_backward_prob; + } } // Now combine any posteriors with the same transition-id. 
for (int32 t = 0; t < max_time; t++) @@ -431,7 +436,7 @@ void ConvertLatticeToPhones(const TransitionModel &trans, arc.olabel = 0; // remove any word. if ((arc.ilabel != 0) // has a transition-id on input.. && (trans.TransitionIdToHmmState(arc.ilabel) == 0) - && (!trans.IsSelfLoop(arc.ilabel))) + && (!trans.IsSelfLoop(arc.ilabel))) { // && trans.IsFinal(arc.ilabel)) // there is one of these per phone... arc.olabel = trans.TransitionIdToPhone(arc.ilabel); aiter.SetValue(arc); @@ -459,6 +464,8 @@ double ComputeLatticeAlphasAndBetas(const LatticeType &lat, StateId num_states = lat.NumStates(); KALDI_ASSERT(lat.Properties(fst::kTopSorted, true) == fst::kTopSorted); KALDI_ASSERT(lat.Start() == 0); + alpha->clear(); + beta->clear(); alpha->resize(num_states, kLogZeroDouble); beta->resize(num_states, kLogZeroDouble); @@ -495,6 +502,11 @@ double ComputeLatticeAlphasAndBetas(const LatticeType &lat, if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) { KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob << ", while total backward probability = " << tot_backward_prob; + + if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-2)) { + KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob + << ", while total backward probability = " << tot_backward_prob; + } } // Split the difference when returning... they should be the same. return 0.5 * (tot_backward_prob + tot_forward_prob); @@ -1646,4 +1658,110 @@ void ComposeCompactLatticeDeterministic( fst::Connect(composed_clat); } + +void ComputeAcousticScoresMap( + const Lattice &lat, + unordered_map, std::pair, + PairHasher > *acoustic_scores) { + // typedef the arc, weight types + typedef Lattice::Arc Arc; + typedef Arc::Weight LatticeWeight; + typedef Arc::StateId StateId; + + acoustic_scores->clear(); + + std::vector state_times; + LatticeStateTimes(lat, &state_times); // Assumes the input is top sorted + + KALDI_ASSERT(lat.Start() == 0); + + for (StateId s = 0; s < lat.NumStates(); s++) { + int32 t = state_times[s]; + for (fst::ArcIterator aiter(lat, s); !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + const LatticeWeight &weight = arc.weight; + + int32 tid = arc.ilabel; + + if (tid != 0) { + unordered_map, std::pair, + PairHasher >::iterator it = acoustic_scores->find(std::make_pair(t, tid)); + if (it == acoustic_scores->end()) { + acoustic_scores->insert(std::make_pair(std::make_pair(t, tid), + std::make_pair(weight.Value2(), 1))); + } else { + if (it->second.second == 2 + && it->second.first / it->second.second != weight.Value2()) { + KALDI_VLOG(2) << "Transitions on the same frame have different " + << "acoustic costs for tid " << tid << "; " + << it->second.first / it->second.second + << " vs " << weight.Value2(); + } + it->second.first += weight.Value2(); + it->second.second++; + } + } else { + // Arcs with epsilon input label (tid) must have 0 acoustic cost + KALDI_ASSERT(weight.Value2() == 0); + } + } + + LatticeWeight f = lat.Final(s); + if (f != LatticeWeight::Zero()) { + // Final acoustic cost must be 0 as we are reading from + // non-determinized, non-compact lattice + KALDI_ASSERT(f.Value2() == 0.0); + } + } +} + +void ReplaceAcousticScoresFromMap( + const unordered_map, std::pair, + PairHasher > &acoustic_scores, + Lattice *lat) { + // typedef the arc, weight types + typedef Lattice::Arc Arc; + typedef Arc::Weight LatticeWeight; + typedef Arc::StateId StateId; + + TopSortLatticeIfNeeded(lat); + + std::vector state_times; + LatticeStateTimes(*lat, &state_times); + 
+ KALDI_ASSERT(lat->Start() == 0); + + for (StateId s = 0; s < lat->NumStates(); s++) { + int32 t = state_times[s]; + for (fst::MutableArcIterator aiter(lat, s); + !aiter.Done(); aiter.Next()) { + Arc arc(aiter.Value()); + + int32 tid = arc.ilabel; + if (tid != 0) { + unordered_map, std::pair, + PairHasher >::const_iterator it = acoustic_scores.find(std::make_pair(t, tid)); + if (it == acoustic_scores.end()) { + KALDI_ERR << "Could not find tid " << tid << " at time " << t + << " in the acoustic scores map."; + } else { + arc.weight.SetValue2(it->second.first / it->second.second); + } + } else { + // For epsilon arcs, set acoustic cost to 0.0 + arc.weight.SetValue2(0.0); + } + aiter.SetValue(arc); + } + + LatticeWeight f = lat->Final(s); + if (f != LatticeWeight::Zero()) { + // Set final acoustic cost to 0.0 + f.SetValue2(0.0); + lat->SetFinal(s, f); + } + } +} + } // namespace kaldi diff --git a/src/lat/lattice-functions.h b/src/lat/lattice-functions.h index b4b16e6221a..abc85f9910d 100644 --- a/src/lat/lattice-functions.h +++ b/src/lat/lattice-functions.h @@ -377,6 +377,16 @@ void ComposeCompactLatticeDeterministic( fst::DeterministicOnDemandFst* det_fst, CompactLattice* composed_clat); +void ComputeAcousticScoresMap( + const Lattice &lat, + unordered_map, std::pair, + PairHasher > *acoustic_scores); + +void ReplaceAcousticScoresFromMap( + const unordered_map, std::pair, + PairHasher > &acoustic_scores, + Lattice *lat); + } // namespace kaldi #endif // KALDI_LAT_LATTICE_FUNCTIONS_H_ diff --git a/src/latbin/lattice-determinize-phone-pruned.cc b/src/latbin/lattice-determinize-phone-pruned.cc index 0959bcbcd74..ff9ece16456 100644 --- a/src/latbin/lattice-determinize-phone-pruned.cc +++ b/src/latbin/lattice-determinize-phone-pruned.cc @@ -1,6 +1,7 @@ // latbin/lattice-determinize-phone-pruned.cc // Copyright 2014 Guoguo Chen +// 2017 Vimal Manohar // See ../../COPYING for clarification regarding multiple authors // @@ -43,11 +44,13 @@ int main(int argc, char *argv[]) { " final.mdl ark:in.lats ark:det.lats\n"; ParseOptions po(usage); + bool write_compact = true; BaseFloat acoustic_scale = 1.0; BaseFloat beam = 10.0; fst::DeterminizeLatticePhonePrunedOptions opts; opts.max_mem = 50000000; + po.Register("write-compact", &write_compact, "If true, write in normal (compact) form."); po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic" " likelihoods."); po.Register("beam", &beam, "Pruning beam [applied after acoustic scaling]."); @@ -70,8 +73,13 @@ int main(int argc, char *argv[]) { // accepts. SequentialLatticeReader lat_reader(lats_rspecifier); - // Writes as compact lattice. 
- CompactLatticeWriter compact_lat_writer(lats_wspecifier); + CompactLatticeWriter compact_lat_writer; + LatticeWriter lat_writer; + + if (write_compact) + compact_lat_writer.Open(lats_wspecifier); + else + lat_writer.Open(lats_wspecifier); int32 n_done = 0, n_warn = 0; @@ -89,6 +97,11 @@ int main(int argc, char *argv[]) { KALDI_VLOG(2) << "Processing lattice " << key; + // Compute a map from each (t, tid) to (sum_of_acoustic_scores, count) + unordered_map, std::pair, + PairHasher > acoustic_scores; + ComputeAcousticScoresMap(lat, &acoustic_scores); + fst::ScaleLattice(fst::AcousticLatticeScale(acoustic_scale), &lat); CompactLattice det_clat; @@ -106,8 +119,19 @@ int main(int argc, char *argv[]) { sum_depth_out += depth * t; sum_t += t; - fst::ScaleLattice(fst::AcousticLatticeScale(1.0/acoustic_scale), &det_clat); - compact_lat_writer.Write(key, det_clat); + if (write_compact) { + fst::ScaleLattice(fst::AcousticLatticeScale(1.0/acoustic_scale), &det_clat); + compact_lat_writer.Write(key, det_clat); + } else{ + Lattice out_lat; + fst::ConvertLattice(det_clat, &out_lat); + + // Replace each arc (t, tid) with the averaged acoustic score from + // the computed map + ReplaceAcousticScoresFromMap(acoustic_scores, &out_lat); + lat_writer.Write(key, out_lat); + } + n_done++; } @@ -118,8 +142,8 @@ int main(int argc, char *argv[]) { << " (average num-frames = " << (sum_t / n_done) << ")."; } KALDI_LOG << "Done " << n_done << " lattices, determinization finished " - << "earlier than specified by the beam on " << n_warn << " of " - << "these."; + << "earlier than specified by the beam (or output was empty) on " + << n_warn << " of these."; return (n_done != 0 ? 0 : 1); } catch(const std::exception &e) { std::cerr << e.what(); diff --git a/src/latbin/lattice-determinize-pruned.cc b/src/latbin/lattice-determinize-pruned.cc index 3e8bca5a3ce..393d98059f5 100644 --- a/src/latbin/lattice-determinize-pruned.cc +++ b/src/latbin/lattice-determinize-pruned.cc @@ -39,6 +39,7 @@ int main(int argc, char *argv[]) { " e.g.: lattice-determinize-pruned --acoustic-scale=0.1 --beam=6.0 ark:in.lats ark:det.lats\n"; ParseOptions po(usage); + bool write_compact = true; BaseFloat acoustic_scale = 1.0; BaseFloat beam = 10.0; bool minimize = false; @@ -48,6 +49,7 @@ int main(int argc, char *argv[]) { opts.max_mem = 50000000; opts.max_loop = 0; // was 500000; + po.Register("write-compact", &write_compact, "If true, write in normal (compact) form."); po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods"); po.Register("beam", &beam, "Pruning beam [applied after acoustic scaling]."); @@ -70,7 +72,12 @@ int main(int argc, char *argv[]) { SequentialLatticeReader lat_reader(lats_rspecifier); // Write as compact lattice. - CompactLatticeWriter compact_lat_writer(lats_wspecifier); + CompactLatticeWriter compact_lat_writer; + LatticeWriter lat_writer; + if (write_compact) + compact_lat_writer.Open(lats_wspecifier); + else + lat_writer.Open(lats_wspecifier); int32 n_done = 0, n_warn = 0; @@ -87,6 +94,11 @@ int main(int argc, char *argv[]) { KALDI_VLOG(2) << "Processing lattice " << key; + // Compute a map from each (t, tid) to (sum_of_acoustic_scores, count) + unordered_map, std::pair, + PairHasher > acoustic_scores; + ComputeAcousticScoresMap(lat, &acoustic_scores); + Invert(&lat); // so word labels are on the input side. 
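For context on the ComputeAcousticScoresMap() / ReplaceAcousticScoresFromMap() pair used in both determinization binaries above: acoustic scaling and determinization do not preserve per-frame acoustic costs, so the map records, for every (frame, transition-id) pair of the raw input lattice, a running (sum, count) of acoustic costs, and after converting the determinized result back to a state-level Lattice the average sum/count is written back onto each arc. A hypothetical consolidation of the new --write-compact=false path (names follow the code above):

#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
#include "util/common-utils.h"

// 'lat' is the original (top-sorted) lattice, 'det_clat' its determinized form.
void WriteWithRestoredAcousticScores(const kaldi::Lattice &lat,
                                     const kaldi::CompactLattice &det_clat,
                                     const std::string &key,
                                     kaldi::LatticeWriter *lat_writer) {
  using namespace kaldi;
  // (frame, transition-id) -> (sum of acoustic costs, count)
  unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
                PairHasher<int32> > acoustic_scores;
  ComputeAcousticScoresMap(lat, &acoustic_scores);
  Lattice out_lat;
  fst::ConvertLattice(det_clat, &out_lat);
  // Each arc at frame t with transition-id tid gets back the average sum / count.
  ReplaceAcousticScoresFromMap(acoustic_scores, &out_lat);
  lat_writer->Write(key, out_lat);
}

This is also why the map is computed from the input lattice before any ScaleLattice() call in the binaries above.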
lat_reader.FreeCurrent(); fst::ScaleLattice(fst::AcousticLatticeScale(acoustic_scale), &lat); @@ -121,8 +133,18 @@ int main(int argc, char *argv[]) { sum_depth_out += depth * t; sum_t += t; - fst::ScaleLattice(fst::AcousticLatticeScale(1.0/acoustic_scale), &det_clat); - compact_lat_writer.Write(key, det_clat); + if (write_compact) { + fst::ScaleLattice(fst::AcousticLatticeScale(1.0/acoustic_scale), &det_clat); + compact_lat_writer.Write(key, det_clat); + } else { + Lattice out_lat; + fst::ConvertLattice(det_clat, &out_lat); + + // Replace each arc (t, tid) with the averaged acoustic score from + // the computed map + ReplaceAcousticScoresFromMap(acoustic_scores, &out_lat); + lat_writer.Write(key, out_lat); + } n_done++; } diff --git a/src/latbin/lattice-scale.cc b/src/latbin/lattice-scale.cc index 5ca6012d994..58a0d2fb372 100644 --- a/src/latbin/lattice-scale.cc +++ b/src/latbin/lattice-scale.cc @@ -39,12 +39,14 @@ int main(int argc, char *argv[]) { " e.g.: lattice-scale --lm-scale=0.0 ark:1.lats ark:scaled.lats\n"; ParseOptions po(usage); + bool write_compact = true; BaseFloat acoustic_scale = 1.0; BaseFloat inv_acoustic_scale = 1.0; BaseFloat lm_scale = 1.0; BaseFloat acoustic2lm_scale = 0.0; BaseFloat lm2acoustic_scale = 0.0; + po.Register("write-compact", &write_compact, "If true, write in normal (compact) form."); po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods"); po.Register("inv-acoustic-scale", &inv_acoustic_scale, "An alternative way " "of setting the acoustic scale: you can set its inverse."); @@ -61,14 +63,9 @@ int main(int argc, char *argv[]) { std::string lats_rspecifier = po.GetArg(1), lats_wspecifier = po.GetArg(2); - - SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier); - - // Write as compact lattice. - CompactLatticeWriter compact_lattice_writer(lats_wspecifier); - + int32 n_done = 0; - + KALDI_ASSERT(acoustic_scale == 1.0 || inv_acoustic_scale == 1.0); if (inv_acoustic_scale != 1.0) acoustic_scale = 1.0 / inv_acoustic_scale; @@ -81,12 +78,32 @@ int main(int argc, char *argv[]) { scale[1][0] = lm2acoustic_scale; scale[1][1] = acoustic_scale; - for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) { - CompactLattice lat = compact_lattice_reader.Value(); - ScaleLattice(scale, &lat); - compact_lattice_writer.Write(compact_lattice_reader.Key(), lat); - n_done++; + if (write_compact) { + SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier); + + // Write as compact lattice. + CompactLatticeWriter compact_lattice_writer(lats_wspecifier); + + for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) { + CompactLattice lat = compact_lattice_reader.Value(); + ScaleLattice(scale, &lat); + compact_lattice_writer.Write(compact_lattice_reader.Key(), lat); + n_done++; + } + } else { + SequentialLatticeReader lattice_reader(lats_rspecifier); + + // Write as regular lattice. + LatticeWriter lattice_writer(lats_wspecifier); + + for (; !lattice_reader.Done(); lattice_reader.Next()) { + Lattice lat = lattice_reader.Value(); + ScaleLattice(scale, &lat); + lattice_writer.Write(lattice_reader.Key(), lat); + n_done++; + } } + KALDI_LOG << "Done " << n_done << " lattices."; return (n_done != 0 ? 
0 : 1); } catch(const std::exception &e) { diff --git a/src/latbin/lattice-to-fst.cc b/src/latbin/lattice-to-fst.cc index 0d2ac29a99b..30687e86232 100644 --- a/src/latbin/lattice-to-fst.cc +++ b/src/latbin/lattice-to-fst.cc @@ -22,6 +22,50 @@ #include "util/common-utils.h" #include "fstext/fstext-lib.h" #include "lat/kaldi-lattice.h" +#include "hmm/transition-model.h" + +namespace kaldi { + +void ConvertLatticeToPdfLabels( + const TransitionModel &tmodel, + const Lattice &ifst, + fst::StdVectorFst *ofst) { + typedef fst::ArcTpl<LatticeWeight> ArcIn; + typedef fst::StdArc ArcOut; + typedef ArcIn::StateId StateId; + ofst->DeleteStates(); + // The states will be numbered exactly the same as the original FST. + // Add the states to the new FST. + StateId num_states = ifst.NumStates(); + for (StateId s = 0; s < num_states; s++) { + StateId news = ofst->AddState(); + assert(news == s); + } + ofst->SetStart(ifst.Start()); + for (StateId s = 0; s < num_states; s++) { + LatticeWeight final_iweight = ifst.Final(s); + if (final_iweight != LatticeWeight::Zero()) { + fst::TropicalWeight final_oweight; + ConvertLatticeWeight(final_iweight, &final_oweight); + ofst->SetFinal(s, final_oweight); + } + for (fst::ArcIterator<Lattice> iter(ifst, s); + !iter.Done(); + iter.Next()) { + ArcIn arc = iter.Value(); + KALDI_PARANOID_ASSERT(arc.weight != LatticeWeight::Zero()); + ArcOut oarc; + ConvertLatticeWeight(arc.weight, &oarc.weight); + oarc.ilabel = tmodel.TransitionIdToPdf(arc.ilabel) + 1; + oarc.olabel = arc.olabel; + oarc.nextstate = arc.nextstate; + ofst->AddArc(s, oarc); + } + } +} + +} // namespace kaldi + int main(int argc, char *argv[]) { try { @@ -34,8 +78,10 @@ int main(int argc, char *argv[]) { using std::vector; BaseFloat acoustic_scale = 0.0; BaseFloat lm_scale = 0.0; - bool rm_eps = true; - + bool rm_eps = true, read_compact = true, convert_to_pdf_labels = false; + std::string trans_model; + bool project_input = false, project_output = true; + const char *usage = "Turn lattices into normal FSTs, retaining only the word labels\n" "By default, removes all weights and also epsilons (configure with\n" @@ -44,9 +90,20 @@ " e.g.: lattice-to-fst ark:1.lats ark:1.fsts\n"; ParseOptions po(usage); + po.Register("read-compact", &read_compact, "If true, read the input as CompactLattice; " "if false, read it as a state-level Lattice"); po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods"); po.Register("lm-scale", &lm_scale, "Scaling factor for graph/lm costs"); po.Register("rm-eps", &rm_eps, "Remove epsilons in resulting FSTs (in lazy way; may not remove all)"); + po.Register("convert-to-pdf-labels", &convert_to_pdf_labels, + "If true, convert input labels from transition-ids to pdf-ids plus one " + "(label 0 is reserved for epsilon); requires --trans-model"); + po.Register("trans-model", &trans_model, + "Transition model, needed when --convert-to-pdf-labels=true"); + po.Register("project-input", &project_input, + "Project to input labels (transition-ids); applicable only " + "when --read-compact=false"); + po.Register("project-output", &project_output, + "Project to output labels (words); applicable only " + "when --read-compact=false"); po.Read(argc, argv); @@ -60,31 +117,70 @@ int main(int argc, char *argv[]) { std::string lats_rspecifier = po.GetArg(1), fsts_wspecifier = po.GetArg(2); - SequentialCompactLatticeReader lattice_reader(lats_rspecifier); + TransitionModel tmodel; + if (!trans_model.empty()) { + ReadKaldiObject(trans_model, &tmodel); + } + + SequentialCompactLatticeReader compact_lattice_reader; + SequentialLatticeReader lattice_reader; + TableWriter<fst::VectorFstHolder> fst_writer(fsts_wspecifier); int32 n_done = 0; // there is no failure mode, barring a crash.
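One detail of ConvertLatticeToPdfLabels() above: the output FST carries TransitionIdToPdf(ilabel) + 1 on its input side, because pdf-ids are zero-based while FST label 0 is reserved for epsilon. Any consumer of these FSTs has to undo that offset; a hypothetical helper (not part of this patch) would be:

#include "base/kaldi-common.h"

// Label 0 is epsilon; a nonzero input label l produced by
// ConvertLatticeToPdfLabels() corresponds to pdf-id (l - 1).
kaldi::int32 PdfIdFromLabel(kaldi::int32 ilabel) {
  KALDI_ASSERT(ilabel > 0);  // 0 is reserved for <eps>, not a pdf
  return ilabel - 1;
}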
- for (; !lattice_reader.Done(); lattice_reader.Next()) { - std::string key = lattice_reader.Key(); - CompactLattice clat = lattice_reader.Value(); - lattice_reader.FreeCurrent(); - ScaleLattice(scale, &clat); // typically scales to zero. - RemoveAlignmentsFromCompactLattice(&clat); // remove the alignments... - fst::VectorFst fst; - { - Lattice lat; - ConvertLattice(clat, &lat); // convert to non-compact form.. won't introduce - // extra states because already removed alignments. - ConvertLattice(lat, &fst); // this adds up the (lm,acoustic) costs to get - // the normal (tropical) costs. - Project(&fst, fst::PROJECT_OUTPUT); // Because in the standard Lattice format, - // the words are on the output, and we want the word labels. + if (read_compact) { + SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier); + for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) { + std::string key = compact_lattice_reader.Key(); + CompactLattice clat = compact_lattice_reader.Value(); + compact_lattice_reader.FreeCurrent(); + ScaleLattice(scale, &clat); // typically scales to zero. + RemoveAlignmentsFromCompactLattice(&clat); // remove the alignments... + fst::VectorFst fst; + { + Lattice lat; + ConvertLattice(clat, &lat); // convert to non-compact form.. won't introduce + // extra states because already removed alignments. + + if (convert_to_pdf_labels) { + ConvertLatticeToPdfLabels(tmodel, lat, &fst); // this adds up the (lm,acoustic) costs to get + // the normal (tropical) costs. + } else { + ConvertLattice(lat, &fst); + } + + Project(&fst, fst::PROJECT_OUTPUT); // Because in the standard compact_lattice format, + // the words are on the output, and we want the word labels. + } + if (rm_eps) RemoveEpsLocal(&fst); + + fst_writer.Write(key, fst); + n_done++; } - if (rm_eps) RemoveEpsLocal(&fst); - - fst_writer.Write(key, fst); - n_done++; + } else { + SequentialLatticeReader lattice_reader(lats_rspecifier); + for (; !lattice_reader.Done(); lattice_reader.Next()) { + std::string key = lattice_reader.Key(); + Lattice lat = lattice_reader.Value(); + lattice_reader.FreeCurrent(); + ScaleLattice(scale, &lat); // typically scales to zero. + fst::VectorFst fst; + if (convert_to_pdf_labels) { + ConvertLatticeToPdfLabels(tmodel, lat, &fst); + } else { + ConvertLattice(lat, &fst); + } + if (project_input) + Project(&fst, fst::PROJECT_INPUT); + else if (project_output) + Project(&fst, fst::PROJECT_OUTPUT); + if (rm_eps) RemoveEpsLocal(&fst); + + fst_writer.Write(key, fst); + n_done++; + } + } KALDI_LOG << "Done converting " << n_done << " lattices to word-level FSTs"; return (n_done != 0 ? 
0 : 1); diff --git a/src/nnet3/nnet-chain-diagnostics.cc b/src/nnet3/nnet-chain-diagnostics.cc index 084b33347df..8846bb0b069 100644 --- a/src/nnet3/nnet-chain-diagnostics.cc +++ b/src/nnet3/nnet-chain-diagnostics.cc @@ -60,6 +60,7 @@ NnetChainComputeProb::NnetChainComputeProb( deriv_nnet_owned_(false), deriv_nnet_(nnet), num_minibatches_processed_(0) { + KALDI_ASSERT(den_graph_.NumPdfs() > 0); KALDI_ASSERT(nnet_config.store_component_stats && !nnet_config.compute_deriv); } @@ -207,6 +208,27 @@ bool NnetChainComputeProb::PrintTotalStats() const { } +std::pair NnetChainComputeProb::GetTotalObjective() const { + unordered_map::const_iterator + iter, end; + iter = objf_info_.begin(); + end = objf_info_.end(); + BaseFloat tot_objf = 0.0, tot_weight = 0.0; + for (; iter != end; ++iter) { + const std::string &name = iter->first; + int32 node_index = nnet_.GetNodeIndex(name); + KALDI_ASSERT(node_index >= 0); + const ChainObjectiveInfo &info = iter->second; + BaseFloat like = (info.tot_like / info.tot_weight); + ObjectiveValues aux_objfs(info.tot_aux_objfs); + aux_objfs.Scale(info.tot_weight); + tot_objf += like + aux_objfs.Sum(); + tot_weight += info.tot_weight; + } + return std::make_pair(tot_objf, tot_weight); +} + + const ChainObjectiveInfo* NnetChainComputeProb::GetObjective( const std::string &output_name) const { unordered_map::const_iterator @@ -217,15 +239,29 @@ const ChainObjectiveInfo* NnetChainComputeProb::GetObjective( return NULL; } +static bool HasXentOutputs(const Nnet &nnet) { + const std::vector node_names = nnet.GetNodeNames(); + for (std::vector::const_iterator it = node_names.begin(); + it != node_names.end(); ++it) { + int32 node_index = nnet.GetNodeIndex(*it); + if (nnet.IsOutputNode(node_index) && + it->find("-xent") != std::string::npos) { + return true; + } + } + return false; +} + void RecomputeStats(const std::vector &egs, const chain::ChainTrainingOptions &chain_config_in, const fst::StdVectorFst &den_fst, Nnet *nnet) { KALDI_LOG << "Recomputing stats on nnet (affects batch-norm)"; chain::ChainTrainingOptions chain_config(chain_config_in); - if (nnet->GetNodeIndex("output-xent") != -1 && + if (HasXentOutputs(*nnet) && chain_config.xent_regularize == 0) { - // this forces it to compute the output for 'output-xent', which + // this forces it to compute the output for xent outputs, + // usually 'output-xent', which // means that we'll be computing batch-norm stats for any // components in that branch that have batch-norm. chain_config.xent_regularize = 0.1; diff --git a/src/nnet3/nnet-chain-example.cc b/src/nnet3/nnet-chain-example.cc index 351312fb952..d40df1a79f9 100644 --- a/src/nnet3/nnet-chain-example.cc +++ b/src/nnet3/nnet-chain-example.cc @@ -31,8 +31,8 @@ void NnetChainSupervision::Write(std::ostream &os, bool binary) const { WriteToken(os, binary, name); WriteIndexVector(os, binary, indexes); supervision.Write(os, binary); - WriteToken(os, binary, ""); // for DerivWeights. Want to save space. - WriteVectorAsChar(os, binary, deriv_weights); + WriteToken(os, binary, ""); // for DerivWeights. Want to save space. + deriv_weights.Write(os, binary); WriteToken(os, binary, ""); } @@ -51,8 +51,11 @@ void NnetChainSupervision::Read(std::istream &is, bool binary) { ReadToken(is, binary, &token); // in the future this back-compatibility code can be reworked. 
if (token != "") { - KALDI_ASSERT(token == ""); - ReadVectorAsChar(is, binary, &deriv_weights); + KALDI_ASSERT(token == "" || token == ""); + if (token == "") + ReadVectorAsChar(is, binary, &deriv_weights); + else + deriv_weights.Read(is, binary); ExpectToken(is, binary, ""); } CheckDim(); @@ -82,8 +85,7 @@ void NnetChainSupervision::CheckDim() const { } if (deriv_weights.Dim() != 0) { KALDI_ASSERT(deriv_weights.Dim() == indexes.size()); - KALDI_ASSERT(deriv_weights.Min() >= 0.0 && - deriv_weights.Max() <= 1.0); + KALDI_ASSERT(deriv_weights.Min() >= 0.0); } }