diff --git a/.gitignore b/.gitignore index df7cb26de9f..4cf0fa4efa9 100644 --- a/.gitignore +++ b/.gitignore @@ -73,15 +73,17 @@ GSYMS /src/kaldi.mk.bak # /egs/ -/egs/*/s*/mfcc -/egs/*/s*/plp -/egs/*/s*/exp -/egs/*/s*/data +/egs/*/*/mfcc +/egs/*/*/plp +/egs/*/*/exp +/egs/*/*/data # /tools/ +/tools/pocolm/ /tools/ATLAS/ /tools/atlas3.8.3.tar.gz /tools/irstlm/ +/tools/mitlm/ /tools/openfst /tools/openfst-1.3.2.tar.gz /tools/openfst-1.3.2/ @@ -143,3 +145,6 @@ GSYMS /tools/mmseg-1.3.0.tar.gz /tools/mmseg-1.3.0/ /kaldiwin_vs* +/tools/cub-1.8.0.zip +/tools/cub-1.8.0/ +/tools/cub diff --git a/.travis.yml b/.travis.yml index 23507297413..51e49653efc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -49,7 +49,7 @@ script: # for the explanation why extra switches needed for clang with ccache. - CXX="ccache clang++-3.8 -Qunused-arguments -fcolor-diagnostics -Wno-tautological-compare" CFLAGS="" - LDFLAGS="-llapack" + LDFLAGS="-llapack -Wl,-fuse-ld=gold" INCDIRS="$XROOT/usr/include" LIBDIRS="$XROOT/usr/lib" tools/extras/travis_script.sh diff --git a/egs/aishell/s5/RESULTS b/egs/aishell/s5/RESULTS index b58ede148c4..b6155cb62d4 100644 --- a/egs/aishell/s5/RESULTS +++ b/egs/aishell/s5/RESULTS @@ -1,8 +1,18 @@ -%WER 33.82 [ 35432 / 104765, 743 ins, 3991 del, 30698 sub ] exp/mono/decode_test/cer_12_0.0 -%WER 19.39 [ 20310 / 104765, 903 ins, 1452 del, 17955 sub ] exp/tri1/decode_test/cer_13_0.5 -%WER 19.23 [ 20147 / 104765, 910 ins, 1287 del, 17950 sub ] exp/tri2/decode_test/cer_14_0.5 -%WER 17.14 [ 17961 / 104765, 812 ins, 1024 del, 16125 sub ] exp/tri3a/decode_test/cer_14_0.0 -%WER 13.64 [ 14294 / 104765, 669 ins, 736 del, 12889 sub ] exp/tri4a/decode_test/cer_14_0.5 -%WER 12.23 [ 12809 / 104765, 656 ins, 580 del, 11573 sub ] exp/tri5a/decode_test/cer_13_1.0 -%WER 8.45 [ 8849 / 104765, 312 ins, 538 del, 7999 sub ] exp/nnet3/tdnn_sp/decode_test/cer_13_1.0 -%WER 7.46 [ 7813 / 104765, 287 ins, 472 del, 7054 sub ] exp/chain/tdnn_1a_sp/decode_test/cer_10_1.0 +%WER 36.41 [ 38146 / 104765, 837 ins, 3114 del, 34195 sub ] exp/mono/decode_test/cer_10_0.0 +%WER 18.76 [ 19654 / 104765, 949 ins, 1152 del, 17553 sub ] exp/tri1/decode_test/cer_13_0.5 +%WER 18.64 [ 19531 / 104765, 941 ins, 1159 del, 17431 sub ] exp/tri2/decode_test/cer_14_0.5 +%WER 17.04 [ 17849 / 104765, 810 ins, 1021 del, 16018 sub ] exp/tri3a/decode_test/cer_14_0.5 +%WER 13.82 [ 14482 / 104765, 764 ins, 670 del, 13048 sub ] exp/tri4a/decode_test/cer_13_0.5 +%WER 12.12 [ 12694 / 104765, 751 ins, 523 del, 11420 sub ] exp/tri5a/decode_test/cer_13_0.5 +%WER 8.65 [ 9064 / 104765, 367 ins, 455 del, 8242 sub ] exp/nnet3/tdnn_sp/decode_test/cer_14_0.5 +%WER 7.48 [ 7839 / 104765, 285 ins, 454 del, 7100 sub ] exp/chain/tdnn_1a_sp/decode_test/cer_10_1.0 + +# nnet3 tdnn with online pitch, local/nnet3/tuning/tun_tdnn_2a.sh +%WER 8.64 [ 9050 / 104765, 349 ins, 521 del, 8180 sub ] exp/nnet3/tdnn_sp/decode_test/cer_15_0.5 +%WER 8.72 [ 9135 / 104765, 367 ins, 422 del, 8346 sub ] exp/nnet3/tdnn_sp_online/decode_test/cer_12_1.0 +%WER 9.36 [ 9807 / 104765, 386 ins, 441 del, 8980 sub ] exp/nnet3/tdnn_sp_online/decode_test_per_utt/cer_13_1.0 + +# chain with online pitch, local/chain/tuning/run_tdnn_2a.sh +%WER 7.45 [ 7807 / 104765, 340 ins, 497 del, 6970 sub ] exp/chain/tdnn_2a_sp/decode_test/cer_11_0.5 +%WER 7.43 [ 7780 / 104765, 341 ins, 469 del, 6970 sub ] exp/chain/tdnn_2a_sp_online/decode_test/cer_11_0.5 +%WER 7.92 [ 8296 / 104765, 384 ins, 472 del, 7440 sub ] exp/chain/tdnn_2a_sp_online/decode_test_per_utt/cer_11_0.5 diff --git a/egs/aishell/s5/conf/online_pitch.conf 
b/egs/aishell/s5/conf/online_pitch.conf new file mode 100644 index 00000000000..c0f1342160d --- /dev/null +++ b/egs/aishell/s5/conf/online_pitch.conf @@ -0,0 +1,4 @@ +--sample-frequency=16000 +--simulate-first-pass-online=true +--normalization-right-context=25 +--frames-per-chunk=10 diff --git a/egs/aishell/s5/local/aishell_prepare_dict.sh b/egs/aishell/s5/local/aishell_prepare_dict.sh index 3763622a3e7..c4cabb24de4 100755 --- a/egs/aishell/s5/local/aishell_prepare_dict.sh +++ b/egs/aishell/s5/local/aishell_prepare_dict.sh @@ -15,21 +15,9 @@ mkdir -p $dict_dir cp $res_dir/lexicon.txt $dict_dir cat $dict_dir/lexicon.txt | awk '{ for(n=2;n<=NF;n++){ phones[$n] = 1; }} END{for (p in phones) print p;}'| \ - sort -u |\ - perl -e ' - my %ph_cl; - while () { - $phone = $_; - chomp($phone); - chomp($_); - $phone = $_; - next if ($phone eq "sil"); - if (exists $ph_cl{$phone}) { push(@{$ph_cl{$phone}}, $_) } - else { $ph_cl{$phone} = [$_]; } - } - foreach $key ( keys %ph_cl ) { - print "@{ $ph_cl{$key} }\n" - } + perl -e 'while(<>){ chomp($_); $phone = $_; next if ($phone eq "sil"); + m:^([^\d]+)(\d*)$: || die "Bad phone $_"; $q{$1} .= "$phone "; } + foreach $l (values %q) {print "$l\n";} ' | sort -k1 > $dict_dir/nonsilence_phones.txt || exit 1; echo sil > $dict_dir/silence_phones.txt diff --git a/egs/aishell/s5/local/chain/tuning/run_tdnn_1a.sh b/egs/aishell/s5/local/chain/tuning/run_tdnn_1a.sh index a0b183e3c5a..b38fa4d9c7a 100755 --- a/egs/aishell/s5/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/aishell/s5/local/chain/tuning/run_tdnn_1a.sh @@ -90,7 +90,7 @@ if [ $stage -le 10 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/aishell/s5/local/chain/tuning/run_tdnn_2a.sh b/egs/aishell/s5/local/chain/tuning/run_tdnn_2a.sh new file mode 100755 index 00000000000..6b7223785d9 --- /dev/null +++ b/egs/aishell/s5/local/chain/tuning/run_tdnn_2a.sh @@ -0,0 +1,211 @@ +#!/bin/bash + +# This script is based on run_tdnn_1a.sh. +# This setup used online pitch to train the neural network. +# It requires a online_pitch.conf in the conf dir. + +set -e + +# configs for 'chain' +affix= +stage=0 +train_stage=-10 +get_egs_stage=-10 +dir=exp/chain/tdnn_2a # Note: _sp will get added to this +decode_iter= + +# training options +num_epochs=4 +initial_effective_lrate=0.001 +final_effective_lrate=0.0001 +max_param_change=2.0 +final_layer_normalize_target=0.5 +num_jobs_initial=2 +num_jobs_final=12 +minibatch_size=128 +frames_per_eg=150,110,90 +remove_egs=true +common_egs_dir= +xent_regularize=0.1 + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 9 ]; then + # Build a tree using our new topology. This is the critically different + # step compared with other recipes. 
+ steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 5000 data/$train_set $lang $ali_dir $treedir +fi + +if [ $stage -le 10 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=43 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=625 + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=625 + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=625 + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=625 + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=625 + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=625 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=625 target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=625 target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 11 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/aishell-$(date +'%m_%d_%H_%M')/s5c/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir exp/nnet3/ivectors_${train_set} \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $frames_per_eg \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate $initial_effective_lrate \ + --trainer.optimization.final-effective-lrate $final_effective_lrate \ + --trainer.max-param-change $max_param_change \ + --cleanup.remove-egs $remove_egs \ + --feat-dir data/${train_set}_hires_online \ + --tree-dir $treedir \ + --lat-dir exp/tri5a_sp_lats \ + --dir $dir || exit 1; +fi + +if [ $stage -le 12 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_test $dir $dir/graph +fi + +graph_dir=$dir/graph +if [ $stage -le 13 ]; then + for test_set in dev test; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 10 --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet3/ivectors_$test_set \ + $graph_dir data/${test_set}_hires_online $dir/decode_${test_set} || exit 1; + done +fi + +if [ $stage -le 14 ]; then + steps/online/nnet3/prepare_online_decoding.sh --mfcc-config conf/mfcc_hires.conf \ + --add-pitch true \ + $lang exp/nnet3/extractor "$dir" ${dir}_online || exit 1; +fi + +dir=${dir}_online +if [ $stage -le 15 ]; then + for test_set in dev test; do + steps/online/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 10 --cmd "$decode_cmd" \ + --config conf/decode.config \ + $graph_dir data/${test_set}_hires_online $dir/decode_${test_set} || exit 1; + done +fi + +if [ $stage -le 16 ]; then + for test_set in dev test; do + steps/online/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 10 --cmd "$decode_cmd" --per-utt true \ + --config conf/decode.config \ + $graph_dir data/${test_set}_hires_online $dir/decode_${test_set}_per_utt || exit 1; + done +fi + +exit; diff --git a/egs/aishell/s5/local/nnet3/run_ivector_common.sh b/egs/aishell/s5/local/nnet3/run_ivector_common.sh index 1643e6381b1..af0ae122372 100755 --- a/egs/aishell/s5/local/nnet3/run_ivector_common.sh +++ b/egs/aishell/s5/local/nnet3/run_ivector_common.sh @@ -14,7 +14,7 @@ stage=0 train_set=train test_sets="dev test" gmm=tri5a - +online=false nnet3_affix= . 
./cmd.sh @@ -31,6 +31,11 @@ for f in data/${train_set}/feats.scp ${gmm_dir}/final.mdl; do fi done +online_affix= +if [ $online = true ]; then + online_affix=_online +fi + if [ $stage -le 1 ]; then # Although the nnet will be trained by high resolution data, we still have to # perturb the normal data to get the alignment _sp stands for speed-perturbed @@ -54,26 +59,26 @@ if [ $stage -le 3 ]; then # Create high-resolution MFCC features (with 40 cepstra instead of 13). # this shows how you can split across multiple file-systems. echo "$0: creating high-resolution MFCC features" - mfccdir=mfcc_perturbed_hires + mfccdir=mfcc_perturbed_hires$online_affix if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then utils/create_split_dir.pl /export/b0{5,6,7,8}/$USER/kaldi-data/mfcc/aishell-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage fi for datadir in ${train_set}_sp ${test_sets}; do - utils/copy_data_dir.sh data/$datadir data/${datadir}_hires + utils/copy_data_dir.sh data/$datadir data/${datadir}_hires$online_affix done # do volume-perturbation on the training data prior to extracting hires # features; this helps make trained nnets more invariant to test data volume. - utils/data/perturb_data_dir_volume.sh data/${train_set}_sp_hires || exit 1; + utils/data/perturb_data_dir_volume.sh data/${train_set}_sp_hires$online_affix || exit 1; for datadir in ${train_set}_sp ${test_sets}; do - steps/make_mfcc_pitch.sh --nj 10 --mfcc-config conf/mfcc_hires.conf \ - --cmd "$train_cmd" data/${datadir}_hires exp/make_hires/$datadir $mfccdir || exit 1; - steps/compute_cmvn_stats.sh data/${datadir}_hires exp/make_hires/$datadir $mfccdir || exit 1; - utils/fix_data_dir.sh data/${datadir}_hires || exit 1; + steps/make_mfcc_pitch$online_affix.sh --nj 10 --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/${datadir}_hires$online_affix exp/make_hires/$datadir $mfccdir || exit 1; + steps/compute_cmvn_stats.sh data/${datadir}_hires$online_affix exp/make_hires/$datadir $mfccdir || exit 1; + utils/fix_data_dir.sh data/${datadir}_hires$online_affix || exit 1; # create MFCC data dir without pitch to extract iVector - utils/data/limit_feature_dim.sh 0:39 data/${datadir}_hires data/${datadir}_hires_nopitch || exit 1; + utils/data/limit_feature_dim.sh 0:39 data/${datadir}_hires$online_affix data/${datadir}_hires_nopitch || exit 1; steps/compute_cmvn_stats.sh data/${datadir}_hires_nopitch exp/make_hires/$datadir $mfccdir || exit 1; done fi diff --git a/egs/aishell/s5/local/nnet3/run_tdnn.sh b/egs/aishell/s5/local/nnet3/run_tdnn.sh deleted file mode 100755 index 3cb8cd861a3..00000000000 --- a/egs/aishell/s5/local/nnet3/run_tdnn.sh +++ /dev/null @@ -1,117 +0,0 @@ -#!/bin/bash - -# This script is based on swbd/s5c/local/nnet3/run_tdnn.sh - -# this is the standard "tdnn" system, built in nnet3; it's what we use to -# call multi-splice. - -# At this script level we don't support not running on GPU, as it would be painfully slow. -# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, -# --num-threads 16 and --minibatch-size 128. -set -e - -stage=0 -train_stage=-10 -affix= -common_egs_dir= - -# training options -initial_effective_lrate=0.0015 -final_effective_lrate=0.00015 -num_epochs=4 -num_jobs_initial=2 -num_jobs_final=12 -remove_egs=true - -# feature options -use_ivectors=true - -# End configuration section. - -. ./cmd.sh -. ./path.sh -. ./utils/parse_options.sh - -if ! 
cuda-compiled; then - cat < $dir/configs/network.xconfig - input dim=100 name=ivector - input dim=43 name=input - - # please note that it is important to have input layer with the name=input - # as the layer immediately preceding the fixed-affine-layer to enable - # the use of short notation for the descriptor - fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat - - # the first splicing is moved before the lda layer, so no splicing here - relu-batchnorm-layer name=tdnn1 dim=850 - relu-batchnorm-layer name=tdnn2 dim=850 input=Append(-1,0,2) - relu-batchnorm-layer name=tdnn3 dim=850 input=Append(-3,0,3) - relu-batchnorm-layer name=tdnn4 dim=850 input=Append(-7,0,2) - relu-batchnorm-layer name=tdnn5 dim=850 input=Append(-3,0,3) - relu-batchnorm-layer name=tdnn6 dim=850 - output-layer name=output input=tdnn6 dim=$num_targets max-change=1.5 -EOF - steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ -fi - -if [ $stage -le 8 ]; then - if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then - utils/create_split_dir.pl \ - /export/b0{5,6,7,8}/$USER/kaldi-data/egs/aishell-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage - fi - - steps/nnet3/train_dnn.py --stage=$train_stage \ - --cmd="$decode_cmd" \ - --feat.online-ivector-dir exp/nnet3/ivectors_${train_set} \ - --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ - --trainer.num-epochs $num_epochs \ - --trainer.optimization.num-jobs-initial $num_jobs_initial \ - --trainer.optimization.num-jobs-final $num_jobs_final \ - --trainer.optimization.initial-effective-lrate $initial_effective_lrate \ - --trainer.optimization.final-effective-lrate $final_effective_lrate \ - --egs.dir "$common_egs_dir" \ - --cleanup.remove-egs $remove_egs \ - --cleanup.preserve-model-interval 500 \ - --use-gpu true \ - --feat-dir=data/${train_set}_hires \ - --ali-dir $ali_dir \ - --lang data/lang \ - --reporting.email="$reporting_email" \ - --dir=$dir || exit 1; -fi - -if [ $stage -le 9 ]; then - # this version of the decoding treats each utterance separately - # without carrying forward speaker information. - for decode_set in dev test; do - num_jobs=`cat data/${decode_set}_hires/utt2spk|cut -d' ' -f2|sort -u|wc -l` - decode_dir=${dir}/decode_$decode_set - steps/nnet3/decode.sh --nj $num_jobs --cmd "$decode_cmd" \ - --online-ivector-dir exp/nnet3/ivectors_${decode_set} \ - $graph_dir data/${decode_set}_hires $decode_dir || exit 1; - done -fi - -wait; -exit 0; diff --git a/egs/aishell/s5/local/nnet3/run_tdnn.sh b/egs/aishell/s5/local/nnet3/run_tdnn.sh new file mode 120000 index 00000000000..34499362831 --- /dev/null +++ b/egs/aishell/s5/local/nnet3/run_tdnn.sh @@ -0,0 +1 @@ +tuning/run_tdnn_1a.sh \ No newline at end of file diff --git a/egs/aishell/s5/local/nnet3/tuning/run_tdnn_1a.sh b/egs/aishell/s5/local/nnet3/tuning/run_tdnn_1a.sh new file mode 100755 index 00000000000..3cb8cd861a3 --- /dev/null +++ b/egs/aishell/s5/local/nnet3/tuning/run_tdnn_1a.sh @@ -0,0 +1,117 @@ +#!/bin/bash + +# This script is based on swbd/s5c/local/nnet3/run_tdnn.sh + +# this is the standard "tdnn" system, built in nnet3; it's what we use to +# call multi-splice. + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+set -e + +stage=0 +train_stage=-10 +affix= +common_egs_dir= + +# training options +initial_effective_lrate=0.0015 +final_effective_lrate=0.00015 +num_epochs=4 +num_jobs_initial=2 +num_jobs_final=12 +remove_egs=true + +# feature options +use_ivectors=true + +# End configuration section. + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=43 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=850 + relu-batchnorm-layer name=tdnn2 dim=850 input=Append(-1,0,2) + relu-batchnorm-layer name=tdnn3 dim=850 input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn4 dim=850 input=Append(-7,0,2) + relu-batchnorm-layer name=tdnn5 dim=850 input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn6 dim=850 + output-layer name=output input=tdnn6 dim=$num_targets max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 8 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/aishell-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/train_dnn.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir exp/nnet3/ivectors_${train_set} \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate $initial_effective_lrate \ + --trainer.optimization.final-effective-lrate $final_effective_lrate \ + --egs.dir "$common_egs_dir" \ + --cleanup.remove-egs $remove_egs \ + --cleanup.preserve-model-interval 500 \ + --use-gpu true \ + --feat-dir=data/${train_set}_hires \ + --ali-dir $ali_dir \ + --lang data/lang \ + --reporting.email="$reporting_email" \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 9 ]; then + # this version of the decoding treats each utterance separately + # without carrying forward speaker information. + for decode_set in dev test; do + num_jobs=`cat data/${decode_set}_hires/utt2spk|cut -d' ' -f2|sort -u|wc -l` + decode_dir=${dir}/decode_$decode_set + steps/nnet3/decode.sh --nj $num_jobs --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet3/ivectors_${decode_set} \ + $graph_dir data/${decode_set}_hires $decode_dir || exit 1; + done +fi + +wait; +exit 0; diff --git a/egs/aishell/s5/local/nnet3/tuning/run_tdnn_2a.sh b/egs/aishell/s5/local/nnet3/tuning/run_tdnn_2a.sh new file mode 100755 index 00000000000..603149585f2 --- /dev/null +++ b/egs/aishell/s5/local/nnet3/tuning/run_tdnn_2a.sh @@ -0,0 +1,145 @@ +#!/bin/bash + +# This script is based on aishell/s5/local/nnet3/tuning/run_tdnn_1a.sh + +# In this script, the neural network in trained based on hires mfcc and online pitch. +# The online pitch setup requires a online_pitch.conf in the conf dir for both training +# and testing. 
+ +set -e + +stage=0 +train_stage=-10 +affix= +common_egs_dir= + +# training options +initial_effective_lrate=0.0015 +final_effective_lrate=0.00015 +num_epochs=4 +num_jobs_initial=2 +num_jobs_final=12 +remove_egs=true + +# feature options +use_ivectors=true + +# End configuration section. + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=43 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=850 + relu-batchnorm-layer name=tdnn2 dim=850 input=Append(-1,0,2) + relu-batchnorm-layer name=tdnn3 dim=850 input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn4 dim=850 input=Append(-7,0,2) + relu-batchnorm-layer name=tdnn5 dim=850 input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn6 dim=850 + output-layer name=output input=tdnn6 dim=$num_targets max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 8 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/aishell-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/train_dnn.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir exp/nnet3/ivectors_${train_set} \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate $initial_effective_lrate \ + --trainer.optimization.final-effective-lrate $final_effective_lrate \ + --egs.dir "$common_egs_dir" \ + --cleanup.remove-egs $remove_egs \ + --cleanup.preserve-model-interval 500 \ + --use-gpu true \ + --feat-dir=data/${train_set}_hires_online \ + --ali-dir $ali_dir \ + --lang data/lang \ + --reporting.email="$reporting_email" \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 9 ]; then + # this version of the decoding treats each utterance separately + # without carrying forward speaker information. + for decode_set in dev test; do + num_jobs=`cat data/${decode_set}_hires_online/utt2spk|cut -d' ' -f2|sort -u|wc -l` + decode_dir=${dir}/decode_$decode_set + steps/nnet3/decode.sh --nj $num_jobs --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet3/ivectors_${decode_set} \ + $graph_dir data/${decode_set}_hires_online $decode_dir || exit 1; + done +fi + +if [ $stage -le 10 ]; then + steps/online/nnet3/prepare_online_decoding.sh --mfcc-config conf/mfcc_hires.conf \ + --add-pitch true \ + data/lang exp/nnet3/extractor "$dir" ${dir}_online || exit 1; +fi + +if [ $stage -le 11 ]; then + # do the actual online decoding with iVectors, carrying info forward from + # previous utterances of the same speaker. 
+  for decode_set in dev test; do
+    num_jobs=`cat data/${decode_set}_hires_online/utt2spk|cut -d' ' -f2|sort -u|wc -l`
+    decode_dir=${dir}_online/decode_$decode_set
+    steps/online/nnet3/decode.sh --nj $num_jobs --cmd "$decode_cmd" \
+      --config conf/decode.config \
+      $graph_dir data/${decode_set}_hires_online $decode_dir || exit 1;
+  done
+fi
+
+if [ $stage -le 12 ]; then
+  # this version of the decoding treats each utterance separately
+  # without carrying forward speaker information.
+  for decode_set in dev test; do
+    num_jobs=`cat data/${decode_set}_hires_online/utt2spk|cut -d' ' -f2|sort -u|wc -l`
+    decode_dir=${dir}_online/decode_${decode_set}_per_utt
+    steps/online/nnet3/decode.sh --nj $num_jobs --cmd "$decode_cmd" \
+      --config conf/decode.config --per-utt true \
+      $graph_dir data/${decode_set}_hires_online $decode_dir || exit 1;
+  done
+fi
+
+wait;
+exit 0;
diff --git a/egs/aishell/v1/local/aishell_data_prep.sh b/egs/aishell/v1/local/aishell_data_prep.sh
index 70d6ba1f3e5..11d131dcdb1 100755
--- a/egs/aishell/v1/local/aishell_data_prep.sh
+++ b/egs/aishell/v1/local/aishell_data_prep.sh
@@ -40,13 +40,11 @@ n=`cat $train_dir/wav.flist $dev_dir/wav.flist $test_dir/wav.flist | wc -l`
 # Transcriptions preparation
 for dir in $train_dir $test_dir; do
   echo Preparing $dir transcriptions
-  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' |\
-    sort > $dir/utt.list
-  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{i=NF-1;printf("%s %s\n",$NF,$i)}' |\
-    sort > $dir/utt2spk_all
+  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
+  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{i=NF-1;printf("%s %s\n",$NF,$i)}' > $dir/utt2spk_all
   paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
   utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text_dir/*.txt > $dir/transcripts.txt
-  awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
+  awk '{print $1}' $dir/transcripts.txt | sort -u > $dir/utt.list
   utils/filter_scp.pl -f 1 $dir/utt.list $dir/utt2spk_all | sort -u > $dir/utt2spk
   utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
   sort -u $dir/transcripts.txt > $dir/text
diff --git a/egs/aishell2/README.md b/egs/aishell2/README.md
new file mode 100644
index 00000000000..f87f3819036
--- /dev/null
+++ b/egs/aishell2/README.md
@@ -0,0 +1,64 @@
+# AISHELL-2
+
+AISHELL-2 is by far the largest free speech corpus available for Mandarin ASR research.
+## 1. DATA
+### Training data
+* 1000 hours of speech data (around 1 million utterances)
+* 1991 speakers (845 male and 1146 female)
+* clean recording environment (studio or quiet living room)
+* read speech
+* reading prompts from various domains: entertainment, finance, technology, sports, control command, places of interest, etc.
+* near field recording via 3 parallel channels (iOS, Android, Microphone).
+* iOS data is free for non-commercial research and education use (e.g. universities and non-commercial institutes)
+
+### Evaluation data:
+Currently we release AISHELL2-2018A-EVAL, containing:
+* dev: 2500 utterances from 5 speakers
+* test: 5000 utterances from 10 speakers
+
+Both sets are available across the three channel conditions.
+
+Anyone interested can download the sets from [here](http://www.aishelltech.com/aishell_eval). Note that we may update and release other evaluation sets on the website later, targeting different applications and scenarios.
+
+## 2. RECIPE
+Based on the Kaldi standard system, AISHELL-2 provides a self-contained Mandarin ASR recipe, with:
+* a word segmentation module, which is a must-have component for Chinese ASR systems
+* an open-sourced Mandarin lexicon (DaCiDian, available [here](https://github.com/aishell-foundation/DaCiDian))
+* a simplified GMM training & alignment generation recipe (we stopped at the speaker-independent stage)
+* an LFMMI TDNN training and decoding recipe
+
+# REFERENCE
+We released a [paper on Arxiv](https://arxiv.org/abs/1808.10583) with a more detailed description of the corpus and some preliminary results. If you use AISHELL-2 in your experiments, please cite the paper as follows:
+```
+@ARTICLE{aishell2,
+  author = {{Du}, J. and {Na}, X. and {Liu}, X. and {Bu}, H.},
+  title = "{AISHELL-2: Transforming Mandarin ASR Research Into Industrial Scale}",
+  journal = {ArXiv},
+  eprint = {1808.10583},
+  primaryClass = "cs.CL",
+  year = 2018,
+  month = Aug,
+}
+```
+
+# APPLY FOR DATA/CONTACT
+AISHELL foundation is a non-profit online organization, with members from the speech industry and research institutes.
+
+We hope the AISHELL-2 corpus and recipe will be beneficial to the entire speech community.
+
+Depending on your location and internet speed, we distribute the corpus in two ways:
+* hard-disk delivery
+* cloud-disk downloading
+
+To apply for the AISHELL-2 corpus for free, you need to fill in a very simple application form, confirming that:
+* university department / educational institute information has been fully provided
+* the corpus will be used only for non-commercial research / education
+
+AISHELL-foundation covers all data distribution fees (including the corpus, hard-disk cost, etc.)
+
+Data re-distribution inside your university department is OK for convenience. However, users are not supposed to re-distribute the data to other universities or educational institutes.
+
+To get the application form, or if you come across any problem with the recipe, contact us via:
+
+aishell.foundation@gmail.com
+
diff --git a/egs/aishell2/README.txt b/egs/aishell2/README.txt
deleted file mode 100644
index e8b4260f2bb..00000000000
--- a/egs/aishell2/README.txt
+++ /dev/null
@@ -1,50 +0,0 @@
-# AISHELL-2
-
-AISHELL-2 is by far the largest free speech corpus available for Mandarin ASR research.
-## 1. DATA
-### training data
-* 1000 hours of speech data (around 1 million utterances)
-* 1991 speakers (845 male and 1146 female)
-* clean recording environment(studio or quiet living room)
-* read speech
-* reading prompts from various domain: entertainment, finance, technology, sports, control command, place of interest etc.
-* near field recording via 3 parallel channels(iOS, Android, Microphone).
-* iOS data is free for non-commercial research and education use (e.g. universities and colleges)
-
-### evaluation data:
-Currently we release AISHELL2-2018A-EVAL, containing:
-* dev: 2500 utterances from 5 speaker
-* test: 5000 utterances from 10 speakers
-
-you can download above evaluation set from:
-http://www.aishelltech.com/aishell_eval
-
-we may update and release other evaluation sets on the website later, targeting on different applications and senarios.
-
-## 2.
RECIPE -Based on Kaldi standard system, AISHELL-2 provides a self-contained Mandarin ASR recipe, with: -* a word segmentation module, which is a must-have component for Chinese ASR systems -* an open-sourced Mandarin lexicon(DaCiDian) -* a simplified GMM training recipe -* acoustic channel adaptation recipe(AM fine-tuning) - -# CONTACT -AISHELL foundation is a non-profit online organization, with members from speech industry and research institutes. - -We hope AISHELL-2 corpus and recipe could be beneficial to the entire speech community. - -Depends on your location and internet speed, we distribute the corpus in two ways: -* hard-disk delivery -* cloud-disk downloading - -To apply for AISHELL-2 corpus for free, you need to fill in a very simple application form, confirming that: -* university department / education institute info -* only for non-commercial research / education use - -AISHELL-foundation covers all data distribution fees (including the corpus, hard-disk cost etc) - -Data re-distribution inside your university department is OK for convenience. However, users are not supposed to re-distribute AISHELL-2 to other universities or education institutes. - -To get the application form, or you come across any problem with the recipe, contact us via: - -aishell.foundation@gmail.com diff --git a/egs/aishell2/s5/conf/online_cmvn.conf b/egs/aishell2/s5/conf/online_cmvn.conf new file mode 100644 index 00000000000..048bdfa65de --- /dev/null +++ b/egs/aishell2/s5/conf/online_cmvn.conf @@ -0,0 +1 @@ +# configuration file for apply-cmvn-online diff --git a/egs/aishell2/s5/local/chain/tuning/run_tdnn_1a.sh b/egs/aishell2/s5/local/chain/tuning/run_tdnn_1a.sh index 459bd64eeb5..86c9becac5b 100755 --- a/egs/aishell2/s5/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/aishell2/s5/local/chain/tuning/run_tdnn_1a.sh @@ -103,7 +103,7 @@ fi if [ $stage -le 10 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.002" linear_opts="orthonormal-constraint=1.0" output_opts="l2-regularize=0.0005 bottleneck-dim=256" diff --git a/egs/aishell2/s5/local/chain/tuning/run_tdnn_1b.sh b/egs/aishell2/s5/local/chain/tuning/run_tdnn_1b.sh index 30a19293181..d8560e63909 100755 --- a/egs/aishell2/s5/local/chain/tuning/run_tdnn_1b.sh +++ b/egs/aishell2/s5/local/chain/tuning/run_tdnn_1b.sh @@ -3,18 +3,17 @@ # _1b is as _1a, but with pitch feats, i-vector and dropout schedule added, referenced from wsj # basic info: -# steps/info/chain_dir_info.pl exp/chain/tdnn_1b_all_sp/ -# exp/chain/tdnn_1b_all_sp/: num-iters=1446 nj=2..2 num-params=19.3M dim=43+100->4456 combine=-0.079->-0.075 (over 9) xent:train/valid[962,1445,final]=(-0.922,-0.795,-0.746/-0.960,-0.840,-0.785) logprob:train/valid[962,1445,final]=(-0.084,-0.072,-0.070/-0.085,-0.075,-0.071) +# steps/info/chain_dir_info.pl exp/chain/tdnn_1f_nopitch_ivec_sp/exp/chain/tdnn_1f_nopitch_ivec_sp/: num-iters=578 nj=2..8 num-params=19.3M dim=43+100->4520 combine=-0.082->-0.081 (over 6) xent:train/valid[384,577,final]=(-0.863,-0.752,-0.740/-0.901,-0.791,-0.784) logprob:train/valid[384,577,final]=(-0.083,-0.076,-0.075/-0.084,-0.077,-0.076) # results: -# local/chain/compare_wer.sh exp/chain/tdnn_1d_all_sp/ -# Model tdnn_1d_all_sp +# local/chain/compare_wer.sh exp/chain/tdnn_1f_nopitch_ivec_sp/ +# Model tdnn_1f_nopitch_ivec_sp # Num. 
of params 19.3M -# WER(%) 8.84 -# Final train prob -0.0696 -# Final valid prob -0.0714 -# Final train prob (xent) -0.7458 -# Final valid prob (xent) -0.7854 +# WER(%) 8.81 +# Final train prob -0.0749 +# Final valid prob -0.0756 +# Final train prob (xent) -0.7401 +# Final valid prob (xent) -0.7837 set -e @@ -68,9 +67,12 @@ if [ $stage -le 5 ]; then mfccdir=mfcc_hires for datadir in ${train_set} ${test_sets}; do utils/copy_data_dir.sh data/${datadir} data/${datadir}_hires - utils/data/perturb_data_dir_volume.sh data/${datadir}_hires || exit 1; - steps/make_mfcc_pitch.sh --mfcc-config conf/mfcc_hires.conf --pitch-config conf/pitch.conf \ + utils/data/perturb_data_dir_volume.sh data/${datadir}_hires || exit 1; + steps/make_mfcc_pitch.sh --mfcc-config conf/mfcc_hires.conf --pitch-config conf/pitch.conf \ --nj $nj data/${datadir}_hires exp/make_mfcc/ ${mfccdir} + steps/compute_cmvn_stats.sh data/${datadir}_hires exp/make_mfcc ${mfccdir} + utils/data/limit_feature_dim.sh 0:39 data/${datadir}_hires data/${datadir}_hires_nopitch + steps/compute_cmvn_stats.sh data/${datadir}_hires_nopitch exp/make_mfcc ${mfccdir} done fi @@ -81,15 +83,11 @@ if [ $stage -le 6 ]; then mkdir -p exp/chain/diag_ubm_${affix} temp_data_root=exp/chain/diag_ubm_${affix} - num_utts_total=$(wc -l < data/${train_set}_hires/utt2spk) + num_utts_total=$(wc -l < data/${train_set}_hires_nopitch/utt2spk) num_utts=$[$num_utts_total/4] - utils/data/subset_data_dir.sh data/${train_set}_hires \ + utils/data/subset_data_dir.sh data/${train_set}_hires_nopitch \ $num_utts ${temp_data_root}/${train_set}_subset - #echo "$0: get cmvn stats if not there for subset" - #[ -f ${temp_data_root}/${train_set}_subset/cmvn.scp ] || \ - steps/compute_cmvn_stats.sh ${temp_data_root}/${train_set}_subset || exit 1; - echo "$0: computing a PCA transform from the hires data." 
steps/online/nnet2/get_pca_transform.sh --cmd "$train_cmd" \ --splice-opts "--left-context=3 --right-context=3" \ @@ -108,13 +106,13 @@ if [ $stage -le 6 ]; then echo "$0: training the iVector extractor" steps/online/nnet2/train_ivector_extractor.sh --cmd "$train_cmd" --nj $nj \ - data/${train_set}_hires exp/chain/diag_ubm_${affix} \ + data/${train_set}_hires_nopitch exp/chain/diag_ubm_${affix} \ exp/chain/extractor_${affix} || exit 1; for datadir in ${train_set} ${test_sets}; do - steps/online/nnet2/copy_data_dir.sh --utts-per-spk-max 2 data/${datadir}_hires data/${datadir}_hires_max2 + steps/online/nnet2/copy_data_dir.sh --utts-per-spk-max 2 data/${datadir}_hires_nopitch data/${datadir}_hires_nopitch_max2 steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj $nj \ - data/${datadir}_hires_max2 exp/chain/extractor_${affix} exp/chain/ivectors_${datadir}_${affix} || exit 1; + data/${datadir}_hires_nopitch_max2 exp/chain/extractor_${affix} exp/chain/ivectors_${datadir}_${affix} || exit 1; done fi @@ -152,7 +150,7 @@ if [ $stage -le 10 ]; then echo "$0: creating neural net configs using the xconfig parser"; feat_dim=$(feat-to-dim scp:data/${train_set}_hires/feats.scp -) num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.002" linear_opts="orthonormal-constraint=1.0" output_opts="l2-regularize=0.0005 bottleneck-dim=256" diff --git a/egs/aishell2/s5/local/prepare_data.sh b/egs/aishell2/s5/local/prepare_data.sh index 419d8eddfd1..4be9664ac31 100755 --- a/egs/aishell2/s5/local/prepare_data.sh +++ b/egs/aishell2/s5/local/prepare_data.sh @@ -45,8 +45,9 @@ utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tm python -c "import jieba" 2>/dev/null || \ (echo "jieba is not found. Use tools/extra/install_jieba.sh to install it." && exit 1;) utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/trans.txt -awk '{print $1}' $dict_dir/lexicon.txt | sort | uniq | awk 'BEGIN{idx=0}{print $1,idx++}'> $tmp/vocab.txt -python local/word_segmentation.py $tmp/vocab.txt $tmp/trans.txt > $tmp/text +# jieba's vocab format requires word count(frequency), set to 99 +awk '{print $1}' $dict_dir/lexicon.txt | sort | uniq | awk '{print $1,99}'> $tmp/word_seg_vocab.txt +python local/word_segmentation.py $tmp/word_seg_vocab.txt $tmp/trans.txt > $tmp/text # utt2spk & spk2utt awk -F'\t' '{print $2}' $tmp/wav.scp > $tmp/wav.list diff --git a/egs/aishell2/s5/local/prepare_dict.sh b/egs/aishell2/s5/local/prepare_dict.sh index d59585273a7..56ab885ae94 100755 --- a/egs/aishell2/s5/local/prepare_dict.sh +++ b/egs/aishell2/s5/local/prepare_dict.sh @@ -10,7 +10,7 @@ download_dir=data/local/DaCiDian dir=data/local/dict -if [ $# -ne 1 ]; then +if [ $# -ne 1 ]; then echo "Usage: $0 "; exit 1; fi @@ -18,7 +18,9 @@ fi dir=$1 # download the DaCiDian from github -git clone https://github.com/aishell-foundation/DaCiDian.git $download_dir +if [ ! 
-d $download_dir ]; then + git clone https://github.com/aishell-foundation/DaCiDian.git $download_dir +fi # here we map to the phone spn(spoken noise) mkdir -p $dir @@ -27,21 +29,9 @@ echo -e "\tspn" >> $dir/lexicon.txt # prepare silence_phones.txt, nonsilence_phones.txt, optional_silence.txt, extra_questions.txt cat $dir/lexicon.txt | awk '{ for(n=2;n<=NF;n++){ phones[$n] = 1; }} END{for (p in phones) print p;}'| \ - sort -u |\ - perl -e ' - my %ph_cl; - while () { - $phone = $_; - chomp($phone); - chomp($_); - $phone = $_; - next if ($phone eq "sil"); - if (exists $ph_cl{$phone}) { push(@{$ph_cl{$phone}}, $_) } - else { $ph_cl{$phone} = [$_]; } - } - foreach $key ( keys %ph_cl ) { - print "@{ $ph_cl{$key} }\n" - } + perl -e 'while(<>){ chomp($_); $phone = $_; next if ($phone eq "sil"); + m:^([^\d]+)(\d*)$: || die "Bad phone $_"; $q{$1} .= "$phone "; } + foreach $l (values %q) {print "$l\n";} ' | sort -k1 > $dir/nonsilence_phones.txt || exit 1; echo sil > $dir/silence_phones.txt @@ -49,9 +39,8 @@ echo sil > $dir/optional_silence.txt cat $dir/silence_phones.txt | awk '{printf("%s ", $1);} END{printf "\n";}' > $dir/extra_questions.txt || exit 1; cat $dir/nonsilence_phones.txt | perl -e 'while(<>){ foreach $p (split(" ", $_)) { - $p =~ m:^([^\d]+)(\d*)$: || die "Bad phone $_"; $q{$2} .= "$p "; } } foreach $l (values %q) {print "$l\n";}' \ + $p =~ m:^([^\d]+)(\d*)$: || die "Bad phone $_"; if($p eq "\$0"){$q{""} .= "$p ";}else{$q{$2} .= "$p ";} } } foreach $l (values %q) {print "$l\n";}' \ >> $dir/extra_questions.txt || exit 1; echo "local/prepare_dict.sh succeeded" exit 0; - diff --git a/egs/aishell2/s5/local/word_segmentation.py b/egs/aishell2/s5/local/word_segmentation.py index 1cb2c1e7350..4ce55a2003e 100644 --- a/egs/aishell2/s5/local/word_segmentation.py +++ b/egs/aishell2/s5/local/word_segmentation.py @@ -4,6 +4,7 @@ # 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU) # Apache 2.0 +from __future__ import print_function import sys import jieba reload(sys) @@ -19,6 +20,6 @@ jieba.set_dictionary(vocab_file) for line in open(trans_file): key,trans = line.strip().split('\t',1) - words = jieba.cut(trans) + words = jieba.cut(trans, HMM=False) # turn off new word discovery (HMM-based) new_line = key + '\t' + " ".join(words) print(new_line) diff --git a/egs/ami/s5/local/ami_download.sh b/egs/ami/s5/local/ami_download.sh index b14f8550c75..cba130c8467 100755 --- a/egs/ami/s5/local/ami_download.sh +++ b/egs/ami/s5/local/ami_download.sh @@ -53,12 +53,12 @@ cat local/split_train.orig local/split_eval.orig local/split_dev.orig > $wdir/am wgetfile=$wdir/wget_$mic.sh # TODO fix this with Pawel, files don't exist anymore, -manifest="wget --continue -O $adir/MANIFEST.TXT http://groups.inf.ed.ac.uk/ami/download/temp/amiBuild-04237-Sun-Jun-15-2014.manifest.txt" -license="wget --continue -O $adir/LICENCE.TXT http://groups.inf.ed.ac.uk/ami/download/temp/Creative-Commons-Attribution-NonCommercial-ShareAlike-2.5.txt" +manifest="wget --continue -O $adir/MANIFEST.TXT http://groups.inf.ed.ac.uk/ami/download/temp/amiBuild-0153-Tue-Oct-2-2018.manifest.txt" + echo "#!/bin/bash" > $wgetfile echo $manifest >> $wgetfile -echo $license >> $wgetfile + while read line; do if [ "$mic" == "ihm" ]; then extra_headset= #some meetings have 5 sepakers (headsets) @@ -100,8 +100,7 @@ else fi fi -echo "Downloads of AMI corpus completed succesfully. License can be found under $adir/LICENCE.TXT" +echo "Downloads of AMI corpus completed succesfully." 
exit 0; - diff --git a/egs/ami/s5/local/ami_ihm_scoring_data_prep.sh b/egs/ami/s5/local/ami_ihm_scoring_data_prep.sh index 3157d7ffec7..7112e0259a0 100755 --- a/egs/ami/s5/local/ami_ihm_scoring_data_prep.sh +++ b/egs/ami/s5/local/ami_ihm_scoring_data_prep.sh @@ -87,18 +87,15 @@ sort -k 2 $dir/utt2spk | utils/utt2spk_to_spk2utt.pl > $dir/spk2utt || exit 1; join $dir/utt2spk $dir/segments | \ perl -ne '{BEGIN{$pu=""; $pt=0.0;} split; if ($pu eq $_[1] && $pt > $_[3]) { - print "$_[0] $_[2] $_[3] $_[4]>$_[0] $_[2] $pt $_[4]\n" + print "s/^$_[0] $_[2] $_[3] $_[4]\$/$_[0] $_[2] $pt $_[4]/;\n" } - $pu=$_[1]; $pt=$_[4]; + $pu=$_[1]; $pt=$_[4]; }' > $dir/segments_to_fix -if [ `cat $dir/segments_to_fix | wc -l` -gt 0 ]; then + +if [ -s $dir/segments_to_fix ]; then echo "$0. Applying following fixes to segments" cat $dir/segments_to_fix - while read line; do - p1=`echo $line | awk -F'>' '{print $1}'` - p2=`echo $line | awk -F'>' '{print $2}'` - sed -ir "s!$p1!$p2!" $dir/segments - done < $dir/segments_to_fix + perl -i -pf $dir/segments_to_fix $dir/segments fi # Copy stuff into its final locations diff --git a/egs/ami/s5/local/ami_mdm_scoring_data_prep.sh b/egs/ami/s5/local/ami_mdm_scoring_data_prep.sh index 4cfa9110edf..9c4b55308f2 100755 --- a/egs/ami/s5/local/ami_mdm_scoring_data_prep.sh +++ b/egs/ami/s5/local/ami_mdm_scoring_data_prep.sh @@ -94,19 +94,15 @@ awk '{print $1}' $tmpdir/segments | \ join $tmpdir/utt2spk_stm $tmpdir/segments | \ awk '{ utt=$1; spk=$2; wav=$3; t_beg=$4; t_end=$5; if(spk_prev == spk && t_end_prev > t_beg) { - print utt, wav, t_beg, t_end">"utt, wav, t_end_prev, t_end; + print "s/^"utt, wav, t_beg, t_end"$/"utt, wav, t_end_prev, t_end"/;"; } spk_prev=spk; t_end_prev=t_end; }' > $tmpdir/segments_to_fix -if [ `cat $tmpdir/segments_to_fix | wc -l` -gt 0 ]; then +if [ -s $tmpdir/segments_to_fix ]; then echo "$0. Applying following fixes to segments" cat $tmpdir/segments_to_fix - while read line; do - p1=`echo $line | awk -F'>' '{print $1}'` - p2=`echo $line | awk -F'>' '{print $2}'` - sed -ir "s:$p1:$p2:" $tmpdir/segments - done < $tmpdir/segments_to_fix + perl -i -pf $tmpdir/segments_to_fix $tmpdir/segments fi # Copy stuff into its final locations [this has been moved from the format_data diff --git a/egs/ami/s5/local/ami_sdm_scoring_data_prep.sh b/egs/ami/s5/local/ami_sdm_scoring_data_prep.sh index 91baa37d6e1..815e1b2d270 100755 --- a/egs/ami/s5/local/ami_sdm_scoring_data_prep.sh +++ b/egs/ami/s5/local/ami_sdm_scoring_data_prep.sh @@ -101,19 +101,15 @@ awk '{print $1}' $tmpdir/segments | \ join $tmpdir/utt2spk_stm $tmpdir/segments | \ awk '{ utt=$1; spk=$2; wav=$3; t_beg=$4; t_end=$5; if(spk_prev == spk && t_end_prev > t_beg) { - print utt, wav, t_beg, t_end">"utt, wav, t_end_prev, t_end; + print "s/^"utt, wav, t_beg, t_end"$/"utt, wav, t_end_prev, t_end"/;"; } spk_prev=spk; t_end_prev=t_end; }' > $tmpdir/segments_to_fix -if [ `cat $tmpdir/segments_to_fix | wc -l` -gt 0 ]; then +if [ -s $tmpdir/segments_to_fix ]; then echo "$0. 
Applying following fixes to segments" cat $tmpdir/segments_to_fix - while read line; do - p1=`echo $line | awk -F'>' '{print $1}'` - p2=`echo $line | awk -F'>' '{print $2}'` - sed -ir "s:$p1:$p2:" $tmpdir/segments - done < $tmpdir/segments_to_fix + perl -i -pf $tmpdir/segments_to_fix $tmpdir/segments fi # Copy stuff into its final locations [this has been moved from the format_data diff --git a/egs/ami/s5/local/sort_bad_utts.py b/egs/ami/s5/local/sort_bad_utts.py index f84fcb12608..baabdc73508 100644 --- a/egs/ami/s5/local/sort_bad_utts.py +++ b/egs/ami/s5/local/sort_bad_utts.py @@ -1,5 +1,6 @@ #!/usr/bin/env python +from __future__ import print_function import sys import argparse import logging @@ -38,10 +39,10 @@ def GetSortedWers(utt_info_file): utt_wer_sorted = sorted(utt_wer, key = lambda k : k[1]) try: import numpy as np - bins = range(0,105,5) + bins = list(range(0,105,5)) bins.append(sys.float_info.max) - hist, bin_edges = np.histogram(map(lambda x: x[1], utt_wer_sorted), + hist, bin_edges = np.histogram([x[1] for x in utt_wer_sorted], bins = bins) num_utts = len(utt_wer) string = '' diff --git a/egs/ami/s5/local/tfrnnlm/run_lstm.sh b/egs/ami/s5/local/tfrnnlm/run_lstm.sh index 31ae4a8bad7..d68fadb10f3 100755 --- a/egs/ami/s5/local/tfrnnlm/run_lstm.sh +++ b/egs/ami/s5/local/tfrnnlm/run_lstm.sh @@ -27,7 +27,7 @@ mkdir -p $dir if [ $stage -le 2 ]; then # the following script uses TensorFlow. You could use tools/extras/install_tensorflow_py.sh to install it $cuda_cmd $dir/train_rnnlm.log utils/parallel/limit_num_gpus.sh \ - python steps/tfrnnlm/lstm.py --data-path=$dir --save-path=$dir/rnnlm --vocab-path=$dir/wordlist.rnn.final + python steps/tfrnnlm/lstm.py --data_path=$dir --save_path=$dir/rnnlm --vocab_path=$dir/wordlist.rnn.final fi final_lm=ami_fsh.o3g.kn @@ -39,7 +39,7 @@ if [ $stage -le 3 ]; then decode_dir=${basedir}/decode_${decode_set} # Lattice rescoring - steps/lmrescore_rnnlm_lat.sh \ + steps/tfrnnlm/lmrescore_rnnlm_lat.sh \ --cmd "$tfrnnlm_cmd --mem 16G" \ --rnnlm-ver tensorflow --weight $weight --max-ngram-order $ngram_order \ data/lang_$LM $dir \ diff --git a/egs/ami/s5/local/tfrnnlm/run_lstm_fast.sh b/egs/ami/s5/local/tfrnnlm/run_lstm_fast.sh index 8dd876c2b2c..4cc71b55b5c 100755 --- a/egs/ami/s5/local/tfrnnlm/run_lstm_fast.sh +++ b/egs/ami/s5/local/tfrnnlm/run_lstm_fast.sh @@ -27,7 +27,7 @@ mkdir -p $dir if [ $stage -le 2 ]; then # the following script uses TensorFlow. You could use tools/extras/install_tensorflow_py.sh to install it $cuda_cmd $dir/train_rnnlm.log utils/parallel/limit_num_gpus.sh \ - python steps/tfrnnlm/lstm_fast.py --data-path=$dir --save-path=$dir/rnnlm --vocab-path=$dir/wordlist.rnn.final + python steps/tfrnnlm/lstm_fast.py --data_path=$dir --save_path=$dir/rnnlm --vocab_path=$dir/wordlist.rnn.final fi final_lm=ami_fsh.o3g.kn diff --git a/egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh b/egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh index 7a4635f07a4..7a95f38ba1e 100755 --- a/egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh +++ b/egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh @@ -27,7 +27,7 @@ mkdir -p $dir if [ $stage -le 2 ]; then # the following script uses TensorFlow. 
You could use tools/extras/install_tensorflow_py.sh to install it $cuda_cmd $dir/train_rnnlm.log utils/parallel/limit_num_gpus.sh \ - python steps/tfrnnlm/vanilla_rnnlm.py --data-path=$dir --save-path=$dir/rnnlm --vocab-path=$dir/wordlist.rnn.final + python steps/tfrnnlm/vanilla_rnnlm.py --data_path=$dir --save_path=$dir/rnnlm --vocab_path=$dir/wordlist.rnn.final fi final_lm=ami_fsh.o3g.kn @@ -39,7 +39,7 @@ if [ $stage -le 3 ]; then decode_dir=${basedir}/decode_${decode_set} # Lattice rescoring - steps/lmrescore_rnnlm_lat.sh \ + steps/tfrnnlm/lmrescore_rnnlm_lat.sh \ --cmd "$tfrnnlm_cmd --mem 16G" \ --rnnlm-ver tensorflow --weight $weight --max-ngram-order $ngram_order \ data/lang_$LM $dir \ diff --git a/egs/ami/s5b/RESULTS_ihm b/egs/ami/s5b/RESULTS_ihm index 42af5763829..7eb908f685e 100644 --- a/egs/ami/s5b/RESULTS_ihm +++ b/egs/ami/s5b/RESULTS_ihm @@ -86,8 +86,7 @@ %WER 19.8 | 13098 94475 | 83.1 9.6 7.4 2.8 19.8 51.8 | -0.041 | exp/ihm/chain_cleaned/tdnn_lstm1l_sp_bi_ld5/decode_dev/ascore_10/dev_hires.ctm.filt.sys %WER 19.2 | 12643 89964 | 83.2 10.7 6.1 2.5 19.2 49.7 | 0.079 | exp/ihm/chain_cleaned/tdnn_lstm1l_sp_bi_ld5/decode_eval/ascore_10/eval_hires.ctm.filt.sys -# local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh --mic ihm +# local/chain/multi_condition/tuning/run_tdnn_lstm_1b.sh --mic ihm # cleanup + chain TDNN+LSTM model + IHM reverberated data -%WER 19.4 | 13098 94479 | 83.8 10.0 6.1 3.2 19.4 51.8 | -0.168 | exp/ihm/chain_cleaned_rvb/tdnn_lstm1i_sp_rvb_bi/decode_dev/ascore_10/dev_hires.ctm.filt.sys -%WER 19.3 | 12643 89977 | 83.3 11.0 5.7 2.6 19.3 49.6 | -0.046 | exp/ihm/chain_cleaned_rvb/tdnn_lstm1i_sp_rvb_bi/decode_eval/ascore_10/eval_hires.ctm.filt.sys - +%WER 18.9 | 13098 94488 | 84.1 9.7 6.2 3.0 18.9 51.2 | 0.012 | exp/ihm/chain_cleaned_rvb/tdnn_lstm1b_sp_rvb_bi/decode_dev/ascore_11/dev_hires.ctm.filt.sys +%WER 19.3 | 12643 89989 | 83.1 10.7 6.2 2.5 19.3 50.0 | 0.136 | exp/ihm/chain_cleaned_rvb/tdnn_lstm1b_sp_rvb_bi/decode_eval/ascore_11/eval_hires.ctm.filt.sys diff --git a/egs/ami/s5b/RESULTS_sdm b/egs/ami/s5b/RESULTS_sdm index 0993b2eb52a..584c50f298a 100644 --- a/egs/ami/s5b/RESULTS_sdm +++ b/egs/ami/s5b/RESULTS_sdm @@ -93,9 +93,13 @@ %WER 35.9 | 14900 94497 | 67.8 18.2 14.1 3.7 35.9 62.5 | 0.647 | exp/sdm1/chain_cleaned/tdnn_lstm1l_sp_bi_ihmali_ld5/decode_dev/ascore_9/dev_hires_o4.ctm.filt.sys %WER 39.4 | 13223 89946 | 64.1 19.7 16.2 3.5 39.4 67.0 | 0.611 | exp/sdm1/chain_cleaned/tdnn_lstm1l_sp_bi_ihmali_ld5/decode_eval/ascore_9/eval_hires_o4.ctm.filt.sys -# local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh --mic sdm1 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned +# local/chain/multi_condition/tuning/run_tdnn_lstm_1b.sh --mic sdm1 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned # cleanup + chain TDNN+LSTM model, SDM original + IHM reverberated data, alignments from ihm data. 
-# *** best system *** -%WER 34.0 | 14455 94497 | 69.8 17.7 12.5 3.8 34.0 63.9 | 0.675 | exp/sdm1/chain_cleaned_rvb/tdnn_lstm1i_sp_rvb_bi_ihmali/decode_dev/ascore_10/dev_hires_o4.ctm.filt.sys -%WER 37.5 | 13261 89982 | 65.9 19.3 14.7 3.5 37.5 66.2 | 0.642 | exp/sdm1/chain_cleaned_rvb/tdnn_lstm1i_sp_rvb_bi_ihmali/decode_eval/ascore_10/eval_hires_o4.ctm.filt.sys +%WER 33.9 | 14185 94492 | 70.3 18.1 11.7 4.2 33.9 66.0 | 0.605 | exp/sdm1/chain_cleaned_rvb/tdnn_lstm1b_sp_rvb_bi_ihmali/decode_dev/ascore_10/dev_hires_o4.ctm.filt.sys +%WER 37.4 | 13610 89969 | 66.3 19.9 13.7 3.7 37.4 65.5 | 0.568 | exp/sdm1/chain_cleaned_rvb/tdnn_lstm1b_sp_rvb_bi_ihmali/decode_eval/ascore_10/eval_hires_o4.ctm.filt.sys +# local/chain/multi_condition/tuning/run_tdnn_1a.sh --mic sdm1 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned +# cleanup + chain TDNN-F model, SDM original + IHM reverberated data, alignments from ihm data. +# *** best system *** +%WER 33.3 | 14696 94538 | 70.4 17.2 12.4 3.7 33.3 63.1 | 0.612 | exp/sdm1/chain_cleaned_rvb/tdnn1a_sp_rvb_bi_ihmali/decode_dev/ascore_10/dev_hires_o4.ctm.filt.sys +%WER 36.7 | 14855 89974 | 66.7 18.9 14.4 3.4 36.7 59.8 | 0.580 | exp/sdm1/chain_cleaned_rvb/tdnn1a_sp_rvb_bi_ihmali/decode_eval/ascore_10/eval_hires_o4.ctm.filt.sys diff --git a/egs/ami/s5b/local/ami_ihm_scoring_data_prep.sh b/egs/ami/s5b/local/ami_ihm_scoring_data_prep.sh index 746c42c4c1a..c54876331f1 100755 --- a/egs/ami/s5b/local/ami_ihm_scoring_data_prep.sh +++ b/egs/ami/s5b/local/ami_ihm_scoring_data_prep.sh @@ -93,18 +93,15 @@ sort -k 2 $dir/utt2spk | utils/utt2spk_to_spk2utt.pl > $dir/spk2utt || exit 1; join $dir/utt2spk $dir/segments | \ perl -ne '{BEGIN{$pu=""; $pt=0.0;} split; if ($pu eq $_[1] && $pt > $_[3]) { - print "$_[0] $_[2] $_[3] $_[4]>$_[0] $_[2] $pt $_[4]\n" + print "s/^$_[0] $_[2] $_[3] $_[4]\$/$_[0] $_[2] $pt $_[4]/;\n" } $pu=$_[1]; $pt=$_[4]; }' > $dir/segments_to_fix -if [ `cat $dir/segments_to_fix | wc -l` -gt 0 ]; then + +if [ -s $dir/segments_to_fix ]; then echo "$0. Applying following fixes to segments" cat $dir/segments_to_fix - while read line; do - p1=`echo $line | awk -F'>' '{print $1}'` - p2=`echo $line | awk -F'>' '{print $2}'` - sed -ir "s!$p1!$p2!" $dir/segments - done < $dir/segments_to_fix + perl -i -pf $dir/segments_to_fix $dir/segments fi # Copy stuff into its final locations diff --git a/egs/ami/s5b/local/ami_mdm_scoring_data_prep.sh b/egs/ami/s5b/local/ami_mdm_scoring_data_prep.sh index 65f514f223c..475ef5405ba 100755 --- a/egs/ami/s5b/local/ami_mdm_scoring_data_prep.sh +++ b/egs/ami/s5b/local/ami_mdm_scoring_data_prep.sh @@ -99,19 +99,15 @@ awk '{print $1}' $tmpdir/segments | \ join $tmpdir/utt2spk_stm $tmpdir/segments | \ awk '{ utt=$1; spk=$2; wav=$3; t_beg=$4; t_end=$5; if(spk_prev == spk && t_end_prev > t_beg) { - print utt, wav, t_beg, t_end">"utt, wav, t_end_prev, t_end; + print "s/^"utt, wav, t_beg, t_end"$/"utt, wav, t_end_prev, t_end"/;"; } spk_prev=spk; t_end_prev=t_end; }' > $tmpdir/segments_to_fix -if [ `cat $tmpdir/segments_to_fix | wc -l` -gt 0 ]; then +if [ -s $tmpdir/segments_to_fix ]; then echo "$0. 
Applying following fixes to segments" cat $tmpdir/segments_to_fix - while read line; do - p1=`echo $line | awk -F'>' '{print $1}'` - p2=`echo $line | awk -F'>' '{print $2}'` - sed -ir "s:$p1:$p2:" $tmpdir/segments - done < $tmpdir/segments_to_fix + perl -i -pf $tmpdir/segments_to_fix $tmpdir/segments fi # Copy stuff into its final locations [this has been moved from the format_data diff --git a/egs/ami/s5b/local/ami_sdm_scoring_data_prep.sh b/egs/ami/s5b/local/ami_sdm_scoring_data_prep.sh index 1378f8b8965..d7ce038c0a7 100755 --- a/egs/ami/s5b/local/ami_sdm_scoring_data_prep.sh +++ b/egs/ami/s5b/local/ami_sdm_scoring_data_prep.sh @@ -111,25 +111,21 @@ awk '{print $1}' $tmpdir/segments | \ join $tmpdir/utt2spk_stm $tmpdir/segments | \ awk '{ utt=$1; spk=$2; wav=$3; t_beg=$4; t_end=$5; if(spk_prev == spk && t_end_prev > t_beg) { - print utt, wav, t_beg, t_end">"utt, wav, t_end_prev, t_end; + print "s/^"utt, wav, t_beg, t_end"$/"utt, wav, t_end_prev, t_end"/;"; } spk_prev=spk; t_end_prev=t_end; }' > $tmpdir/segments_to_fix -if [ `cat $tmpdir/segments_to_fix | wc -l` -gt 0 ]; then +if [ -s $tmpdir/segments_to_fix ]; then echo "$0. Applying following fixes to segments" cat $tmpdir/segments_to_fix - while read line; do - p1=`echo $line | awk -F'>' '{print $1}'` - p2=`echo $line | awk -F'>' '{print $2}'` - sed -ir "s:$p1:$p2:" $tmpdir/segments - done < $tmpdir/segments_to_fix + perl -i -pf $tmpdir/segments_to_fix $tmpdir/segments fi # Copy stuff into its final locations [this has been moved from the format_data # script] mkdir -p $dir -for f in spk2utt utt2spk utt2spk_stm wav.scp text segments reco2file_and_channel; do +for f in segments_to_fix spk2utt utt2spk utt2spk_stm wav.scp text segments reco2file_and_channel; do cp $tmpdir/$f $dir/$f || exit 1; done diff --git a/egs/ami/s5b/local/chain/multi_condition/run_tdnn.sh b/egs/ami/s5b/local/chain/multi_condition/run_tdnn.sh deleted file mode 100755 index 754a9508e66..00000000000 --- a/egs/ami/s5b/local/chain/multi_condition/run_tdnn.sh +++ /dev/null @@ -1,283 +0,0 @@ -#!/bin/bash - -# This is a chain-training script with TDNN neural networks. -# This script is based on local/chain/tuning/run_tdnn_1a.sh, but adding -# the reverberated IHM data into the train set. -# This script obtains better results on IHM, SDM and MDM tasks. - -# Please see RESULTS_* for examples of command lines invoking this script. - -# local/chain/multi_condition/run_tdnn.sh --mic ihm --train-set train_cleaned --gmm tri3_cleaned & -# local/chain/multi_condition/run_tdnn.sh --mic sdm1 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned & -# local/chain/multi_condition/run_tdnn.sh --mic mdm8 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned & - - -set -e -o pipefail - -# First the options that are passed through to run_ivector_common.sh -# (some of which are also used in this script directly). -stage=1 -mic=ihm -nj=30 -min_seg_len=1.55 -use_ihm_ali=false -train_set=train_cleaned -gmm=tri3_cleaned # the gmm for the target data -ihm_gmm=tri3_cleaned # the gmm for the IHM system (if --use-ihm-ali true). -num_threads_ubm=32 -num_data_reps=1 - -# The rest are configs specific to this script. Most of the parameters -# are just hardcoded at this level, in the commands below. -train_stage=-10 -tree_affix= # affix for tree directory, e.g. "a" or "b", in case we change the configuration. -tdnn_affix= #affix for TDNN directory, e.g. "a" or "b", in case we change the configuration. -common_egs_dir= # you can set this to use previously dumped egs. 
- -# End configuration section. -echo "$0 $@" # Print the command line for logging - -. ./cmd.sh -. ./path.sh -. ./utils/parse_options.sh - -if ! $use_ihm_ali; then - [ "$mic" != "ihm" ] && \ - echo "$0: you cannot specify --use-ihm-ali false if the microphone is not ihm." && \ - exit 1; -else - [ "$mic" == "ihm" ] && \ - echo "$0: you must specify --use-ihm-ali false if the microphone is ihm." && \ - exit 1; -fi - -if ! cuda-compiled; then - cat <data/lang_chain/topo - fi -fi - -if [ $stage -le 13 ]; then - # Get the alignments as lattices (gives the chain training more freedom). - # use the same num-jobs as the alignments - steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \ - data/lang $gmm_dir $original_lat_dir - rm $original_lat_dir/fsts.*.gz # save space - - lat_dir_ihmdata=exp/ihm/chain${nnet3_affix}/${gmm}_${train_set}_sp_comb_lats - - mkdir -p $lat_dir/temp/ - mkdir -p $lat_dir/temp2/ - lattice-copy "ark:gunzip -c $original_lat_dir/lat.*.gz |" ark,scp:$lat_dir/temp/lats.ark,$lat_dir/temp/lats.scp - lattice-copy "ark:gunzip -c $lat_dir_ihmdata/lat.*.gz |" ark,scp:$lat_dir/temp2/lats.ark,$lat_dir/temp2/lats.scp - - # copy the lattices for the reverberated data - rm -f $lat_dir/temp/combined_lats.scp - touch $lat_dir/temp/combined_lats.scp - cat $lat_dir/temp/lats.scp >> $lat_dir/temp/combined_lats.scp - for i in `seq 1 $num_data_reps`; do - cat $lat_dir/temp2/lats.scp | sed -e "s/^/rev${i}_/" >> $lat_dir/temp/combined_lats.scp - done - sort -u $lat_dir/temp/combined_lats.scp > $lat_dir/temp/combined_lats_sorted.scp - - lattice-copy scp:$lat_dir/temp/combined_lats_sorted.scp "ark:|gzip -c >$lat_dir/lat.1.gz" || exit 1; - echo "1" > $lat_dir/num_jobs - - # copy other files from original lattice dir - for f in cmvn_opts final.mdl splice_opts tree; do - cp $original_lat_dir/$f $lat_dir/$f - done -fi - - -if [ $stage -le 14 ]; then - # Build a tree using our new topology. We know we have alignments for the - # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use - # those. - if [ -f $tree_dir/final.mdl ]; then - echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." - exit 1; - fi - steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ - --context-opts "--context-width=2 --central-position=1" \ - --leftmost-questions-truncate -1 \ - --cmd "$train_cmd" 4200 ${lores_train_data_dir} data/lang_chain $ali_dir $tree_dir -fi - -if [ $stage -le 15 ]; then - mkdir -p $dir - - echo "$0: creating neural net configs"; - - steps/nnet3/tdnn/make_configs.py \ - --self-repair-scale-nonlinearity 0.00001 \ - --feat-dir data/$mic/${train_set}_sp_hires_comb \ - --ivector-dir $train_ivector_dir \ - --tree-dir $tree_dir \ - --relu-dim 450 \ - --splice-indexes "-1,0,1 -1,0,1,2 -3,0,3 -3,0,3 -3,0,3 -6,-3,0 0" \ - --use-presoftmax-prior-scale false \ - --xent-regularize 0.1 \ - --xent-separate-forward-affine true \ - --include-log-softmax false \ - --final-layer-normalize-target 1.0 \ - $dir/configs || exit 1; -fi - -if [ $stage -le 16 ]; then - if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then - utils/create_split_dir.pl \ - /export/b0{5,6,7,8}/$USER/kaldi-data/egs/ami-rvb$(date +'%m_%d_%H_%M')/s5b/$dir/egs/storage $dir/egs/storage - fi - - touch $dir/egs/.nodelete # keep egs around when that run dies. 
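One consequence of replacing the script deleted here: its stage 13 merged the original and reverberated lattice lists into a single archive (lat.1.gz, with num_jobs forced to 1), which serializes a large copy; the versions added later in this patch keep the combined scp but split it back into $nj archives by filtering on each split's utt2spk. A rough sketch of that filtering step, with hypothetical paths (lattice-copy --include is the same option the new scripts use):

utils/split_data.sh data/train_combined $nj           # creates split$nj/JOB/utt2spk lists
for job in $(seq $nj); do
  lattice-copy --include=data/train_combined/split$nj/$job/utt2spk \
    scp:combined_lats_sorted.scp "ark:|gzip -c > lat.$job.gz"
done
echo $nj > num_jobs                                    # chain training reads the job count from here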
- - steps/nnet3/chain/train.py --stage $train_stage \ - --cmd "$decode_cmd" \ - --feat.online-ivector-dir $train_ivector_dir \ - --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ - --chain.xent-regularize 0.1 \ - --chain.leaky-hmm-coefficient 0.1 \ - --chain.l2-regularize 0.00005 \ - --chain.apply-deriv-weights false \ - --chain.lm-opts="--num-extra-lm-states=2000" \ - --egs.dir "$common_egs_dir" \ - --egs.opts "--frames-overlap-per-eg 0" \ - --egs.chunk-width 150 \ - --trainer.num-chunk-per-minibatch 128 \ - --trainer.frames-per-iter 1500000 \ - --trainer.num-epochs 4 \ - --trainer.optimization.num-jobs-initial 2 \ - --trainer.optimization.num-jobs-final 12 \ - --trainer.optimization.initial-effective-lrate 0.001 \ - --trainer.optimization.final-effective-lrate 0.0001 \ - --trainer.max-param-change 2.0 \ - --cleanup.remove-egs true \ - --feat-dir $train_data_dir \ - --tree-dir $tree_dir \ - --lat-dir $lat_dir \ - --dir $dir -fi - - -graph_dir=$dir/graph_${LM} -if [ $stage -le 17 ]; then - # Note: it might appear that this data/lang_chain directory is mismatched, and it is as - # far as the 'topo' is concerned, but this script doesn't read the 'topo' from - # the lang directory. - utils/mkgraph.sh --self-loop-scale 1.0 data/lang_${LM} $dir $graph_dir -fi - -if [ $stage -le 18 ]; then - rm $dir/.error 2>/dev/null || true - for decode_set in dev eval; do - ( - steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --nj $nj --cmd "$decode_cmd" \ - --online-ivector-dir exp/$mic/nnet3${nnet3_affix}${rvb_affix}/ivectors_${decode_set}_hires \ - --scoring-opts "--min-lmwt 5 " \ - $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set} || exit 1; - ) || touch $dir/.error & - done - wait - if [ -f $dir/.error ]; then - echo "$0: something went wrong in decoding" - exit 1 - fi -fi -exit 0 diff --git a/egs/ami/s5b/local/chain/multi_condition/run_tdnn.sh b/egs/ami/s5b/local/chain/multi_condition/run_tdnn.sh new file mode 120000 index 00000000000..34499362831 --- /dev/null +++ b/egs/ami/s5b/local/chain/multi_condition/run_tdnn.sh @@ -0,0 +1 @@ +tuning/run_tdnn_1a.sh \ No newline at end of file diff --git a/egs/ami/s5b/local/chain/multi_condition/run_tdnn_lstm.sh b/egs/ami/s5b/local/chain/multi_condition/run_tdnn_lstm.sh index 8e647598556..a4fa11e0908 120000 --- a/egs/ami/s5b/local/chain/multi_condition/run_tdnn_lstm.sh +++ b/egs/ami/s5b/local/chain/multi_condition/run_tdnn_lstm.sh @@ -1 +1 @@ -tuning/run_tdnn_lstm_1a.sh \ No newline at end of file +tuning/run_tdnn_lstm_1b.sh \ No newline at end of file diff --git a/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_1a.sh b/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_1a.sh new file mode 100755 index 00000000000..4d260e3c517 --- /dev/null +++ b/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_1a.sh @@ -0,0 +1,334 @@ +#!/bin/bash + +# This script is based on swbd 7q TDNN-F recipe +# with resnet-style skip connections, more layers, +# skinnier bottlenecks, removing the 3-way splicing and skip-layer splicing, +# and re-tuning the learning rate and l2 regularize. The configs are +# standardized and substantially simplified. +# The advantage of this style of config is that it also works +# well on smaller datasets, and we adopt this style here also for consistency. +# This gives better results than TDNN+LSTM on AMI SDM. 
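A brief reading of the layer type this new recipe's header refers to, since the xconfig lines only appear further down: a tdnnf-layer is a factorized TDNN layer that pushes the usual weight matrix through a narrow linear bottleneck (one factor is kept semi-orthogonal during training) and adds a scaled residual connection from the previous layer, which is what the "resnet-style skip connections" and "skinnier bottlenecks" above refer to. Roughly, for one of the layers defined in stage 15 below:

# tdnnf-layer name=tdnnf6 dim=1536 bottleneck-dim=160 time-stride=3 bypass-scale=0.66
#   a 1536-dim layer factored through a 160-dim linear bottleneck, splicing context at +/-3 frames,
#   with 0.66 times the previous layer's output added back in (the skip connection)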
+ +# local/chain/multi_condition/tuning/run_tdnn_1a.sh --mic ihm --train-set train_cleaned --gmm tri3_cleaned & +# local/chain/multi_condition/tuning/run_tdnn_1a.sh --mic sdm1 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned & +# local/chain/multi_condition/tuning/run_tdnn_1a.sh --mic mdm8 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned & + +# steps/info/chain_dir_info.pl exp/sdm1/chain_cleaned_rvb/tdnn1a_sp_rvb_bi_ihmali +# exp/sdm1/chain_cleaned_rvb/tdnn1a_sp_rvb_bi_ihmali: num-iters=193 nj=3..16 num-params=17.5M dim=40+100->3728 combine=-0.122->-0.121 (over 2) xent:train/valid[127,192,final]=(-2.03,-1.57,-1.58/-2.12,-1.71,-1.71) logprob:train/valid[127,192,final]=(-0.179,-0.121,-0.122/-0.198,-0.158,-0.157) + +# local/chain/compare_wer_general.sh sdm1 chain_cleaned_rvb tdnn_lstm1b_sp_rvb_bi_ihmali tdnn1a_sp_rvb_bi_ihmali +# System tdnn_lstm1b_sp_rvb_bi_ihmali tdnn1a_sp_rvb_bi_ihmali +# WER on dev 33.9 33.3 +# WER on eval 37.4 36.7 +# Final train prob -0.133611 -0.122155 +# Final valid prob -0.161014 -0.156612 +# Final train prob (xent) -1.9774 -1.57504 +# Final valid prob (xent) -2.09991 -1.705 + +set -e -o pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +mic=ihm +nj=30 +use_ihm_ali=false +train_set=train_cleaned +gmm=tri3_cleaned # the gmm for the target data +ihm_gmm=tri3_cleaned # the gmm for the IHM system (if --use-ihm-ali true). +num_threads_ubm=32 +num_data_reps=1 +num_epochs=6 +get_egs_stage=-5 +remove_egs=false + +chunk_width=160,140,110,80 +dropout_schedule='0,0@0.20,0.5@0.50,0' # dropout schedule controls the dropout + # proportion for each training iteration. +xent_regularize=0.1 + +train_stage=-10 +tree_affix= # affix for tree directory, e.g. "a" or "b", in case we change the configuration. +tdnn_affix=1a #affix for TDNN directory, e.g. "a" or "b", in case we change the configuration. +common_egs_dir= # you can set this to use previously dumped egs. + +# decode options +frames_per_chunk=160 + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! $use_ihm_ali; then + [ "$mic" != "ihm" ] && \ + echo "$0: you cannot specify --use-ihm-ali false if the microphone is not ihm." && \ + exit 1; +else + [ "$mic" == "ihm" ] && \ + echo "$0: you must specify --use-ihm-ali false if the microphone is ihm." && \ + exit 1; +fi + +if ! cuda-compiled; then + cat <data/lang_chain/topo + fi +fi + +if [ $stage -le 13 ]; then + # Get the alignments as lattices (gives the chain training more freedom). 
+ # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" \ + --generate-ali-from-lats true ${lores_train_data_dir} \ + data/lang $gmm_dir $original_lat_dir + rm $original_lat_dir/fsts.*.gz # save space + + lat_dir_ihmdata=exp/ihm/chain${nnet3_affix}/${gmm}_${train_set}_sp_lats + + original_lat_nj=$(cat $original_lat_dir/num_jobs) + ihm_lat_nj=$(cat $lat_dir_ihmdata/num_jobs) + + $train_cmd --max-jobs-run 10 JOB=1:$original_lat_nj $lat_dir/temp/log/copy_original_lats.JOB.log \ + lattice-copy "ark:gunzip -c $original_lat_dir/lat.JOB.gz |" ark,scp:$lat_dir/temp/lats.JOB.ark,$lat_dir/temp/lats.JOB.scp + + $train_cmd --max-jobs-run 10 JOB=1:$ihm_lat_nj $lat_dir/temp2/log/copy_ihm_lats.JOB.log \ + lattice-copy "ark:gunzip -c $lat_dir_ihmdata/lat.JOB.gz |" ark,scp:$lat_dir/temp2/lats.JOB.ark,$lat_dir/temp2/lats.JOB.scp + + for n in $(seq $original_lat_nj); do + cat $lat_dir/temp/lats.$n.scp + done > $lat_dir/temp/combined_lats.scp + + for i in `seq 1 $num_data_reps`; do + for n in $(seq $ihm_lat_nj); do + cat $lat_dir/temp2/lats.$n.scp + done | sed -e "s/^/rev${i}_/" + done >> $lat_dir/temp/combined_lats.scp + + sort -u $lat_dir/temp/combined_lats.scp > $lat_dir/temp/combined_lats_sorted.scp + + utils/split_data.sh $train_data_dir $nj + + $train_cmd --max-jobs-run 10 JOB=1:$nj $lat_dir/copy_combined_lats.JOB.log \ + lattice-copy --include=$train_data_dir/split$nj/JOB/utt2spk \ + scp:$lat_dir/temp/combined_lats_sorted.scp \ + "ark:|gzip -c >$lat_dir/lat.JOB.gz" || exit 1; + + echo $nj > $lat_dir/num_jobs + + # copy other files from original lattice dir + for f in cmvn_opts final.mdl splice_opts tree; do + cp $original_lat_dir/$f $lat_dir/$f + done +fi + + +if [ $stage -le 14 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 4200 ${lores_train_data_dir} data/lang_chain $original_lat_dir $tree_dir +fi + +if [ $stage -le 15 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + affine_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" + tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.01" + output_opts="l2-regularize=0.002" + + mkdir -p $dir/configs + + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-dropout-layer name=tdnn1 $affine_opts dim=1536 + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf3 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf4 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf5 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=0 + tdnnf-layer name=tdnnf6 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf7 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf14 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf15 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + linear-component name=prefinal-l dim=256 $linear_opts + + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 16 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/swbd-$(date +'%m_%d_%H_%M')/s5c/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.0 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $chunk_width \ + --trainer.num-chunk-per-minibatch 64,32 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.00025 \ + --trainer.optimization.final-effective-lrate 0.000025 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs $remove_egs \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir $lat_dir \ + --dir $dir +fi + + +graph_dir=$dir/graph_${LM} +if [ $stage -le 17 ]; then + # Note: it might appear that this data/lang_chain directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_${LM} $dir $graph_dir +fi + +if [ $stage -le 18 ]; then + rm $dir/.error 2>/dev/null || true + + [ -z $extra_left_context ] && extra_left_context=$chunk_left_context; + + for decode_set in dev eval; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$decode_cmd" \ + --frames-per-chunk "$frames_per_chunk" \ + --online-ivector-dir exp/$mic/nnet3${nnet3_affix}${rvb_affix}/ivectors_${decode_set}_hires \ + --scoring-opts "--min-lmwt 5 " \ + $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set} || exit 1; + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi +exit 0 diff --git a/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh b/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh index 2869049843f..3546b6a7ced 100755 --- a/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh +++ b/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh @@ -19,7 +19,6 @@ set -e -o pipefail stage=0 mic=ihm nj=30 -min_seg_len=1.55 use_ihm_ali=false train_set=train_cleaned gmm=tri3_cleaned # the gmm for the target data @@ -27,7 +26,7 @@ ihm_gmm=tri3_cleaned # the gmm for the IHM system (if --use-ihm-ali true). num_threads_ubm=32 num_data_reps=1 -chunk_width=150 +chunk_width=160,140,110,80 chunk_left_context=40 chunk_right_context=0 label_delay=5 @@ -35,13 +34,13 @@ label_delay=5 # are just hardcoded at this level, in the commands below. train_stage=-10 tree_affix= # affix for tree directory, e.g. "a" or "b", in case we change the configuration. -tlstm_affix=1i #affix for TDNN-LSTM directory, e.g. "a" or "b", in case we change the configuration. +tlstm_affix=1a #affix for TDNN-LSTM directory, e.g. "a" or "b", in case we change the configuration. common_egs_dir= # you can set this to use previously dumped egs. 
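On the chunk_width change a little above in this hunk (adopted by the other LSTM recipes in this patch as well): a comma-separated value such as 160,140,110,80 gives the example generator a set of allowed chunk sizes rather than one fixed size, so utterances can be cut up with less padding or discarded audio, and each minibatch groups chunks of a single size; the first value is the principal width, which is why the decode-time frames_per_chunk just below is set to 160. As a quick reading:

# chunk_width=160,140,110,80   -> training egs may use chunks of 160, 140, 110 or 80 frames (160 is the principal width)
# frames_per_chunk=160         -> decoding processes input in chunks matching the principal width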
# decode options extra_left_context=50 -frames_per_chunk= +frames_per_chunk=160 # End configuration section. @@ -75,21 +74,19 @@ rvb_affix=_rvb if $use_ihm_ali; then gmm_dir=exp/ihm/${ihm_gmm} - ali_dir=exp/${mic}/${ihm_gmm}_ali_${train_set}_sp_comb_ihmdata - lores_train_data_dir=data/$mic/${train_set}_ihmdata_sp_comb + lores_train_data_dir=data/$mic/${train_set}_ihmdata_sp tree_dir=exp/$mic/chain${nnet3_affix}/tree_bi${tree_affix}_ihmdata - original_lat_dir=exp/$mic/chain${nnet3_affix}/${ihm_gmm}_${train_set}_sp_comb_lats_ihmdata - lat_dir=exp/$mic/chain${nnet3_affix}${rvb_affix}/${ihm_gmm}_${train_set}_sp${rvb_affix}_comb_lats_ihmdata + original_lat_dir=exp/$mic/chain${nnet3_affix}/${ihm_gmm}_${train_set}_sp_lats_ihmdata + lat_dir=exp/$mic/chain${nnet3_affix}${rvb_affix}/${ihm_gmm}_${train_set}_sp${rvb_affix}_lats_ihmdata dir=exp/$mic/chain${nnet3_affix}${rvb_affix}/tdnn_lstm${tlstm_affix}_sp${rvb_affix}_bi_ihmali # note: the distinction between when we use the 'ihmdata' suffix versus # 'ihmali' is pretty arbitrary. else gmm_dir=exp/${mic}/$gmm - ali_dir=exp/${mic}/${gmm}_ali_${train_set}_sp_comb - lores_train_data_dir=data/$mic/${train_set}_sp_comb + lores_train_data_dir=data/$mic/${train_set}_sp tree_dir=exp/$mic/chain${nnet3_affix}/tree_bi${tree_affix} - original_lat_dir=exp/$mic/chain${nnet3_affix}/${gmm}_${train_set}_sp_comb_lats - lat_dir=exp/$mic/chain${nnet3_affix}${rvb_affix}/${gmm}_${train_set}_sp${rvb_affix}_comb_lats + original_lat_dir=exp/$mic/chain${nnet3_affix}/${gmm}_${train_set}_sp_lats + lat_dir=exp/$mic/chain${nnet3_affix}${rvb_affix}/${gmm}_${train_set}_sp${rvb_affix}_lats dir=exp/$mic/chain${nnet3_affix}${rvb_affix}/tdnn_lstm${tlstm_affix}_sp${rvb_affix}_bi fi @@ -97,9 +94,7 @@ fi local/nnet3/multi_condition/run_ivector_common.sh --stage $stage \ --mic $mic \ --nj $nj \ - --min-seg-len $min_seg_len \ --train-set $train_set \ - --gmm $gmm \ --num-threads-ubm $num_threads_ubm \ --num-data-reps $num_data_reps \ --nnet3-affix "$nnet3_affix" @@ -109,13 +104,13 @@ local/nnet3/multi_condition/run_ivector_common.sh --stage $stage \ local/nnet3/prepare_lores_feats.sh --stage $stage \ --mic $mic \ --nj $nj \ - --min-seg-len $min_seg_len \ + --min-seg-len "" \ --use-ihm-ali $use_ihm_ali \ --train-set $train_set -train_data_dir=data/$mic/${train_set}_sp${rvb_affix}_hires_comb -train_ivector_dir=exp/$mic/nnet3${nnet3_affix}${rvb_affix}/ivectors_${train_set}_sp${rvb_affix}_hires_comb +train_data_dir=data/$mic/${train_set}_sp${rvb_affix}_hires +train_ivector_dir=exp/$mic/nnet3${nnet3_affix}${rvb_affix}/ivectors_${train_set}_sp${rvb_affix}_hires final_lm=`cat data/local/lm/final_lm` LM=$final_lm.pr1-7 @@ -126,19 +121,6 @@ for f in $gmm_dir/final.mdl $lores_train_data_dir/feats.scp \ done -if [ $stage -le 11 ]; then - if [ -f $ali_dir/ali.1.gz ]; then - echo "$0: alignments in $ali_dir appear to already exist. Please either remove them " - echo " ... or use a later --stage option." - exit 1 - fi - echo "$0: aligning perturbed, short-segment-combined ${maybe_ihm}data" - steps/align_fmllr.sh --nj $nj --cmd "$train_cmd" \ - ${lores_train_data_dir} data/lang $gmm_dir $ali_dir -fi - -[ ! -f $ali_dir/ali.1.gz ] && echo "$0: expected $ali_dir/ali.1.gz to exist" && exit 1 - if [ $stage -le 12 ]; then echo "$0: creating lang directory with one state per phone." # Create a version of the lang/ directory that has one state per phone in the @@ -165,28 +147,42 @@ fi if [ $stage -le 13 ]; then # Get the alignments as lattices (gives the chain training more freedom). 
# use the same num-jobs as the alignments - steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \ + steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" \ + --generate-ali-from-lats true ${lores_train_data_dir} \ data/lang $gmm_dir $original_lat_dir rm $original_lat_dir/fsts.*.gz # save space - lat_dir_ihmdata=exp/ihm/chain${nnet3_affix}/${gmm}_${train_set}_sp_comb_lats + lat_dir_ihmdata=exp/ihm/chain${nnet3_affix}/${gmm}_${train_set}_sp_lats + + original_lat_nj=$(cat $original_lat_dir/num_jobs) + ihm_lat_nj=$(cat $lat_dir_ihmdata/num_jobs) - mkdir -p $lat_dir/temp/ - mkdir -p $lat_dir/temp2/ - lattice-copy "ark:gunzip -c $original_lat_dir/lat.*.gz |" ark,scp:$lat_dir/temp/lats.ark,$lat_dir/temp/lats.scp - lattice-copy "ark:gunzip -c $lat_dir_ihmdata/lat.*.gz |" ark,scp:$lat_dir/temp2/lats.ark,$lat_dir/temp2/lats.scp + $train_cmd --max-jobs-run 10 JOB=1:$original_lat_nj $lat_dir/temp/log/copy_original_lats.JOB.log \ + lattice-copy "ark:gunzip -c $original_lat_dir/lat.JOB.gz |" ark,scp:$lat_dir/temp/lats.JOB.ark,$lat_dir/temp/lats.JOB.scp + + $train_cmd --max-jobs-run 10 JOB=1:$ihm_lat_nj $lat_dir/temp2/log/copy_ihm_lats.JOB.log \ + lattice-copy "ark:gunzip -c $lat_dir_ihmdata/lat.JOB.gz |" ark,scp:$lat_dir/temp2/lats.JOB.ark,$lat_dir/temp2/lats.JOB.scp + + for n in $(seq $original_lat_nj); do + cat $lat_dir/temp/lats.$n.scp + done > $lat_dir/temp/combined_lats.scp - # copy the lattices for the reverberated data - rm -f $lat_dir/temp/combined_lats.scp - touch $lat_dir/temp/combined_lats.scp - cat $lat_dir/temp/lats.scp >> $lat_dir/temp/combined_lats.scp for i in `seq 1 $num_data_reps`; do - cat $lat_dir/temp2/lats.scp | sed -e "s/^/rev${i}_/" >> $lat_dir/temp/combined_lats.scp - done + for n in $(seq $ihm_lat_nj); do + cat $lat_dir/temp2/lats.$n.scp + done | sed -e "s/^/rev${i}_/" + done >> $lat_dir/temp/combined_lats.scp + sort -u $lat_dir/temp/combined_lats.scp > $lat_dir/temp/combined_lats_sorted.scp - lattice-copy scp:$lat_dir/temp/combined_lats_sorted.scp "ark:|gzip -c >$lat_dir/lat.1.gz" || exit 1; - echo "1" > $lat_dir/num_jobs + utils/split_data.sh $train_data_dir $nj + + $train_cmd --max-jobs-run 10 JOB=1:$nj $lat_dir/copy_combined_lats.JOB.log \ + lattice-copy --include=$train_data_dir/split$nj/JOB/utt2spk \ + scp:$lat_dir/temp/combined_lats_sorted.scp \ + "ark:|gzip -c >$lat_dir/lat.JOB.gz" || exit 1; + + echo $nj > $lat_dir/num_jobs # copy other files from original lattice dir for f in cmvn_opts final.mdl splice_opts tree; do @@ -206,7 +202,7 @@ if [ $stage -le 14 ]; then steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ --context-opts "--context-width=2 --central-position=1" \ --leftmost-questions-truncate -1 \ - --cmd "$train_cmd" 4200 ${lores_train_data_dir} data/lang_chain $ali_dir $tree_dir + --cmd "$train_cmd" 4200 ${lores_train_data_dir} data/lang_chain $original_lat_dir $tree_dir fi xent_regularize=0.1 @@ -215,7 +211,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig @@ -312,7 +308,6 @@ if [ $stage -le 18 ]; then rm $dir/.error 2>/dev/null || true [ -z $extra_left_context ] && extra_left_context=$chunk_left_context; - [ -z $frames_per_chunk ] && frames_per_chunk=$chunk_width; for decode_set in dev eval; do ( diff --git 
a/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_lstm_1b.sh b/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_lstm_1b.sh new file mode 100755 index 00000000000..1a839b045bd --- /dev/null +++ b/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_lstm_1b.sh @@ -0,0 +1,360 @@ +#!/bin/bash + +# This is a chain-training script with TDNN+LSTM neural networks. +# This script is similar to local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh, +# but updated to use new l2-regularize options and fast-lstmp with decay-time. +# It uses the reverberated IHM data in the train set. +# This script obtains better results on IHM, SDM and MDM tasks. + +# Please see RESULTS_* for examples of command lines invoking this script. + +# local/chain/multi_condition/tuning/run_tdnn_lstm_1b.sh --mic ihm --train-set train_cleaned --gmm tri3_cleaned & +# local/chain/multi_condition/tuning/run_tdnn_lstm_1b.sh --mic sdm1 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned & +# local/chain/multi_condition/tuning/run_tdnn_lstm_1b.sh --mic mdm8 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned & + +# steps/info/chain_dir_info.pl exp/ihm/chain_cleaned_rvb/tdnn_lstm1b_sp_rvb_bi +# exp/ihm/chain_cleaned_rvb/tdnn_lstm1b_sp_rvb_bi: num-iters=176 nj=2..12 num-params=43.4M dim=40+100->3736 combine=-0.101->-0.100 (over 2) xent:train/valid[116,175,final]=(-2.47,-1.60,-1.55/-2.58,-1.73,-1.69) logprob:train/valid[116,175,final]=(-0.144,-0.101,-0.099/-0.163,-0.138,-0.136) +# steps/info/chain_dir_info.pl exp/sdm1/chain_cleaned_rvb/tdnn_lstm1b_sp_rvb_bi_ihmali +# exp/sdm1/chain_cleaned_rvb/tdnn_lstm1b_sp_rvb_bi_ihmali: num-iters=174 nj=2..12 num-params=43.4M dim=40+100->3728 combine=-0.129->-0.126 (over 4) xent:train/valid[115,173,final]=(-2.86,-1.97,-1.98/-2.96,-2.10,-2.10) logprob:train/valid[115,173,final]=(-0.184,-0.134,-0.134/-0.200,-0.164,-0.161) + +# local/chain/compare_wer_general.sh ihm chain_cleaned_rvb tdnn_lstm1{a,b}_sp_rvb_bi +# System tdnn_lstm1a_sp_rvb_bi tdnn_lstm1b_sp_rvb_bi +# WER on dev 19.4 18.9 +# WER on eval 19.4 19.3 +# Final train prob -0.0627414-0.0985175 +# Final valid prob -0.141082 -0.136302 +# Final train prob (xent) -0.847054 -1.55263 +# Final valid prob (xent) -1.25849 -1.69064 + +# local/chain/compare_wer_general.sh sdm1 chain_cleaned_rvb tdnn_lstm1{a,b}_sp_rvb_bi_ihmali +# System tdnn_lstm1a_sp_rvb_bi_ihmali tdnn_lstm1b_sp_rvb_bi_ihmali +# WER on dev 34.6 33.9 +# WER on eval 37.6 37.4 +# Final train prob -0.0861836 -0.133611 +# Final valid prob -0.149669 -0.161014 +# Final train prob (xent) -1.21927 -1.9774 +# Final valid prob (xent) -1.53542 -2.09991 + +set -e -o pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +mic=ihm +nj=30 +use_ihm_ali=false +train_set=train_cleaned +gmm=tri3_cleaned # the gmm for the target data +ihm_gmm=tri3_cleaned # the gmm for the IHM system (if --use-ihm-ali true). +num_threads_ubm=32 +num_data_reps=1 +num_epochs=4 + +chunk_width=160,140,110,80 +chunk_left_context=40 +chunk_right_context=0 +label_delay=5 +dropout_schedule='0,0@0.20,0.3@0.50,0' # dropout schedule controls the dropout + # proportion for each training iteration. +xent_regularize=0.025 + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +train_stage=-10 +tree_affix= # affix for tree directory, e.g. "a" or "b", in case we change the configuration. 
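The dropout_schedule string set above is a piecewise-linear schedule over the fraction of training completed (points are written proportion@fraction, with linear interpolation in between); reading the value this 1b recipe uses:

# dropout_schedule='0,0@0.20,0.3@0.50,0'
#   start of training  -> dropout proportion 0.0
#   20% through        -> still 0.0
#   50% through        -> 0.3   (ramped up linearly between 20% and 50%)
#   end of training    -> 0.0   (ramped back down over the second half)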
+tlstm_affix=1b #affix for TDNN-LSTM directory, e.g. "a" or "b", in case we change the configuration. +common_egs_dir= # you can set this to use previously dumped egs. + +# decode options +extra_left_context=50 +frames_per_chunk=160 + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! $use_ihm_ali; then + [ "$mic" != "ihm" ] && \ + echo "$0: you cannot specify --use-ihm-ali false if the microphone is not ihm." && \ + exit 1; +else + [ "$mic" == "ihm" ] && \ + echo "$0: you must specify --use-ihm-ali false if the microphone is ihm." && \ + exit 1; +fi + +if ! cuda-compiled; then + cat <data/lang_chain/topo + fi +fi + +if [ $stage -le 13 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" \ + --generate-ali-from-lats true ${lores_train_data_dir} \ + data/lang $gmm_dir $original_lat_dir + rm $original_lat_dir/fsts.*.gz # save space + + lat_dir_ihmdata=exp/ihm/chain${nnet3_affix}/${gmm}_${train_set}_sp_lats + + original_lat_nj=$(cat $original_lat_dir/num_jobs) + ihm_lat_nj=$(cat $lat_dir_ihmdata/num_jobs) + + $train_cmd --max-jobs-run 10 JOB=1:$original_lat_nj $lat_dir/temp/log/copy_original_lats.JOB.log \ + lattice-copy "ark:gunzip -c $original_lat_dir/lat.JOB.gz |" ark,scp:$lat_dir/temp/lats.JOB.ark,$lat_dir/temp/lats.JOB.scp + + $train_cmd --max-jobs-run 10 JOB=1:$ihm_lat_nj $lat_dir/temp2/log/copy_ihm_lats.JOB.log \ + lattice-copy "ark:gunzip -c $lat_dir_ihmdata/lat.JOB.gz |" ark,scp:$lat_dir/temp2/lats.JOB.ark,$lat_dir/temp2/lats.JOB.scp + + for n in $(seq $original_lat_nj); do + cat $lat_dir/temp/lats.$n.scp + done > $lat_dir/temp/combined_lats.scp + + for i in `seq 1 $num_data_reps`; do + for n in $(seq $ihm_lat_nj); do + cat $lat_dir/temp2/lats.$n.scp + done | sed -e "s/^/rev${i}_/" + done >> $lat_dir/temp/combined_lats.scp + + sort -u $lat_dir/temp/combined_lats.scp > $lat_dir/temp/combined_lats_sorted.scp + + utils/split_data.sh $train_data_dir $nj + + $train_cmd --max-jobs-run 10 JOB=1:$nj $lat_dir/copy_combined_lats.JOB.log \ + lattice-copy --include=$train_data_dir/split$nj/JOB/utt2spk \ + scp:$lat_dir/temp/combined_lats_sorted.scp \ + "ark:|gzip -c >$lat_dir/lat.JOB.gz" || exit 1; + + echo $nj > $lat_dir/num_jobs + + # copy other files from original lattice dir + for f in cmvn_opts final.mdl splice_opts tree; do + cp $original_lat_dir/$f $lat_dir/$f + done +fi + + +if [ $stage -le 14 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --leftmost-questions-truncate -1 \ + --cmd "$train_cmd" 4200 ${lores_train_data_dir} data/lang_chain $original_lat_dir $tree_dir +fi + +if [ $stage -le 15 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + tdnn_opts="l2-regularize=0.006" + lstm_opts="l2-regularize=0.0025 decay-time=20 dropout-proportion=0.0" + output_opts="l2-regularize=0.001" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=1024 $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=1024 $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=1024 $tdnn_opts + + # check steps/libs/nnet3/xconfig/lstm.py for the other options and defaults + fast-lstmp-layer name=lstm1 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $lstm_opts + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=1024 $tdnn_opts + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=1024 $tdnn_opts + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=1024 $tdnn_opts + fast-lstmp-layer name=lstm2 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $lstm_opts + relu-batchnorm-layer name=tdnn7 input=Append(-3,0,3) dim=1024 $tdnn_opts + relu-batchnorm-layer name=tdnn8 input=Append(-3,0,3) dim=1024 $tdnn_opts + relu-batchnorm-layer name=tdnn9 input=Append(-3,0,3) dim=1024 $tdnn_opts + fast-lstmp-layer name=lstm3 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $lstm_opts + + ## adding the layers for chain branch + output-layer name=output input=lstm3 output-delay=$label_delay include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + output-layer name=output-xent input=lstm3 output-delay=$label_delay dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts + +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 16 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/ami-$(date +'%m_%d_%H_%M')/s5b/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.0 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $chunk_width \ + --egs.chunk-left-context $chunk_left_context \ + --egs.chunk-right-context $chunk_right_context \ + --egs.chunk-left-context-initial 0 \ + --egs.chunk-right-context-final 0 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.num-chunk-per-minibatch 64,32 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 12 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.max-param-change 2.0 \ + --trainer.deriv-truncate-margin 8 \ + --cleanup.remove-egs false \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir $lat_dir \ + --dir $dir +fi + + +graph_dir=$dir/graph_${LM} +if [ $stage -le 17 ]; then + # Note: it might appear that this data/lang_chain directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_${LM} $dir $graph_dir +fi + +if [ $stage -le 18 ]; then + rm $dir/.error 2>/dev/null || true + + [ -z $extra_left_context ] && extra_left_context=$chunk_left_context; + + for decode_set in dev eval; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$decode_cmd" \ + --extra-left-context $extra_left_context \ + --frames-per-chunk "$frames_per_chunk" \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --online-ivector-dir exp/$mic/nnet3${nnet3_affix}${rvb_affix}/ivectors_${decode_set}_hires \ + --scoring-opts "--min-lmwt 5 " \ + $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set} || exit 1; + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi +exit 0 diff --git a/egs/ami/s5b/local/chain/tuning/run_cnn_tdnn_lstm_1a.sh b/egs/ami/s5b/local/chain/tuning/run_cnn_tdnn_lstm_1a.sh index 16d1f4044f5..d926c1dc6d7 100644 --- a/egs/ami/s5b/local/chain/tuning/run_cnn_tdnn_lstm_1a.sh +++ b/egs/ami/s5b/local/chain/tuning/run_cnn_tdnn_lstm_1a.sh @@ -184,7 +184,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/ami/s5b/local/chain/tuning/run_cnn_tdnn_lstm_1b.sh b/egs/ami/s5b/local/chain/tuning/run_cnn_tdnn_lstm_1b.sh index 83e6a95582f..d9cd1c356e8 100644 --- a/egs/ami/s5b/local/chain/tuning/run_cnn_tdnn_lstm_1b.sh +++ b/egs/ami/s5b/local/chain/tuning/run_cnn_tdnn_lstm_1b.sh @@ -176,7 +176,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the 
xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20 dropout-proportion=0" diff --git a/egs/ami/s5b/local/chain/tuning/run_cnn_tdnn_lstm_1c.sh b/egs/ami/s5b/local/chain/tuning/run_cnn_tdnn_lstm_1c.sh index 387b4bfcc88..a0805b4f9f1 100755 --- a/egs/ami/s5b/local/chain/tuning/run_cnn_tdnn_lstm_1c.sh +++ b/egs/ami/s5b/local/chain/tuning/run_cnn_tdnn_lstm_1c.sh @@ -185,7 +185,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=40" diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_1b.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_1b.sh index 57108dbddae..997357b80a9 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_1b.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_1b.sh @@ -164,7 +164,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_1c.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_1c.sh index f87e1a12d36..4d062e65429 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_1c.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_1c.sh @@ -151,7 +151,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_1d.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_1d.sh index eb84a1cd876..387570388d0 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_1d.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_1d.sh @@ -163,7 +163,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_1e.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_1e.sh index e6592b667dc..0436b08cdc0 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_1e.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_1e.sh @@ -161,7 +161,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_1f.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_1f.sh index 
8bf2b73dada..4ca526d63b8 100644 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_1f.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_1f.sh @@ -165,7 +165,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_1g.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_1g.sh index dfb6dfedee7..baed760bb68 100644 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_1g.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_1g.sh @@ -166,7 +166,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_1h.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_1h.sh index 3e26a8b38bd..e721a858c0a 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_1h.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_1h.sh @@ -167,7 +167,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_1i.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_1i.sh index 1931127c86d..de40cb2d1a4 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_1i.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_1i.sh @@ -168,7 +168,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.02" output_opts="l2-regularize=0.004" diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1a.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1a.sh index d63712f1f0f..4f580b88f6b 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1a.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1a.sh @@ -171,7 +171,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1b.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1b.sh index a53785f45c2..904a079d7de 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1b.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1b.sh @@ -173,7 +173,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 
0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1c.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1c.sh index 76a9f735c5f..511e520465a 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1c.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1c.sh @@ -172,7 +172,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1d.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1d.sh index 8cc1a4e15fa..bd81b7df4eb 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1d.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1d.sh @@ -172,7 +172,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1e.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1e.sh index accfd158a9d..50903e78b6d 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1e.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1e.sh @@ -174,7 +174,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1f.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1f.sh index 2b275e4e27d..f6c53001498 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1f.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1f.sh @@ -173,7 +173,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1g.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1g.sh index 1c90af38c4c..79fd9ef3fb5 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1g.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1g.sh @@ -174,7 +174,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1h.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1h.sh index 
fb4b6a475e2..e58a7f89e03 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1h.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1h.sh @@ -171,7 +171,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1i.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1i.sh index 92636b4c17e..13f894f5a48 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1i.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1i.sh @@ -174,7 +174,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1j.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1j.sh index 89fd8ce2915..48b31832e8c 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1j.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1j.sh @@ -181,7 +181,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1k.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1k.sh index b8d947d8e92..e675bc494bb 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1k.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1k.sh @@ -177,7 +177,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1l.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1l.sh index 74c0f5a6ead..2d019398274 100644 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1l.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1l.sh @@ -224,7 +224,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1m.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1m.sh index b0e7af0618d..9e5b971bbe2 100644 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1m.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1m.sh @@ -226,7 +226,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 
0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20 dropout-proportion=0.0" diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1n.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1n.sh index bee4d997b01..9575c3cf686 100644 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1n.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1n.sh @@ -178,7 +178,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1o.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1o.sh index 1e4111adc6a..a7f2625c181 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1o.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_1o.sh @@ -182,7 +182,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) tdnn_opts="l2-regularize=0.025" lstm_opts="l2-regularize=0.01" output_opts="l2-regularize=0.004" diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_bs_1a.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_bs_1a.sh index b672a44e572..ca920869b30 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_bs_1a.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_lstm_bs_1a.sh @@ -180,7 +180,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) tdnn_opts="l2-regularize=0.003" lstm_opts="l2-regularize=0.005" output_opts="l2-regularize=0.001" diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_opgru_1a.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_opgru_1a.sh index f68c4203767..53dbd5238db 100644 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_opgru_1a.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_opgru_1a.sh @@ -178,7 +178,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) gru_opts="dropout-per-frame=true dropout-proportion=0.0" mkdir -p $dir/configs diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_opgru_1b.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_opgru_1b.sh index ac4266ca162..dafef668e60 100644 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_opgru_1b.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_opgru_1b.sh @@ -177,7 +177,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) gru_opts="dropout-per-frame=true dropout-proportion=0.0" mkdir -p $dir/configs diff 
--git a/egs/ami/s5b/local/chain/tuning/run_tdnn_opgru_1c.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_opgru_1c.sh index 74b21f10c33..677946d0b9a 100644 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_opgru_1c.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_opgru_1c.sh @@ -176,7 +176,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) gru_opts="dropout-per-frame=true dropout-proportion=0.0" mkdir -p $dir/configs diff --git a/egs/ami/s5b/local/nnet3/multi_condition/run_ivector_common.sh b/egs/ami/s5b/local/nnet3/multi_condition/run_ivector_common.sh index eb20415e515..5ba35fa421c 100755 --- a/egs/ami/s5b/local/nnet3/multi_condition/run_ivector_common.sh +++ b/egs/ami/s5b/local/nnet3/multi_condition/run_ivector_common.sh @@ -10,19 +10,17 @@ set -e -o pipefail stage=1 mic=ihm nj=30 -min_seg_len=1.55 # min length in seconds... we do this because chain training - # will discard segments shorter than 1.5 seconds. Must remain in sync with - # the same option given to prepare_lores_feats.sh. train_set=train_cleaned # you might set this to e.g. train_cleaned. -gmm=tri3_cleaned # This specifies a GMM-dir from the features of the type you're training the system on; - # it should contain alignments for 'train_set'. - +norvb_datadir=data/ihm/train_cleaned_sp num_threads_ubm=32 rvb_affix=_rvb nnet3_affix=_cleaned # affix for exp/$mic/nnet3 directory to put iVector stuff in, so it # becomes exp/$mic/nnet3_cleaned or whatever. num_data_reps=1 +sample_rate=16000 + +max_jobs_run=10 . ./cmd.sh . ./path.sh @@ -30,10 +28,7 @@ num_data_reps=1 nnet3_affix=${nnet3_affix}$rvb_affix -gmmdir=exp/${mic}/${gmm} - - -for f in data/${mic}/${train_set}/feats.scp ${gmmdir}/final.mdl; do +for f in data/${mic}/${train_set}/feats.scp; do if [ ! -f $f ]; then echo "$0: expected file $f to exist" exit 1 @@ -73,36 +68,23 @@ if [ $stage -le 1 ]; then for datadir in ${train_set}_sp dev eval; do steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \ - --cmd "$train_cmd" data/$mic/${datadir}_hires + --cmd "$train_cmd --max-jobs-run $max_jobs_run" data/$mic/${datadir}_hires steps/compute_cmvn_stats.sh data/$mic/${datadir}_hires utils/fix_data_dir.sh data/$mic/${datadir}_hires done fi -if [ $stage -le 2 ]; then - echo "$0: combining short segments of speed-perturbed high-resolution MFCC training data" - # we have to combine short segments or we won't be able to train chain models - # on those segments. - utils/data/combine_short_segments.sh \ - data/${mic}/${train_set}_sp_hires $min_seg_len data/${mic}/${train_set}_sp_hires_comb - - # just copy over the CMVN to avoid having to recompute it. - cp data/${mic}/${train_set}_sp_hires/cmvn.scp data/${mic}/${train_set}_sp_hires_comb/ - utils/fix_data_dir.sh data/${mic}/${train_set}_sp_hires_comb/ -fi if [ $stage -le 3 ]; then echo "$0: creating reverberated MFCC features" - datadir=data/ihm/train_cleaned_sp - - mfccdir=${datadir}_rvb${num_data_reps}_hires/data + mfccdir=${norvb_datadir}${rvb_affix}${num_data_reps}_hires/data if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then utils/create_split_dir.pl /export/b0{5,6,7,8}/$USER/kaldi-data/egs/ami-$mic-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage fi - if [ ! -f ${datadir}_rvb${num_data_reps}_hires/feats.scp ]; then - if [ ! 
-d "RIRS_NOISES" ]; then + if [ ! -f ${norvb_datadir}${rvb_affix}${num_data_reps}_hires/feats.scp ]; then + if [ ! -d "RIRS_NOISES/" ]; then # Download the package that includes the real RIRs, simulated RIRs, isotropic noises and point-source noises wget --no-check-certificate http://www.openslr.org/resources/28/rirs_noises.zip unzip rirs_noises.zip @@ -123,60 +105,29 @@ if [ $stage -le 3 ]; then --isotropic-noise-addition-probability 1 \ --num-replications ${num_data_reps} \ --max-noises-per-minute 1 \ - --source-sampling-rate 16000 \ - ${datadir} ${datadir}_rvb${num_data_reps} + --source-sampling-rate $sample_rate \ + ${norvb_datadir} ${norvb_datadir}${rvb_affix}${num_data_reps} - utils/copy_data_dir.sh ${datadir}_rvb${num_data_reps} ${datadir}_rvb${num_data_reps}_hires - utils/data/perturb_data_dir_volume.sh ${datadir}_rvb${num_data_reps}_hires + utils/copy_data_dir.sh ${norvb_datadir}${rvb_affix}${num_data_reps} ${norvb_datadir}${rvb_affix}${num_data_reps}_hires + utils/data/perturb_data_dir_volume.sh ${norvb_datadir}${rvb_affix}${num_data_reps}_hires steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \ - --cmd "$train_cmd" ${datadir}_rvb${num_data_reps}_hires - steps/compute_cmvn_stats.sh ${datadir}_rvb${num_data_reps}_hires - utils/fix_data_dir.sh ${datadir}_rvb${num_data_reps}_hires - - utils/data/combine_short_segments.sh \ - ${datadir}_rvb${num_data_reps}_hires $min_seg_len ${datadir}_rvb${num_data_reps}_hires_comb - - # just copy over the CMVN to avoid having to recompute it. - cp ${datadir}_rvb${num_data_reps}_hires/cmvn.scp ${datadir}_rvb${num_data_reps}_hires_comb/ - utils/fix_data_dir.sh ${datadir}_rvb${num_data_reps}_hires_comb/ + --cmd "$train_cmd --max-jobs-run $max_jobs_run" ${norvb_datadir}${rvb_affix}${num_data_reps}_hires + steps/compute_cmvn_stats.sh ${norvb_datadir}${rvb_affix}${num_data_reps}_hires + utils/fix_data_dir.sh ${norvb_datadir}${rvb_affix}${num_data_reps}_hires fi - utils/combine_data.sh data/${mic}/${train_set}_sp_rvb_hires data/${mic}/${train_set}_sp_hires ${datadir}_rvb${num_data_reps}_hires - utils/combine_data.sh data/${mic}/${train_set}_sp_rvb_hires_comb data/${mic}/${train_set}_sp_hires_comb ${datadir}_rvb${num_data_reps}_hires_comb + utils/combine_data.sh data/${mic}/${train_set}_sp${rvb_affix}_hires data/${mic}/${train_set}_sp_hires ${norvb_datadir}${rvb_affix}${num_data_reps}_hires fi - if [ $stage -le 4 ]; then - echo "$0: selecting segments of hires training data that were also present in the" - echo " ... original training data." - - # note, these data-dirs are temporary; we put them in a sub-directory - # of the place where we'll make the alignments. - temp_data_root=exp/$mic/nnet3${nnet3_affix}/tri5 - mkdir -p $temp_data_root - - utils/data/subset_data_dir.sh --utt-list data/${mic}/${train_set}/feats.scp \ - data/${mic}/${train_set}_sp_hires $temp_data_root/${train_set}_hires - - # note: essentially all the original segments should be in the hires data. - n1=$(wc -l $tmpdir/ihmutt2utt # Map the 1st field of the segments file from the ihm data (the 1st field being # the utterance-id) to the corresponding SDM or MDM utterance-id. The other # fields remain the same (e.g. we want the recording-ids from the IHM data). 
-utils/apply_map.pl -f 1 $tmpdir/ihmutt2utt data/$mic/train_ihmdata/segments +utils/apply_map.pl -f 1 $tmpdir/ihmutt2utt data/$mic/${train_set}_ihmdata/segments -utils/fix_data_dir.sh data/$mic/train_ihmdata +utils/fix_data_dir.sh data/$mic/${train_set}_ihmdata rm $tmpdir/ihmutt2utt diff --git a/egs/an4/s5/local/data_prep.py b/egs/an4/s5/local/data_prep.py index 24cb9bffb07..9d8083f3b60 100644 --- a/egs/an4/s5/local/data_prep.py +++ b/egs/an4/s5/local/data_prep.py @@ -15,6 +15,7 @@ # See the Apache 2 License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function import os import re import sys diff --git a/egs/an4/s5/local/lexicon_prep.py b/egs/an4/s5/local/lexicon_prep.py index 8d451daf869..3584fa86dfb 100644 --- a/egs/an4/s5/local/lexicon_prep.py +++ b/egs/an4/s5/local/lexicon_prep.py @@ -15,6 +15,7 @@ # See the Apache 2 License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function import os import re import sys diff --git a/egs/aspire/s5/local/chain/tuning/run_blstm_7b.sh b/egs/aspire/s5/local/chain/tuning/run_blstm_7b.sh index 8ff59d83ed0..bd13010c791 100755 --- a/egs/aspire/s5/local/chain/tuning/run_blstm_7b.sh +++ b/egs/aspire/s5/local/chain/tuning/run_blstm_7b.sh @@ -138,7 +138,7 @@ if [ $stage -le 11 ]; then num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/aspire/s5/local/chain/tuning/run_tdnn_7b.sh b/egs/aspire/s5/local/chain/tuning/run_tdnn_7b.sh index 201f61dc64b..d6292fbadb3 100755 --- a/egs/aspire/s5/local/chain/tuning/run_tdnn_7b.sh +++ b/egs/aspire/s5/local/chain/tuning/run_tdnn_7b.sh @@ -136,7 +136,7 @@ if [ $stage -le 11 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/aspire/s5/local/chain/tuning/run_tdnn_lstm_1a.sh b/egs/aspire/s5/local/chain/tuning/run_tdnn_lstm_1a.sh index 63d3a7ca988..e6aa37a7543 100755 --- a/egs/aspire/s5/local/chain/tuning/run_tdnn_lstm_1a.sh +++ b/egs/aspire/s5/local/chain/tuning/run_tdnn_lstm_1a.sh @@ -151,7 +151,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=40" diff --git a/egs/aspire/s5/local/multi_condition/create_uniform_segments.py b/egs/aspire/s5/local/multi_condition/create_uniform_segments.py index e7baafc028c..010811490ef 100755 --- a/egs/aspire/s5/local/multi_condition/create_uniform_segments.py +++ b/egs/aspire/s5/local/multi_condition/create_uniform_segments.py @@ -4,13 +4,14 @@ # creates a segments file in the provided data directory # into uniform segments with specified window and overlap +from __future__ import division import imp, sys, argparse, os, math, subprocess min_segment_length = 10 # in seconds def segment(total_length, 
window_length, overlap = 0): increment = window_length - overlap num_windows = int(math.ceil(float(total_length)/increment)) - segments = map(lambda x: (x * increment, min( total_length, (x * increment) + window_length)), range(0, num_windows)) + segments = [(x * increment, min( total_length, (x * increment) + window_length)) for x in range(0, num_windows)] if segments[-1][1] - segments[-1][0] < min_segment_length: segments[-2] = (segments[-2][0], segments[-1][1]) segments.pop() @@ -53,7 +54,7 @@ def prepare_segments_file(kaldi_data_dir, window_length, overlap): parser = argparse.ArgumentParser() parser.add_argument('--window-length', type = float, default = 30.0, help = 'length of the window used to cut the segment') parser.add_argument('--overlap', type = float, default = 5.0, help = 'overlap of neighboring windows') - parser.add_argument('data_dir', type=str, help='directory such as data/train') + parser.add_argument('data_dir', help='directory such as data/train') params = parser.parse_args() diff --git a/egs/aspire/s5/local/multi_condition/fill_missing_recordings.py b/egs/aspire/s5/local/multi_condition/fill_missing_recordings.py index e249e54e5f6..2b4bcddda69 100755 --- a/egs/aspire/s5/local/multi_condition/fill_missing_recordings.py +++ b/egs/aspire/s5/local/multi_condition/fill_missing_recordings.py @@ -38,14 +38,14 @@ def fill_ctm(input_ctm_file, output_ctm_file, recording_names): sys.stderr.write(str(" ".join(sys.argv))) parser = argparse.ArgumentParser(usage) - parser.add_argument('input_ctm_file', type=str, help='ctm file for the recordings') - parser.add_argument('output_ctm_file', type=str, help='ctm file for the recordings') - parser.add_argument('recording_name_file', type=str, help='file with names of the recordings') + parser.add_argument('input_ctm_file', help='ctm file for the recordings') + parser.add_argument('output_ctm_file', help='ctm file for the recordings') + parser.add_argument('recording_name_file', help='file with names of the recordings') params = parser.parse_args() try: - file_names = map(lambda x: x.strip(), open("{0}".format(params.recording_name_file)).readlines()) + file_names = [x.strip() for x in open("{0}".format(params.recording_name_file)).readlines()] except IOError: raise Exception("Expected to find {0}".format(params.recording_name_file)) diff --git a/egs/aspire/s5/local/multi_condition/get_air_file_patterns.py b/egs/aspire/s5/local/multi_condition/get_air_file_patterns.py index cc06f58616a..1f06d3e7c3b 100755 --- a/egs/aspire/s5/local/multi_condition/get_air_file_patterns.py +++ b/egs/aspire/s5/local/multi_condition/get_air_file_patterns.py @@ -3,6 +3,7 @@ # script to generate the file_patterns of the AIR database # see load_air.m file in AIR db to understand the naming convention +from __future__ import print_function import sys, glob, re, os.path air_dir = sys.argv[1] @@ -45,4 +46,4 @@ file_patterns.append(file_pattern+" "+output_file_name) file_patterns = list(set(file_patterns)) file_patterns.sort() -print "\n".join(file_patterns) +print("\n".join(file_patterns)) diff --git a/egs/aspire/s5/local/multi_condition/get_ctm_conf.sh b/egs/aspire/s5/local/multi_condition/get_ctm_conf.sh deleted file mode 100755 index 23f3bcb8378..00000000000 --- a/egs/aspire/s5/local/multi_condition/get_ctm_conf.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/bin/bash -# Copyright Johns Hopkins University (Author: Daniel Povey) 2012. Apache 2.0. - -# This script produces CTM files from a decoding directory that has lattices -# present. 
This version gives you confidence scores. - - -# begin configuration section. -cmd=run.pl -stage=0 -min_lmwt=5 -max_lmwt=20 -use_segments=true # if we have a segments file, use it to convert - # the segments to be relative to the original files. -iter=final -#end configuration section. - -echo "$0 $@" # Print the command line for logging - -[ -f ./path.sh ] && . ./path.sh -. parse_options.sh || exit 1; - -if [ $# -ne 3 ]; then - echo "Usage: $0 [options] " - echo " Options:" - echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." - echo " --stage (0|1|2) # start scoring script from part-way through." - echo " --use-segments (true|false) # use segments and reco2file_and_channel files " - echo " # to produce a ctm relative to the original audio" - echo " # files, with channel information (typically needed" - echo " # for NIST scoring)." - echo "e.g.:" - echo "$0 data/train data/lang exp/tri4a/decode/" - echo "See also: steps/get_train_ctm.sh" - exit 1; -fi - -data=$1 -lang=$2 # Note: may be graph directory not lang directory, but has the necessary stuff copied. -dir=$3 - -model=$dir/../$iter.mdl # assume model one level up from decoding dir. - - -for f in $lang/words.txt $model $dir/lat.1.gz; do - [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; -done - -name=`basename $data`; # e.g. eval2000 - -mkdir -p $dir/scoring/log - -if [ -f $dir/../frame_shift ]; then - frame_shift_opt="--frame-shift=$(cat $dir/../frame_shift)" - echo "$0: $dir/../frame_shift exists, using $frame_shift_opt" -elif [ -f $dir/../frame_subsampling_factor ]; then - factor=$(cat $dir/../frame_subsampling_factor) || exit 1 - frame_shift_opt="--frame-shift=0.0$factor" - echo "$0: $dir/../frame_subsampling_factor exists, using $frame_shift_opt" -fi - - - -if [ $stage -le 0 ]; then - if [ -f $data/segments ] && $use_segments; then - f=$data/reco2file_and_channel - [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; - filter_cmd="utils/convert_ctm.pl $data/segments $data/reco2file_and_channel" - else - filter_cmd=cat - fi - - if [ -f $lang/phones/word_boundary.int ]; then - $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/get_ctm.LMWT.log \ - mkdir -p $dir/score_LMWT/ '&&' \ - lattice-prune --inv-acoustic-scale=LMWT --beam=5 "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ - lattice-align-words $lang/phones/word_boundary.int $model ark:- ark:- \| \ - lattice-to-ctm-conf $frame_shift_opt --decode-mbr=true --inv-acoustic-scale=LMWT ark:- - \| \ - utils/int2sym.pl -f 5 $lang/words.txt \| \ - $filter_cmd '>' $dir/score_LMWT/$name.ctm || exit 1; - else - if [ ! -f $lang/phones/align_lexicon.int ]; then - echo "$0: neither $lang/phones/word_boundary.int nor $lang/phones/align_lexicon.int exists: cannot align." 
- exit 1; - fi - - $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/get_ctm.LMWT.log \ - mkdir -p $dir/score_LMWT/ '&&' \ - lattice-prune --inv-acoustic-scale=LMWT --beam=5 "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ - lattice-align-words-lexicon $lang/phones/align_lexicon.int $model ark:- ark:- \| \ - lattice-to-ctm-conf $frame_shift_opt --decode-mbr=true --inv-acoustic-scale=LMWT ark:- - \| \ - utils/int2sym.pl -f 5 $lang/words.txt \| \ - $filter_cmd '>' $dir/score_LMWT/$name.ctm || exit 1; - fi -fi - - diff --git a/egs/aspire/s5/local/multi_condition/get_ctm_conf.sh b/egs/aspire/s5/local/multi_condition/get_ctm_conf.sh new file mode 120000 index 00000000000..4c0ff429c31 --- /dev/null +++ b/egs/aspire/s5/local/multi_condition/get_ctm_conf.sh @@ -0,0 +1 @@ +../../../../wsj/s5/steps/conf/get_ctm_conf.sh \ No newline at end of file diff --git a/egs/aspire/s5/local/multi_condition/normalize_wavs.py b/egs/aspire/s5/local/multi_condition/normalize_wavs.py index dabf420d9f8..6e67d2113c1 100755 --- a/egs/aspire/s5/local/multi_condition/normalize_wavs.py +++ b/egs/aspire/s5/local/multi_condition/normalize_wavs.py @@ -3,6 +3,8 @@ # normalizes the wave files provided in input file list with a common scaling factor # the common scaling factor is computed to 1/\sqrt(1/(total_samples) * \sum_i{\sum_j x_i(j)^2}) where total_samples is sum of all samples of all wavefiles. If the data is multi-channel then each channel is treated as a seperate wave files +from __future__ import division +from __future__ import print_function import argparse, scipy.io.wavfile, warnings, numpy as np, math def get_normalization_coefficient(file_list, is_rir, additional_scaling): @@ -29,7 +31,7 @@ def get_normalization_coefficient(file_list, is_rir, additional_scaling): assert(rate == sampling_rate) else: sampling_rate = rate - data = data / dtype_max_value + data = data/dtype_max_value if is_rir: # just count the energy of the direct impulse response # this is treated as energy of signal from 0.001 seconds before impulse @@ -55,8 +57,8 @@ def get_normalization_coefficient(file_list, is_rir, additional_scaling): except IOError: warnings.warn("Did not find the file {0}.".format(file)) assert(total_samples > 0) - scaling_coefficient = np.sqrt(total_samples / total_energy) - print "Scaling coefficient is {0}.".format(scaling_coefficient) + scaling_coefficient = np.sqrt(total_samples/total_energy) + print("Scaling coefficient is {0}.".format(scaling_coefficient)) if math.isnan(scaling_coefficient): raise Exception(" Nan encountered while computing scaling coefficient. 
This is mostly due to numerical overflow") return scaling_coefficient diff --git a/egs/aspire/s5/local/multi_condition/prepare_impulses_noises.sh b/egs/aspire/s5/local/multi_condition/prepare_impulses_noises.sh index 804de611cae..8297cdee9ca 100755 --- a/egs/aspire/s5/local/multi_condition/prepare_impulses_noises.sh +++ b/egs/aspire/s5/local/multi_condition/prepare_impulses_noises.sh @@ -114,7 +114,7 @@ cp ${output_dir}_non_normalized/info/* $output_dir/info # rename file location in the noise-rir pairing files for file in `ls $output_dir/info/noise_impulse*`; do - sed -i "s/_non_normalized//g" $file + perl -i -pe "s/_non_normalized//g" $file done # generating the rir-list with probabilities alloted for each rir diff --git a/egs/aspire/s5/local/multi_condition/read_rir.py b/egs/aspire/s5/local/multi_condition/read_rir.py index a2e1c2052e2..04898bda760 100755 --- a/egs/aspire/s5/local/multi_condition/read_rir.py +++ b/egs/aspire/s5/local/multi_condition/read_rir.py @@ -29,9 +29,9 @@ def usage(): #sys.stderr.write(" ".join(sys.argv)+"\n") parser = argparse.ArgumentParser(usage()) parser.add_argument('--output-sampling-rate', type = int, default = 8000, help = 'sampling rate of the output') - parser.add_argument('type', type = str, default = None, help = 'database type', choices = ['air']) - parser.add_argument('input', type = str, default = None, help = 'directory containing the multi-channel data for a particular recording, or file name or file-regex-pattern') - parser.add_argument('output_filename', type = str, default = None, help = 'output filename (if "-" then output is written to output pipe)') + parser.add_argument('type', default = None, help = 'database type', choices = ['air']) + parser.add_argument('input', default = None, help = 'directory containing the multi-channel data for a particular recording, or file name or file-regex-pattern') + parser.add_argument('output_filename', default = None, help = 'output filename (if "-" then output is written to output pipe)') params = parser.parse_args() if params.output_filename == "-": diff --git a/egs/aspire/s5/local/multi_condition/reverberate_wavs.py b/egs/aspire/s5/local/multi_condition/reverberate_wavs.py index 998a3ed5e74..f43e4a2f894 100755 --- a/egs/aspire/s5/local/multi_condition/reverberate_wavs.py +++ b/egs/aspire/s5/local/multi_condition/reverberate_wavs.py @@ -4,18 +4,20 @@ # script to generate multicondition training data / dev data / test data import argparse, glob, math, os, random, scipy.io.wavfile, sys -class list_cyclic_iterator: +class list_cyclic_iterator(object): def __init__(self, list, random_seed = 0): self.list_index = 0 self.list = list random.seed(random_seed) random.shuffle(self.list) - def next(self): + def __next__(self): item = self.list[self.list_index] self.list_index = (self.list_index + 1) % len(self.list) return item + next = __next__ # for Python 2 + def return_nonempty_lines(lines): new_lines = [] for line in lines: @@ -71,15 +73,15 @@ def return_nonempty_lines(lines): for i in range(len(wav_files)): wav_file = " ".join(wav_files[i].split()[1:]) output_wav_file = wav_out_files[i] - impulse_file = impulses.next() + impulse_file = next(impulses) noise_file = '' snr = '' found_impulse = False if add_noise: - for i in xrange(len(impulse_noise_index)): + for i in range(len(impulse_noise_index)): if impulse_file in impulse_noise_index[i][0]: - noise_file = impulse_noise_index[i][1].next() - snr = snrs.next() + noise_file = next(impulse_noise_index[i][1]) + snr = next(snrs) assert(len(wav_file.strip()) > 0) 
assert(len(impulse_file.strip()) > 0) assert(len(noise_file.strip()) > 0) diff --git a/egs/aspire/s5/local/nnet3/segment_and_decode.sh b/egs/aspire/s5/local/nnet3/segment_and_decode.sh index d66b72200c1..e8917d091e2 100755 --- a/egs/aspire/s5/local/nnet3/segment_and_decode.sh +++ b/egs/aspire/s5/local/nnet3/segment_and_decode.sh @@ -109,9 +109,9 @@ fi if [ $stage -le 4 ]; then utils/copy_data_dir.sh $sad_work_dir/${segmented_data_set}_seg \ - data/${segmented_data_set}_hires - steps/compute_cmvn_stats.sh data/${segmented_data_set}_hires - utils/fix_data_dir.sh data/${segmented_data_set}_hires + data/${segmented_data_set}_seg_hires + steps/compute_cmvn_stats.sh data/${segmented_data_set}_seg_hires + utils/fix_data_dir.sh data/${segmented_data_set}_seg_hires fi if [ $stage -le 5 ]; then @@ -122,11 +122,11 @@ if [ $stage -le 5 ]; then # acoustic conditions drift over time within the speaker's data. steps/online/nnet2/extract_ivectors.sh --cmd "$train_cmd" --nj $decode_num_jobs \ --sub-speaker-frames $sub_speaker_frames --max-count $max_count \ - data/${segmented_data_set}_hires $lang $ivector_root_dir/extractor \ - $ivector_root_dir/ivectors_${segmented_data_set} + data/${segmented_data_set}_seg_hires $lang $ivector_root_dir/extractor \ + $ivector_root_dir/ivectors_${segmented_data_set}_seg fi -decode_dir=$dir/decode_${segmented_data_set}${affix}_pp +decode_dir=$dir/decode_${segmented_data_set}_seg${affix}_pp if [ $stage -le 6 ]; then echo "Generating lattices" rm -f ${decode_dir}_tg/.error @@ -138,8 +138,8 @@ if [ $stage -le 6 ]; then --extra-right-context-final $extra_right_context_final \ --frames-per-chunk "$frames_per_chunk" \ --skip-scoring true ${iter:+--iter $iter} --lattice-beam $lattice_beam \ - --online-ivector-dir $ivector_root_dir/ivectors_${segmented_data_set} \ - $graph data/${segmented_data_set}_hires ${decode_dir}_tg || \ + --online-ivector-dir $ivector_root_dir/ivectors_${segmented_data_set}_seg \ + $graph data/${segmented_data_set}_seg_hires ${decode_dir}_tg || \ { echo "$0: Error decoding" && exit 1; } fi @@ -147,7 +147,7 @@ if [ $stage -le 7 ]; then echo "Rescoring lattices" steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ --skip-scoring true \ - ${lang}_pp_test{,_fg} data/${segmented_data_set}_hires \ + ${lang}_pp_test{,_fg} data/${segmented_data_set}_seg_hires \ ${decode_dir}_{tg,fg}; fi @@ -161,5 +161,5 @@ if [ $stage -le 8 ]; then ${iter:+--iter $iter} \ --decode-mbr true \ --tune-hyper true \ - $lang $decode_dir $act_data_set $segmented_data_set $out_file + $lang $decode_dir $act_data_set ${segmented_data_set}_seg $out_file fi diff --git a/egs/aurora4/s5/local/run_mmi_tri2b.sh b/egs/aurora4/s5/local/run_mmi_tri2b.sh index 6517e46a1a7..8a4d03c59c4 100755 --- a/egs/aurora4/s5/local/run_mmi_tri2b.sh +++ b/egs/aurora4/s5/local/run_mmi_tri2b.sh @@ -38,7 +38,7 @@ steps/train_diag_ubm.sh --silence-weight 0.5 --nj 10 --cmd "$train_cmd" \ data/train_si84 data/lang exp/tri2b_ali_si84 exp/dubm2b exp/tri2b_denlats_si84 \ exp/tri2b_fmmi_b0.1 - for iter in `seq 3 8`; do + for iter in `seq 3 8`; do steps/decode_fmmi.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ exp/tri2b/graph_tgpr data/test_dev93 exp/tri2b_fmmi_b0.1/decode_tgpr_dev93_it$iter & done @@ -46,7 +46,7 @@ steps/train_diag_ubm.sh --silence-weight 0.5 --nj 10 --cmd "$train_cmd" \ steps/train_mmi_fmmi.sh --learning-rate 0.005 --boost 0.1 --cmd "$train_cmd" \ data/train_si84 data/lang exp/tri2b_ali_si84 exp/dubm2b exp/tri2b_denlats_si84 \ exp/tri2b_fmmi_b0.1_lr0.005 || exit 1; - for iter in `seq 3 8`; do + for iter in 
`seq 3 8`; do steps/decode_fmmi.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ exp/tri2b/graph_tgpr data/test_dev93 exp/tri2b_fmmi_b0.1_lr0.005/decode_tgpr_dev93_it$iter & done @@ -54,7 +54,7 @@ steps/train_diag_ubm.sh --silence-weight 0.5 --nj 10 --cmd "$train_cmd" \ steps/train_mmi_fmmi_indirect.sh --boost 0.1 --cmd "$train_cmd" \ data/train_si84 data/lang exp/tri2b_ali_si84 exp/dubm2b exp/tri2b_denlats_si84 \ exp/tri2b_fmmi_indirect_b0.1 - for iter in `seq 3 8`; do + for iter in `seq 3 8`; do steps/decode_fmmi.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ exp/tri2b/graph_tgpr data/test_dev93 exp/tri2b_fmmi_indirect_b0.1/decode_tgpr_dev93_it$iter & done diff --git a/egs/aurora4/s5/local/run_sgmm2.sh b/egs/aurora4/s5/local/run_sgmm2.sh index b7f872930e0..2eb70785bcb 100755 --- a/egs/aurora4/s5/local/run_sgmm2.sh +++ b/egs/aurora4/s5/local/run_sgmm2.sh @@ -88,14 +88,14 @@ exp/ubm5b/final.ubm exp/sgmm2_5c || exit 1; # Decode from lattices in exp/sgmm2_5b steps/decode_sgmm2_fromlats.sh --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_tgpr_dev93 \ - data/test_dev93 data/lang_test_tgpr exp/sgmm2_5b/decode_tgpr_dev93 exp/sgmm2_5c/decode_tgpr_dev93 + data/test_dev93 data/lang_test_tgpr exp/sgmm2_5b/decode_tgpr_dev93 exp/sgmm2_5c/decode_tgpr_dev93 steps/decode_sgmm2_fromlats.sh --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_tgpr_eval92 \ - data/test_eval92 data/lang_test_tgpr exp/sgmm2_5b/decode_tgpr_eval92 exp/sgmm2_5c/decode_tgpr_eval92 + data/test_eval92 data/lang_test_tgpr exp/sgmm2_5b/decode_tgpr_eval92 exp/sgmm2_5c/decode_tgpr_eval92 ) & steps/align_sgmm2.sh --nj 30 --cmd "$train_cmd" --transform-dir exp/tri4b_ali_si284 \ - --use-graphs true --use-gselect true data/train_si284 data/lang exp/sgmm2_5b exp/sgmm2_5b_ali_si284 + --use-graphs true --use-gselect true data/train_si284 data/lang exp/sgmm2_5b exp/sgmm2_5b_ali_si284 steps/make_denlats_sgmm2.sh --nj 30 --sub-split 30 --cmd "$decode_cmd" --transform-dir exp/tri4b_ali_si284 \ data/train_si284 data/lang exp/sgmm2_5b_ali_si284 exp/sgmm2_5b_denlats_si284 @@ -128,7 +128,7 @@ wait # Examples of combining some of the best decodings: SGMM+MMI with # MMI+fMMI on a conventional system. - + local/score_combine.sh data/test_eval92 \ data/lang_test_bd_tgpr \ exp/tri4b_fmmi_a/decode_tgpr_eval92_it8 \ diff --git a/egs/babel/s5b/local/lonestar.py b/egs/babel/s5b/local/lonestar.py index e1594e55ada..809f99b22cf 100755 --- a/egs/babel/s5b/local/lonestar.py +++ b/egs/babel/s5b/local/lonestar.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +from __future__ import print_function from pylauncher import * import pylauncher import sys @@ -39,7 +40,7 @@ def KaldiLauncher(lo, **kwargs): logfiles = list() commands = list() - for q in xrange(lo.jobstart, lo.jobend+1): + for q in range(lo.jobstart, lo.jobend+1): s = "bash " + lo.queue_scriptfile + " " + str(q) commands.append(s) @@ -74,7 +75,7 @@ def KaldiLauncher(lo, **kwargs): time.sleep(delay); lines=tail(10, logfile) - with_status=filter(lambda x:re.search(r'with status (\d+)', x), lines) + with_status=[x for x in lines if re.search(r'with status (\d+)', x)] if len(with_status) == 0: sys.stderr.write("The last line(s) of the log-file " + logfile + " does not seem" @@ -98,7 +99,7 @@ def KaldiLauncher(lo, **kwargs): sys.exit(-1); #Remove service files. 
Be careful not to remove something that might be needed in problem diagnostics - for i in xrange(len(commands)): + for i in range(len(commands)): out_file=os.path.join(qdir, ce.outstring+str(i)) #First, let's wait on files missing (it might be that those are missing @@ -149,7 +150,7 @@ def KaldiLauncher(lo, **kwargs): #print job.final_report() -class LauncherOpts: +class LauncherOpts(object): def __init__(self): self.sync=0 self.nof_threads = 1 @@ -199,7 +200,7 @@ def CmdLineParser(argv): jobend=int(m.group(2)) argv.pop(0) elif re.match("^.+=.*:.*$", argv[0]): - print >> sys.stderr, "warning: suspicious JOB argument " + argv[0]; + print("warning: suspicious JOB argument " + argv[0], file=sys.stderr); if jobstart > jobend: sys.stderr.write("lonestar.py: JOBSTART("+ str(jobstart) + ") must be lower than JOBEND(" + str(jobend) + ")\n") @@ -238,8 +239,8 @@ def setup_paths_and_vars(opts): cwd = os.getcwd() if opts.varname and (opts.varname not in opts.logfile ) and (opts.jobstart != opts.jobend): - print >>sys.stderr, "lonestar.py: you are trying to run a parallel job" \ - "but you are putting the output into just one log file (" + opts.logfile + ")"; + print("lonestar.py: you are trying to run a parallel job" \ + "but you are putting the output into just one log file (" + opts.logfile + ")", file=sys.stderr); sys.exit(1) if not os.path.isabs(opts.logfile): @@ -261,8 +262,8 @@ def setup_paths_and_vars(opts): taskname=os.path.basename(queue_logfile) taskname = taskname.replace(".log", ""); if taskname == "": - print >> sys.stderr, "lonestar.py: you specified the log file name in such form " \ - "that leads to an empty task name ("+logfile + ")"; + print("lonestar.py: you specified the log file name in such form " \ + "that leads to an empty task name ("+logfile + ")", file=sys.stderr); sys.exit(1) if not os.path.isabs(queue_logfile): diff --git a/egs/babel/s5b/local/resegment/segmentation.py b/egs/babel/s5b/local/resegment/segmentation.py index 7c5c8665a16..aed65a4ca14 100755 --- a/egs/babel/s5b/local/resegment/segmentation.py +++ b/egs/babel/s5b/local/resegment/segmentation.py @@ -3,6 +3,7 @@ # Copyright 2014 Vimal Manohar # Apache 2.0 +from __future__ import division import os, glob, argparse, sys, re, time from argparse import ArgumentParser @@ -19,12 +20,12 @@ def mean(l): if len(l) > 0: - return float(sum(l)) / len(l) + return (float(sum(l))/len(l)) return 0 # Analysis class # Stores statistics like the confusion matrix, length of the segments etc. -class Analysis: +class Analysis(object): def __init__(self, file_id, frame_shift, prefix): self.confusion_matrix = [0] * 9 self.type_counts = [ [[] for j in range(0,9)] for i in range(0,3) ] @@ -274,8 +275,8 @@ def read_rttm_file(rttm_file, temp_dir, frame_shift): i = len(this_file) category = splits[6] word = splits[5] - start_time = int(float(splits[3])/frame_shift + 0.5) - duration = int(float(splits[4])/frame_shift + 0.5) + start_time = int((float(splits[3])/frame_shift) + 0.5) + duration = int((float(splits[4])/frame_shift) + 0.5) if i < start_time: this_file.extend(["0"]*(start_time - i)) if type1 == "NON-LEX": @@ -295,7 +296,7 @@ def read_rttm_file(rttm_file, temp_dir, frame_shift): # Stats class to store some basic stats about the number of # times the post-processor goes through particular loops or blocks # of code in the algorithm. This is just for debugging. 
-class Stats: +class Stats(object): def __init__(self): self.inter_utt_nonspeech = 0 self.merge_nonspeech_segment = 0 @@ -321,7 +322,7 @@ def reset(self): self.noise_only = 0 # Timer class to time functions -class Timer: +class Timer(object): def __enter__(self): self.start = time.clock() return self @@ -332,7 +333,7 @@ def __exit__(self, *args): # The main class for post-processing a file. # This does the segmentation either looking at the file isolated # or by looking at both classes simultaneously -class JointResegmenter: +class JointResegmenter(object): def __init__(self, P, A, f, options, phone_map, stats = None, reference = None): # Pointers to prediction arrays and Initialization @@ -1290,22 +1291,22 @@ def main(): dest='hard_max_segment_length', default=15.0, \ help="Hard maximum on the segment length above which the segment " \ + "will be broken even if in the middle of speech (default: %(default)s)") - parser.add_argument('--first-separator', type=str, \ + parser.add_argument('--first-separator', \ dest='first_separator', default="-", \ help="Separator between recording-id and start-time (default: %(default)s)") - parser.add_argument('--second-separator', type=str, \ + parser.add_argument('--second-separator', \ dest='second_separator', default="-", \ help="Separator between start-time and end-time (default: %(default)s)") - parser.add_argument('--remove-noise-only-segments', type=str, \ + parser.add_argument('--remove-noise-only-segments', \ dest='remove_noise_only_segments', default="true", choices=("true", "false"), \ help="Remove segments that have only noise. (default: %(default)s)") parser.add_argument('--min-inter-utt-silence-length', type=float, \ dest='min_inter_utt_silence_length', default=1.0, \ help="Minimum silence that must exist between two separate utterances (default: %(default)s)"); - parser.add_argument('--channel1-file', type=str, \ + parser.add_argument('--channel1-file', \ dest='channel1_file', default="inLine", \ help="String that matches with the channel 1 file (default: %(default)s)") - parser.add_argument('--channel2-file', type=str, \ + parser.add_argument('--channel2-file', \ dest='channel2_file', default="outLine", \ help="String that matches with the channel 2 file (default: %(default)s)") parser.add_argument('--isolated-resegmentation', \ @@ -1388,7 +1389,7 @@ def main(): speech_cap = None if options.speech_cap_length != None: - speech_cap = int( options.speech_cap_length / options.frame_shift ) + speech_cap = int(options.speech_cap_length/options.frame_shift) # End if for f in pred_files: @@ -1454,7 +1455,7 @@ def main(): f2 = f3 # End if - if (len(A1) - len(A2)) > options.max_length_diff / options.frame_shift: + if (len(A1) - len(A2)) > int(options.max_length_diff/options.frame_shift): sys.stderr.write( \ "%s: Warning: Lengths of %s and %s differ by more than %f. 
" \ % (sys.argv[0], f1,f2, options.max_length_diff) \ diff --git a/egs/babel/s5c/local/lonestar.py b/egs/babel/s5c/local/lonestar.py index e1594e55ada..809f99b22cf 100755 --- a/egs/babel/s5c/local/lonestar.py +++ b/egs/babel/s5c/local/lonestar.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +from __future__ import print_function from pylauncher import * import pylauncher import sys @@ -39,7 +40,7 @@ def KaldiLauncher(lo, **kwargs): logfiles = list() commands = list() - for q in xrange(lo.jobstart, lo.jobend+1): + for q in range(lo.jobstart, lo.jobend+1): s = "bash " + lo.queue_scriptfile + " " + str(q) commands.append(s) @@ -74,7 +75,7 @@ def KaldiLauncher(lo, **kwargs): time.sleep(delay); lines=tail(10, logfile) - with_status=filter(lambda x:re.search(r'with status (\d+)', x), lines) + with_status=[x for x in lines if re.search(r'with status (\d+)', x)] if len(with_status) == 0: sys.stderr.write("The last line(s) of the log-file " + logfile + " does not seem" @@ -98,7 +99,7 @@ def KaldiLauncher(lo, **kwargs): sys.exit(-1); #Remove service files. Be careful not to remove something that might be needed in problem diagnostics - for i in xrange(len(commands)): + for i in range(len(commands)): out_file=os.path.join(qdir, ce.outstring+str(i)) #First, let's wait on files missing (it might be that those are missing @@ -149,7 +150,7 @@ def KaldiLauncher(lo, **kwargs): #print job.final_report() -class LauncherOpts: +class LauncherOpts(object): def __init__(self): self.sync=0 self.nof_threads = 1 @@ -199,7 +200,7 @@ def CmdLineParser(argv): jobend=int(m.group(2)) argv.pop(0) elif re.match("^.+=.*:.*$", argv[0]): - print >> sys.stderr, "warning: suspicious JOB argument " + argv[0]; + print("warning: suspicious JOB argument " + argv[0], file=sys.stderr); if jobstart > jobend: sys.stderr.write("lonestar.py: JOBSTART("+ str(jobstart) + ") must be lower than JOBEND(" + str(jobend) + ")\n") @@ -238,8 +239,8 @@ def setup_paths_and_vars(opts): cwd = os.getcwd() if opts.varname and (opts.varname not in opts.logfile ) and (opts.jobstart != opts.jobend): - print >>sys.stderr, "lonestar.py: you are trying to run a parallel job" \ - "but you are putting the output into just one log file (" + opts.logfile + ")"; + print("lonestar.py: you are trying to run a parallel job" \ + "but you are putting the output into just one log file (" + opts.logfile + ")", file=sys.stderr); sys.exit(1) if not os.path.isabs(opts.logfile): @@ -261,8 +262,8 @@ def setup_paths_and_vars(opts): taskname=os.path.basename(queue_logfile) taskname = taskname.replace(".log", ""); if taskname == "": - print >> sys.stderr, "lonestar.py: you specified the log file name in such form " \ - "that leads to an empty task name ("+logfile + ")"; + print("lonestar.py: you specified the log file name in such form " \ + "that leads to an empty task name ("+logfile + ")", file=sys.stderr); sys.exit(1) if not os.path.isabs(queue_logfile): diff --git a/egs/babel/s5c/local/resegment/segmentation.py b/egs/babel/s5c/local/resegment/segmentation.py index 7c5c8665a16..4bdb0fea75c 100755 --- a/egs/babel/s5c/local/resegment/segmentation.py +++ b/egs/babel/s5c/local/resegment/segmentation.py @@ -3,6 +3,7 @@ # Copyright 2014 Vimal Manohar # Apache 2.0 +from __future__ import division import os, glob, argparse, sys, re, time from argparse import ArgumentParser @@ -19,12 +20,12 @@ def mean(l): if len(l) > 0: - return float(sum(l)) / len(l) + return (float(sum(l))/len(l)) return 0 # Analysis class # Stores statistics like the confusion matrix, length of the segments etc. 
-class Analysis: +class Analysis(object): def __init__(self, file_id, frame_shift, prefix): self.confusion_matrix = [0] * 9 self.type_counts = [ [[] for j in range(0,9)] for i in range(0,3) ] @@ -274,7 +275,7 @@ def read_rttm_file(rttm_file, temp_dir, frame_shift): i = len(this_file) category = splits[6] word = splits[5] - start_time = int(float(splits[3])/frame_shift + 0.5) + start_time = int((float(splits[3])/frame_shift) + 0.5) duration = int(float(splits[4])/frame_shift + 0.5) if i < start_time: this_file.extend(["0"]*(start_time - i)) @@ -295,7 +296,7 @@ def read_rttm_file(rttm_file, temp_dir, frame_shift): # Stats class to store some basic stats about the number of # times the post-processor goes through particular loops or blocks # of code in the algorithm. This is just for debugging. -class Stats: +class Stats(object): def __init__(self): self.inter_utt_nonspeech = 0 self.merge_nonspeech_segment = 0 @@ -321,7 +322,7 @@ def reset(self): self.noise_only = 0 # Timer class to time functions -class Timer: +class Timer(object): def __enter__(self): self.start = time.clock() return self @@ -332,7 +333,7 @@ def __exit__(self, *args): # The main class for post-processing a file. # This does the segmentation either looking at the file isolated # or by looking at both classes simultaneously -class JointResegmenter: +class JointResegmenter(object): def __init__(self, P, A, f, options, phone_map, stats = None, reference = None): # Pointers to prediction arrays and Initialization @@ -351,9 +352,9 @@ def __init__(self, P, A, f, options, phone_map, stats = None, reference = None): self.frame_shift = options.frame_shift # Convert length in seconds to frames - self.max_frames = int(options.max_segment_length / options.frame_shift) - self.hard_max_frames = int(options.hard_max_segment_length / options.frame_shift) - self.min_inter_utt_nonspeech_length = int(options.min_inter_utt_silence_length / options.frame_shift) + self.max_frames = int(options.max_segment_length/options.frame_shift) + self.hard_max_frames = int(options.hard_max_segment_length/options.frame_shift) + self.min_inter_utt_nonspeech_length = int(options.min_inter_utt_silence_length/options.frame_shift) if ( options.remove_noise_only_segments == "false" ): self.remove_noise_segments = False elif ( options.remove_noise_only_segments == "true" ): @@ -540,7 +541,7 @@ def set_nonspeech_proportion(self): # Set the number of non-speech frames to be added depending on the # silence proportion. The target number of frames in the segments # is computed as below: - target_segment_frames = int(num_speech_frames / (1.0 - self.options.silence_proportion)) + target_segment_frames = int(num_speech_frames/(1.0 - self.options.silence_proportion)) # The number of frames currently in the segments num_segment_frames = num_speech_frames @@ -599,7 +600,7 @@ def set_nonspeech_proportion(self): if not changed: # avoid an infinite loop. if no changes, then break.
break if num_segment_frames < target_segment_frames: - proportion = float(num_segment_frames - num_speech_frames) / num_segment_frames + proportion = float(num_segment_frames - num_speech_frames)/num_segment_frames sys.stderr.write("%s: Warning: for recording %s, only got a proportion %f of non-speech frames, versus target %f\n" % (sys.argv[0], self.file_id, proportion, self.options.silence_proportion)) ########################################################################### @@ -863,14 +864,14 @@ def split_long_segments(self): # Count the number of times long segments are split self.stats.split_segments += 1 - num_pieces = int((float(segment_length) / self.hard_max_frames) + 0.99999) + num_pieces = int((float(segment_length)/self.hard_max_frames) + 0.99999) sys.stderr.write("%s: Warning: for recording %s, " \ % (sys.argv[0], self.file_id) \ + "splitting segment of length %f seconds into %d pieces " \ % (segment_length * self.frame_shift, num_pieces) \ + "(--hard-max-segment-length %f)\n" \ % self.options.hard_max_segment_length) - frames_per_piece = int(segment_length / num_pieces) + frames_per_piece = int(segment_length/num_pieces) for i in range(1,num_pieces): q = n + i * frames_per_piece self.S[q] = True @@ -1290,22 +1291,22 @@ def main(): dest='hard_max_segment_length', default=15.0, \ help="Hard maximum on the segment length above which the segment " \ + "will be broken even if in the middle of speech (default: %(default)s)") - parser.add_argument('--first-separator', type=str, \ + parser.add_argument('--first-separator', \ dest='first_separator', default="-", \ help="Separator between recording-id and start-time (default: %(default)s)") - parser.add_argument('--second-separator', type=str, \ + parser.add_argument('--second-separator', \ dest='second_separator', default="-", \ help="Separator between start-time and end-time (default: %(default)s)") - parser.add_argument('--remove-noise-only-segments', type=str, \ + parser.add_argument('--remove-noise-only-segments', \ dest='remove_noise_only_segments', default="true", choices=("true", "false"), \ help="Remove segments that have only noise. (default: %(default)s)") parser.add_argument('--min-inter-utt-silence-length', type=float, \ dest='min_inter_utt_silence_length', default=1.0, \ help="Minimum silence that must exist between two separate utterances (default: %(default)s)"); - parser.add_argument('--channel1-file', type=str, \ + parser.add_argument('--channel1-file', \ dest='channel1_file', default="inLine", \ help="String that matches with the channel 1 file (default: %(default)s)") - parser.add_argument('--channel2-file', type=str, \ + parser.add_argument('--channel2-file', \ dest='channel2_file', default="outLine", \ help="String that matches with the channel 2 file (default: %(default)s)") parser.add_argument('--isolated-resegmentation', \ @@ -1388,7 +1389,7 @@ def main(): speech_cap = None if options.speech_cap_length != None: - speech_cap = int( options.speech_cap_length / options.frame_shift ) + speech_cap = int(options.speech_cap_length/options.frame_shift) # End if for f in pred_files: @@ -1454,7 +1455,7 @@ def main(): f2 = f3 # End if - if (len(A1) - len(A2)) > options.max_length_diff / options.frame_shift: + if (len(A1) - len(A2)) > int(options.max_length_diff/options.frame_shift): sys.stderr.write( \ "%s: Warning: Lengths of %s and %s differ by more than %f. 
" \ % (sys.argv[0], f1,f2, options.max_length_diff) \ diff --git a/egs/babel/s5c/local/syllab/generate_syllable_lang.sh b/egs/babel/s5c/local/syllab/generate_syllable_lang.sh index 2d1fcb2259e..4a0810b9415 100755 --- a/egs/babel/s5c/local/syllab/generate_syllable_lang.sh +++ b/egs/babel/s5c/local/syllab/generate_syllable_lang.sh @@ -118,8 +118,7 @@ ln -s lex.syllabs2phones.disambig.fst $out/L_disambig.fst echo "Validating the output lang dir" utils/validate_lang.pl $out || exit 1 -sed -i'' 's/#1$//g' $lout/lexicon.txt -sed -i'' 's/#1$//g' $lout/lexiconp.txt +perl -i -pe 's/#1$//g' $lout/lexicon.txt $lout/lexiconp.txt echo "Done OK." exit 0 diff --git a/egs/babel/s5d/conf/lang/404-georgian.FLP.official.conf b/egs/babel/s5d/conf/lang/404-georgian.FLP.official.conf index a6b22de419f..9cd043716ce 100644 --- a/egs/babel/s5d/conf/lang/404-georgian.FLP.official.conf +++ b/egs/babel/s5d/conf/lang/404-georgian.FLP.official.conf @@ -75,8 +75,8 @@ unsup_data_list=./conf/lists/404-georgian/untranscribed-training.list unsup_nj=32 -lexicon_file= -lexiconFlags="--romanized --oov " +lexicon_file=/export/corpora/LDC/LDC2016S12/IARPA_BABEL_OP3_404/conversational/reference_materials/lexicon.txt +lexiconFlags=" --romanized --oov " diff --git a/egs/babel/s5d/local/chain/tuning/run_tdnn.sh b/egs/babel/s5d/local/chain/tuning/run_tdnn.sh index 4f485edf7da..7b4535f8c5e 100755 --- a/egs/babel/s5d/local/chain/tuning/run_tdnn.sh +++ b/egs/babel/s5d/local/chain/tuning/run_tdnn.sh @@ -128,7 +128,7 @@ if [ $stage -le 17 ]; then num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm.sh b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm.sh index 72f7a3c32dd..5fc14dda826 100755 --- a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm.sh +++ b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm.sh @@ -129,7 +129,7 @@ if [ $stage -le 17 ]; then num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" label_delay=5 diff --git a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab1.sh b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab1.sh index be0c2cc4b9b..8c7de5d18d4 100755 --- a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab1.sh +++ b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab1.sh @@ -127,7 +127,7 @@ if [ $stage -le 17 ]; then num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" label_delay=5 diff --git a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab2.sh b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab2.sh index 8f21a239794..0b3e70b5a04 100755 --- a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab2.sh +++ b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab2.sh @@ -127,7 +127,7 @@ if [ $stage -le 17 ]; then num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk 
'{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" label_delay=5 diff --git a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab3.sh b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab3.sh index 7898d172242..45f2907645e 100755 --- a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab3.sh +++ b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab3.sh @@ -128,7 +128,7 @@ if [ $stage -le 17 ]; then num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" label_delay=5 diff --git a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab4.sh b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab4.sh index 49462573245..0d92aff5c28 100755 --- a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab4.sh +++ b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab4.sh @@ -128,7 +128,7 @@ if [ $stage -le 17 ]; then num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" label_delay=5 diff --git a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab5.sh b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab5.sh index c888d985f5e..4129c00dcb4 100755 --- a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab5.sh +++ b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab5.sh @@ -128,7 +128,7 @@ if [ $stage -le 17 ]; then num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" label_delay=5 diff --git a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab6.sh b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab6.sh index e9a045e113a..1cfa50c1aa1 100755 --- a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab6.sh +++ b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab6.sh @@ -128,7 +128,7 @@ if [ $stage -le 17 ]; then num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" label_delay=5 diff --git a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab7.sh b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab7.sh index ce192a91665..ba8ac1e0373 100755 --- a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab7.sh +++ b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab7.sh @@ -129,7 +129,7 @@ if [ $stage -le 17 ]; then num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20 
dropout-proportion=0.0" label_delay=5 diff --git a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab8.sh b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab8.sh index 3fc0ef2206c..5de285e080e 100755 --- a/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab8.sh +++ b/egs/babel/s5d/local/chain/tuning/run_tdnn_lstm_bab8.sh @@ -129,7 +129,7 @@ if [ $stage -le 17 ]; then num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20 dropout-proportion=0.0 " label_delay=5 diff --git a/egs/babel/s5d/local/lexicon/make_unicode_lexicon.py b/egs/babel/s5d/local/lexicon/make_unicode_lexicon.py index 68280762597..91419f6e920 100755 --- a/egs/babel/s5d/local/lexicon/make_unicode_lexicon.py +++ b/egs/babel/s5d/local/lexicon/make_unicode_lexicon.py @@ -106,6 +106,7 @@ # Import Statements from __future__ import print_function +from __future__ import division import codecs import argparse import unicodedata @@ -340,7 +341,7 @@ def encode(unicode_transcription, tag_percentage, log=False): int2graph = {v: k for k, v in graph2int.items()} graph_list_int = [graph2int[g] for g in graph_list] bin_edges = range(0, len(int2graph.keys()) + 1) - graph_counts = np.histogram(graph_list_int, bins=bin_edges)[0] / float(len(graph_list_int)) + graph_counts = np.histogram(graph_list_int, bins=bin_edges)[0]/float(len(graph_list_int)) # Set count threshold to frequency that tags the bottom 10% of graphemes bottom_idx = int(np.floor(tag_percentage * len(graph_counts))) count_thresh = sorted(graph_counts)[bottom_idx] @@ -465,7 +466,7 @@ def encode(unicode_transcription, tag_percentage, log=False): for g_dict in table: g_map = "" map_number = 0 - for g_field, g_val in sorted(g_dict.iteritems()): + for g_field, g_val in sorted(g_dict.items()): if(g_field == ("MAP" + str(map_number))): g_map = g_map + g_val + " " map_number = map_number + 1 @@ -561,7 +562,7 @@ def write_table(table, outfile): # Start writing to output with codecs.open(outfile, "w", "utf-8") as fo: # Get header names - header_names = sorted(set().union(*[d.keys() for d in table])) + header_names = sorted(set().union(*[list(d.keys()) for d in table])) # Write headers for h in header_names[:-1]: fo.write("%s\t" % h) @@ -595,7 +596,7 @@ def write_map(grapheme_map, mapfile): ''' with codecs.open(mapfile, 'w', encoding='utf-8') as f: - for g, g_map in grapheme_map.iteritems(): + for g, g_map in grapheme_map.items(): print(g, g_map, file=f) @@ -613,14 +614,14 @@ def write_lexicon(baseforms, encoded_transcription, outfile, sil_lex=None, with codecs.open(outfile, "w", "utf-8") as f: # First write the non-speech words try: - for w in sil_lex.iterkeys(): + for w in sil_lex.keys(): f.write("%s\t%s\n" % (w, sil_lex[w])) except AttributeError: pass # Then write extra-speech words try: - for w in extra_lex.iterkeys(): + for w in extra_lex.keys(): f.write("%s\t%s\n" % (w, extra_lex[w])) except AttributeError: pass @@ -629,9 +630,9 @@ def write_lexicon(baseforms, encoded_transcription, outfile, sil_lex=None, for idx, w in enumerate(baseforms): # This is really just for BABEL in case is written as a word if(w[0].lower() == ""): - f.write("%s\t\n" % (unicode(w[0]))) + f.write("%s\t\n" % (w[0])) else: - f.write("%s\t%s\n" % (unicode(w[0]), + f.write("%s\t%s\n" % (w[0], encoded_transcription[idx])) if __name__ == "__main__": diff --git 
a/egs/babel/s5d/local/lexicon/make_word_list.py b/egs/babel/s5d/local/lexicon/make_word_list.py index 9a9e17f6c60..c1473b8ced8 100755 --- a/egs/babel/s5d/local/lexicon/make_word_list.py +++ b/egs/babel/s5d/local/lexicon/make_word_list.py @@ -85,7 +85,7 @@ def main(): # Print the word list with codecs.open(args.word_list, "w", encoding="utf-8") as f: for word, count in words: - f.write("%d %s\n" % (count, unicode(word))) + f.write("%d %s\n" % (count, word)) if args.misprons is not None: with codecs.open(args.misprons, "w", encoding="utf-8") as f: diff --git a/egs/babel/s5d/local/make_L_align.sh b/egs/babel/s5d/local/make_L_align.sh index 50e46a00493..41e9ff32958 100755 --- a/egs/babel/s5d/local/make_L_align.sh +++ b/egs/babel/s5d/local/make_L_align.sh @@ -34,18 +34,24 @@ tmpdir=$1 dir=$2 outdir=$3 +for f in $dir/phones/optional_silence.txt $dir/phones.txt $dir/words.txt ; do + [ ! -f $f ] && echo "$0: The file $f must exist!" && exit 1 +done + silphone=`cat $dir/phones/optional_silence.txt` || exit 1; +if [ ! -f $tmpdir/lexicon.txt ] && [ ! -f $tmpdir/lexiconp.txt ] ; then + echo "$0: At least one of the files $tmpdir/lexicon.txt or $tmpdir/lexiconp.txt must exist" >&2 + exit 1 +fi + # Create lexicon with alignment info if [ -f $tmpdir/lexicon.txt ] ; then cat $tmpdir/lexicon.txt | \ awk '{printf("%s #1 ", $1); for (n=2; n <= NF; n++) { printf("%s ", $n); } print "#2"; }' -elif [ -f $tmpdir/lexiconp.txt ] ; then +else cat $tmpdir/lexiconp.txt | \ awk '{printf("%s #1 ", $1); for (n=3; n <= NF; n++) { printf("%s ", $n); } print "#2"; }' -else - echo "Neither $tmpdir/lexicon.txt nor $tmpdir/lexiconp.txt does not exist" - exit 1 fi | utils/make_lexicon_fst.pl - 0.5 $silphone | \ fstcompile --isymbols=$dir/phones.txt --osymbols=$dir/words.txt \ --keep_isymbols=false --keep_osymbols=false | \ diff --git a/egs/babel/s5d/local/prepare_unicode_lexicon.py b/egs/babel/s5d/local/prepare_unicode_lexicon.py index 86fa4d60ba1..3b9dc1abd86 100755 --- a/egs/babel/s5d/local/prepare_unicode_lexicon.py +++ b/egs/babel/s5d/local/prepare_unicode_lexicon.py @@ -89,7 +89,7 @@ def extract_phonemes(lexicon): # Read all baseform units into dictionary with {a: [a, a_1, a_2], # b: [b_1, b_3], ...} phonemes_dict = {} - for word, pron in lexicon.iteritems(): + for word, pron in lexicon.items(): for p in pron.split(): try: base = p.split("_",1)[0] @@ -98,11 +98,11 @@ def extract_phonemes(lexicon): phonemes_dict[base] = [p] # Makes sure there are no repeats in the list - phonemes_dict = {k: set(v) for k, v in phonemes_dict.iteritems()} + phonemes_dict = {k: set(v) for k, v in phonemes_dict.items()} # Get all unique phonemes phonemes = [] - for v in phonemes_dict.itervalues(): + for v in phonemes_dict.values(): for p in v: phonemes.append(p) @@ -137,11 +137,11 @@ def write_extra_questions(nonsil_phonemes, nonsil_phonemes_dict, # Write all possible phone_tag combinations that occur in the lexicon for tag in tags: - for p in nonsil_phonemes_dict.iterkeys(): + for p in nonsil_phonemes_dict.keys(): tagged_phoneme = "_".join([p, tag]) if(tagged_phoneme in nonsil_phonemes_dict[p]): fp.write("%s " % tagged_phoneme) - for p in sil_phonemes_dict.iterkeys(): + for p in sil_phonemes_dict.keys(): tagged_phoneme = "_".join([p, tag]) if(tagged_phoneme in sil_phonemes_dict[p]): fp.write("%s " % tagged_phoneme) diff --git a/egs/babel/s5d/local/resegment/segmentation.py b/egs/babel/s5d/local/resegment/segmentation.py index 7c5c8665a16..02fd7646b96 100755 --- a/egs/babel/s5d/local/resegment/segmentation.py +++
b/egs/babel/s5d/local/resegment/segmentation.py @@ -3,6 +3,7 @@ # Copyright 2014 Vimal Manohar # Apache 2.0 +from __future__ import division import os, glob, argparse, sys, re, time from argparse import ArgumentParser @@ -19,12 +20,12 @@ def mean(l): if len(l) > 0: - return float(sum(l)) / len(l) + return float(sum(l))/len(l) return 0 # Analysis class # Stores statistics like the confusion matrix, length of the segments etc. -class Analysis: +class Analysis(object): def __init__(self, file_id, frame_shift, prefix): self.confusion_matrix = [0] * 9 self.type_counts = [ [[] for j in range(0,9)] for i in range(0,3) ] @@ -274,8 +275,8 @@ def read_rttm_file(rttm_file, temp_dir, frame_shift): i = len(this_file) category = splits[6] word = splits[5] - start_time = int(float(splits[3])/frame_shift + 0.5) - duration = int(float(splits[4])/frame_shift + 0.5) + start_time = int((float(splits[3])/frame_shift) + 0.5) + duration = int((float(splits[4])/frame_shift) + 0.5) if i < start_time: this_file.extend(["0"]*(start_time - i)) if type1 == "NON-LEX": @@ -295,7 +296,7 @@ def read_rttm_file(rttm_file, temp_dir, frame_shift): # Stats class to store some basic stats about the number of # times the post-processor goes through particular loops or blocks # of code in the algorithm. This is just for debugging. -class Stats: +class Stats(object): def __init__(self): self.inter_utt_nonspeech = 0 self.merge_nonspeech_segment = 0 @@ -321,7 +322,7 @@ def reset(self): self.noise_only = 0 # Timer class to time functions -class Timer: +class Timer(object): def __enter__(self): self.start = time.clock() return self @@ -332,7 +333,7 @@ def __exit__(self, *args): # The main class for post-processing a file. # This does the segmentation either looking at the file isolated # or by looking at both classes simultaneously -class JointResegmenter: +class JointResegmenter(object): def __init__(self, P, A, f, options, phone_map, stats = None, reference = None): # Pointers to prediction arrays and Initialization @@ -351,8 +352,8 @@ def __init__(self, P, A, f, options, phone_map, stats = None, reference = None): self.frame_shift = options.frame_shift # Convert length in seconds to frames - self.max_frames = int(options.max_segment_length / options.frame_shift) - self.hard_max_frames = int(options.hard_max_segment_length / options.frame_shift) + self.max_frames = int(options.max_segment_length/options.frame_shift) + self.hard_max_frames = int(options.hard_max_segment_length/options.frame_shift) self.min_inter_utt_nonspeech_length = int(options.min_inter_utt_silence_length / options.frame_shift) if ( options.remove_noise_only_segments == "false" ): self.remove_noise_segments = False @@ -540,7 +541,7 @@ def set_nonspeech_proportion(self): # Set the number of non-speech frames to be added depending on the # silence proportion. The target number of frames in the segments # is computed as below: - target_segment_frames = int(num_speech_frames / (1.0 - self.options.silence_proportion)) + target_segment_frames = int(num_speech_frames/(1.0 - self.options.silence_proportion)) # The number of frames currently in the segments num_segment_frames = num_speech_frames @@ -599,7 +600,7 @@ def set_nonspeech_proportion(self): if not changed: # avoid an infinite loop. if no changes, then break. 
break if num_segment_frames < target_segment_frames: - proportion = float(num_segment_frames - num_speech_frames) / num_segment_frames + proportion = float(num_segment_frames - num_speech_frames)/ num_segment_frames sys.stderr.write("%s: Warning: for recording %s, only got a proportion %f of non-speech frames, versus target %f\n" % (sys.argv[0], self.file_id, proportion, self.options.silence_proportion)) ########################################################################### @@ -863,14 +864,14 @@ def split_long_segments(self): # Count the number of times long segments are split self.stats.split_segments += 1 - num_pieces = int((float(segment_length) / self.hard_max_frames) + 0.99999) + num_pieces = int((float(segment_length)/self.hard_max_frames) + 0.99999) sys.stderr.write("%s: Warning: for recording %s, " \ % (sys.argv[0], self.file_id) \ + "splitting segment of length %f seconds into %d pieces " \ % (segment_length * self.frame_shift, num_pieces) \ + "(--hard-max-segment-length %f)\n" \ % self.options.hard_max_segment_length) - frames_per_piece = int(segment_length / num_pieces) + frames_per_piece = int(segment_length/num_pieces) for i in range(1,num_pieces): q = n + i * frames_per_piece self.S[q] = True @@ -1388,7 +1389,7 @@ def main(): speech_cap = None if options.speech_cap_length != None: - speech_cap = int( options.speech_cap_length / options.frame_shift ) + speech_cap = int(options.speech_cap_length/options.frame_shift) # End if for f in pred_files: @@ -1454,7 +1455,7 @@ def main(): f2 = f3 # End if - if (len(A1) - len(A2)) > options.max_length_diff / options.frame_shift: + if (len(A1) - len(A2)) > options.max_length_diff/options.frame_shift: sys.stderr.write( \ "%s: Warning: Lengths of %s and %s differ by more than %f. " \ % (sys.argv[0], f1,f2, options.max_length_diff) \ diff --git a/egs/babel/s5d/local/syllab/generate_phone_lang.sh b/egs/babel/s5d/local/syllab/generate_phone_lang.sh index fc21a23231b..81d8a0acdc7 100755 --- a/egs/babel/s5d/local/syllab/generate_phone_lang.sh +++ b/egs/babel/s5d/local/syllab/generate_phone_lang.sh @@ -122,8 +122,7 @@ ln -s lex.syllabs2phones.disambig.fst $out/L_disambig.fst echo "Validating the output lang dir" utils/validate_lang.pl $out || exit 1 -sed -i'' 's/#1$//g' $lout/lexicon.txt -sed -i'' 's/#1$//g' $lout/lexiconp.txt +perl -i -pe 's/#1$//g' $lout/lexicon.txt $lout/lexiconp.txt echo "Done OK." exit 0 diff --git a/egs/babel/s5d/local/syllab/generate_syllable_lang.sh b/egs/babel/s5d/local/syllab/generate_syllable_lang.sh index db7b0902425..a7bd667027c 100755 --- a/egs/babel/s5d/local/syllab/generate_syllable_lang.sh +++ b/egs/babel/s5d/local/syllab/generate_syllable_lang.sh @@ -122,8 +122,7 @@ ln -s lex.syllabs2phones.disambig.fst $out/L_disambig.fst echo "Validating the output lang dir" utils/validate_lang.pl $out || exit 1 -sed -i'' 's/#1$//g' $lout/lexicon.txt -sed -i'' 's/#1$//g' $lout/lexiconp.txt +perl -i -pe 's/#1$//g' $lout/lexicon.txt $lout/lexiconp.txt echo "Done OK." exit 0 diff --git a/egs/bentham/README.txt b/egs/bentham/README.txt new file mode 100644 index 00000000000..02870c265f6 --- /dev/null +++ b/egs/bentham/README.txt @@ -0,0 +1,5 @@ +This directory contains example scripts for handwriting recognition on +the Bentham dataset: +http://www.transcriptorium.eu/~htrcontest/contestICFHR2014/public_html/ +In the ICFHR 2014 contest, the best performing system in the unrestricted +track obtained a WER of 8.6%. 
diff --git a/egs/bentham/v1/cmd.sh b/egs/bentham/v1/cmd.sh new file mode 100755 index 00000000000..3c8eb9f93a5 --- /dev/null +++ b/egs/bentham/v1/cmd.sh @@ -0,0 +1,13 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export cmd="queue.pl" diff --git a/egs/bentham/v1/image b/egs/bentham/v1/image new file mode 120000 index 00000000000..6a4b3afeb09 --- /dev/null +++ b/egs/bentham/v1/image @@ -0,0 +1 @@ +../../cifar/v1/image \ No newline at end of file diff --git a/egs/bentham/v1/local/chain/compare_wer.sh b/egs/bentham/v1/local/chain/compare_wer.sh new file mode 100755 index 00000000000..2ce14e13694 --- /dev/null +++ b/egs/bentham/v1/local/chain/compare_wer.sh @@ -0,0 +1,120 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/cnn{1a,1b} + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... ]" + echo "e.g.: $0 exp/chain/cnn{1a,1b}" + exit 1 +fi +. ./path.sh + +echo "# $0 $*" +used_epochs=false + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +echo -n "# WER " +for x in $*; do + wer=$(cat $x/decode_test/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# WER (rescored) " +for x in $*; do + wer="--" + [ -d $x/decode_test_rescored ] && wer=$(cat $x/decode_test_rescored/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER " +for x in $*; do + cer=$(cat $x/decode_test/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +echo -n "# CER (rescored) " +for x in $*; do + cer="--" + [ -d $x/decode_test_rescored ] && cer=$(cat $x/decode_test_rescored/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +echo -n "# WER val " +for x in $*; do + wer=$(cat $x/decode_val/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# WER (rescored) val " +for x in $*; do + wer="--" + [ -d $x/decode_val_rescored ] && wer=$(cat $x/decode_val_rescored/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER val " +for x in $*; do + cer=$(cat $x/decode_val/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +echo -n "# CER (rescored) val " +for x in $*; do + cer="--" + [ -d $x/decode_val_rescored ] && cer=$(cat $x/decode_val_rescored/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. 
+fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Parameters " +for x in $*; do + params=$(nnet3-info $x/final.mdl 2>/dev/null | grep num-parameters | cut -d' ' -f2 | awk '{printf "%0.2fM\n",$1/1000000}') + printf "% 10s" $params +done +echo diff --git a/egs/bentham/v1/local/chain/run_cnn_e2eali.sh b/egs/bentham/v1/local/chain/run_cnn_e2eali.sh new file mode 120000 index 00000000000..e2545b0186e --- /dev/null +++ b/egs/bentham/v1/local/chain/run_cnn_e2eali.sh @@ -0,0 +1 @@ +tuning/run_cnn_e2eali_1a.sh \ No newline at end of file diff --git a/egs/bentham/v1/local/chain/run_e2e_cnn.sh b/egs/bentham/v1/local/chain/run_e2e_cnn.sh new file mode 120000 index 00000000000..d26ba0182ce --- /dev/null +++ b/egs/bentham/v1/local/chain/run_e2e_cnn.sh @@ -0,0 +1 @@ +tuning/run_e2e_cnn_1a.sh \ No newline at end of file diff --git a/egs/bentham/v1/local/chain/tuning/run_cnn_e2eali_1a.sh b/egs/bentham/v1/local/chain/tuning/run_cnn_e2eali_1a.sh new file mode 100755 index 00000000000..ec530ef1ce4 --- /dev/null +++ b/egs/bentham/v1/local/chain/tuning/run_cnn_e2eali_1a.sh @@ -0,0 +1,261 @@ +#!/bin/bash + +# local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ exp/chain/cnn_e2eali_1a +# System e2e_cnn_1a cnn_e2eali_1a +# WER 13.72 8.14 +# WER (rescored) 13.40 8.00 +# CER 6.56 2.82 +# CER (rescored) 6.33 2.73 +# WER val 13.51 8.19 +# WER (rescored) val 13.38 7.97 +# CER val 6.40 2.93 +# CER (rescored) val 6.29 2.90 +# Final train prob 0.1037 -0.0613 +# Final valid prob 0.0720 -0.0988 +# Final train prob (xent) -0.3706 +# Final valid prob (xent) -0.4669 +# Parameters 11.54M 4.29M + +# steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1a +# exp/chain/cnn_e2eali_1a: num-iters=20 nj=3..5 num-params=4.3M dim=40->336 combine=-0.066->-0.066 (over 1) xent:train/valid[12,19,final]=(-0.822,-0.437,-0.371/-0.859,-0.514,-0.467) logprob:train/valid[12,19,final]=(-0.188,-0.078,-0.061/-0.204,-0.114,-0.099) + +set -e -o pipefail + +stage=0 + +nj=30 +train_set=train +decode_val=true +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +e2echain_model_dir=exp/chain/e2e_cnn_1a +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=500 +# we don't need extra left/right context for TDNN systems. +chunk_left_context=0 +chunk_right_context=0 +tdnn_dim=550 +# training options +srand=0 +remove_egs=true +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +if $decode_val; then maybe_val=val; else maybe_val= ; fi +dropout_schedule='0,0@0.20,0.2@0.50,0' +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.03 dropout-proportion=0.0" + tdnn_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.04" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + + conv-relu-batchnorm-dropout-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-dropout-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common3 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. 
we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=true \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=5 \ + --trainer.frames-per-iter=1500000 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=5 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. 
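+  # For instance (purely hypothetical paths), a second decoding graph could be
+  # built from another lang dir that shares the same phones.txt:
+  #   utils/mkgraph.sh --self-loop-scale 1.0 data/lang_other_lm $dir $dir/graph_other_lm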
+ + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --extra-left-context $chunk_left_context \ + --extra-right-context $chunk_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/$decode_set $dir/decode_${decode_set}{,_rescored} || exit 1 + done +fi + + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/bentham/v1/local/chain/tuning/run_e2e_cnn_1a.sh b/egs/bentham/v1/local/chain/tuning/run_e2e_cnn_1a.sh new file mode 100755 index 00000000000..716bdce3729 --- /dev/null +++ b/egs/bentham/v1/local/chain/tuning/run_e2e_cnn_1a.sh @@ -0,0 +1,166 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +# This script does end2end chain training (i.e. from scratch) +# ./local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ +# System e2e_cnn_1b +# WER 13.72 +# WER (rescored) 13.40 +# CER 6.56 +# CER (rescored) 6.33 +# WER val 13.51 +# WER (rescored) val 13.38 +# CER val 6.40 +# CER (rescored) val 6.29 +# Final train prob 0.1037 +# Final valid prob 0.0720 +# Final train prob (xent) +# Final valid prob (xent) +# Parameters 11.54M +# steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1a +# exp/chain/e2e_cnn_1a: num-iters=26 nj=2..4 num-params=11.5M dim=40->17112 combine=0.054->0.054 (over 1) logprob:train/valid[16,25,final]=(0.078,0.102,0.104/0.051,0.069,0.072) +set -e + +# configs for 'chain' +stage=0 +train_stage=-10 +get_egs_stage=-10 +affix=1a +nj=30 + +# training options +tdnn_dim=450 +minibatch_size=150=100,64/300=50,32/600=25,16/1200=16,8 +common_egs_dir= +train_set=train +decode_val=true +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +if $decode_val; then maybe_val=val; else maybe_val= ; fi +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! 
cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj 30 --cmd "$cmd" \ + --shared-phones true \ + --type biphone \ + data/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat data/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py data/lang \| \ + utils/sym2int.pl -f 2- data/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + common1="height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="height-offsets=-2,-1,0,1,2 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + relu-batchnorm-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn4 input=Append(-4,0,4) dim=$tdnn_dim + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $output_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + # no need to store the egs in a shared storage because we always + # remove them. Anyway, it takes only 5 minutes to generate them. + + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1000000 \ + --trainer.num-epochs 4 \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 4 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir data/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi + +if [ $stage -le 4 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. 
+ + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 5 ]; then + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/$decode_set $dir/decode_${decode_set}{,_rescored} || exit 1 + done +fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/bentham/v1/local/check_tools.sh b/egs/bentham/v1/local/check_tools.sh new file mode 100755 index 00000000000..5b4d3107d3b --- /dev/null +++ b/egs/bentham/v1/local/check_tools.sh @@ -0,0 +1,43 @@ +#!/bin/bash -u + +# Copyright 2015 (c) Johns Hopkins University (Jan Trmal ) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +[ -f ./path.sh ] && . ./path.sh +set +e + +command -v python3 >&/dev/null \ + || { echo >&2 "python3 not found on PATH. You will have to install Python3, preferably >= 3.6"; exit 1; } + +python3 -c "import numpy" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs numpy installed." + exit 1 +fi + +python3 -c "import scipy" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy installed." + exit 1 +fi + +python3 -c "import scipy.misc; scipy.misc.__dict__['imread']" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy-image and Pillow installed." + exit 1 +fi + + +exit 0 diff --git a/egs/bentham/v1/local/create_splits.sh b/egs/bentham/v1/local/create_splits.sh new file mode 100755 index 00000000000..e8ea2279a49 --- /dev/null +++ b/egs/bentham/v1/local/create_splits.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2018 Desh Raj (Johns Hopkins University) + +# This script reads the extracted Bentham database files and creates +# the following files (for all the data subsets): +# text, utt2spk, images.scp. 
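+# Each of these follows the usual Kaldi conventions; with a made-up line id
+# "071_080_001_01_01" (illustrative only), the entries would look roughly like:
+#   text:       071_080_001_01_01 <transcription of that line>
+#   images.scp: 071_080_001_01_01 <download_dir>/gt/Images/Lines/071_080_001_01_01.png
+#   utt2spk:    071_080_001_01_01 071_080_001   (speaker id = first 11 characters of the line id)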
+ +download_dir=$1 +save_dir=$2 +mkdir -p $save_dir/{train,val,test} +touch $save_dir/{train,val,test}/{text,images.scp,utt2spk,spk2utt} + +partition_dir=$download_dir"/gt/Partitions/" +lines_dir=$download_dir"/gt/Images/Lines/" +text_dir=$download_dir"/gt/Transcriptions/" + +function split { + echo "Creating $1 split" + split_dir=$save_dir/$1 + line_file=$partition_dir/$2 + + while read -r line; do + name="$line" + spkid=${name:0:11} + echo -n $name" " | cat - $text_dir/$name* >> $split_dir/text + echo >> $split_dir/text + echo $name $lines_dir"/"$name".png" >> $split_dir/images.scp + echo $name $spkid >> $split_dir/utt2spk + done < "$line_file" + + perl -i -ne 'print if /\S/' $split_dir/images.scp $split_dir/text $split_dir/utt2spk + utils/utt2spk_to_spk2utt.pl $split_dir/utt2spk > $split_dir/spk2utt +} + +split train TrainLines.lst +split val ValidationLines.lst +split test TestLines.lst diff --git a/egs/bentham/v1/local/download_bentham_text.sh b/egs/bentham/v1/local/download_bentham_text.sh new file mode 100755 index 00000000000..e09403718a1 --- /dev/null +++ b/egs/bentham/v1/local/download_bentham_text.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# Copyright 2018 Desh Raj +# Apache 2.0 + +## Download all written works of Jeremy Bentham for the Bentham HWR task LM training + +baseurl='http://oll.libertyfund.org/titles/' +savedir=$1 + +mkdir -p $savedir + +declare -a texts=("bentham-the-works-of-jeremy-bentham-vol-1/simple" + "bentham-the-works-of-jeremy-bentham-vol-2/simple" + "bentham-the-works-of-jeremy-bentham-vol-3/simple" + "bentham-the-works-of-jeremy-bentham-vol-5-scotch-reform-real-property-codification-petitions/simple" + "bentham-the-works-of-jeremy-bentham-vol-6/simple" + "bentham-the-works-of-jeremy-bentham-vol-7-rationale-of-judicial-evidence-part-2/simple" + "bentham-the-works-of-jeremy-bentham-vol-8/simple" + "bentham-the-works-of-jeremy-bentham-vol-9-constitutional-code" + "bentham-the-works-of-jeremy-bentham-vol-10-memoirs-part-i-and-correspondence/simple" + "bentham-the-works-of-jeremy-bentham-vol-11-memoirs-of-bentham-part-ii-and-analytical-index") + +counter=1 +for i in "${texts[@]}" +do + echo "Downloading $baseurl$i" + curl -s -N "${baseurl}${i}" | sed -e 's/<[^>]*>//g' > $savedir"/bentham"$counter".txt" + ((counter++)) +done + +cat $savedir/*.txt > $savedir"/complete.txt" +rm $savedir/bentham*.txt diff --git a/egs/bentham/v1/local/extract_features.sh b/egs/bentham/v1/local/extract_features.sh new file mode 100755 index 00000000000..460e467e99c --- /dev/null +++ b/egs/bentham/v1/local/extract_features.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +# Copyright 2017 Yiwen Shao +# 2018 Ashish Arora + +# Apache 2.0 +# This script runs the make features script in parallel. + +nj=4 +cmd=run.pl +feat_dim=40 +augment='no_aug' +fliplr=false +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +.
./utils/parse_options.sh || exit 1; + +data=$1 +featdir=$data/data +scp=$data/images.scp +logdir=$data/log + +mkdir -p $logdir +mkdir -p $featdir + +# make $featdir an absolute pathname +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +for n in $(seq $nj); do + split_scps="$split_scps $logdir/images.$n.scp" +done + +# split images.scp +utils/split_scp.pl $scp $split_scps || exit 1; + +$cmd JOB=1:$nj $logdir/extract_features.JOB.log \ + image/ocr/make_features.py $logdir/images.JOB.scp \ + --allowed_len_file_path $data/allowed_lengths.txt \ + --num-channels 4 \ + --feat-dim $feat_dim --fliplr $fliplr --augment_type $augment \| \ + copy-feats --compress=true --compression-method=7 \ + ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp + +## aggregates the output scp's to get feats.scp +for n in $(seq $nj); do + cat $featdir/images.$n.scp || exit 1; +done > $data/feats.scp || exit 1 diff --git a/egs/bentham/v1/local/gen_topo.py b/egs/bentham/v1/local/gen_topo.py new file mode 100755 index 00000000000..af9e20317d8 --- /dev/null +++ b/egs/bentham/v1/local/gen_topo.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python + +# Copyright 2017 (author: Chun-Chieh Chang) + +# Generate a topology file. This allows control of the number of states in the +# non-silence HMMs, and in the silence HMMs. This is a modified version of +# 'utils/gen_topo.pl'. The difference is that this creates two topologies for +# the non-silence HMMs. The number of states for punctuations is different than +# the number of states for other characters. + +from __future__ import print_function +from __future__ import division +import argparse +import string + +parser = argparse.ArgumentParser(description="Usage: steps/nnet3/chain/gen_topo.py " + " " + "e.g.: steps/nnet3/chain/gen_topo.pl 4:5:6:7:8:9:10 1:2:3\n", + epilog="See egs/swbd/s5c/local/chain/train_tdnn_a.sh for example of usage."); +parser.add_argument("num_nonsil_states", type=int, help="number of states for nonsilence phones"); +parser.add_argument("num_sil_states", type=int, help="number of states for silence phones"); +parser.add_argument("num_punctuation_states", type=int, help="number of states for punctuation"); +parser.add_argument("nonsilence_phones", + help="List of non-silence phones as integers, separated by colons, e.g. 4:5:6:7:8:9"); +parser.add_argument("silence_phones", + help="List of silence phones as integers, separated by colons, e.g. 
1:2:3"); +parser.add_argument("phone_list", help="file containing all phones and their corresponding number."); + +args = parser.parse_args() + +silence_phones = [ int(x) for x in args.silence_phones.split(":") ] +nonsilence_phones = [ int(x) for x in args.nonsilence_phones.split(":") ] +all_phones = silence_phones + nonsilence_phones + +punctuation_phones = [] +exclude = set("!(),.?;:'-\"") +with open(args.phone_list) as f: + for line in f: + line = line.strip() + phone = line.split(' ')[0] + if len(phone) == 1 and phone in exclude: + punctuation_phones.append(int(line.split(' ')[1])) +# For nonsilence phones that are not punctuations +print("") +print("") +print("") +print(" ".join([str(x) for x in nonsilence_phones if x not in punctuation_phones])) +print("") +for x in range(0, args.num_nonsil_states): + xp1 = x + 1 + print(" {0} {0} {0} 0.75 {1} 0.25 ".format(x, xp1)) +print(" {} ".format(args.num_nonsil_states)) +print("") + +# For nonsilence phones that ar punctuations +print("") +print("") +print(" ".join([str(x) for x in nonsilence_phones if x in punctuation_phones])) +print("") +for x in range(0, args.num_punctuation_states): + xp1 = x + 1 + print(" {0} {0} {0} 0.75 {1} 0.25 ".format(x, xp1)) +print(" {} ".format(args.num_punctuation_states)) +print("") + +# For silence phones +print("") +print("") +print(" ".join([str(x) for x in silence_phones])) +print("") +if(args.num_sil_states > 1): + transp = 1.0/(args.num_sil_states - 1) + + state_str = " 0 0 " + for x in range(0, (args.num_sil_states - 1)): + state_str = "{} {} {} ".format(state_str, x, transp) + state_str = state_str + "" + print(state_str) + + for x in range(1, (args.num_sil_states - 1)): + state_str = " {0} {0} ".format(x) + for y in range(1, args.num_sil_states): + state_str = "{} {} {} ".format(state_str, y, transp) + state_str = state_str + "" + print(state_str) + second_last = args.num_sil_states - 1 + print(" {0} {0} {0} 0.75 {1} 0.25 ".format(second_last, args.num_sil_states)) + print(" {} ".format(args.num_sil_states)) +else: + print(" 0 0 0 0.75 1 0.25 ") + print(" {} ".format(args.num_sil_states)) +print("") +print("") diff --git a/egs/bentham/v1/local/prepare_data.sh b/egs/bentham/v1/local/prepare_data.sh new file mode 100755 index 00000000000..bbcc9863611 --- /dev/null +++ b/egs/bentham/v1/local/prepare_data.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +# Copyright 2018 Desh Raj (Johns Hopkins University) + +# Apache 2.0 + +# This script downloads the Bentham handwriting database and prepares the training +# and test data (i.e text, images.scp, utt2spk and spk2utt) by calling create_splits.sh. + +# In addition, it downloads data for all texts of Bentham for LM training purpose. + +stage=0 +download_dir=data/local/download/ +database_dir="" +text_corpus_dir="" + +mkdir -p $download_dir + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +BENTHAM_IMAGES_URL='http://transcriptorium.eu/~tsdata/BenthamR0/BenthamDatasetR0-Images.zip' +BENTHAM_GT_URL='http://transcriptorium.eu/~tsdata/BenthamR0/BenthamDatasetR0-GT.zip' +bentham_images=$database_dir"/images.zip" +bentham_gt=$database_dir"/gt.zip" +bentham_text=$download_dir"/text" + +# download and extract images and transcriptions +if [ ! -f $bentham_images ]; then + echo "Downloading images and transcriptions to $database_dir" + mkdir -p $database_dir + wget $BENTHAM_IMAGES_URL -O $bentham_images + wget $BENTHAM_GT_URL -O $bentham_gt +else + echo "Not downloading since corpus already exists" +fi + +if [ ! 
-d $download_dir/"gt" ]; then + unzip $bentham_gt -d $download_dir + mv $download_dir"/BenthamDatasetR0-GT" $download_dir"/gt" +else + echo "Local extracted corpus already exists" +fi + +# Download extra Bentham text for LM training +if [ -d $text_corpus_dir ]; then + echo "$0: Not downloading Bentham text corpus as it is already there." +else + local/download_bentham_text.sh $text_corpus_dir +fi + +# Copy extra Bentham text to local +if [ -d $bentham_text ]; then + echo "$0: Not copying as local Bentham already present." +else + mkdir -p $bentham_text + cp $text_corpus_dir/Bentham-Text/* $bentham_text + echo "$0: Done copying extra Bentham text to local." +fi + +# Creating train, val, and test splits for all directories +if [ -d data/train ]; then + echo "Data splits and files already exist. Not creating again." +else + echo "Creating train, val, and test splits and corresponding files.." + local/create_splits.sh $download_dir "data/" +fi + diff --git a/egs/bentham/v1/local/prepare_dict.sh b/egs/bentham/v1/local/prepare_dict.sh new file mode 100755 index 00000000000..22db5ae834d --- /dev/null +++ b/egs/bentham/v1/local/prepare_dict.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +# Copyright 2017 Hossein Hadian +# 2017 Babak Rekabdar +# 2017 Chun Chieh Chang +# 2017 Ashish Arora + +# This script prepares the dictionary. + +set -e +dir=data/local/dict +. ./utils/parse_options.sh || exit 1; + +mkdir -p $dir + +local/prepare_lexicon.py $dir + +cut -d' ' -f2- $dir/lexicon.txt | sed 's/SIL//g' | tr ' ' '\n' | sort -u | sed '/^$/d' >$dir/nonsilence_phones.txt || exit 1; + +echo ' SIL' >> $dir/lexicon.txt + +echo SIL > $dir/silence_phones.txt + +echo SIL >$dir/optional_silence.txt + +echo -n "" >$dir/extra_questions.txt diff --git a/egs/bentham/v1/local/prepare_lexicon.py b/egs/bentham/v1/local/prepare_lexicon.py new file mode 100755 index 00000000000..3de96056c2a --- /dev/null +++ b/egs/bentham/v1/local/prepare_lexicon.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Babak Rekabdar +# 2017 Hossein Hadian +# 2017 Chun Chieh Chang +# 2017 Ashish Arora +# Apache 2.0 + +# This script prepares lexicon for BPE. It gets the set of all words that occur in data/train/text. +# Since this lexicon is based on BPE, it replaces '|' with silence. 
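+# For example (illustrative): the BPE token "|the" is written to lexicon.txt as
+#   |the SIL t h e
+# i.e. the leading '|' word-boundary marker maps to SIL and every remaining
+# character becomes one phone; any '#' characters are dropped from the pronunciation.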
+ +import argparse +import os + +parser = argparse.ArgumentParser(description="""Creates the list of characters and words in lexicon""") +parser.add_argument('dir', type=str, help='output path') +args = parser.parse_args() + +### main ### +lex = {} +text_path = os.path.join('data', 'train', 'text') +with open(text_path, 'r', encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split(' ') + for i in range(1, len(line_vect)): + characters = list(line_vect[i]) + characters = " ".join([ 'SIL' if char == '|' else char for char in characters]) + characters = list(characters) + characters = "".join([ '' if char == '#' else char for char in characters]) + lex[line_vect[i]] = characters + +with open(os.path.join(args.dir, 'lexicon.txt'), 'w', encoding='utf-8') as fp: + for key in sorted(lex): + fp.write(key + " " + lex[key] + "\n") diff --git a/egs/bentham/v1/local/score.sh b/egs/bentham/v1/local/score.sh new file mode 100755 index 00000000000..1d84815fc69 --- /dev/null +++ b/egs/bentham/v1/local/score.sh @@ -0,0 +1,6 @@ + +#!/bin/bash + + +steps/scoring/score_kaldi_wer.sh "$@" +steps/scoring/score_kaldi_cer.sh --stage 2 "$@" diff --git a/egs/bentham/v1/local/train_lm.sh b/egs/bentham/v1/local/train_lm.sh new file mode 100755 index 00000000000..48632a90769 --- /dev/null +++ b/egs/bentham/v1/local/train_lm.sh @@ -0,0 +1,141 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# 2018 Desh Raj +# Apache 2.0 +# +# This script trains an LM on the Bentham text corpus and training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +vocab_size=50000 + +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +dir=data/local/local_lm +lm_dir=${dir}/data +bentham_text_dir=data/local/download/text/ + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. 
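+# An illustrative example of what such a line looks like (the numbers below are
+# made up and must be replaced by values taken from your own train_lm.py log):
+#   #bypass_metaparam_optim_opt="--bypass-metaparameter-optimization=0.091,0.867,0.753,0.275,0.100,0.018,0.902,0.371"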
+#bypass_metaparam_optim_opt= +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # Using Bentham text with last 5000 lines for dev + + cat $bentham_text_dir/complete.txt | \ + sed '/^\s*$/d' | \ + utils/lang/bpe/prepend_words.py | utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > ${dir}/bentham.txt + tail -n +5000 ${dir}/bentham.txt > ${dir}/data/text/bentham.txt + + # use the validation data as the dev set. + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + head -5000 ${dir}/bentham.txt > ${dir}/data/text/dev.txt + + # use the training data as an additional data source. + # we can later fold the dev data into this. + cat data/train/text | cut -d " " -f 2- > ${dir}/data/text/hwr.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < data/val/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from Bentham text + cat ${dir}/data/text/{bentham,hwr}.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + head -n $vocab_size ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +order=6 + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='bentham=1 hwr=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + + train_lm.py --wordlist=${wordlist} --num-splits=10 --warm-start-ratio=20 \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' +fi + +if [ $stage -le 2 ]; then + echo "$0: pruning the LM (to larger size)" + # Using 1 million n-grams for a big LM for rescoring purposes. + size=1000000 + prune_lm_dir.py --target-num-ngrams=$size --initial-threshold=0.02 ${unpruned_lm_dir} ${dir}/data/lm_${order}_prune_big + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_big 2>&1 | grep -F '[perplexity' + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${dir}/data/lm_${order}_prune_big | gzip -c > ${dir}/data/arpa/${order}gram_big.arpa.gz +fi + +if [ $stage -le 3 ]; then + echo "$0: pruning the LM (to smaller size)" + # Using 500,000 n-grams for a smaller LM for graph building. Prune from the + # bigger-pruned LM, it'll be faster. 
+ size=500000 + prune_lm_dir.py --target-num-ngrams=$size ${dir}/data/lm_${order}_prune_big ${dir}/data/lm_${order}_prune_small + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_small 2>&1 | grep -F '[perplexity' + + format_arpa_lm.py ${dir}/data/lm_${order}_prune_small | gzip -c > ${dir}/data/arpa/${order}gram_small.arpa.gz +fi diff --git a/egs/bentham/v1/local/wer_output_filter b/egs/bentham/v1/local/wer_output_filter new file mode 100755 index 00000000000..24691a160a9 --- /dev/null +++ b/egs/bentham/v1/local/wer_output_filter @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# Copyright 2017 Hossein Hadian + +# This is a filter used in scoring. It separates all +# punctuations from words. For e.g. this sentence: + +# "They have come!" he said reverently, gripping his +# hands. "Isn't it a glorious thing! Long awaited." + +# is converted to this: + +# " They have come ! " he said reverently , gripping his +# hands . " Isn ' t it a glorious thing ! Long awaited . " + +# Sample BPE-based output: +# |He |ro se |from |his |b re ak f as t - s ch oo l |b en ch + +import sys +import re + +punctuations = "!(),.?;:'-\"" +escaped_punctuations = re.escape(punctuations) + +for line in sys.stdin: + words = line.strip().split() + uttid = words[0] + transcript = ''.join(words[1:]) + transcript = transcript.replace('|', ' ') + split_transcript = " ".join(re.split("([{}])".format(escaped_punctuations), + transcript)).strip() + print("{} {}".format(uttid, split_transcript)) diff --git a/egs/bentham/v1/path.sh b/egs/bentham/v1/path.sh new file mode 100755 index 00000000000..2d17b17a84a --- /dev/null +++ b/egs/bentham/v1/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/bentham/v1/run_end2end.sh b/egs/bentham/v1/run_end2end.sh new file mode 100755 index 00000000000..63c034e41f6 --- /dev/null +++ b/egs/bentham/v1/run_end2end.sh @@ -0,0 +1,121 @@ +#!/bin/bash +# Copyright 2018 Ashish Arora (Johns Hopkins University) +# 2018 Desh Raj (Johns Hopkins University) + +set -e +stage=0 +nj=20 +# bentham_hwr_database points to the official database path on the JHU grid. If you have not +# already downloaded the data, you will have to first download it and then name the Images +# and Ground Truth zipped files as images.zip and gt.zip. Then, point the path below to the +# location where your zipped files are present on the grid. +bentham_hwr_database=/export/corpora5/handwriting_ocr/hwr1/ICDAR-HTR-Competition-2015 +# bentham_text_database points to the database path on the JHU grid. +# It contains all of the written works of Bentham, and can be used to train +# an LM for the HWR task. We have provided a script which downloads the data +# and saves it to the location provided below. +bentham_text_corpus=/export/corpora5/handwriting_ocr/hwr1/ICDAR-HTR-Competition-2015/Bentham-Text + +. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. + ## This relates to the queue. +. ./path.sh +. ./utils/parse_options.sh # e.g. this parses the above options + # if supplied. + + +./local/check_tools.sh + +if [ $stage -le 0 ]; then + echo "$0: Preparing data..." 
+ local/prepare_data.sh --database-dir $bentham_hwr_database \ + --text-corpus-dir $bentham_text_corpus +fi + +if [ $stage -le 1 ]; then + image/get_image2num_frames.py data/train # This will be needed for the next command + # The next command creates a "allowed_lengths.txt" file in data/train + # which will be used by local/make_features.py to enforce the images to + # have allowed lengths. The allowed lengths will be spaced by 10% difference in length. + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train + echo "$(date) Extracting features, creating feats.scp file" + for dataset in train val test; do + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/$dataset + steps/compute_cmvn_stats.sh data/$dataset + done + utils/fix_data_dir.sh data/train +fi + +if [ $stage -le 2 ]; then + echo "$0: Preparing BPE..." + # getting non-silence phones. + cut -d' ' -f2- data/train/text | \ +python3 <( +cat << "END" +import os, sys, io; +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8'); +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8'); +phone_dict = dict(); +for line in infile: + line_vect = line.strip().split(); + for word in line_vect: + for phone in word: + phone_dict[phone] = phone; +for phone in phone_dict.keys(): + output.write(phone+ '\n'); +END + ) > data/local/phones.txt + + cut -d' ' -f2- data/train/text > data/local/train_data.txt + cat data/local/phones.txt data/local/train_data.txt | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/learn_bpe.py -s 700 > data/local/bpe.txt + for set in test train val; do + cut -d' ' -f1 data/$set/text > data/$set/ids + cut -d' ' -f2- data/$set/text | \ + utils/lang/bpe/prepend_words.py | utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > data/$set/bpe_text + mv data/$set/text data/$set/text.old + paste -d' ' data/$set/ids data/$set/bpe_text > data/$set/text + done +fi + +if [ $stage -le 3 ]; then + echo "$0: Estimating a language model for decoding..." + local/train_lm.sh +fi + +if [ $stage -le 4 ]; then + echo "$0: Preparing dictionary and lang..." + local/prepare_dict.sh + # This recipe uses byte-pair encoding, the silences are part of the words' pronunciations. + # So we set --sil-prob to 0.0 + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ + data/local/dict "" data/lang/temp data/lang + silphonelist=`cat data/lang/phones/silence.csl` + nonsilphonelist=`cat data/lang/phones/nonsilence.csl` + local/gen_topo.py 8 4 4 $nonsilphonelist $silphonelist data/lang/phones.txt >data/lang/topo + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 data/lang + + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/6gram_big.arpa.gz \ + data/local/dict/lexicon.txt data/lang + utils/build_const_arpa_lm.sh data/local/local_lm/data/arpa/6gram_unpruned.arpa.gz \ + data/lang data/lang_rescore_6g +fi + +if [ $stage -le 5 ]; then + echo "$0: Calling the flat-start chain recipe..." + local/chain/run_e2e_cnn.sh +fi + +if [ $stage -le 6 ]; then + echo "$0: Aligning the training data using the e2e chain model..." + steps/nnet3/align.sh --nj 50 --cmd "$cmd" \ + --use-gpu false \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0 --acoustic-scale=1.0' \ + data/train data/lang exp/chain/e2e_cnn_1a exp/chain/e2e_ali_train +fi + +if [ $stage -le 7 ]; then + echo "$0: Building a tree and training a regular chain model using the e2e alignments..." 
+ local/chain/run_cnn_e2eali.sh +fi diff --git a/egs/bentham/v1/steps b/egs/bentham/v1/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/bentham/v1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/bentham/v1/utils b/egs/bentham/v1/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/bentham/v1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/bn_music_speech/v1/local/make_annotations_bn.py b/egs/bn_music_speech/v1/local/make_annotations_bn.py index 53cebf52ea4..86bec7b16ae 100755 --- a/egs/bn_music_speech/v1/local/make_annotations_bn.py +++ b/egs/bn_music_speech/v1/local/make_annotations_bn.py @@ -9,6 +9,7 @@ # # This file is meant to be invoked by make_bn.sh. +from __future__ import print_function import sys, re, os def is_speech(line): @@ -37,7 +38,7 @@ def extract_speech(line): m = re.search('(?<=E_time=)\d+.\d+', line) end = float(m.group(0)) if start > end: - print "Skipping annotation where end time is before start time:", line + print("Skipping annotation where end time is before start time: {}".format(line)) return start, end def extract_other_type2(line): @@ -46,7 +47,7 @@ def extract_other_type2(line): m = re.search('(?<=E_time=)\d+.\d+', line) end = float(m.group(0)) if start > end: - print "Skipping annotation where end time is before start time:", line + print("Skipping annotation where end time is before start time: {}".format(line)) return start, end def extract_music(line): @@ -60,7 +61,7 @@ def extract_music(line): elif level == "O": is_on = False else: - print "Encountered bad token on line:", line + print("Encountered bad token on line: {}".format(line)) sys.exit() return time, is_on @@ -75,7 +76,7 @@ def extract_other_type1(line): elif level == "O": is_on = False else: - print "Encountered bad token on line:", line + print("Encountered bad token on line: {}".format(line)) sys.exit() return time, is_on @@ -92,11 +93,11 @@ def process_file(annos): for line in annos: if is_speech(line): speech_start, speech_end = extract_speech(line) - speech = speech + str(speech_start) + " " + str(speech_end) + "\n" + speech = "{}{} {}\n".format(speech, speech_start, speech_end) max_time = max(speech_end, max_time) elif is_other_type2(line): other_type2_start, other_type2_end = extract_other_type2(line) - other_type2 = other_type2 + str(other_type2_start) + " " + str(other_type2_end) + "\n" + other_type2 = "{}{} {}\n".format(other_type2, other_type2_start, other_type2_end) max_time = max(other_type2_end, max_time) elif is_music(line): time, is_on = extract_music(line) @@ -105,7 +106,7 @@ def process_file(annos): prev_music_time = time start_new_music_segment = False elif not is_on and not start_new_music_segment: - music = music + str(prev_music_time) + " " + str(time) + "\n" + music = "{}{} {}\n".format(music, prev_music_time, time) start_new_music_segment = True elif is_other_type1(line): time, is_on = extract_other_type1(line) @@ -114,13 +115,13 @@ def process_file(annos): prev_other_time = time start_new_other_segment = False elif not is_on and not start_new_other_segment: - other_type1 = other_type1 + str(prev_other_time) + " " + str(time) + "\n" + other_type1 = "{}{} {}\n".format(other_type1, prev_other_time, time) start_new_other_segment = True if not start_new_music_segment: - music = music + str(prev_music_time) + " " + str(max_time) + "\n" + music = "{}{} {}\n".format(music, prev_music_time, max_time) if not start_new_other_segment: - other_type1 
= other_type1 + str(prev_other_time) + " " + str(max_time) + "\n" + other_type1 = "{}{} {}\n".format(other_type1, prev_other_time, max_time) other = other_type1 + other_type2 return speech, music, other diff --git a/egs/bn_music_speech/v1/local/make_bn.py b/egs/bn_music_speech/v1/local/make_bn.py index 98836d32534..7ec9aabcbdf 100755 --- a/egs/bn_music_speech/v1/local/make_bn.py +++ b/egs/bn_music_speech/v1/local/make_bn.py @@ -20,7 +20,7 @@ for file in files: utt = str(file).replace(".sph", "") if file.endswith(".sph") and utt in utts: - wav = wav + utt + " sox " + subdir + "/" + utt + ".sph" + " -c 1 -r 16000 -t wav - |\n" + wav = "{0}{1} sox {2}/{1}.sph -c 1 -r 16000 -t wav - |\n".format(wav, utt, subdir) wav_fi = open(os.path.join(out_dir, "wav.scp"), 'w') wav_fi.write(wav) @@ -32,14 +32,14 @@ count = 1 for line in music_fi: left, right = line.rstrip().split(" ") - segments = segments + utt + "-music-" + str(count) + " " + utt + " " + left + " " + right + "\n" - utt2spk = utt2spk + utt + "-music-" + str(count) + " " + utt + "-music-" + str(count) + "\n" + segments = "{0}{1}-music-{2} {1} {3} {4}\n".format(segments, utt, count, left, right) + utt2spk = "{0}{1}-music-{2} {1}-music-{2}\n".format(utt2spk, utt, count) count += 1 count = 1 for line in speech_fi: left, right = line.rstrip().split(" ") - segments = segments + utt + "-speech-" + str(count) + " " + utt + " " + left + " " + right + "\n" - utt2spk = utt2spk + utt + "-speech-" + str(count) + " " + utt + "-speech-" + str(count) + "\n" + segments = "{0}{1}-speech-{2} {1} {3} {4}\n".format(segments, utt, count, left, right) + utt2spk = "{0}{1}-speech-{2} {1}-speech-{2}\n".format(utt2spk, utt, count) count += 1 utt2spk_fi = open(os.path.join(out_dir, "utt2spk"), 'w') utt2spk_fi.write(utt2spk) diff --git a/egs/bn_music_speech/v1/local/make_musan.py b/egs/bn_music_speech/v1/local/make_musan.py index b3795fe2b7d..eb739b68180 100755 --- a/egs/bn_music_speech/v1/local/make_musan.py +++ b/egs/bn_music_speech/v1/local/make_musan.py @@ -43,9 +43,9 @@ def prepare_music(root_dir, use_vocals): utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file {}".format(utt)) num_bad_files += 1 - print("In music directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In music directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def prepare_speech(root_dir): @@ -69,9 +69,9 @@ def prepare_speech(root_dir): utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file {}".format(utt)) num_bad_files += 1 - print("In speech directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In speech directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def prepare_noise(root_dir): @@ -95,9 +95,9 @@ def prepare_noise(root_dir): utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file {}".format(utt)) num_bad_files += 1 - print("In noise directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In noise directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def main(): diff --git
a/egs/bn_music_speech/v1/local/print_scores.py b/egs/bn_music_speech/v1/local/print_scores.py index c2b587cdcad..e563afb63d7 100755 --- a/egs/bn_music_speech/v1/local/print_scores.py +++ b/egs/bn_music_speech/v1/local/print_scores.py @@ -11,6 +11,7 @@ # those strings to determine if it is a target or nontarget # utterance. We arbitrarily pick music to be the target class. +from __future__ import print_function import sys utt2score = open(sys.argv[1], 'r').readlines() for i in range(0, len(utt2score)): @@ -19,4 +20,4 @@ type = "target" else: type = "nontarget" - print score, type + print(score, type) diff --git a/egs/bn_music_speech/v1/local/refine_annotations_bn.py b/egs/bn_music_speech/v1/local/refine_annotations_bn.py index 52ac87c8640..31cb1803f57 100755 --- a/egs/bn_music_speech/v1/local/refine_annotations_bn.py +++ b/egs/bn_music_speech/v1/local/refine_annotations_bn.py @@ -10,6 +10,7 @@ # designated length are created. # # This file is meant to be invoked from make_bn.sh. +from __future__ import division import sys, os def seg_to_string(seg): @@ -23,7 +24,7 @@ def seg_to_string(seg): def process_segs(raw_segs): segs = [] for seg in raw_segs: - lower, upper = map(float, seg.rstrip().split(" ")) + lower, upper = [float(i) for i in seg.rstrip().split(" ")] segs.append((lower, upper)) return segs @@ -60,8 +61,8 @@ def resegment(music, speech, other, frame_length, min_seg): start_frame = 0 for i in range(1, len(frame2classes)): if curr_class != frame2classes[i]: - start = float(start_frame) / frame_length - end = float(i) / frame_length + start = float(start_frame)/frame_length + end = float(i)/frame_length if end - start > min_seg: if curr_class == "music": new_music.append((start, end)) diff --git a/egs/callhome_diarization/v1/diarization/cluster.sh b/egs/callhome_diarization/v1/diarization/cluster.sh index 4f46b3ba5ef..5e5c6e9dbe5 100755 --- a/egs/callhome_diarization/v1/diarization/cluster.sh +++ b/egs/callhome_diarization/v1/diarization/cluster.sh @@ -14,6 +14,9 @@ stage=0 nj=10 cleanup=true threshold=0.5 +max_spk_fraction=1.0 +first_pass_max_utterances=32767 +rttm_channel=0 read_costs=false reco2num_spk= # End configuration section. @@ -35,6 +38,17 @@ if [ $# != 2 ]; then echo " --threshold # Cluster stopping criterion. Clusters with scores greater" echo " # than this value will be merged until all clusters" echo " # exceed this value." + echo " --max-spk-fraction # Clusters with total fraction of utterances greater than" + echo " # this value will not be merged. This is active only when" + echo " # reco2num-spk is supplied and" + echo " # 1.0 / num-spk <= max-spk-fraction <= 1.0." + echo " --first-pass-max-utterances # If the number of utterances is larger than first-pass-max-utterances," + echo " # then clustering is done in two passes. In the first pass, input points" + echo " # are divided into contiguous subsets of size first-pass-max-utterances" + echo " # and each subset is clustered separately. In the second pass, the first" + echo " # pass clusters are merged into the final set of clusters." + echo " --rttm-channel # The value passed into the RTTM channel field. Only affects" + echo " # the format of the RTTM file." echo " --read-costs # If true, interpret input scores as costs, i.e. similarity" echo " # is indicated by smaller values. 
If enabled, clusters will" echo " # be merged until all cluster scores are less than the" @@ -75,8 +89,10 @@ if [ $stage -le 0 ]; then echo "$0: clustering scores" $cmd JOB=1:$nj $dir/log/agglomerative_cluster.JOB.log \ agglomerative-cluster --threshold=$threshold --read-costs=$read_costs \ - --reco2num-spk-rspecifier=$reco2num_spk scp:"$feats" \ - ark,t:$sdata/JOB/spk2utt ark,t:$dir/labels.JOB || exit 1; + --reco2num-spk-rspecifier=$reco2num_spk \ + --max-spk-fraction=$max_spk_fraction \ + --first-pass-max-utterances=$first_pass_max_utterances \ + scp:"$feats" ark,t:$sdata/JOB/spk2utt ark,t:$dir/labels.JOB || exit 1; fi if [ $stage -le 1 ]; then @@ -86,7 +102,7 @@ fi if [ $stage -le 2 ]; then echo "$0: computing RTTM" - diarization/make_rttm.py $srcdir/segments $dir/labels $dir/rttm || exit 1; + diarization/make_rttm.py --rttm-channel $rttm_channel $srcdir/segments $dir/labels $dir/rttm || exit 1; fi if $cleanup ; then diff --git a/egs/callhome_diarization/v1/diarization/extract_ivectors.sh b/egs/callhome_diarization/v1/diarization/extract_ivectors.sh index 370a37b873e..d7bb389bad5 100755 --- a/egs/callhome_diarization/v1/diarization/extract_ivectors.sh +++ b/egs/callhome_diarization/v1/diarization/extract_ivectors.sh @@ -29,6 +29,10 @@ min_post=0.025 # Minimum posterior to use (posteriors below this are pruned out) posterior_scale=1.0 # This scale helps to control for successve features being highly # correlated. E.g. try 0.1 or 0.3. apply_cmn=true # If true, apply sliding window cepstral mean normalization +apply_deltas=true # If true, copy the delta options from the i-vector extractor directory. + # If false, we won't add deltas in this step. For speaker diarization, + # we sometimes need to write features to disk that already have various + # post-processing applied so adding deltas is no longer needed in this stage. # End configuration section. echo "$0 $@" # Print the command line for logging @@ -57,6 +61,12 @@ if [ $# != 3 ]; then echo " --min-post # Pruning threshold for posteriors" echo " --apply-cmn # if true, apply sliding window cepstral mean" echo " # normalization to features" + echo " --apply-deltas # If true, copy the delta options from the i-vector" + echo " # extractor directory. If false, we won't add deltas" + echo " # in this step. For speaker diarization, we sometimes" + echo " # need to write features to disk that already have" + echo " # various post-processing applied so adding deltas is" + echo " # no longer needed in this stage." exit 1; fi @@ -82,7 +92,7 @@ if [ $stage -le 0 ]; then fi utils/data/get_uniform_subsegments.py \ --max-segment-duration=$window \ - --overlap-duration=$(echo "$window-$period" | bc) \ + --overlap-duration=$(perl -e "print $window-$period") \ --max-remaining-duration=$min_segment \ --constant-duration=True \ $segments > $dir/subsegments @@ -95,7 +105,11 @@ mkdir -p $dir/log sub_sdata=$sub_data/split$nj; utils/split_data.sh $sub_data $nj || exit 1; -delta_opts=`cat $srcdir/delta_opts 2>/dev/null` +if $apply_deltas; then + delta_opts=`cat $srcdir/delta_opts 2>/dev/null` +else + delta_opts="--delta-order=0" +fi ## Set up features. 
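extract_ivectors.sh above now computes --overlap-duration with a perl one-liner instead of bc, since plain shell arithmetic cannot subtract floats. The window/period/min_segment values feed get_uniform_subsegments.py, which slices each segment into fixed-length overlapping subsegments before i-vector extraction. A simplified sketch of that slicing (not the exact logic of the Kaldi script, and the default durations here are assumptions):

```python
def uniform_subsegments(start, end, window=1.5, period=0.75, min_remainder=0.5):
    """Cut [start, end) into windows of `window` seconds that advance by
    `period` seconds (overlap = window - period).  A leftover shorter than
    `min_remainder` is absorbed into the final window.  Simplified
    illustration only; get_uniform_subsegments.py has more options."""
    subs = []
    t = start
    while t + window < end:
        subs.append((t, t + window))
        t += period
    if not subs or end - t >= min_remainder:
        subs.append((t, min(t + window, end)))
    else:
        # extend the last window to cover the small remainder
        subs[-1] = (subs[-1][0], end)
    return subs

print(uniform_subsegments(0.0, 5.2))
```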
if $apply_cmn; then diff --git a/egs/callhome_diarization/v1/diarization/make_rttm.py b/egs/callhome_diarization/v1/diarization/make_rttm.py index 1705411069f..cc1145ab9ab 100755 --- a/egs/callhome_diarization/v1/diarization/make_rttm.py +++ b/egs/callhome_diarization/v1/diarization/make_rttm.py @@ -51,6 +51,9 @@ def get_args(): help="Input labels file") parser.add_argument("rttm_file", type=str, help="Output RTTM file") + parser.add_argument("--rttm-channel", type=int, default=0, + help="The value passed into the RTTM channel field. \ + Only affects the format of the RTTM file.") args = parser.parse_args() return args @@ -80,7 +83,7 @@ def main(): # Cut up overlapping segments so they are contiguous contiguous_segs = [] - for reco in reco2segs: + for reco in sorted(reco2segs): segs = reco2segs[reco].strip().split() new_segs = "" for i in range(1, len(segs)-1): @@ -120,8 +123,8 @@ def main(): reco = segs[0] for i in range(1, len(segs)): start, end, label = segs[i].strip().split(',') - print("SPEAKER {0} 0 {1:7.3f} {2:7.3f} {3} ".format( - reco, float(start), float(end)-float(start), label), file=rttm_writer) + print("SPEAKER {0} {1} {2:7.3f} {3:7.3f} {4} ".format( + reco, args.rttm_channel, float(start), float(end)-float(start), label), file=rttm_writer) if __name__ == '__main__': main() diff --git a/egs/callhome_diarization/v1/diarization/nnet3/xvector/extract_xvectors.sh b/egs/callhome_diarization/v1/diarization/nnet3/xvector/extract_xvectors.sh index d7591a6a3a8..8d579138c73 100755 --- a/egs/callhome_diarization/v1/diarization/nnet3/xvector/extract_xvectors.sh +++ b/egs/callhome_diarization/v1/diarization/nnet3/xvector/extract_xvectors.sh @@ -102,7 +102,7 @@ if [ $stage -le 0 ]; then fi utils/data/get_uniform_subsegments.py \ --max-segment-duration=$window \ - --overlap-duration=$(echo "$window-$period" | bc) \ + --overlap-duration=$(perl -e "print ($window-$period);") \ --max-remaining-duration=$min_segment \ --constant-duration=True \ $segments > $dir/subsegments diff --git a/egs/callhome_diarization/v1/local/make_musan.py b/egs/callhome_diarization/v1/local/make_musan.py index b3f6652ba40..7c50adf7c83 100755 --- a/egs/callhome_diarization/v1/local/make_musan.py +++ b/egs/callhome_diarization/v1/local/make_musan.py @@ -43,9 +43,9 @@ def prepare_music(root_dir, use_vocals): utt2wav_str = utt2wav_str + utt + " sox -t wav " + utt2wav[utt] + " -r 8k -t wav - |\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file: {}".format(utt)) num_bad_files += 1 - print("In music directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In music directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def prepare_speech(root_dir): @@ -69,9 +69,9 @@ def prepare_speech(root_dir): utt2wav_str = utt2wav_str + utt + " sox -t wav " + utt2wav[utt] + " -r 8k -t wav - |\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file: {}".format(utt)) num_bad_files += 1 - print("In speech directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In speech directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def prepare_noise(root_dir): @@ -95,9 +95,9 @@ def prepare_noise(root_dir): utt2wav_str = utt2wav_str + utt + " sox -t wav " + utt2wav[utt] + " -r 8k -t wav - |\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file: 
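make_rttm.py above gains a --rttm-channel option so that the channel field of each SPEAKER line is no longer hard-coded to 0, and iterates over recordings in sorted order for deterministic output. A stripped-down sketch of the same pattern; the segment values and recording id below are invented:

```python
from __future__ import print_function
import argparse
import sys

def write_rttm(segs, rttm_channel, out):
    # segs: list of (recording_id, start, end, speaker_label) tuples
    for reco, start, end, label in segs:
        print("SPEAKER {0} {1} {2:7.3f} {3:7.3f} {4}".format(
            reco, rttm_channel, start, end - start, label), file=out)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--rttm-channel", type=int, default=0,
                        help="Value written into the RTTM channel field")
    args = parser.parse_args()
    write_rttm([("reco1", 0.00, 2.50, 1), ("reco1", 2.50, 7.10, 2)],
               args.rttm_channel, sys.stdout)
```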
{}".format(utt)) num_bad_files += 1 - print("In noise directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In noise directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def main(): diff --git a/egs/callhome_diarization/v1/local/make_swbd2_phase1.pl b/egs/callhome_diarization/v1/local/make_swbd2_phase1.pl new file mode 100755 index 00000000000..71b26b55de5 --- /dev/null +++ b/egs/callhome_diarization/v1/local/make_swbd2_phase1.pl @@ -0,0 +1,106 @@ +#!/usr/bin/perl +use warnings; #sed replacement for -w perl parameter +# +# Copyright 2017 David Snyder +# Apache 2.0 + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n"; + print STDERR "e.g. $0 /export/corpora3/LDC/LDC98S75 data/swbd2_phase1_train\n"; + exit(1); +} +($db_base, $out_dir) = @ARGV; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(CS, "<$db_base/doc/callstat.tbl") || die "Could not open $db_base/doc/callstat.tbl"; +open(GNDR, ">$out_dir/spk2gender") || die "Could not open the output file $out_dir/spk2gender"; +open(SPKR, ">$out_dir/utt2spk") || die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">$out_dir/wav.scp") || die "Could not open the output file $out_dir/wav.scp"; + +@badAudio = ("3", "4"); + +$tmp_dir = "$out_dir/tmp"; +if (system("mkdir -p $tmp_dir") != 0) { + die "Error making directory $tmp_dir"; +} + +if (system("find $db_base -name '*.sph' > $tmp_dir/sph.list") != 0) { + die "Error getting list of sph files"; +} + +open(WAVLIST, "<$tmp_dir/sph.list") or die "cannot open wav list"; + +%wavs = (); +while() { + chomp; + $sph = $_; + @t = split("/",$sph); + @t1 = split("[./]",$t[$#t]); + $uttId = $t1[0]; + $wavs{$uttId} = $sph; +} + +while () { + $line = $_ ; + @A = split(",", $line); + @A1 = split("[./]",$A[0]); + $wav = $A1[0]; + if (/$wav/i ~~ @badAudio) { + # do nothing + print "Bad Audio = $wav"; + } else { + $spkr1= "sw_" . $A[2]; + $spkr2= "sw_" . $A[3]; + $gender1 = $A[5]; + $gender2 = $A[6]; + if ($gender1 eq "M") { + $gender1 = "m"; + } elsif ($gender1 eq "F") { + $gender1 = "f"; + } else { + die "Unknown Gender in $line"; + } + if ($gender2 eq "M") { + $gender2 = "m"; + } elsif ($gender2 eq "F") { + $gender2 = "f"; + } else { + die "Unknown Gender in $line"; + } + if (-e "$wavs{$wav}") { + $uttId = $spkr1 ."_" . $wav ."_1"; + if (!$spk2gender{$spkr1}) { + $spk2gender{$spkr1} = $gender1; + print GNDR "$spkr1"," $gender1\n"; + } + print WAV "$uttId"," sph2pipe -f wav -p -c 1 $wavs{$wav} |\n"; + print SPKR "$uttId"," $spkr1","\n"; + + $uttId = $spkr2 . "_" . 
$wav ."_2"; + if (!$spk2gender{$spkr2}) { + $spk2gender{$spkr2} = $gender2; + print GNDR "$spkr2"," $gender2\n"; + } + print WAV "$uttId"," sph2pipe -f wav -p -c 2 $wavs{$wav} |\n"; + print SPKR "$uttId"," $spkr2","\n"; + } else { + print STDERR "Missing $wavs{$wav} for $wav\n"; + } + } +} + +close(WAV) || die; +close(SPKR) || die; +close(GNDR) || die; +if (system("utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} +if (system("utils/fix_data_dir.sh $out_dir") != 0) { + die "Error fixing data dir $out_dir"; +} +if (system("utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} diff --git a/egs/callhome_diarization/v1/run.sh b/egs/callhome_diarization/v1/run.sh index acc48bd24f9..f4652c0c0ef 100755 --- a/egs/callhome_diarization/v1/run.sh +++ b/egs/callhome_diarization/v1/run.sh @@ -188,7 +188,7 @@ if [ $stage -le 6 ]; then der=$(grep -oP 'DIARIZATION\ ERROR\ =\ \K[0-9]+([.][0-9]+)?' \ exp/tuning/${dataset}_t${threshold}) - if [ $(echo $der'<'$best_der | bc -l) -eq 1 ]; then + if [ $(perl -e "print ($der < $best_der ? 1 : 0);") -eq 1 ]; then best_der=$der best_threshold=$threshold fi diff --git a/egs/callhome_diarization/v2/run.sh b/egs/callhome_diarization/v2/run.sh index 4f730d4753c..b79717e2348 100755 --- a/egs/callhome_diarization/v2/run.sh +++ b/egs/callhome_diarization/v2/run.sh @@ -115,7 +115,7 @@ if [ $stage -le 2 ]; then # Make a reverberated version of the SWBD+SRE list. Note that we don't add any # additive noise here. - python steps/data/reverberate_data_dir.py \ + steps/data/reverberate_data_dir.py \ "${rvb_opts[@]}" \ --speech-rvb-probability 1 \ --pointsource-noise-addition-probability 0 \ @@ -140,11 +140,11 @@ if [ $stage -le 2 ]; then done # Augment with musan_noise - python steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/train data/train_noise + steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/train data/train_noise # Augment with musan_music - python steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/train data/train_music + steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/train data/train_music # Augment with musan_speech - python steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/train data/train_babble + steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/train data/train_babble # Combine reverb, noise, music, and babble into one directory. utils/combine_data.sh data/train_aug data/train_reverb data/train_noise data/train_music data/train_babble @@ -297,7 +297,7 @@ if [ $stage -le 10 ]; then der=$(grep -oP 'DIARIZATION\ ERROR\ =\ \K[0-9]+([.][0-9]+)?' \ $nnet_dir/tuning/${dataset}_t${threshold}) - if [ $(echo $der'<'$best_der | bc -l) -eq 1 ]; then + if [ $(perl -e "print ($der < $best_der ? 
1 : 0);") -eq 1 ]; then best_der=$der best_threshold=$threshold fi diff --git a/egs/callhome_egyptian/s5/local/callhome_prepare_dict.sh b/egs/callhome_egyptian/s5/local/callhome_prepare_dict.sh index 62bca974e53..d9faa97f266 100755 --- a/egs/callhome_egyptian/s5/local/callhome_prepare_dict.sh +++ b/egs/callhome_egyptian/s5/local/callhome_prepare_dict.sh @@ -54,9 +54,8 @@ cat $dir/silence_phones.txt| awk '{printf("%s ", $1);} END{printf "\n";}' > \ $dir/extra_questions.txt || exit 1; # Add prons for laughter, noise, oov -for w in `grep -v sil $dir/silence_phones.txt`; do -sed -i "/\[$w\]/d" $tmpdir/lexicon.3 -done +w=$(grep -v sil $dir/silence_phones.txt | tr '\n' '|') +perl -i -ne "print unless /\[(${w%?})\]/" $tmpdir/lexicon.3 for w in `grep -v sil $dir/silence_phones.txt`; do echo "[$w] $w" diff --git a/egs/callhome_egyptian/s5/local/convert_symtable_to_utf.py b/egs/callhome_egyptian/s5/local/convert_symtable_to_utf.py index f5b69a1ff86..7192ff7a1cc 100644 --- a/egs/callhome_egyptian/s5/local/convert_symtable_to_utf.py +++ b/egs/callhome_egyptian/s5/local/convert_symtable_to_utf.py @@ -1,3 +1,4 @@ +from __future__ import print_function #!/usr/bin/env py # Converts a romanized ECA word list (symbol table) to @@ -7,9 +8,9 @@ import codecs if len(sys.argv) < 3: - print "USAGE: local/convert_symtable_to_utf.py [SYMTABLE] [ECA-LEXICON]" - print "E.g., local/convert_symtable_to_utf.py data/lang/words.txt \ - /export/corpora/LDC/LDC99L22" + print("USAGE: local/convert_symtable_to_utf.py [SYMTABLE] [ECA-LEXICON]") + print("E.g., local/convert_symtable_to_utf.py data/lang/words.txt \ + /export/corpora/LDC/LDC99L22") sys.exit(1) # Note that the ECA lexicon's default encoding is ISO-8859-6, not UTF8 diff --git a/egs/callhome_egyptian/s5/local/ctm.sh b/egs/callhome_egyptian/s5/local/ctm.sh index 14056b7a44b..64a7cf0d4f6 100755 --- a/egs/callhome_egyptian/s5/local/ctm.sh +++ b/egs/callhome_egyptian/s5/local/ctm.sh @@ -18,9 +18,9 @@ fi steps/get_ctm.sh $data_dir $lang_dir $decode_dir # Make sure that channel markers match -#sed -i "s:\s.*_fsp-([AB]): \1:g" data/dev/stm -#ls exp/tri5a/decode_dev/score_*/dev.ctm | xargs -I {} sed -i -r 's:fsp\s1\s:fsp A :g' {} -#ls exp/tri5a/decode_dev/score_*/dev.ctm | xargs -I {} sed -i -r 's:fsp\s2\s:fsp B :g' {} +#perl -i -pe "s:\s.*_fsp-([AB]): \1:g" data/dev/stm +#ls exp/tri5a/decode_dev/score_*/dev.ctm | xargs -I {} perl -i -pe 's:fsp\s1\s:fsp A :g' {} +#ls exp/tri5a/decode_dev/score_*/dev.ctm | xargs -I {} perl -i -pe 's:fsp\s2\s:fsp B :g' {} # Get the environment variables . 
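Both run.sh hunks above swap `bc -l` for a perl expression because the comparison involves floating-point DER values; the surrounding loop simply keeps the threshold whose tuning file reports the lowest DIARIZATION ERROR. In Python the same selection needs no external tool. A small sketch, with the file contents invented to stand in for the exp/tuning outputs:

```python
import re

DER_RE = re.compile(r"DIARIZATION\s+ERROR\s+=\s+([0-9]+(?:\.[0-9]+)?)")

def best_threshold(results):
    """results: dict mapping threshold -> text of the corresponding tuning file.
    Returns (threshold, DER) with the smallest diarization error rate."""
    best_t, best_der = None, float("inf")
    for threshold, text in results.items():
        match = DER_RE.search(text)
        if match and float(match.group(1)) < best_der:
            best_t, best_der = threshold, float(match.group(1))
    return best_t, best_der

# invented tuning-file contents, for illustration only
print(best_threshold({0.3: "DIARIZATION ERROR = 9.12", 0.5: "DIARIZATION ERROR = 8.40"}))
```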
/export/babel/data/software/env.sh diff --git a/egs/callhome_egyptian/s5/local/splits/get_conversation.py b/egs/callhome_egyptian/s5/local/splits/get_conversation.py index c999d3e597e..80f66174e2b 100755 --- a/egs/callhome_egyptian/s5/local/splits/get_conversation.py +++ b/egs/callhome_egyptian/s5/local/splits/get_conversation.py @@ -1,5 +1,6 @@ #!/usr/bin/env python +from __future__ import print_function import os import re @@ -37,14 +38,14 @@ evaltest[pathComponents[12]] = numberOfConversations testConv = testConv + numberOfConversations -print "==============Train===============" -print train -print "Total Conversations in train = " + str(trainConv) -print "==============Dev===============" -print devtest -print "Total Conversations in dev = " + str(devConv) -print "==============Test===============" -print evaltest -print "Total Conversations in test = " + str(testConv) -print "=================================" -print "Total Conversations in Corpus = " + str(trainConv + devConv + testConv) +print("==============Train===============") +print(train) +print("Total Conversations in train = {}".format(trainConv)) +print("==============Dev===============") +print(devtest) +print("Total Conversations in dev = {}".format(devConv)) +print("==============Test===============") +print(evaltest) +print("Total Conversations in test = {}".format(testConv)) +print("=================================") +print("Total Conversations in Corpus = {}".format(trainConv + devConv + testConv)) diff --git a/egs/chime1/s5/local/chime1_prepare_data.sh b/egs/chime1/s5/local/chime1_prepare_data.sh index e60c46ff8da..c5963b5d4ab 100755 --- a/egs/chime1/s5/local/chime1_prepare_data.sh +++ b/egs/chime1/s5/local/chime1_prepare_data.sh @@ -53,7 +53,7 @@ for x in "devel" "test"; do for sid in `seq 34`; do sid2=`printf "s%02d" $sid` ls -1 $wav_dir/*/s${sid}_*.wav \ - | perl -ape "s/(.*)\/(.*)\/s.*_(.*).wav/${sid2}_\3_\2$\t\1\/\2\/s${sid}_\3.wav/;" \ + | perl -ape "s/(.*)\/(.*)\/s.*_(.*).wav/${sid2}_\3_\2\t\1\/\2\/s${sid}_\3.wav/;" \ | sort >> $scp done fi @@ -68,7 +68,7 @@ for x in $set_list; do # Create utt2spk files # No speaker ID - perl -ape "s/(.*)\t.*/\1$\t\1/;" < "$scp" > "$data/$x/utt2spk" + perl -ape "s/(.*)\t.*/\1\t\1/;" < "$scp" > "$data/$x/utt2spk" # Use speaker ID # perl -ape "s/(s..)(.*)\\t.*/\1\2\t\1/;" < "$scp" > "$data/$x/utt2spk" diff --git a/egs/chime4/s5_1ch/RESULTS b/egs/chime4/s5_1ch/RESULTS index c0146b772b7..3e5f752a803 100644 --- a/egs/chime4/s5_1ch/RESULTS +++ b/egs/chime4/s5_1ch/RESULTS @@ -17,7 +17,22 @@ et05_simu WER: 33.30% (Average), 26.65% (BUS), 38.40% (CAFE), 34.68% (PEDESTRIAN et05_real WER: 37.54% (Average), 51.92% (BUS), 39.67% (CAFE), 34.04% (PEDESTRIAN), 24.54% (STREET) ------------------- -Advanced baseline: +GMM noisy multi-condition without enhancement using 6 channel data +exp/tri3b_tr05_multi_noisy/best_wer_isolated_1ch_track.result +------------------- +best overall dt05 WER 22.32% (language model weight = 10) +------------------- +dt05_simu WER: 23.24% (Average), 19.28% (BUS), 28.41% (CAFE), 19.16% (PEDESTRIAN), 26.12% (STREET) +------------------- +dt05_real WER: 21.40% (Average), 25.86% (BUS), 21.81% (CAFE), 16.80% (PEDESTRIAN), 21.12% (STREET) +------------------- +et05_simu WER: 32.03% (Average), 25.42% (BUS), 36.25% (CAFE), 33.34% (PEDESTRIAN), 33.10% (STREET) +------------------- +et05_real WER: 36.14% (Average), 49.28% (BUS), 38.79% (CAFE), 32.44% (PEDESTRIAN), 24.06% (STREET) +------------------- + +GMM noisy multi-condition without enhancement using 6 channel data plus 
enhanced data +exp/tri3b_tr05_multi_noisy/best_wer_isolated_1ch_track.result ------------------- best overall dt05 WER 22.28% (language model weight = 10) ------------------- @@ -30,6 +45,34 @@ et05_simu WER: 32.18% (Average), 25.33% (BUS), 37.37% (CAFE), 33.36% (PEDESTRIAN et05_real WER: 35.54% (Average), 49.07% (BUS), 38.94% (CAFE), 31.60% (PEDESTRIAN), 22.56% (STREET) ------------------- +GMM noisy multi-condition with BLSTM masking using 6 channel data +exp/tri3b_tr05_multi_noisy/best_wer_single_BLSTMmask.result +------------------- +best overall dt05 WER 28.82% (language model weight = 14) +------------------- +dt05_simu WER: 28.54% (Average), 25.46% (BUS), 33.47% (CAFE), 25.19% (PEDESTRIAN), 30.06% (STREET) +------------------- +dt05_real WER: 29.10% (Average), 33.46% (BUS), 31.80% (CAFE), 25.71% (PEDESTRIAN), 25.42% (STREET) +------------------- +et05_simu WER: 36.10% (Average), 30.97% (BUS), 40.42% (CAFE), 35.82% (PEDESTRIAN), 37.19% (STREET) +------------------- +et05_real WER: 41.84% (Average), 52.57% (BUS), 46.41% (CAFE), 39.87% (PEDESTRIAN), 28.52% (STREET) +------------------- + +GMM noisy multi-condition with BLSTM masking using 6 channel data plus enhanced data +exp/tri3b_tr05_multi_noisy/best_wer_single_BLSTMmask.result +------------------- +best overall dt05 WER 22.72% (language model weight = 13) +------------------- +dt05_simu WER: 23.37% (Average), 20.71% (BUS), 28.26% (CAFE), 19.85% (PEDESTRIAN), 24.66% (STREET) +------------------- +dt05_real WER: 22.07% (Average), 25.92% (BUS), 24.32% (CAFE), 18.47% (PEDESTRIAN), 19.58% (STREET) +------------------- +et05_simu WER: 30.41% (Average), 24.08% (BUS), 35.86% (CAFE), 30.80% (PEDESTRIAN), 30.89% (STREET) +------------------- +et05_real WER: 34.02% (Average), 44.68% (BUS), 37.19% (CAFE), 31.73% (PEDESTRIAN), 22.49% (STREET) +------------------- + DNN sMBR exp/tri4a_dnn_tr05_multi_noisy_smbr_i1lats/best_wer_isolated_1ch_track.result ------------------- @@ -45,7 +88,7 @@ et05_simu WER: 24.13% (Average), 19.65% (BUS), 27.57% (CAFE), 23.14% (PEDESTRIAN et05_real WER: 27.68% (Average), 40.40% (BUS), 28.95% (CAFE), 24.25% (PEDESTRIAN), 17.13% (STREET) ------------------- -Advanced baseline: +DNN sMBR using all 6 channel data ------------------- best overall dt05 WER 12.84% (language model weight = 12) (Number of iterations = 3) @@ -73,7 +116,7 @@ et05_simu WER: 22.32% (Average), 17.82% (BUS), 25.48% (CAFE), 21.70% (PEDESTRIAN et05_real WER: 24.92% (Average), 37.52% (BUS), 26.45% (CAFE), 21.28% (PEDESTRIAN), 14.44% (STREET) ------------------- -Advanced baseline: +5-gram rescoring using all 6 channel data ------------------- best overall dt05 WER 11.07% (language model weight = 12) ------------------- @@ -100,7 +143,7 @@ et05_simu WER: 20.84% (Average), 16.49% (BUS), 23.91% (CAFE), 20.25% (PEDESTRIAN et05_real WER: 23.70% (Average), 35.93% (BUS), 24.60% (CAFE), 19.94% (PEDESTRIAN), 14.36% (STREET) ------------------- -Advanced baseline: +RNNLM using all 6 channel data ------------------- best overall dt05 WER 9.99% (language model weight = 14) ------------------- @@ -113,30 +156,86 @@ et05_simu WER: 17.31% (Average), 12.81% (BUS), 20.32% (CAFE), 17.03% (PEDESTRIAN et05_real WER: 18.10% (Average), 26.58% (BUS), 19.97% (CAFE), 14.44% (PEDESTRIAN), 11.43% (STREET) ------------------- -TDNN -exp/chain/tdnn1d_sp/best_wer_beamformit_5mics.result +TDNN using all 6 channel data +exp/chain/tdnniso_sp/best_wer_beamformit_5mics.result +------------------- +best overall dt05 WER 9.56% (language model weight = 10) +------------------- 
+dt05_simu WER: 10.23% (Average), 8.86% (BUS), 13.13% (CAFE), 7.94% (PEDESTRIAN), 11.00% (STREET) +------------------- +dt05_real WER: 8.89% (Average), 11.90% (BUS), 8.54% (CAFE), 6.09% (PEDESTRIAN), 9.03% (STREET) +------------------- +et05_simu WER: 16.48% (Average), 12.87% (BUS), 18.60% (CAFE), 15.52% (PEDESTRIAN), 18.94% (STREET) +------------------- +et05_real WER: 16.34% (Average), 24.32% (BUS), 16.51% (CAFE), 13.43% (PEDESTRIAN), 11.11% (STREET) +------------------- + +TDNN+RNNLM using all 6 channel data +exp/chain/tdnniso_sp_smbr_lmrescore/best_wer_beamformit_5mics_rnnlm_5k_h300_w0.5_n100.result +------------------- +best overall dt05 WER 7.21% (language model weight = 11) +------------------- +dt05_simu WER: 7.78% (Average), 6.52% (BUS), 10.27% (CAFE), 5.69% (PEDESTRIAN), 8.66% (STREET) +------------------- +dt05_real WER: 6.64% (Average), 9.06% (BUS), 6.62% (CAFE), 4.26% (PEDESTRIAN), 6.61% (STREET) +------------------- +et05_simu WER: 13.54% (Average), 10.22% (BUS), 15.07% (CAFE), 12.94% (PEDESTRIAN), 15.93% (STREET) +------------------- +et05_real WER: 12.92% (Average), 20.79% (BUS), 12.35% (CAFE), 9.62% (PEDESTRIAN), 8.91% (STREET) +------------------- + +TDNN with BLSTM masking using all 6 channel data +exp/chain/tdnn1a_sp/best_wer_single_BLSTMmask.result +------------------- +best overall dt05 WER 18.00% (language model weight = 13) +------------------- +dt05_simu WER: 18.81% (Average), 15.34% (BUS), 23.58% (CAFE), 15.27% (PEDESTRIAN), 21.06% (STREET) +------------------- +dt05_real WER: 17.18% (Average), 21.12% (BUS), 19.45% (CAFE), 11.61% (PEDESTRIAN), 16.53% (STREET) +------------------- +et05_simu WER: 25.85% (Average), 20.06% (BUS), 30.13% (CAFE), 26.88% (PEDESTRIAN), 26.32% (STREET) +------------------- +et05_real WER: 27.68% (Average), 37.88% (BUS), 29.51% (CAFE), 24.74% (PEDESTRIAN), 18.60% (STREET) +------------------- + +TDNN+RNNLM with BLSTM masking using all 6 channel data +exp/chain/tdnn1a_sp/best_wer_single_BLSTMmask.result +------------------- +best overall dt05 WER 14.38% (language model weight = 14) +------------------- +dt05_simu WER: 15.62% (Average), 12.36% (BUS), 20.46% (CAFE), 12.11% (PEDESTRIAN), 17.55% (STREET) +------------------- +dt05_real WER: 13.15% (Average), 16.43% (BUS), 15.21% (CAFE), 8.59% (PEDESTRIAN), 12.37% (STREET) +------------------- +et05_simu WER: 21.61% (Average), 16.01% (BUS), 25.87% (CAFE), 22.15% (PEDESTRIAN), 22.39% (STREET) +------------------- +et05_real WER: 22.47% (Average), 32.34% (BUS), 24.08% (CAFE), 18.91% (PEDESTRIAN), 14.57% (STREET) +------------------- + +TDNN with BLSTM masking using all 6 channel data plus enhanced data +exp/chain/tdnn1a_sp/best_wer_single_BLSTMmask.result ------------------- -best overall dt05 WER 10.37% (language model weight = 9) +best overall dt05 WER 11.73% (language model weight = 12) ------------------- -dt05_simu WER: 10.79% (Average), 9.62% (BUS), 13.70% (CAFE), 8.23% (PEDESTRIAN), 11.61% (STREET) +dt05_simu WER: 13.06% (Average), 10.78% (BUS), 17.20% (CAFE), 10.15% (PEDESTRIAN), 14.10% (STREET) ------------------- -dt05_real WER: 9.95% (Average), 14.38% (BUS), 8.81% (CAFE), 6.43% (PEDESTRIAN), 10.19% (STREET) +dt05_real WER: 10.40% (Average), 13.44% (BUS), 10.72% (CAFE), 7.29% (PEDESTRIAN), 10.16% (STREET) ------------------- -et05_simu WER: 17.18% (Average), 13.75% (BUS), 19.48% (CAFE), 15.82% (PEDESTRIAN), 19.67% (STREET) +et05_simu WER: 19.48% (Average), 14.48% (BUS), 23.10% (CAFE), 19.84% (PEDESTRIAN), 20.49% (STREET) ------------------- -et05_real WER: 18.36% (Average), 30.77% (BUS), 
16.17% (CAFE), 14.29% (PEDESTRIAN), 12.20% (STREET) +et05_real WER: 19.08% (Average), 27.43% (BUS), 19.76% (CAFE), 16.93% (PEDESTRIAN), 12.22% (STREET) ------------------- -TDNN+RNNLM -exp/chain/tdnn1d_sp_smbr_lmrescore/best_wer_beamformit_5mics_rnnlm_5k_h300_w0.5_n100.result +TDNN+RNNLM with BLSTM masking using all 6 channel data plus enhanced data +exp/chain/tdnn1a_sp/best_wer_single_BLSTMmask.result ------------------- -best overall dt05 WER 7.98% (language model weight = 10) +best overall dt05 WER 8.95% (language model weight = 13) ------------------- -dt05_simu WER: 8.40% (Average), 7.37% (BUS), 10.91% (CAFE), 6.36% (PEDESTRIAN), 8.97% (STREET) +dt05_simu WER: 10.28% (Average), 8.51% (BUS), 13.88% (CAFE), 7.58% (PEDESTRIAN), 11.17% (STREET) ------------------- -dt05_real WER: 7.56% (Average), 11.58% (BUS), 6.58% (CAFE), 4.41% (PEDESTRIAN), 7.65% (STREET) +dt05_real WER: 7.62% (Average), 10.25% (BUS), 7.86% (CAFE), 5.31% (PEDESTRIAN), 7.05% (STREET) ------------------- -et05_simu WER: 13.91% (Average), 10.87% (BUS), 15.09% (CAFE), 12.78% (PEDESTRIAN), 16.88% (STREET) +et05_simu WER: 16.18% (Average), 12.03% (BUS), 18.71% (CAFE), 16.62% (PEDESTRIAN), 17.35% (STREET) ------------------- -et05_real WER: 14.99% (Average), 26.88% (BUS), 13.32% (CAFE), 10.07% (PEDESTRIAN), 9.71% (STREET) +et05_real WER: 15.08% (Average), 22.96% (BUS), 15.45% (CAFE), 12.74% (PEDESTRIAN), 9.17% (STREET) ------------------- diff --git a/egs/chime4/s5_1ch/local/CHiME3_simulate_data_patched_parallel.m b/egs/chime4/s5_1ch/local/CHiME3_simulate_data_patched_parallel.m new file mode 100755 index 00000000000..49c1ed48018 --- /dev/null +++ b/egs/chime4/s5_1ch/local/CHiME3_simulate_data_patched_parallel.m @@ -0,0 +1,362 @@ +function CHiME3_simulate_data_patched_parallel(official,nj,chime4_dir,chime3_dir) + +% CHIME3_SIMULATE_DATA Creates simulated data for the 3rd CHiME Challenge +% +% CHiME3_simulate_data +% CHiME3_simulate_data(official) +% +% Input: +% official: boolean flag indicating whether to recreate the official +% Challenge data (default) or to create new (non-official) data +% +% If you use this software in a publication, please cite: +% +% Jon Barker, Ricard Marxer, Emmanuel Vincent, and Shinji Watanabe, The +% third 'CHiME' Speech Separation and Recognition Challenge: Dataset, +% task and baselines, submitted to IEEE 2015 Automatic Speech Recognition +% and Understanding Workshop (ASRU), 2015. 
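The new MATLAB script introduced above parallelises the official CHiME-3 data simulation. Early in the script (below), it derives an equalization filter as the square root of the ratio between the average booth power spectrum and the average original-WSJ0 power spectrum, then applies it bin by bin to the clean STFT. A rough numpy equivalent of that step, assuming STFT arrays shaped bins x frames (the array contents here are synthetic):

```python
import numpy as np

def equalization_filter(O_bth, O_org):
    """O_bth, O_org: complex STFTs (frequency bins x frames) of booth and
    original recordings.  Returns the per-bin gain
    sqrt(mean|O_bth|^2 / mean|O_org|^2)."""
    bth_spec = np.mean(np.abs(O_bth) ** 2, axis=1)
    org_spec = np.mean(np.abs(O_org) ** 2, axis=1)
    return np.sqrt(bth_spec / org_spec)

rng = np.random.default_rng(0)
O_org = rng.standard_normal((129, 200)) + 1j * rng.standard_normal((129, 200))
O_bth = 0.5 * O_org
eq = equalization_filter(O_bth, O_org)
O_equalized = O_org * eq[:, None]   # mirrors O .* repmat(equal_filter, [1 nfram])
```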
+% +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% Copyright 2015 University of Sheffield (Jon Barker, Ricard Marxer) +% Inria (Emmanuel Vincent) +% Mitsubishi Electric Research Labs (Shinji Watanabe) +% This software is distributed under the terms of the GNU Public License +% version 3 (http://www.gnu.org/licenses/gpl.txt) +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +utils_folder = sprintf('%s/tools/utils', chime4_dir); +enhancement_folder = sprintf('%s/tools/enhancement/', chime3_dir); +addpath(utils_folder,'-end'); +addpath(enhancement_folder); +sim_folder = sprintf('%s/tools/simulation', chime4_dir); +addpath(sim_folder); +upath = sprintf('%s/data/audio/16kHz/isolated/', chime4_dir); +cpath = sprintf('%s/data/audio/16kHz/embedded/', chime4_dir); +bpath = sprintf('%s/data/audio/16kHz/backgrounds/', chime4_dir); +apath = sprintf('%s/data/annotations/', chime4_dir); +upath_ext = 'local/nn-gev/data/audio/16kHz/isolated_ext/'; +upath_simu = 'local/nn-gev/data/audio/16kHz/isolated/'; +nchan=6; + +% Define hyper-parameters +pow_thresh=-20; % threshold in dB below which a microphone is considered to fail +wlen_sub=256; % STFT window length in samples +blen_sub=4000; % average block length in samples for speech subtraction (250 ms) +ntap_sub=12; % filter length in frames for speech subtraction (88 ms) +wlen_add=1024; % STFT window length in samples for speaker localization +del=-3; % minimum delay (0 for a causal filter) + +%% Create simulated training dataset from original WSJ0 data %% +if exist('equal_filter.mat','file'), + load('equal_filter.mat'); +else + % Compute average power spectrum of booth data + nfram=0; + bth_spec=zeros(wlen_sub/2+1,1); + sets={'tr05' 'dt05'}; + for set_ind=1:length(sets), + set=sets{set_ind}; + mat=json2mat([apath set '_bth.json']); + for utt_ind=1:length(mat), + oname=[mat{utt_ind}.speaker '_' mat{utt_ind}.wsj_name '_BTH']; + fprintf('%s\n',[upath set '_bth/' oname '.CH0.wav']); + o=audioread([upath set '_bth/' oname '.CH0.wav']); + O=stft_multi(o.',wlen_sub); + nfram=nfram+size(O,2); + bth_spec=bth_spec+sum(abs(O).^2,2); + end + end + bth_spec=bth_spec/nfram; + + % Compute average power spectrum of original WSJ0 data + nfram=0; + org_spec=zeros(wlen_sub/2+1,1); + olist=dir([upath 'tr05_org/*.wav']); + for f=1:length(olist), + oname=olist(f).name; + o=audioread([upath 'tr05_org/' oname]); + O=stft_multi(o.',wlen_sub); + nfram=nfram+size(O,2); + org_spec=org_spec+sum(abs(O).^2,2); + end + org_spec=org_spec/nfram; + + % Derive equalization filter + equal_filter=sqrt(bth_spec./org_spec); + save('equal_filter.mat','equal_filter'); +end +% Read official annotations +if official, + mat=json2mat([apath 'tr05_simu.json']); +% Create new (non-official) annotations +else + mat=json2mat([apath 'tr05_org.json']); + ir_mat=json2mat([apath 'tr05_real.json']); + for utt_ind=1:length(mat), + oname=[mat{utt_ind}.speaker '_' mat{utt_ind}.wsj_name '_ORG']; + osize=audioread([upath 'tr05_org/' oname '.wav'],'size'); + dur=osize(1)/16000; + envirs={'BUS' 'CAF' 'PED' 'STR'}; + envir=envirs{randperm(4,1)}; % draw a random environment + mat{utt_ind}.environment=envir; + blist=dir([bpath '*' envir '.CH1.wav']); + dur_diff=inf(1,length(ir_mat)); + for ir_ind=1:length(ir_mat), + if strcmp(ir_mat{ir_ind}.environment,envir), + ir_dur=ir_mat{ir_ind}.end-ir_mat{ir_ind}.start; + dur_diff(ir_ind)=abs(ir_dur-dur); + end + end + ir_ind=find(isinf(dur_diff)); + ir_ind=ir_ind(1); + nfail=true; + while nfail, + 
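The loop that continues below keeps redrawing a random background segment (and, separately, an impulse response) until no microphone is flagged as failed; pow_thresh=-20 above defines failure as any channel whose energy sits at least 20 dB below the strongest channel. A one-function numpy sketch of that check, with a synthetic multichannel signal:

```python
import numpy as np

def has_mic_failure(x, pow_thresh_db=-20.0):
    """x: samples x channels.  A channel 'fails' when its energy, relative to
    the strongest channel, falls at or below pow_thresh_db (dB)."""
    pow_per_chan = np.sum(x ** 2, axis=0)
    rel_db = 10.0 * np.log10(pow_per_chan / pow_per_chan.max())
    return bool(np.any(rel_db <= pow_thresh_db))

x = np.random.randn(16000, 6)
x[:, 3] *= 0.001                  # simulate a dead microphone
print(has_mic_failure(x))         # True
```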
bname=blist(randperm(length(blist),1)).name(1:end-8); % draw a random background recording + mat{utt_ind}.noise_wavfile=bname; + bsize=audioread([bpath bname '.CH1.wav'],'size'); + bdur=bsize(1)/16000; + mat{utt_ind}.noise_start=ceil(rand(1)*(bdur-dur)*16000)/16000; % draw a random time + mat{utt_ind}.noise_end=mat{utt_ind}.noise_start+dur; + nname=mat{utt_ind}.noise_wavfile; + nbeg=round(mat{utt_ind}.noise_start*16000)+1; + nend=round(mat{utt_ind}.noise_end*16000); + n=zeros(nend-nbeg+1,nchan); + for c=1:nchan, + n(:,c)=audioread([bpath nname '.CH' int2str(c) '.wav'],[nbeg nend]); + end + npow=sum(n.^2,1); + npow=10*log10(npow/max(npow)); + nfail=any(npow<=pow_thresh); % check for microphone failure + end + xfail=true; + while xfail, + dur_diff(ir_ind)=inf; + [~,ir_ind]=min(dur_diff); % pick impulse response from the same environment with the closest duration + if dur_diff(ir_ind)==inf, + keyboard; + end + mat{utt_ind}.ir_wavfile=ir_mat{ir_ind}.wavfile; + mat{utt_ind}.ir_start=ir_mat{ir_ind}.start; + mat{utt_ind}.ir_end=ir_mat{ir_ind}.end; + iname=mat{utt_ind}.ir_wavfile; + ibeg=round(mat{utt_ind}.ir_start*16000)+1; + iend=round(mat{utt_ind}.ir_end*16000); + x=zeros(iend-ibeg+1,nchan); + for c=1:nchan, + x(:,c)=audioread([cpath iname '.CH' int2str(c) '.wav'],[ibeg iend]); + end + xpow=sum(x.^2,1); + xpow=10*log10(xpow/max(xpow)); + xfail=any(xpow<=pow_thresh); % check for microphone failure + end + mat{utt_ind}=orderfields(mat{utt_ind}); + end + mat2json(mat,[apath 'tr05_simu_new.json']); +end + +p = parpool('local', nj); +% Loop over utterances +parfor utt_ind=1:length(mat), + if official, + udir=[upath_simu 'tr05_' lower(mat{utt_ind}.environment) '_simu/']; + udir_ext=[upath_ext 'tr05_' lower(mat{utt_ind}.environment) '_simu/']; + else + udir=[upath 'tr05_' lower(mat{utt_ind}.environment) '_simu_new/']; + end + if ~exist(udir,'dir'), + system(['mkdir -p ' udir]); + end + if ~exist(udir_ext,'dir'), + system(['mkdir -p ' udir_ext]); + end + oname=[mat{utt_ind}.speaker '_' mat{utt_ind}.wsj_name '_ORG']; + iname=mat{utt_ind}.ir_wavfile; + nname=mat{utt_ind}.noise_wavfile; + uname=[mat{utt_ind}.speaker '_' mat{utt_ind}.wsj_name '_' mat{utt_ind}.environment]; + ibeg=round(mat{utt_ind}.ir_start*16000)+1; + iend=round(mat{utt_ind}.ir_end*16000); + nbeg=round(mat{utt_ind}.noise_start*16000)+1; + nend=round(mat{utt_ind}.noise_end*16000); + + % Load WAV files + fprintf('%s\n',[upath 'tr05_org/' oname '.wav']); + o=audioread([upath 'tr05_org/' oname '.wav']); + [r,fs]=audioread([cpath iname '.CH0.wav'],[ibeg iend]); + fprintf('%s\n',[cpath iname '.CH0.wav'],[ibeg iend]); + x=zeros(iend-ibeg+1,nchan); + n=zeros(nend-nbeg+1,nchan); + for c=1:nchan, + fprintf('%s Place1\n',[cpath iname '.CH' int2str(c) '.wav']); + x(:,c)=audioread([cpath iname '.CH' int2str(c) '.wav'],[ibeg iend]); + n(:,c)=audioread([bpath nname '.CH' int2str(c) '.wav'],[nbeg nend]); + fprintf('%s Place2\n',[bpath nname '.CH' int2str(c) '.wav']); + end + + % Compute the STFT (short window) + O=stft_multi(o.',wlen_sub); + R=stft_multi(r.',wlen_sub); + X=stft_multi(x.',wlen_sub); + + % Estimate 88 ms impulse responses on 250 ms time blocks + A=estimate_ir(R,X,blen_sub,ntap_sub,del); + + % Derive SNR + Y=apply_ir(A,R,del); + y=istft_multi(Y,iend-ibeg+1).'; + SNR=sum(sum(y.^2))/sum(sum((x-y).^2)); + + % Equalize microphone + [~,nfram]=size(O); + O=O.*repmat(equal_filter,[1 nfram]); + o=istft_multi(O,nend-nbeg+1).'; + + % Compute the STFT (long window) + O=stft_multi(o.',wlen_add); + X=stft_multi(x.',wlen_add); + [nbin,nfram] = size(O); + 
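After the long-window STFT computed above, the loop below localizes the speaker, interpolates the per-channel TDOAs over the clean-speech frames, and re-spatializes the clean STFT with a per-bin steering vector Df = sqrt(1/nchan) * exp(-2i*pi*(f-1)/wlen_add*fs*TDOA(:,t)). A numpy sketch of that filtering step, vectorised over bins, frames and channels; shapes are assumed to be O: bins x frames and tdoa: channels x frames:

```python
import numpy as np

def apply_steering(O, tdoa, fs=16000, wlen=1024):
    """Return Ysimu with shape (bins, frames, channels):
    Ysimu[f, t, c] = sqrt(1/nchan) * exp(-2j*pi * f_hz * tdoa[c, t]) * O[f, t]."""
    nbin, nfram = O.shape
    nchan = tdoa.shape[0]
    freqs_hz = np.arange(nbin) / float(wlen) * fs                    # bin index -> Hz
    phase = -2j * np.pi * freqs_hz[:, None, None] * tdoa.T[None, :, :]
    return np.sqrt(1.0 / nchan) * np.exp(phase) * O[:, :, None]

O = np.ones((513, 10), dtype=complex)      # 513 bins matches wlen_add = 1024
tdoa = np.zeros((6, 10))                   # zero delays: every channel gets O / sqrt(6)
Ysimu = apply_steering(O, tdoa)
```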
+ % Localize and track the speaker + [~,TDOAx]=localize(X); + + % Interpolate the spatial position over the duration of clean speech + TDOA=zeros(nchan,nfram); + for c=1:nchan, + TDOA(c,:)=interp1(0:size(X,2)-1,TDOAx(c,:),(0:nfram-1)/(nfram-1)*(size(X,2)-1)); + end + + % Filter clean speech + Ysimu=zeros(nbin,nfram,nchan); + for f=1:nbin, + for t=1:nfram, + Df=sqrt(1/nchan)*exp(-2*1i*pi*(f-1)/wlen_add*fs*TDOA(:,t)); + Ysimu(f,t,:)=permute(Df*O(f,t),[2 3 1]); + end + end + ysimu=istft_multi(Ysimu,nend-nbeg+1).'; + + % Normalize level and add + ysimu=sqrt(SNR/sum(sum(ysimu.^2))*sum(sum(n.^2)))*ysimu; + xsimu=ysimu+n; + + % Write WAV file + for c=1:nchan, + audiowrite([udir uname '.CH' int2str(c) '.wav'],xsimu(:,c),fs); + audiowrite([udir_ext uname '.CH' int2str(c) '.Noise.wav'],n(:, c),fs); + audiowrite([udir_ext uname '.CH' int2str(c) '.Clean.wav'],ysimu(:, c), fs); + end +end + +%% Create simulated development and test datasets from booth recordings %% +sets={'dt05' 'et05'}; +for set_ind=1:length(sets), + set=sets{set_ind}; + + % Read official annotations + if official, + mat=json2mat([apath set '_simu.json']); + + % Create new (non-official) annotations + else + mat=json2mat([apath set '_real.json']); + clean_mat=json2mat([apath set '_bth.json']); + for utt_ind=1:length(mat), + for clean_ind=1:length(clean_mat), % match noisy utterance with same clean utterance (may be from a different speaker) + if strcmp(clean_mat{clean_ind}.wsj_name,mat{utt_ind}.wsj_name), + break; + end + end + noise_mat=mat{utt_ind}; + mat{utt_ind}=clean_mat{clean_ind}; + mat{utt_ind}.environment=noise_mat.environment; + mat{utt_ind}.noise_wavfile=noise_mat.wavfile; + dur=mat{utt_ind}.end-mat{utt_ind}.start; + noise_dur=noise_mat.end-noise_mat.start; + pbeg=round((dur-noise_dur)/2*16000)/16000; + pend=round((dur-noise_dur)*16000)/16000-pbeg; + mat{utt_ind}.noise_start=noise_mat.start-pbeg; + mat{utt_ind}.noise_end=noise_mat.end+pend; + mat{utt_ind}=orderfields(mat{utt_ind}); + end + mat2json(mat,[apath set '_simu_new.json']); + end + + % Loop over utterances + parfor utt_ind=1:length(mat), + if official, + udir=[upath_simu set '_' lower(mat{utt_ind}.environment) '_simu/']; + udir_ext=[upath_ext set '_' lower(mat{utt_ind}.environment) '_simu/']; + else + udir=[upath set '_' lower(mat{utt_ind}.environment) '_simu_new/']; + end + if ~exist(udir,'dir'), + system(['mkdir -p ' udir]); + end + if ~exist(udir_ext,'dir'), + system(['mkdir -p ' udir_ext]); + end + oname=[mat{utt_ind}.speaker '_' mat{utt_ind}.wsj_name '_BTH']; + nname=mat{utt_ind}.noise_wavfile; + uname=[mat{utt_ind}.speaker '_' mat{utt_ind}.wsj_name '_' mat{utt_ind}.environment]; + tbeg=round(mat{utt_ind}.noise_start*16000)+1; + tend=round(mat{utt_ind}.noise_end*16000); + + % Load WAV files + o=audioread([upath set '_bth/' oname '.CH0.wav']); + [r,fs]=audioread([cpath nname '.CH0.wav'],[tbeg tend]); + nsampl=length(r); + x=zeros(nsampl,nchan); + for c=1:nchan, + x(:,c)=audioread([cpath nname '.CH' int2str(c) '.wav'],[tbeg tend]); + end + + % Compute the STFT (short window) + R=stft_multi(r.',wlen_sub); + X=stft_multi(x.',wlen_sub); + + % Estimate 88 ms impulse responses on 250 ms time blocks + A=estimate_ir(R,X,blen_sub,ntap_sub,del); + + % Filter and subtract close-mic speech + Y=apply_ir(A,R,del); + y=istft_multi(Y,nsampl).'; + level=sum(sum(y.^2)); + n=x-y; + + % Compute the STFT (long window) + O=stft_multi(o.',wlen_add); + X=stft_multi(x.',wlen_add); + [nbin,nfram] = size(O); + + % Localize and track the speaker + [~,TDOAx]=localize(X); + + % Interpolate 
the spatial position over the duration of clean speech + TDOA=zeros(nchan,nfram); + for c=1:nchan, + TDOA(c,:)=interp1(0:size(X,2)-1,TDOAx(c,:),(0:nfram-1)/(nfram-1)*(size(X,2)-1)); + end + + % Filter clean speech + Ysimu=zeros(nbin,nfram,nchan); + for f=1:nbin, + for t=1:nfram, + Df=sqrt(1/nchan)*exp(-2*1i*pi*(f-1)/wlen_add*fs*TDOA(:,t)); + Ysimu(f,t,:)=permute(Df*O(f,t),[2 3 1]); + end + end + ysimu=istft_multi(Ysimu,nsampl).'; + + % Normalize level and add + ysimu=sqrt(level/sum(sum(ysimu.^2)))*ysimu; + xsimu=ysimu+n; + + % Write WAV file + for c=1:nchan, + audiowrite([udir uname '.CH' int2str(c) '.wav'],xsimu(:,c),fs); + audiowrite([udir_ext uname '.CH' int2str(c) '.Noise.wav'],n(:, c),fs); + audiowrite([udir_ext uname '.CH' int2str(c) '.Clean.wav'],ysimu(:, c), fs); + end + end +end +delete(p); +end diff --git a/egs/chime4/s5_1ch/local/chain/run_tdnn_lstm_recog.sh b/egs/chime4/s5_1ch/local/chain/run_tdnn_lstm_recog.sh deleted file mode 100755 index 9348cd6fa5a..00000000000 --- a/egs/chime4/s5_1ch/local/chain/run_tdnn_lstm_recog.sh +++ /dev/null @@ -1,223 +0,0 @@ -#!/bin/bash - -set -e -o pipefail - -stage=0 -nj=30 -train=noisy -enhan=$1 -mdir=$2 -train_set=tr05_multi_${train} -test_sets="dt05_real_$enhan dt05_simu_$enhan et05_real_$enhan et05_simu_$enhan" -gmm=tri3b_tr05_multi_${train} # this is the source gmm-dir that we'll use for alignments; it - # should have alignments for the specified training data. -nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. - -# Options which are not passed through to run_ivector_common.sh -affix=1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. -common_egs_dir= -reporting_email= - -# training chunk-options -chunk_width=140,100,160 -# we don't need extra left/right context for TDNN systems. -chunk_left_context=0 -chunk_right_context=0 - -#decode options -test_online_decoding=false # if true, it will run the last decoding stage. - -# End configuration section. -echo "$0 $@" # Print the command line for logging - - -. ./cmd.sh -. ./path.sh -. ./utils/parse_options.sh - - -if ! cuda-compiled; then - cat < \n\n" `basename $0` - echo "First argument specifies a unique name for different enhancement method" - echo "Second argument specifies acoustic and language model directory" - exit 1; -fi - -# check whether run_init is executed -if [ ! -d data/lang ]; then - echo "error, execute local/run_init.sh, first" - exit 1; -fi - -# check whether run_init is executed -if [ ! -d exp/tri3b_tr05_multi_${train} ]; then - echo "error, execute local/run_init.sh, first" - exit 1; -fi - -# check ivector extractor -if [ ! -d $mdir/exp/nnet3${nnet3_affix}/extractor ]; then - echo "error, set $mdir correctly" - exit 1; -elif [ ! -d exp/nnet3${nnet3_affix}/extractor ]; then - echo "copy $mdir/exp/nnet3${nnet3_affix}/extractor" - mkdir -p exp/nnet3${nnet3_affix} - cp -r $mdir/exp/nnet3${nnet3_affix}/extractor exp/nnet3${nnet3_affix}/ -fi - -# check tdnn-lstm graph -if [ ! -d $mdir/exp/chain${nnet3_affix}/tree_a_sp/graph_tgpr_5k ]; then - echo "error, set $mdir correctly" - exit 1; -elif [ ! -d exp/chain${nnet3_affix}/tree_a_sp/graph_tgpr_5k ]; then - echo "copy $mdir/exp/chain${nnet3_affix}/tree_a_sp/graph_tgpr_5k" - mkdir -p exp/chain${nnet3_affix}/tree_a_sp - cp -r $mdir/exp/chain${nnet3_affix}/tree_a_sp/graph_tgpr_5k exp/chain${nnet3_affix}/tree_a_sp/ -fi - -# check dir -if [ ! -d $mdir/exp/chain${nnet3_affix}/tdnn_lstm${affix}_sp ]; then - echo "error, set $mdir correctly" - exit 1; -elif [ ! 
-d exp/chain${nnet3_affix}/tdnn_lstm${affix}_sp ]; then - echo "copy $mdir/exp/chain${nnet3_affix}/tdnn_lstm${affix}_sp" - cp -r $mdir/exp/chain${nnet3_affix}/tdnn_lstm${affix}_sp exp/chain${nnet3_affix}/ - rm -rf exp/chain${nnet3_affix}/tdnn_lstm${affix}_sp/decode_* - rm -rf exp/chain${nnet3_affix}/tdnn_lstm${affix}_sp/best_* -fi - -dir=exp/chain${nnet3_affix}/tdnn_lstm${affix}_sp - -# note: you don't necessarily have to change the treedir name -# each time you do a new experiment-- only if you change the -# configuration in a way that affects the tree. -tree_dir=$mdir/exp/chain${nnet3_affix}/tree_a_sp - -# make ivector for dev and eval -if [ $stage -le 2 ]; then - for datadir in ${test_sets}; do - utils/copy_data_dir.sh data/$datadir data/${datadir}_hires - done - - # extracting hires features - for datadir in ${test_sets}; do - steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \ - --cmd "$train_cmd" data/${datadir}_hires - steps/compute_cmvn_stats.sh data/${datadir}_hires - utils/fix_data_dir.sh data/${datadir}_hires - done - - # extract iVectors for the test data, but in this case we don't need the speed - # perturbation (sp). - for data in ${test_sets}; do - nspk=$(wc -l /dev/null || true - - for data in $test_sets; do - ( - data_affix=$(echo $data | sed s/test_//) - nspk=$(wc -l /dev/null || true - - for data in $test_sets; do - ( - data_affix=$(echo $data | sed s/test_//) - nspk=$(wc -l /dev/null || true - - for data in $test_sets; do - ( - data_affix=$(echo $data | sed s/test_//) - nspk=$(wc -l exp/chain/tdnn_lstm${affix}_sp/best_wer_$enhan.result - head -n 15 exp/chain/tdnn_lstm${affix}_sp/best_wer_$enhan.result - - echo "score looped decoding results" - local/chime4_calc_wers_looped.sh exp/chain/tdnn_lstm${affix}_sp $enhan exp/chain/tree_a_sp/graph_tgpr_5k \ - > exp/chain/tdnn_lstm${affix}_sp/best_wer_looped_$enhan.result - head -n 15 exp/chain/tdnn_lstm${affix}_sp/best_wer_looped_$enhan.result -fi - -exit 0; diff --git a/egs/chime4/s5_1ch/local/chain/run_tdnn_recog.sh b/egs/chime4/s5_1ch/local/chain/run_tdnn_recog.sh deleted file mode 100755 index 38a9cc391e7..00000000000 --- a/egs/chime4/s5_1ch/local/chain/run_tdnn_recog.sh +++ /dev/null @@ -1,200 +0,0 @@ -#!/bin/bash - -set -e -o pipefail - -stage=0 -nj=30 -train=noisy -enhan=$1 -mdir=$2 -train_set=tr05_multi_${train} -test_sets="dt05_real_$enhan dt05_simu_$enhan et05_real_$enhan et05_simu_$enhan" -gmm=tri3b_tr05_multi_${train} # this is the source gmm-dir that we'll use for alignments; it - # should have alignments for the specified training data. -nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. - -# Options which are not passed through to run_ivector_common.sh -affix=1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. -common_egs_dir= -reporting_email= - -# training chunk-options -chunk_width=140,100,160 -# we don't need extra left/right context for TDNN systems. -chunk_left_context=0 -chunk_right_context=0 - -#decode options -test_online_decoding=false # if true, it will run the last decoding stage. - -# End configuration section. -echo "$0 $@" # Print the command line for logging - - -. ./cmd.sh -. ./path.sh -. ./utils/parse_options.sh - - -if ! cuda-compiled; then - cat < \n\n" `basename $0` - echo "First argument specifies a unique name for different enhancement method" - echo "Second argument specifies acoustic and language model directory" - exit 1; -fi - -# check whether run_init is executed -if [ ! 
-d data/lang ]; then - echo "error, execute local/run_init.sh, first" - exit 1; -fi - -# check whether run_init is executed -if [ ! -d exp/tri3b_tr05_multi_${train} ]; then - echo "error, execute local/run_init.sh, first" - exit 1; -fi - -# check ivector extractor -if [ ! -d $mdir/exp/nnet3${nnet3_affix}/extractor ]; then - echo "error, set $mdir correctly" - exit 1; -elif [ ! -d exp/nnet3${nnet3_affix}/extractor ]; then - echo "copy $mdir/exp/nnet3${nnet3_affix}/extractor" - mkdir -p exp/nnet3${nnet3_affix} - cp -r $mdir/exp/nnet3${nnet3_affix}/extractor exp/nnet3${nnet3_affix}/ -fi - -# check tdnn graph -if [ ! -d $mdir/exp/chain${nnet3_affix}/tree_a_sp/graph_tgpr_5k ]; then - echo "error, set $mdir correctly" - exit 1; -elif [ ! -d exp/chain${nnet3_affix}/tree_a_sp/graph_tgpr_5k ]; then - echo "copy $mdir/exp/chain${nnet3_affix}/tree_a_sp/graph_tgpr_5k" - mkdir -p exp/chain${nnet3_affix}/tree_a_sp - cp -r $mdir/exp/chain${nnet3_affix}/tree_a_sp/graph_tgpr_5k exp/chain${nnet3_affix}/tree_a_sp/ -fi - -# check dir -if [ ! -d $mdir/exp/chain${nnet3_affix}/tdnn${affix}_sp ]; then - echo "error, set $mdir correctly" - exit 1; -elif [ ! -d exp/chain${nnet3_affix}/tdnn${affix}_sp ]; then - echo "copy $mdir/exp/chain${nnet3_affix}/tdnn${affix}_sp" - cp -r $mdir/exp/chain${nnet3_affix}/tdnn${affix}_sp exp/chain${nnet3_affix}/ - rm -rf exp/chain${nnet3_affix}/tdnn${affix}_sp/decode_* - rm -rf exp/chain${nnet3_affix}/tdnn${affix}_sp/best_* -fi - -dir=exp/chain${nnet3_affix}/tdnn${affix}_sp - -# note: you don't necessarily have to change the treedir name -# each time you do a new experiment-- only if you change the -# configuration in a way that affects the tree. -tree_dir=$mdir/exp/chain${nnet3_affix}/tree_a_sp - -# make ivector for dev and eval -if [ $stage -le 2 ]; then - for datadir in ${test_sets}; do - utils/copy_data_dir.sh data/$datadir data/${datadir}_hires - done - - # extracting hires features - for datadir in ${test_sets}; do - steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \ - --cmd "$train_cmd" data/${datadir}_hires - steps/compute_cmvn_stats.sh data/${datadir}_hires - utils/fix_data_dir.sh data/${datadir}_hires - done - - # extract iVectors for the test data, but in this case we don't need the speed - # perturbation (sp). - for data in ${test_sets}; do - nspk=$(wc -l /dev/null || true - - for data in $test_sets; do - ( - data_affix=$(echo $data | sed s/test_//) - nspk=$(wc -l /dev/null || true - - for data in $test_sets; do - ( - data_affix=$(echo $data | sed s/test_//) - nspk=$(wc -l exp/chain/tdnn${affix}_sp/best_wer_$enhan.result - head -n 15 exp/chain/tdnn${affix}_sp/best_wer_$enhan.result -fi - - -exit 0; diff --git a/egs/chime4/s5_1ch/local/chain/tuning/run_tdnn_1a.sh b/egs/chime4/s5_1ch/local/chain/tuning/run_tdnn_1a.sh index aa7d07b636a..3f8b7c60090 100755 --- a/egs/chime4/s5_1ch/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/chime4/s5_1ch/local/chain/tuning/run_tdnn_1a.sh @@ -1,20 +1,20 @@ #!/bin/bash -# This was modified from wsj/local/chain/tunning/run_tdnn_1d.sh to be +# This was modified from wsj/local/chain/tunning/run_tdnn_1e.sh to be # used in Chime4. 
#This is the result using all 6 channels: -# exp/chain/tdnn1a_sp/best_wer_beamformit_5mics.result +# exp/chain/tdnn1a_sp/best_wer_blstm_gev.result # ------------------- -# best overall dt05 WER 6.04% (language model weight = 9) +# best overall dt05 WER 4.34% (language model weight = 7) # ------------------- -# dt05_simu WER: 6.25% (Average), 5.71% (BUS), 6.92% (CAFE), 5.37% (PEDESTRIAN), 7.02% (STREET) +# dt05_simu WER: 4.46% (Average), 4.12% (BUS), 5.29% (CAFE), 4.00% (PEDESTRIAN), 4.42% (STREET) # ------------------- -# dt05_real WER: 5.83% (Average), 7.48% (BUS), 5.28% (CAFE), 4.43% (PEDESTRIAN), 6.13% (STREET) +# dt05_real WER: 4.21% (Average), 4.94% (BUS), 4.07% (CAFE), 3.81% (PEDESTRIAN), 4.04% (STREET) # ------------------- -# et05_simu WER: 10.30% (Average), 7.34% (BUS), 10.37% (CAFE), 10.05% (PEDESTRIAN), 13.43% (STREET) +# et05_simu WER: 5.50% (Average), 4.78% (BUS), 5.86% (CAFE), 5.51% (PEDESTRIAN), 5.83% (STREET) # ------------------- -# et05_real WER: 9.67% (Average), 12.71% (BUS), 8.33% (CAFE), 8.20% (PEDESTRIAN), 9.45% (STREET) +# et05_real WER: 5.78% (Average), 6.82% (BUS), 5.10% (CAFE), 5.70% (PEDESTRIAN), 5.51% (STREET) # ------------------- # Final train prob -0.080 # Final valid prob -0.075 @@ -32,9 +32,7 @@ set -e -o pipefail stage=1 nj=30 train=noisy -enhan=$1 train_set=tr05_multi_${train} -test_sets="dt05_real_$enhan dt05_simu_$enhan et05_real_$enhan et05_simu_$enhan" gmm=tri3b_tr05_multi_${train} # this is the source gmm-dir that we'll use for alignments; it # should have alignments for the specified training data. num_threads_ubm=32 @@ -57,11 +55,11 @@ chunk_right_context=0 # training options srand=0 -remove_egs=false +remove_egs=true #decode options test_online_decoding=false # if true, it will run the last decoding stage. - +decode_only=false # if true, it wouldn't train a model again and will only do decoding # End configuration section. echo "$0 $@" # Print the command line for logging @@ -70,6 +68,8 @@ echo "$0 $@" # Print the command line for logging . ./path.sh . ./utils/parse_options.sh +enhan=$1 +test_sets="dt05_real_$enhan dt05_simu_$enhan et05_real_$enhan et05_simu_$enhan" if ! 
cuda-compiled; then cat < $dir/configs/network.xconfig @@ -187,18 +232,18 @@ if [ $stage -le 15 ]; then fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat # the first splicing is moved before the lda layer, so no splicing here - relu-batchnorm-layer name=tdnn1 dim=750 - relu-batchnorm-layer name=tdnn2 dim=750 input=Append(-1,0,1) - relu-batchnorm-layer name=tdnn3 dim=750 - relu-batchnorm-layer name=tdnn4 dim=750 input=Append(-1,0,1) - relu-batchnorm-layer name=tdnn5 dim=750 - relu-batchnorm-layer name=tdnn6 dim=750 input=Append(-3,0,3) - relu-batchnorm-layer name=tdnn7 dim=750 input=Append(-3,0,3) - relu-batchnorm-layer name=tdnn8 dim=750 input=Append(-6,-3,0) + relu-batchnorm-layer name=tdnn1 $opts dim=850 + relu-batchnorm-layer name=tdnn2 $opts dim=850 input=Append(-1,0,1) + relu-batchnorm-layer name=tdnn3 $opts dim=850 + relu-batchnorm-layer name=tdnn4 $opts dim=850 input=Append(-1,0,1) + relu-batchnorm-layer name=tdnn5 $opts dim=850 + relu-batchnorm-layer name=tdnn6 $opts dim=850 input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn7 $opts dim=850 input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn8 $opts dim=850 input=Append(-6,-3,0) ## adding the layers for chain branch - relu-batchnorm-layer name=prefinal-chain dim=750 target-rms=0.5 - output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + relu-batchnorm-layer name=prefinal-chain $opts dim=850 target-rms=0.5 + output-layer name=output $output_opts include-log-softmax=false dim=$num_targets max-change=1.5 # adding the layers for xent branch # This block prints the configs for a separate output that will be @@ -209,8 +254,8 @@ if [ $stage -le 15 ]; then # final-layer learns at a rate independent of the regularization # constant; and the 0.5 was tuned so as to make the relative progress # similar in the xent and regular final layers. 
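Further down in this hunk, the recipe also rewrites utt2uniq so that every per-channel or enhanced copy maps back to one original utterance: the second column is stripped of a dot-plus-three-character suffix, a _CH<n> channel tag and an _enhan tag. A rough Python equivalent of that sed pipeline; the utterance-id format in the example is an assumption:

```python
import re

def normalize_uniq(uniq_id):
    """Mirror the sed pipeline: 's/\\....//g', 's/\\_CH.//g', 's/\\_enhan//g'."""
    uniq_id = re.sub(r"\....", "", uniq_id)   # drop '.' plus any three characters
    uniq_id = re.sub(r"_CH.", "", uniq_id)    # drop channel tags such as _CH5
    return uniq_id.replace("_enhan", "")

def rewrite_utt2uniq(lines):
    out = []
    for line in lines:
        utt, uniq = line.split(None, 1)
        out.append("{} {}".format(utt, normalize_uniq(uniq.strip())))
    return out

# hypothetical utt2uniq entry, just to show the mapping
print(rewrite_utt2uniq(["F01_22GC010A_BUS.CH5_enhan F01_22GC010A_BUS.CH5_enhan"]))
```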
- relu-batchnorm-layer name=prefinal-xent input=tdnn8 dim=750 target-rms=0.5 - output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + relu-batchnorm-layer name=prefinal-xent $opts input=tdnn8 dim=850 target-rms=0.5 + output-layer name=output-xent $output_opts dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 EOF steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ fi @@ -221,7 +266,12 @@ if [ $stage -le 16 ]; then utils/create_split_dir.pl \ /export/b0{3,4,5,6}/$USER/kaldi-data/egs/chime4-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage fi - + + cat $train_data_dir/utt2uniq | awk -F' ' '{print $1}' > $train_data_dir/utt2uniq.tmp1 + cat $train_data_dir/utt2uniq | awk -F' ' '{print $2}' | sed -e 's/\....//g' | sed -e 's/\_CH.//g' | sed -e 's/\_enhan//g' > $train_data_dir/utt2uniq.tmp2 + paste -d" " $train_data_dir/utt2uniq.tmp1 $train_data_dir/utt2uniq.tmp2 > $train_data_dir/utt2uniq + rm -rf $train_data_dir/utt2uniq.tmp{1,2} + steps/nnet3/chain/train.py --stage=$train_stage \ --cmd="$decode_cmd" \ --feat.online-ivector-dir=$train_ivector_dir \ @@ -233,16 +283,17 @@ if [ $stage -le 16 ]; then --chain.lm-opts="--num-extra-lm-states=2000" \ --trainer.srand=$srand \ --trainer.max-param-change=2.0 \ - --trainer.num-epochs=6 \ + --trainer.num-epochs=12 \ --trainer.frames-per-iter=3000000 \ --trainer.optimization.num-jobs-initial=2 \ - --trainer.optimization.num-jobs-final=5 \ - --trainer.optimization.initial-effective-lrate=0.003 \ - --trainer.optimization.final-effective-lrate=0.0003 \ + --trainer.optimization.num-jobs-final=12 \ + --trainer.optimization.initial-effective-lrate=0.005 \ + --trainer.optimization.final-effective-lrate=0.0005 \ --trainer.optimization.shrink-value=1.0 \ - --trainer.optimization.proportional-shrink=60.0 \ --trainer.num-chunk-per-minibatch=128,64 \ --trainer.optimization.momentum=0.0 \ + --trainer.optimization.backstitch-training-scale=0.3 \ + --trainer.optimization.backstitch-training-interval=1 \ --egs.chunk-width=$chunk_width \ --egs.chunk-left-context=0 \ --egs.chunk-right-context=0 \ @@ -280,8 +331,11 @@ if [ $stage -le 18 ]; then for data in $test_sets; do ( + utils/data/modify_speaker_info.sh --seconds-per-spk-max 200 \ + data/${data}_hires data/${data}_chunked + data_affix=$(echo $data | sed s/test_//) - nspk=$(wc -l $dir/configs/network.xconfig diff --git a/egs/chime4/s5_1ch/local/chime4_calc_wers_looped.sh b/egs/chime4/s5_1ch/local/chime4_calc_wers_looped.sh index 9fe4a20f43a..84bb2cb8dbd 100755 --- a/egs/chime4/s5_1ch/local/chime4_calc_wers_looped.sh +++ b/egs/chime4/s5_1ch/local/chime4_calc_wers_looped.sh @@ -82,4 +82,4 @@ for e_d in $tasks; do | utils/int2sym.pl -f 2- $graph_dir/words.txt \ | sed s:\::g done -done \ No newline at end of file +done diff --git a/egs/chime4/s5_1ch/local/compute_pesq.sh b/egs/chime4/s5_1ch/local/compute_pesq.sh new file mode 100755 index 00000000000..1d290a4893f --- /dev/null +++ b/egs/chime4/s5_1ch/local/compute_pesq.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright 2017 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +# Apache 2.0 + +# This script creates the average PESQ score of files in an enhanced directory with corresponding +# files in a reference directory. +# Expects the PESQ third party executable in "local/PESQ" +# PESQ source was dowloaded and compiled using "local/download_se_eval_tool.sh" +# Eg. 
local/compute_pesq.sh blstm_gev enhan/blstm_gev local/nn-gev/data/audio/16kHz/isolated_ext $PWD + +set -e +set -u +set -o pipefail + +if [ $# != 4 ]; then + echo "Wrong #arguments ($#, expected 4)" + echo "Usage: local/compute_pesq.sh " + exit 1; +fi + +enhancement_method=$1 +enhancement_directory=$2 +chime_rir_directory=$3 +modeldir=$4 + +expdir=$modeldir/exp/compute_pesq_${enhancement_method} +mkdir -p $expdir +pushd $expdir +ls $enhancement_directory/et05_*_simu/*.wav > $expdir/et05_files +ls $enhancement_directory/dt05_*_simu/*.wav > $expdir/dt05_files + +for set in "dt05" "et05" +do +declare -i n_files=0 +t_mos=0 +avg_mos=0 + while read filename; do + n_files=$n_files+1 + target_filename=`echo $filename | rev | cut -d"/" -f1 | rev` + speaker=`echo $target_filename | cut -d"_" -f1` + utt_id=`echo $target_filename | cut -d"_" -f2` + noise_cap=`echo $target_filename | cut -d"_" -f3 | cut -d"." -f1` + noise=`echo "$noise_cap" | awk '{ print tolower($1) }'` + temp=`$modeldir/local/PESQ +16000 ../../$chime_rir_directory/"$set"_"$noise"_simu/"$speaker"_"$utt_id"_"$noise_cap".CH5.Clean.wav $filename` + pesq_score=`echo $temp | rev | cut -d " " -f1 | rev` + t_mos=$(awk "BEGIN {print $t_mos+$pesq_score; exit}") + done <$expdir/"$set"_files +avg_mos=$(awk "BEGIN {print $t_mos/$n_files; exit}") +echo $avg_mos>"$expdir"/pesq_"$set" +done +popd diff --git a/egs/chime4/s5_1ch/local/compute_stoi_estoi_sdr.sh b/egs/chime4/s5_1ch/local/compute_stoi_estoi_sdr.sh new file mode 100755 index 00000000000..b7627560b67 --- /dev/null +++ b/egs/chime4/s5_1ch/local/compute_stoi_estoi_sdr.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Copyright 2017 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +# Apache 2.0 + +# This script creates the list of enhanced files and reference files and calls the +# matlab script "stoi_estoi_sdr.m" to get STOI, eSTOI and SDR scores +# Eg. local/compute_stoi_estoi_sdr.sh --njobs 10 blstm_gev enhan/blstm_gev local/nn-gev/data/audio/16kHz/isolated_ext + +. ./cmd.sh +. ./path.sh +set -e +set -u +set -o pipefail + +njobs=10 +cmd=run.pl + +. 
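compute_pesq.sh above invokes the ITU PESQ binary once per enhanced/reference pair, takes the last whitespace-separated field of its output as the score and accumulates the average with awk. A compact Python sketch of the same bookkeeping; the PESQ output lines below are invented and only their trailing score is used:

```python
def average_pesq(pesq_output_lines):
    """Each element is the text printed by the PESQ binary for one file pair;
    like the shell loop, only the last whitespace-separated field is kept."""
    scores = [float(line.split()[-1]) for line in pesq_output_lines if line.strip()]
    return sum(scores) / len(scores) if scores else 0.0

# invented example outputs standing in for the per-file PESQ results
print(average_pesq(["... Prediction : = 3.104",
                    "... Prediction : = 2.877"]))
```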
utils/parse_options.sh || exit 1; + +if [ $# != 3 ]; then + echo "Wrong #arguments ($#, expected 3)" + echo "Usage: local/compute_stoi_estoi_sdr.sh [options] " + echo "options" + echo " --njobs # number of parallel jobs" + echo " --cmd # Command to run in parallel with" + exit 1; +fi + +enhancement_method=$1 +enhancement_directory=$2 +chime_rir_directory=$3 + +expdir=exp/compute_stoi_estoi_sdr_${enhancement_method} +mkdir -p $expdir +ls $chime_rir_directory/dt05_*/*CH5.Clean.wav > $expdir/original_list +ls $enhancement_directory/dt05_*simu/*.wav > $expdir/enhanced_list +$cmd $expdir/compute_stoi_estoi_sdr_dt05.log matlab -nodisplay -nosplash -r "addpath('local'); stoi_estoi_sdr($njobs,'$enhancement_method','$expdir','dt05');exit" +ls $chime_rir_directory/et05_*/*CH5.Clean.wav > $expdir/original_list +ls $enhancement_directory/et05_*simu/*.wav > $expdir/enhanced_list +$cmd $expdir/compute_stoi_estoi_sdr_et05.log matlab -nodisplay -nosplash -r "addpath('local'); stoi_estoi_sdr($njobs,'$enhancement_method','$expdir','et05');exit" diff --git a/egs/chime4/s5_1ch/local/download_se_eval_tool.sh b/egs/chime4/s5_1ch/local/download_se_eval_tool.sh new file mode 100755 index 00000000000..ddd86a03d8a --- /dev/null +++ b/egs/chime4/s5_1ch/local/download_se_eval_tool.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# Copyright 2017 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +# Apache 2.0 + +wget http://bass-db.gforge.inria.fr/bss_eval/bss_eval_sources.m -O local/bss_eval_sources.m +wget https://github.com/JacobD10/SoundZone_Tools/raw/master/stoi.m -O local/stoi.m +wget https://github.com/JacobD10/SoundZone_Tools/raw/master/estoi.m -O local/estoi.m +wget 'https://www.itu.int/rec/dologin_pub.asp?lang=e&id=T-REC-P.862-200102-I!!SOFT-ZST-E&type=items' -O PESQ.zip +unzip PESQ.zip -d local/PESQ_sources +cd local/PESQ_sources/P862/Software/source +gcc *.c -lm -o PESQ +cd ../../../../../ +mv local/PESQ_sources/P862/Software/source/PESQ local/ diff --git a/egs/chime4/s5_1ch/local/fix_read_sim_from_different_directory.patch b/egs/chime4/s5_1ch/local/fix_read_sim_from_different_directory.patch new file mode 100644 index 00000000000..46121357c5e --- /dev/null +++ b/egs/chime4/s5_1ch/local/fix_read_sim_from_different_directory.patch @@ -0,0 +1,244 @@ +diff --git a/beamform.py b/beamform.py +index 02eeed8..070c76d 100644 +--- a/beamform.py ++++ b/beamform.py +@@ -6,9 +6,10 @@ from chainer import Variable + from chainer import cuda + from chainer import serializers + from tqdm import tqdm ++import sys + +-from chime_data import gen_flist_simu, \ +- gen_flist_real, get_audio_data, get_audio_data_with_context ++from chime_data import gen_flist_simu, gen_flist_2ch,\ ++ gen_flist_real, get_audio_data, get_audio_data_1ch, get_audio_data_with_context + from fgnt.beamforming import gev_wrapper_on_masks + from fgnt.signal_processing import audiowrite, stft, istft + from fgnt.utils import Timer +@@ -20,6 +21,8 @@ parser.add_argument('flist', + help='Name of the flist to process (e.g. 
tr05_simu)') + parser.add_argument('chime_dir', + help='Base directory of the CHiME challenge.') ++parser.add_argument('sim_dir', ++ help='Base directory of the CHiME challenge simulated data.') + parser.add_argument('output_dir', + help='The directory where the enhanced wav files will ' + 'be stored.') +@@ -29,6 +32,10 @@ parser.add_argument('model_type', + help='Type of model (BLSTM or FW)') + parser.add_argument('--gpu', '-g', default=-1, type=int, + help='GPU ID (negative value indicates CPU)') ++parser.add_argument('--single', '-s', default=0, type=int, ++ help='0 for multi-channel and channel number (1-6) for single channel') ++parser.add_argument('--track', '-t', default=6, type=int, ++ help='1, 2 or 6 depending on the data used') + args = parser.parse_args() + + # Prepare model +@@ -48,11 +55,35 @@ xp = np if args.gpu < 0 else cuda.cupy + stage = args.flist[:2] + scenario = args.flist.split('_')[-1] + ++if stage == 'tr' and (args.track == 1 or args.track == 2): ++ print("No train data for 1ch track and 2ch track"); ++ sys.exit(0); ++ + # CHiME data handling + if scenario == 'simu': +- flist = gen_flist_simu(args.chime_dir, stage) ++ if args.track == 6: ++ flist = gen_flist_simu(args.chime_dir, args.sim_dir, stage) ++ elif args.track == 2: ++ flist = gen_flist_2ch(args.chime_dir, stage, scenario) ++ elif args.track == 1: ++ flist = list() ++ for env in ['caf', 'bus', 'str', 'ped']: ++ flist_temp = os.listdir(os.path.join(args.chime_dir, 'audio', '16kHz', 'isolated_1ch_track', '{}05_{}_{}'.format(stage, env, scenario))) ++ flist_ext = [i for i in flist_temp if i.endswith('.wav')] ++ flist_with_dir = [os.path.join(args.chime_dir, 'audio', '16kHz', 'isolated_1ch_track', '{}05_{}_{}'.format(stage, env, scenario), i) for i in flist_ext] ++ flist = flist + flist_with_dir + elif scenario == 'real': +- flist = gen_flist_real(args.chime_dir, stage) ++ if args.track == 6: ++ flist = gen_flist_real(args.chime_dir, stage) ++ elif args.track == 2: ++ flist = gen_flist_2ch(args.chime_dir, stage, scenario) ++ elif args.track == 1: ++ flist = list() ++ for env in ['caf', 'bus', 'str', 'ped']: ++ flist_temp = os.listdir(os.path.join(args.chime_dir, 'audio', '16kHz', 'isolated_1ch_track', '{}05_{}_{}'.format(stage, env, scenario))) ++ flist_ext = [i for i in flist_temp if i.endswith('.wav')] ++ flist_with_dir = [os.path.join(args.chime_dir, 'audio', '16kHz', 'isolated_1ch_track', '{}05_{}_{}'.format(stage, env, scenario), i) for i in flist_ext] ++ flist = flist + flist_with_dir + else: + raise ValueError('Unknown flist {}'.format(args.flist)) + +@@ -67,12 +98,19 @@ t_beamform = 0 + # Beamform loop + for cur_line in tqdm(flist): + with Timer() as t: +- if scenario == 'simu': ++ if args.track == 6: ++ if scenario == 'simu': ++ audio_data = get_audio_data(cur_line) ++ context_samples = 0 ++ elif scenario == 'real': ++ audio_data, context_samples = get_audio_data_with_context( ++ cur_line[0], cur_line[1], cur_line[2]) ++ elif args.track == 2: + audio_data = get_audio_data(cur_line) + context_samples = 0 +- elif scenario == 'real': +- audio_data, context_samples = get_audio_data_with_context( +- cur_line[0], cur_line[1], cur_line[2]) ++ elif args.track == 1: ++ audio_data = get_audio_data_1ch(cur_line) ++ context_samples = 0 + t_io += t.msecs + Y = stft(audio_data, time_dim=1).transpose((1, 0, 2)) + Y_var = Variable(np.abs(Y).astype(np.float32), True) +@@ -85,28 +123,45 @@ for cur_line in tqdm(flist): + t_net += t.msecs + + with Timer() as t: +- N_mask = np.median(N_masks.data, axis=1) +- X_mask = 
np.median(X_masks.data, axis=1) +- Y_hat = gev_wrapper_on_masks(Y, N_mask, X_mask) ++ if args.single >= 1 or args.track == 1: ++ Y_hat = X_masks.data * Y ++ elif args.single == 0: ++ N_mask = np.median(N_masks.data, axis=1) ++ X_mask = np.median(X_masks.data, axis=1) ++ Y_hat = gev_wrapper_on_masks(Y, N_mask, X_mask) + t_beamform += t.msecs + +- if scenario == 'simu': +- wsj_name = cur_line.split('/')[-1].split('_')[1] +- spk = cur_line.split('/')[-1].split('_')[0] +- env = cur_line.split('/')[-1].split('_')[-1] +- elif scenario == 'real': +- wsj_name = cur_line[3] +- spk = cur_line[0].split('/')[-1].split('_')[0] +- env = cur_line[0].split('/')[-1].split('_')[-1] ++ if args.track == 1: ++ env = cur_line.split('/')[-1].split('_')[2].split('.')[0] ++ filename = os.path.join(args.output_dir, '{}05_{}_{}'.format(stage, env.lower(), scenario), os.path.basename(cur_line)) ++ else: ++ if scenario == 'simu' or args.track == 2: ++ wsj_name = cur_line.split('/')[-1].split('_')[1] ++ spk = cur_line.split('/')[-1].split('_')[0] ++ env = cur_line.split('/')[-1].split('_')[-1] ++ elif scenario == 'real': ++ wsj_name = cur_line[3] ++ spk = cur_line[0].split('/')[-1].split('_')[0] ++ env = cur_line[0].split('/')[-1].split('_')[-1] + +- filename = os.path.join( +- args.output_dir, +- '{}05_{}_{}'.format(stage, env.lower(), scenario), +- '{}_{}_{}.wav'.format(spk, wsj_name, env.upper()) +- ) +- with Timer() as t: +- audiowrite(istft(Y_hat)[context_samples:], filename, 16000, True, True) +- t_io += t.msecs ++ filename = os.path.join( ++ args.output_dir, ++ '{}05_{}_{}'.format(stage, env.lower(), scenario), ++ '{}_{}_{}.wav'.format(spk, wsj_name, env.upper()) ++ ) ++ if args.track == 1: ++ with Timer() as t: ++ audiowrite(istft(Y_hat[:,0,:])[int(context_samples):], filename, 16000, True, True) ++ t_io += t.msecs ++ elif args.single == 0: ++ with Timer() as t: ++ audiowrite(istft(Y_hat)[int(context_samples):], filename, 16000, True, True) ++ t_io += t.msecs ++ elif args.single >= 1: ++ ch = args.single ++ with Timer() as t: ++ audiowrite(istft(Y_hat[:,ch-1,:])[int(context_samples):], filename, 16000, True, True) ++ t_io += t.msecs + + print('Finished') + print('Timings: I/O: {:.2f}s | Net: {:.2f}s | Beamformer: {:.2f}s'.format( +diff --git a/beamform.sh b/beamform.sh +index 3c7de5a..aaae10d 100755 +--- a/beamform.sh ++++ b/beamform.sh +@@ -1,5 +1,5 @@ + #!/usr/bin/env bash + + for flist in tr05_simu tr05_real dt05_simu dt05_real et05_simu et05_real; do +- python beamform.py $flist "$@" +-done +\ No newline at end of file ++ $HOME/miniconda3/bin/python local/nn-gev/beamform.py $flist "$@" ++done +diff --git a/chime_data.py b/chime_data.py +index 0072e1b..641d9d3 100644 +--- a/chime_data.py ++++ b/chime_data.py +@@ -11,7 +11,7 @@ from fgnt.signal_processing import stft + from fgnt.utils import mkdir_p + + +-def gen_flist_simu(chime_data_dir, stage, ext=False): ++def gen_flist_simu(chime_data_dir, dest_dir, stage, ext=False): + with open(os.path.join( + chime_data_dir, 'annotations', + '{}05_{}.json'.format(stage, 'simu'))) as fid: +@@ -21,7 +21,7 @@ def gen_flist_simu(chime_data_dir, stage, ext=False): + else: + isolated_dir = 'isolated' + flist = [os.path.join( +- chime_data_dir, 'audio', '16kHz', isolated_dir, ++ dest_dir, 'audio', '16kHz', isolated_dir, + '{}05_{}_{}'.format(stage, a['environment'].lower(), 'simu'), + '{}_{}_{}'.format(a['speaker'], a['wsj_name'], a['environment'])) + for a in annotations] +@@ -39,11 +39,33 @@ def gen_flist_real(chime_data_dir, stage): + return flist_tuples + + ++def 
gen_flist_2ch(chime_data_dir, stage, scenario): ++ with open(os.path.join( ++ chime_data_dir, 'annotations', ++ '{}05_{}.json'.format(stage, scenario))) as fid: ++ annotations = json.load(fid) ++ flist = [os.path.join( ++ chime_data_dir, 'audio', '16kHz', 'isolated_2ch_track', ++ '{}05_{}_{}'.format(stage, a['environment'].lower(), scenario), ++ '{}_{}_{}'.format(a['speaker'], a['wsj_name'], a['environment'])) ++ for a in annotations] ++ return flist ++ ++ ++def get_audio_data_1ch(filename): ++ audio_data = list() ++ audio_data.append(audioread(filename)[None, :]) ++ audio_data = np.concatenate(audio_data, axis=0) ++ audio_data = audio_data.astype(np.float32) ++ return audio_data ++ ++ + def get_audio_data(file_template, postfix='', ch_range=range(1, 7)): + audio_data = list() + for ch in ch_range: +- audio_data.append(audioread( +- file_template + '.CH{}{}.wav'.format(ch, postfix))[None, :]) ++ if os.path.isfile(file_template + '.CH{}{}.wav'.format(ch, postfix)): ++ audio_data.append(audioread( ++ file_template + '.CH{}{}.wav'.format(ch, postfix))[None, :]) + audio_data = np.concatenate(audio_data, axis=0) + audio_data = audio_data.astype(np.float32) + return audio_data +@@ -65,7 +87,7 @@ def get_audio_data_with_context(embedded_template, t_start, t_end, + + def prepare_training_data(chime_data_dir, dest_dir): + for stage in ['tr', 'dt']: +- flist = gen_flist_simu(chime_data_dir, stage, ext=True) ++ flist = gen_flist_simu(chime_data_dir, dest_dir, stage, ext=True) + export_flist = list() + mkdir_p(os.path.join(dest_dir, stage)) + for f in tqdm.tqdm(flist, desc='Generating data for {}'.format(stage)): diff --git a/egs/chime4/s5_1ch/local/real_noisy_chime4_data_prep.sh b/egs/chime4/s5_1ch/local/real_noisy_chime4_data_prep.sh index edbbfd41e69..0173b022176 100755 --- a/egs/chime4/s5_1ch/local/real_noisy_chime4_data_prep.sh +++ b/egs/chime4/s5_1ch/local/real_noisy_chime4_data_prep.sh @@ -68,10 +68,14 @@ if $eval_flag; then cp $trans_dir/et05_real.dot_all et05_real.dot fi -# make a scp file from file list +# make a scp temporary file from file list for x in $list_set; do - cat $x.flist | awk -F'[/]' '{print $NF}'| sed -e 's/\.wav/_REAL/' > ${x}_wav.ids - paste -d" " ${x}_wav.ids $x.flist | sort -k 1 > ${x}_wav.scp + cat $x.flist | awk -F'[/]' '{print $NF}'| sed -e 's/\.wav/_REAL/' > ${x}_wav.id.temp + cat ${x}_wav.id.temp | awk -F'_' '{print $3}' | awk -F'.' '{print $2}' > $x.ch + cat ${x}_wav.id.temp | awk -F'_' '{print $1}' > $x.part1 + cat ${x}_wav.id.temp | sed -e 's/^..._//' > $x.part2 + paste -d"_" $x.part1 $x.ch $x.part2 > ${x}_wav.ids + paste -d" " ${x}_wav.ids $x.flist | sort -t_ -k1,1 -k3 > ${x}_wav.scp.temp done #make a transcription from dot @@ -98,13 +102,17 @@ fi # data-preparation stage independent of the specific lexicon used. noiseword=""; for x in $list_set;do + cat ${x}_wav.scp.temp | awk '{print $1}' > $x.txt.part1 + cat $x.trans1 | awk '{$1=""; print $0}' | sed 's/^[ \t]*//g' > $x.txt.part2 + paste -d" " $x.txt.part1 $x.txt.part2 > $x.trans1 cat $x.trans1 | $local/normalize_transcript.pl $noiseword \ | sort > $x.txt || exit 1; done # Make the utt2spk and spk2utt files. 
for x in $list_set; do - cat ${x}_wav.scp | awk -F'_' '{print $1}' > $x.spk + sort ${x}_wav.scp.temp > ${x}_wav.scp + cat ${x}_wav.scp | awk -F'_' '{print $1"_"$2}' > $x.spk cat ${x}_wav.scp | awk '{print $1}' > $x.utt paste -d" " $x.utt $x.spk > $x.utt2spk cat $x.utt2spk | $utils/utt2spk_to_spk2utt.pl > $x.spk2utt || exit 1; @@ -119,4 +127,8 @@ for x in $list_set; do cp ${x}.utt2spk ../../$x/utt2spk || exit 1; done +# clean up temp files +rm *.temp +rm *.part{1,2} + echo "Data preparation succeeded" diff --git a/egs/chime4/s5_1ch/local/rnnlm/run_lstm.sh b/egs/chime4/s5_1ch/local/rnnlm/run_lstm.sh new file mode 120000 index 00000000000..c53740399ce --- /dev/null +++ b/egs/chime4/s5_1ch/local/rnnlm/run_lstm.sh @@ -0,0 +1 @@ +tuning/run_lstm_1a.sh \ No newline at end of file diff --git a/egs/chime4/s5_1ch/local/rnnlm/run_lstm_back.sh b/egs/chime4/s5_1ch/local/rnnlm/run_lstm_back.sh new file mode 100755 index 00000000000..76e2b563e6b --- /dev/null +++ b/egs/chime4/s5_1ch/local/rnnlm/run_lstm_back.sh @@ -0,0 +1,93 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (author: Daniel Povey) +# 2015 Guoguo Chen +# 2017 Hainan Xu +# 2017 Szu-Jui Chen + +# This script trains LMs on the reversed Chime4 data, which we +# call the backward model. + +# Begin configuration section. +affix=1a +dir=exp/rnnlm_lstm_${affix}_back +embedding_dim=2048 +lstm_rpd=512 +lstm_nrpd=512 +stage=-10 +train_stage=-10 + +# variables for lattice rescoring +ngram_order=4 # approximate the lattice-rescoring by limiting the max-ngram-order + # if it's set, it merges histories in the lattice if they share + # the same ngram history and this prevents the lattice from + # exploding exponentially + +. cmd.sh +. utils/parse_options.sh + +srcdir=data/local/local_lm +lexicon=data/local/dict/lexiconp.txt +text_dir=data/rnnlm/text_nosp_${affix}_back +mkdir -p $dir/config +set -e + +for f in $lexicon; do + [ ! -f $f ] && \ + echo "$0: expected file $f to exist; search for local/wsj_extend_dict.sh in run.sh" && exit 1 +done + +#prepare training and dev data +if [ $stage -le 0 ]; then + mkdir -p $text_dir + cat $srcdir/train.rnn | awk '{for(i=NF;i>0;i--) printf("%s ",$i); print""}'> $text_dir/chime4.txt.tmp + sed -e "s///g" $text_dir/chime4.txt.tmp > $text_dir/chime4.txt + rm $text_dir/chime4.txt.tmp + cat $srcdir/valid.rnn | awk '{for(i=NF;i>0;i--) printf("%s ",$i); print""}'> $text_dir/dev.txt +fi + +if [ $stage -le 1 ]; then + cp data/lang_chain/words.txt $dir/config/words.txt + n=`cat $dir/config/words.txt | wc -l` + echo " $n" >> $dir/config/words.txt + # words that are not present in words.txt but are in the training or dev data, will be + # mapped to during training. + echo "" >$dir/config/oov.txt + + cat > $dir/config/data_weights.txt <$dir/config/unigram_probs.txt + + # choose features + rnnlm/choose_features.py --unigram-probs=$dir/config/unigram_probs.txt \ + --use-constant-feature=true \ + --special-words=',,,' \ + $dir/config/words.txt > $dir/config/features.txt + + cat >$dir/config/xconfig <//g" $text_dir/chime4.txt.tmp > $text_dir/chime4.txt + cp $srcdir/valid.rnn $text_dir/dev.txt +fi + +if [ $stage -le 1 ]; then + cp data/lang_chain/words.txt $dir/config/words.txt + n=`cat $dir/config/words.txt | wc -l` + echo " $n" >> $dir/config/words.txt + # words that are not present in words.txt but are in the training or dev data, will be + # mapped to during training.
+ echo "" >$dir/config/oov.txt + + cat > $dir/config/data_weights.txt <$dir/config/unigram_probs.txt + + # choose features + rnnlm/choose_features.py --unigram-probs=$dir/config/unigram_probs.txt \ + --use-constant-feature=true \ + --special-words=',,,' \ + $dir/config/words.txt > $dir/config/features.txt + + cat >$dir/config/xconfig < $tgtdir/best_wer_${enhan}_${decode_dir_suffix}.result + head -n 15 $tgtdir/best_wer_${enhan}_${decode_dir_suffix}.result +fi + +nbest=100 +rnnweight=0.8 +if [ $stage -le 6 ] && $run_nbest_rescore; then + echo "$0: Perform nbest-rescoring on $ac_model_dir" + for decode_set in dt05_real dt05_simu et05_real et05_simu; do + decode_dir=$tgtdir/decode_tgpr_5k_${decode_set}_${enhan}_${LM} + ( + # Lattice rescoring + rnnlm/lmrescore_nbest.sh \ + --cmd "$train_cmd --mem 2G" --N $nbest \ + $rnnweight data/lang_test_$LM $dir \ + data/${decode_set}_${enhan}_chunked ${decode_dir} \ + $tgtdir/decode_tgpr_5k_${decode_set}_${enhan}_${decode_dir_suffix}_w${rnnweight}_n${nbest} + + if $use_backward_model; then + rnnlm/lmrescore_nbest_back.sh \ + --cmd "$train_cmd --mem 2G" --N $nbest \ + $rnnweight data/lang_test_$LM ${dir}_back \ + data/${decode_set}_${enhan}_chunked \ + $tgtdir/decode_tgpr_5k_${decode_set}_${enhan}_${decode_dir_suffix}_w${rnnweight}_n${nbest} \ + $tgtdir/decode_tgpr_5k_${decode_set}_${enhan}_${decode_dir_suffix}_w${rnnweight}_n${nbest}_bi + fi + ) & + done + wait + # calc wers for nbest-rescoring results + if $use_backward_model; then + local/chime4_calc_wers.sh $tgtdir ${enhan}_${decode_dir_suffix}_w${rnnweight}_n${nbest}_bi \ + $tgtdir/graph_tgpr_5k \ + > $tgtdir/best_wer_${enhan}_${decode_dir_suffix}_w${rnnweight}_n${nbest}_bi.result + head -n 15 $tgtdir/best_wer_${enhan}_${decode_dir_suffix}_w${rnnweight}_n${nbest}_bi.result + else + local/chime4_calc_wers.sh $tgtdir ${enhan}_${decode_dir_suffix}_w${rnnweight}_n${nbest} \ + $tgtdir/graph_tgpr_5k \ + > $tgtdir/best_wer_${enhan}_${decode_dir_suffix}_w${rnnweight}_n${nbest}.result + head -n 15 $tgtdir/best_wer_${enhan}_${decode_dir_suffix}_w${rnnweight}_n${nbest}.result + fi +fi + +exit 0 diff --git a/egs/chime4/s5_1ch/local/run_blstm_gev.sh b/egs/chime4/s5_1ch/local/run_blstm_gev.sh new file mode 100755 index 00000000000..2ee92b70fbd --- /dev/null +++ b/egs/chime4/s5_1ch/local/run_blstm_gev.sh @@ -0,0 +1,81 @@ +#!/bin/bash +# Copyright 2017 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +# Apache 2.0 + +. ./cmd.sh +. ./path.sh + +# Config: +nj=10 +cmd=run.pl +track=6 +. utils/parse_options.sh || exit 1; + +if [ $# != 4 ]; then + echo "Wrong #arguments ($#, expected 4)" + echo "Usage: local/run_blstm_gev.sh [options] " + echo "main options (for others, see top of script file)" + echo " --nj # number of parallel jobs" + echo " --cmd # Command to run in parallel with" + echo " --track # Chime data to use (1, 2 or 6)" + exit 1; +fi + +sdir=$1 +chime3_dir=$2 +odir=$3 +enhancement_type=$4 + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +miniconda_dir=$HOME/miniconda3/ +if [ ! -d $miniconda_dir ]; then + echo "$miniconda_dir does not exist. 
Please run '../../../tools/extras/install_miniconda.sh' and '../../../tools/extras/install_chainer.sh';" +fi + +# check if chainer is installed +result=`$HOME/miniconda3/bin/python -c "\ +try: + import chainer + print('1') +except ImportError: + print('0')"` + +if [ "$result" == "1" ]; then + echo "Chainer is installed" +else + echo "Chainer is not installed. Please run ../../../tools/extras/install_chainer.sh" +fi + +if [ ! -d local/nn-gev ]; then + cd local/ + git clone https://github.com/fgnt/nn-gev.git + cd nn-gev/ + git checkout 3a039a4b707419fab05deb9679b41360ea92d779 . + git apply ../fix_read_sim_from_different_directory.patch + cd ../../ +else + cd local/nn-gev/ + git checkout 3a039a4b707419fab05deb9679b41360ea92d779 . + git apply ../fix_read_sim_from_different_directory.patch + cd ../../ +fi + +mkdir -p $odir +set +e +n_isolated_dirs=`ls local/nn-gev/data/audio/16kHz/isolated/ 2>/dev/null | wc -l` +n_isolated_ext_dirs=`ls local/nn-gev/data/audio/16kHz/isolated_ext/ 2>/dev/null | wc -l` +set -e +if [[ "$n_isolated_dirs" -ne 12 || "$n_isolated_ext_dirs" -ne 12 ]];then + echo "generating simulation data and storing in local/nn-gev/data" + $cmd $odir/simulation.log matlab -nodisplay -nosplash -r "addpath('local'); CHiME3_simulate_data_patched_parallel(1,$nj,'$sdir','$chime3_dir');exit" +else + echo "Didn't run Matlab simulation. Using existing data in local/nn-gev/data/audio/" +fi + +echo "Training a BLSTM-based mask network and enhancing signals with mask-based GEV beamformer" +$cuda_cmd $odir/beamform.log local/run_nn-gev.sh $sdir $odir $enhancement_type $track diff --git a/egs/chime4/s5_1ch/local/run_dnn.sh b/egs/chime4/s5_1ch/local/run_dnn.sh deleted file mode 100755 index 2207574e71c..00000000000 --- a/egs/chime4/s5_1ch/local/run_dnn.sh +++ /dev/null @@ -1,237 +0,0 @@ -#!/bin/bash - -# Copyright 2016 University of Sheffield (Jon Barker, Ricard Marxer) -# Inria (Emmanuel Vincent) -# Mitsubishi Electric Research Labs (Shinji Watanabe) -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -# This script is made from the kaldi recipe of the 2nd CHiME Challenge Track 2 -# made by Chao Weng - -. ./path.sh -. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. - ## This relates to the queue. - -# Config: -nj=30 -stage=0 # resume training with --stage N -train=noisy -eval_flag=true # make it true when the evaluation data are released - -. utils/parse_options.sh || exit 1; - -# This is a shell script, but it's recommended that you run the commands one by -# one by copying and pasting into the shell. - -if [ $# -ne 1 ]; then - printf "\nUSAGE: %s \n\n" `basename $0` - echo "First argument specifies a unique name for different enhancement method" - exit 1; -fi - -# set enhanced data -enhan=$1 - -# Set bash to 'debug' mode, it will exit on : -# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', -set -e -set -u -set -o pipefail - -# check whether run_init is executed -if [ ! -d data/lang ]; then - echo "error, execute local/run_init.sh, first" - exit 1; -fi - -# check whether run_init is executed -if [ ! 
-d exp/tri3b_tr05_multi_${train} ]; then - echo "error, execute local/run_init.sh, first" - exit 1; -fi - -# get alignments -if [ $stage -le 0 ]; then - steps/align_fmllr.sh --nj $nj --cmd "$train_cmd" \ - data/tr05_multi_${train} data/lang exp/tri3b_tr05_multi_${train} exp/tri3b_tr05_multi_${train}_ali - steps/align_fmllr.sh --nj 4 --cmd "$train_cmd" \ - data/dt05_multi_$enhan data/lang exp/tri3b_tr05_multi_${train} exp/tri3b_tr05_multi_${train}_ali_dt05 -fi - -# make fmllr feature for training multi = simu + real -gmmdir=exp/tri3b_tr05_multi_${train}_ali -data_fmllr=data-fmllr-tri3b -mkdir -p $data_fmllr -fmllrdir=fmllr-tri3b/${train} -if [ $stage -le 1 ]; then - for x in tr05_real_${train} tr05_simu_${train}; do - steps/nnet/make_fmllr_feats.sh --nj 4 --cmd "$train_cmd" \ - --transform-dir $gmmdir \ - $data_fmllr/$x data/$x $gmmdir exp/make_fmllr_tri3b/$x $fmllrdir - done -fi - -# make fmllr feature for dev and eval -gmmdir=exp/tri3b_tr05_multi_${train} -fmllrdir=fmllr-tri3b/$enhan -if [ $stage -le 2 ]; then - if $eval_flag; then - tasks="dt05_real_$enhan dt05_simu_$enhan et05_real_$enhan et05_simu_$enhan" - else - tasks="dt05_real_$enhan dt05_simu_$enhan" - fi - for x in $tasks; do - steps/nnet/make_fmllr_feats.sh --nj 4 --cmd "$train_cmd" \ - --transform-dir $gmmdir/decode_tgpr_5k_$x \ - $data_fmllr/$x data/$x $gmmdir exp/make_fmllr_tri3b/$x $fmllrdir - done -fi - -# make mixed training set from real and simulation enhanced data -# multi = simu + real -if [ $stage -le 3 ]; then - for data_dir in $data_fmllr/tr05_real_${train} $data_fmllr/tr05_simu_${train} $data_fmllr/dt05_real_$enhan $data_fmllr/dt05_simu_$enhan; do - utils/data/get_utt2dur.sh $data_dir - done - - utils/combine_data.sh $data_fmllr/tr05_multi_${train} $data_fmllr/tr05_simu_${train} $data_fmllr/tr05_real_${train} - utils/combine_data.sh $data_fmllr/dt05_multi_$enhan $data_fmllr/dt05_simu_$enhan $data_fmllr/dt05_real_$enhan - if $eval_flag; then - for data_dir in $data_fmllr/et05_real_$enhan $data_fmllr/et05_simu_$enhan; do - utils/data/get_utt2dur.sh $data_dir - done - utils/combine_data.sh $data_fmllr/et05_multi_$enhan $data_fmllr/et05_simu_$enhan $data_fmllr/et05_real_$enhan - fi -fi - -# pre-train dnn -dir=exp/tri4a_dnn_pretrain_tr05_multi_${train} -if [ $stage -le 4 ]; then - $cuda_cmd $dir/_pretrain_dbn.log \ - steps/nnet/pretrain_dbn.sh --nn-depth 7 --rbm-iter 3 $data_fmllr/tr05_multi_${train} $dir -fi - -# train dnn -dir=exp/tri4a_dnn_tr05_multi_${train} -ali=exp/tri3b_tr05_multi_${train}_ali -ali_dev=exp/tri3b_tr05_multi_${train}_ali_dt05 -feature_transform=exp/tri4a_dnn_pretrain_tr05_multi_${train}/final.feature_transform -dbn=exp/tri4a_dnn_pretrain_tr05_multi_${train}/7.dbn -if [ $stage -le 5 ]; then - $cuda_cmd $dir/_train_nnet.log \ - steps/nnet/train.sh --feature-transform $feature_transform --dbn $dbn --hid-layers 0 --learn-rate 0.008 \ - $data_fmllr/tr05_multi_${train} $data_fmllr/dt05_multi_$enhan data/lang $ali $ali_dev $dir -fi - -# decode enhanced speech -if [ $stage -le 6 ]; then - utils/mkgraph.sh data/lang_test_tgpr_5k $dir $dir/graph_tgpr_5k - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --acwt 0.10 --config conf/decode_dnn.config \ - $dir/graph_tgpr_5k $data_fmllr/dt05_real_$enhan $dir/decode_tgpr_5k_dt05_real_$enhan & - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --acwt 0.10 --config conf/decode_dnn.config \ - $dir/graph_tgpr_5k $data_fmllr/dt05_simu_$enhan $dir/decode_tgpr_5k_dt05_simu_$enhan & - if $eval_flag; then - steps/nnet/decode.sh --nj 4 
--num-threads 3 --cmd "$decode_cmd" --acwt 0.10 --config conf/decode_dnn.config \ - $dir/graph_tgpr_5k $data_fmllr/et05_real_$enhan $dir/decode_tgpr_5k_et05_real_$enhan & - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --acwt 0.10 --config conf/decode_dnn.config \ - $dir/graph_tgpr_5k $data_fmllr/et05_simu_$enhan $dir/decode_tgpr_5k_et05_simu_$enhan & - fi - wait; -fi - -# Sequence training using sMBR criterion, we do Stochastic-GD -# with per-utterance updates. We use usually good acwt 0.1 -# Lattices are re-generated after 1st epoch, to get faster convergence. -dir=exp/tri4a_dnn_tr05_multi_${train}_smbr -srcdir=exp/tri4a_dnn_tr05_multi_${train} -acwt=0.1 - -# First we generate lattices and alignments: -# awk -v FS="/" '{ NF_nosuffix=$NF; sub(".gz","",NF_nosuffix); print NF_nosuffix gunzip -c "$0" |"; }' in -# steps/nnet/make_denlats.sh -if [ $stage -le 7 ]; then - steps/nnet/align.sh --nj $nj --cmd "$train_cmd" \ - $data_fmllr/tr05_multi_${train} data/lang $srcdir ${srcdir}_ali - steps/nnet/make_denlats.sh --nj $nj --cmd "$decode_cmd" --config conf/decode_dnn.config --acwt $acwt \ - $data_fmllr/tr05_multi_${train} data/lang $srcdir ${srcdir}_denlats -fi - -# Re-train the DNN by 1 iteration of sMBR -if [ $stage -le 8 ]; then - steps/nnet/train_mpe.sh --cmd "$cuda_cmd" --num-iters 1 --acwt $acwt --do-smbr true \ - $data_fmllr/tr05_multi_${train} data/lang $srcdir ${srcdir}_ali ${srcdir}_denlats $dir -fi - -# Decode (reuse HCLG graph) -if [ $stage -le 9 ]; then - for ITER in 1; do - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --config conf/decode_dnn.config \ - --nnet $dir/${ITER}.nnet --acwt $acwt \ - exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k $data_fmllr/dt05_real_${enhan} $dir/decode_tgpr_5k_dt05_real_${enhan}_it${ITER} & - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --config conf/decode_dnn.config \ - --nnet $dir/${ITER}.nnet --acwt $acwt \ - exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k $data_fmllr/dt05_simu_${enhan} $dir/decode_tgpr_5k_dt05_simu_${enhan}_it${ITER} & - if $eval_flag; then - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --config conf/decode_dnn.config \ - --nnet $dir/${ITER}.nnet --acwt $acwt \ - exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k $data_fmllr/et05_real_${enhan} $dir/decode_tgpr_5k_et05_real_${enhan}_it${ITER} & - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --config conf/decode_dnn.config \ - --nnet $dir/${ITER}.nnet --acwt $acwt \ - exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k $data_fmllr/et05_simu_${enhan} $dir/decode_tgpr_5k_et05_simu_${enhan}_it${ITER} & - fi - done -fi - -# Re-generate lattices, run 4 more sMBR iterations -dir=exp/tri4a_dnn_tr05_multi_${train}_smbr_i1lats -srcdir=exp/tri4a_dnn_tr05_multi_${train}_smbr -acwt=0.1 - -# Generate lattices and alignments: -if [ $stage -le 10 ]; then - steps/nnet/align.sh --nj $nj --cmd "$train_cmd" \ - $data_fmllr/tr05_multi_${train} data/lang $srcdir ${srcdir}_ali - steps/nnet/make_denlats.sh --nj $nj --cmd "$decode_cmd" --config conf/decode_dnn.config --acwt $acwt \ - $data_fmllr/tr05_multi_${train} data/lang $srcdir ${srcdir}_denlats -fi - -# Re-train the DNN by 4 iterations of sMBR -if [ $stage -le 11 ]; then - steps/nnet/train_mpe.sh --cmd "$cuda_cmd" --num-iters 4 --acwt $acwt --do-smbr true \ - $data_fmllr/tr05_multi_${train} data/lang $srcdir ${srcdir}_ali ${srcdir}_denlats $dir || exit 1 -fi - -# Decode (reuse HCLG graph) -if [ $stage -le 12 ]; then - for ITER in 1 2 3 4; do - 
steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --config conf/decode_dnn.config \ - --nnet $dir/${ITER}.nnet --acwt $acwt \ - exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k $data_fmllr/dt05_real_${enhan} $dir/decode_tgpr_5k_dt05_real_${enhan}_it${ITER} & - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --config conf/decode_dnn.config \ - --nnet $dir/${ITER}.nnet --acwt $acwt \ - exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k $data_fmllr/dt05_simu_${enhan} $dir/decode_tgpr_5k_dt05_simu_${enhan}_it${ITER} & - if $eval_flag; then - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --config conf/decode_dnn.config \ - --nnet $dir/${ITER}.nnet --acwt $acwt \ - exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k $data_fmllr/et05_real_${enhan} $dir/decode_tgpr_5k_et05_real_${enhan}_it${ITER} & - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --config conf/decode_dnn.config \ - --nnet $dir/${ITER}.nnet --acwt $acwt \ - exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k $data_fmllr/et05_simu_${enhan} $dir/decode_tgpr_5k_et05_simu_${enhan}_it${ITER} & - fi - done - wait -fi - -# scoring -if [ $stage -le 13 ]; then - # decoded results of enhanced speech using DNN AMs trained with enhanced data - local/chime4_calc_wers.sh exp/tri4a_dnn_tr05_multi_${train} $enhan exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k \ - > exp/tri4a_dnn_tr05_multi_${train}/best_wer_$enhan.result - head -n 15 exp/tri4a_dnn_tr05_multi_${train}/best_wer_$enhan.result - # decoded results of enhanced speech using sequence-training DNN - ./local/chime4_calc_wers_smbr.sh exp/tri4a_dnn_tr05_multi_${train}_smbr_i1lats ${enhan} exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k \ - > exp/tri4a_dnn_tr05_multi_${train}_smbr_i1lats/best_wer_${enhan}.result - head -n 15 exp/tri4a_dnn_tr05_multi_${train}_smbr_i1lats/best_wer_${enhan}.result -fi - -echo "`basename $0` Done." diff --git a/egs/chime4/s5_1ch/local/run_dnn_recog.sh b/egs/chime4/s5_1ch/local/run_dnn_recog.sh deleted file mode 100755 index 5e6ade02387..00000000000 --- a/egs/chime4/s5_1ch/local/run_dnn_recog.sh +++ /dev/null @@ -1,143 +0,0 @@ -#!/bin/bash - -# Copyright 2016 University of Sheffield (Jon Barker, Ricard Marxer) -# Inria (Emmanuel Vincent) -# Mitsubishi Electric Research Labs (Shinji Watanabe) -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -# This script is made from the kaldi recipe of the 2nd CHiME Challenge Track 2 -# made by Chao Weng - -. ./path.sh -. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. - ## This relates to the queue. - -# Config: -nj=30 -stage=0 # resume training with --stage=N -train=noisy -eval_flag=true # make it true when the evaluation data are released - -. utils/parse_options.sh || exit 1; - -# This is a shell script, but it's recommended that you run the commands one by -# one by copying and pasting into the shell. - -if [ $# -ne 2 ]; then - printf "\nUSAGE: %s \n\n" `basename $0` - echo "First argument specifies a unique name for different enhancement method" - echo "Second argument specifies acoustic and language model directory" - exit 1; -fi - -# set enhanced data -enhan=$1 -# set model directory -mdir=$2 - -# Set bash to 'debug' mode, it will exit on : -# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', -set -e -set -u -set -o pipefail - -# check data/loca/data -if [ ! -d $mdir/data/local/data ]; then - echo "error, set $mdir correctly" - exit 1; -elif [ ! 
-d data/local/data ]; then - echo "copy $mdir/data/local/data" - mkdir -p data/local - cp -r $mdir/data/local/data data/local/ -fi - -# check gmm model -if [ ! -d $mdir/exp/tri3b_tr05_multi_${train} ]; then - echo "error, set $mdir correctly" - exit 1; -elif [ ! -d exp/tri3b_tr05_multi_${train} ]; then - echo "copy $mdir/exp/tri3b_tr05_multi_${train}" - mkdir -p exp - cp -r $mdir/exp/tri3b_tr05_multi_${train} exp/ -fi - -# check dnn graph -if [ ! -d $mdir/exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k ]; then - echo "error, set $mdir correctly" - exit 1; -elif [ ! -d exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k ]; then - echo "copy $mdir/exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k" - mkdir -p exp/tri4a_dnn_tr05_multi_${train} - cp -r $mdir/exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k exp/tri4a_dnn_tr05_multi_${train}/ -fi - -# check dnn smbr model -if [ ! -d $mdir/exp/tri4a_dnn_tr05_multi_${train}_smbr_i1lats ]; then - echo "error, set $mdir correctly" - exit 1; -elif [ ! -d exp/tri4a_dnn_tr05_multi_${train}_smbr_i1lats ]; then - echo "copy $mdir/exp/tri4a_dnn_tr05_multi_${train}_smbr_i1lats" - mkdir -p exp - cp -r $mdir/exp/tri4a_dnn_tr05_multi_${train}_smbr_i1lats exp/ -fi - -# make fmllr feature for dev and eval -gmmdir=exp/tri3b_tr05_multi_${train} -data_fmllr=data-fmllr-tri3b -mkdir -p $data_fmllr -fmllrdir=fmllr-tri3b/$enhan -if [ $stage -le 4 ]; then - if $eval_flag; then - tasks="dt05_real_$enhan dt05_simu_$enhan et05_real_$enhan et05_simu_$enhan" - else - tasks="dt05_real_$enhan dt05_simu_$enhan" - fi - for x in $tasks; do - steps/nnet/make_fmllr_feats.sh --nj 4 --cmd "$train_cmd" \ - --transform-dir $gmmdir/decode_tgpr_5k_$x \ - $data_fmllr/$x data/$x $gmmdir exp/make_fmllr_tri3b/$x $fmllrdir - done -fi - -# make mixed training set from real and simulation enhanced data -# multi = simu + real -if [ $stage -le 5 ]; then - utils/combine_data.sh $data_fmllr/dt05_multi_$enhan $data_fmllr/dt05_simu_$enhan $data_fmllr/dt05_real_$enhan - if $eval_flag; then - utils/combine_data.sh $data_fmllr/et05_multi_$enhan $data_fmllr/et05_simu_$enhan $data_fmllr/et05_real_$enhan - fi -fi - -# Re-generate lattices, run 4 more sMBR iterations -dir=exp/tri4a_dnn_tr05_multi_${train}_smbr_i1lats -acwt=0.1 - -# Decode (reuse HCLG graph) -if [ $stage -le 6 ]; then - for ITER in 1 2 3 4; do - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --config conf/decode_dnn.config \ - --nnet $dir/${ITER}.nnet --acwt $acwt \ - exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k $data_fmllr/dt05_real_${enhan} $dir/decode_tgpr_5k_dt05_real_${enhan}_it${ITER} & - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --config conf/decode_dnn.config \ - --nnet $dir/${ITER}.nnet --acwt $acwt \ - exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k $data_fmllr/dt05_simu_${enhan} $dir/decode_tgpr_5k_dt05_simu_${enhan}_it${ITER} & - if $eval_flag; then - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --config conf/decode_dnn.config \ - --nnet $dir/${ITER}.nnet --acwt $acwt \ - exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k $data_fmllr/et05_real_${enhan} $dir/decode_tgpr_5k_et05_real_${enhan}_it${ITER} & - steps/nnet/decode.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" --config conf/decode_dnn.config \ - --nnet $dir/${ITER}.nnet --acwt $acwt \ - exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k $data_fmllr/et05_simu_${enhan} $dir/decode_tgpr_5k_et05_simu_${enhan}_it${ITER} & - fi - wait - done -fi - -# scoring -if [ $stage -le 7 ]; then - # decoded results of enhanced speech using 
sequence-training DNN - ./local/chime4_calc_wers_smbr.sh $dir ${enhan} exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k > $dir/best_wer_${enhan}.result - head -n 15 $dir/best_wer_${enhan}.result -fi - -echo "`basename $0` Done." diff --git a/egs/chime4/s5_1ch/local/run_gmm.sh b/egs/chime4/s5_1ch/local/run_gmm.sh index 2a3c8680f23..5178433dfc2 100755 --- a/egs/chime4/s5_1ch/local/run_gmm.sh +++ b/egs/chime4/s5_1ch/local/run_gmm.sh @@ -17,6 +17,8 @@ nj=30 stage=0 # resume training with --stage=N train=noisy # noisy data multi-condition training eval_flag=true # make it true when the evaluation data are released +add_enhanced_data=true # make it true when you want to add enhanced data into training set +decode_only=false # if true, it wouldn't train a model again and will only do decoding . utils/parse_options.sh || exit 1; @@ -49,6 +51,33 @@ if [ ! -d data/lang ]; then exit 1; fi +if $decode_only; then + # check data/loca/data + mdir=`pwd` + if [ ! -d $mdir/data/local/data ]; then + echo "error, set $mdir correctly" + exit 1; + elif [ ! -d data/local/data ]; then + echo "copy $mdir/data/local/data" + mkdir -p data/local + cp -r $mdir/data/local/data data/local/ + fi + # check gmm model + if [ ! -d $mdir/exp/tri3b_tr05_multi_${train} ]; then + echo "error, set $mdir correctly" + exit 1; + elif [ ! -d exp/tri3b_tr05_multi_${train} ]; then + echo "copy $mdir/exp/tri3b_tr05_multi_${train}" + mkdir -p exp + cp -r $mdir/exp/tri3b_tr05_multi_${train} exp/ + fi + # process for enhanced data + if [ ! -d data/dt05_real_$enhan ] || [ ! -d data/et05_real_$enhan ]; then + local/real_enhan_chime4_data_prep.sh $enhan $enhan_data + local/simu_enhan_chime4_data_prep.sh $enhan $enhan_data + fi + stage=6 +fi ####################### #### training ######### if [ $stage -le 1 ]; then @@ -63,27 +92,51 @@ if [ $stage -le 1 ]; then local/simu_enhan_chime4_data_prep.sh $enhan $enhan_data fi fi +# Copy enhanced data for 1ch and 2ch experiments +if [ $stage -le 2 ] && [[ "$PWD" != *s5_6ch* ]]; then + beamformed=0 + # First remove empty files generated from previous stage + for d in tr05_{real,simu}_$enhan; do + [ -d data/$d ] && rm -rf data/$d && \ + echo "remove empty directory $d" + done + if [[ "$enhan" == *beamformit_2mics* ]] && [ -d ../s5_6ch/data/tr05_real_beamformit_5mics ]; then + echo "copy tr05_{real,simu}_beamformit_5mics from ../s5_6ch/data/" + cp -r ../s5_6ch/data/tr05_real_beamformit_5mics data/tr05_real_beamformit_2mics + cp -r ../s5_6ch/data/tr05_simu_beamformit_5mics data/tr05_simu_beamformit_2mics + beamformed=1 + elif [ -d ../s5_6ch/data/tr05_real_$enhan ]; then + echo "copy enhanced training data ${d} from ../s5_6ch/data/" + cp -r ../s5_6ch/data/tr05_real_$enhan data/ + cp -r ../s5_6ch/data/tr05_simu_$enhan data/ + beamformed=1 + elif [[ "$enhan" == *isolated_1ch_track* ]]; then + beamformed=1 + fi + if [ $beamformed == 0 ]; then + echo "no such directory tr05_{real,simu}_{beamformit_5mics,blstm_gev,single_BLSTMmask}" + echo "They are generated by run_beamform_6ch_track.sh in ../s5_6ch/run.sh, please execute it first" && \ + exit 1; + fi +fi # Now make MFCC features for clean, close, and noisy data # mfccdir should be some place with a largish disk where you # want to store MFCC features. 
mfccdir=mfcc -if [ $stage -le 2 ]; then - if $eval_flag; then - tasks="tr05_real_${train} dt05_real_${train} tr05_simu_${train} dt05_simu_${train} et05_real_${train} et05_simu_${train} tr05_real_$enhan tr05_simu_$enhan" +if [ $stage -le 3 ]; then + if $add_enhanced_data; then + if $eval_flag; then + tasks="tr05_real_${train} dt05_real_${train} tr05_simu_${train} dt05_simu_${train} et05_real_${train} et05_simu_${train} tr05_real_$enhan tr05_simu_$enhan" + else + tasks="tr05_real_${train} dt05_real_${train} tr05_simu_${train} dt05_simu_${train} tr05_real_$enhan tr05_simu_$enhan" + fi else - tasks="tr05_real_${train} dt05_real_${train} tr05_simu_${train} dt05_simu_${train} tr05_real_$enhan tr05_simu_$enhan" - fi - if [ "$enhan" == "beamformit_2mics" ]; then - for d in ../s5_6ch/data/tr05_{real,simu}_beamformit_5mics; do - [ ! -d $d ] && echo "no such directory $d" && \ - echo "It is generated by run_beamform_6ch_track.sh within ../s5_6ch/run.sh, execute it first" && \ - exit 1; - done - echo "copy enhanced training data from ../s5_6ch/data/" - rm -rf data/tr05_{real,simu}_beamformit_2mics - cp -r ../s5_6ch/data/tr05_real_beamformit_5mics data/tr05_real_beamformit_2mics - cp -r ../s5_6ch/data/tr05_simu_beamformit_5mics data/tr05_simu_beamformit_2mics + if $eval_flag; then + tasks="tr05_real_${train} dt05_real_${train} tr05_simu_${train} dt05_simu_${train} et05_real_${train} et05_simu_${train}" + else + tasks="tr05_real_${train} dt05_real_${train} tr05_simu_${train} dt05_simu_${train}" + fi fi for x in $tasks; do steps/make_mfcc.sh --nj 8 --cmd "$train_cmd" \ @@ -95,17 +148,20 @@ fi # make mixed training set from real and simulation training data # multi = simu + real # Note that we are combining enhanced training data with noisy training data -if [ $stage -le 3 ]; then - utils/combine_data.sh data/tr05_multi_${train} data/tr05_simu_${train} data/tr05_real_${train} data/tr05_simu_$enhan data/tr05_real_$enhan - #utils/combine_data.sh data/tr05_multi_${train} data/tr05_simu_${train} data/tr05_real_${train} +if [ $stage -le 4 ]; then + if $add_enhanced_data; then + utils/combine_data.sh data/tr05_multi_${train} data/tr05_simu_${train} data/tr05_real_${train} data/tr05_simu_$enhan data/tr05_real_$enhan + else + utils/combine_data.sh data/tr05_multi_${train} data/tr05_simu_${train} data/tr05_real_${train} + fi utils/combine_data.sh data/dt05_multi_${train} data/dt05_simu_${train} data/dt05_real_${train} if $eval_flag; then - utils/combine_data.sh data/et05_multi_${train} data/et05_simu_${train} data/et05_real_${train} + utils/combine_data.sh data/et05_multi_${train} data/et05_simu_${train} data/et05_real_${train} fi fi # training models for noisy data -if [ $stage -le 4 ]; then +if [ $stage -le 5 ]; then nspk=`wc -l data/tr05_multi_${train}/spk2utt | awk '{print $1}'` if [ $nj -gt $nspk ]; then nj2=$nspk diff --git a/egs/chime4/s5_1ch/local/run_gmm_recog.sh b/egs/chime4/s5_1ch/local/run_gmm_recog.sh deleted file mode 100755 index 5f7f47b39d7..00000000000 --- a/egs/chime4/s5_1ch/local/run_gmm_recog.sh +++ /dev/null @@ -1,127 +0,0 @@ -#!/bin/bash - -# Copyright 2016 University of Sheffield (Jon Barker, Ricard Marxer) -# Inria (Emmanuel Vincent) -# Mitsubishi Electric Research Labs (Shinji Watanabe) -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -# This script is made from the kaldi recipe of the 2nd CHiME Challenge Track 2 -# made by Chao Weng - -. ./path.sh -. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. - ## This relates to the queue. 
- -# Config: -nj=30 -stage=0 # resume training with --stage=N -train=noisy -eval_flag=true # make it true when the evaluation data are released - -. utils/parse_options.sh || exit 1; - -# This is a shell script, but it's recommended that you run the commands one by -# one by copying and pasting into the shell. - -if [ $# -ne 3 ]; then - printf "\nUSAGE: %s \n\n" `basename $0` - echo "First argument specifies a unique name for different enhancement method" - echo "Second argument specifies the directory of enhanced wav files" - echo "Third argument specifies acoustic and language model directory" - exit 1; -fi - -# set enhanced data -enhan=$1 -enhan_data=$2 -# set model directory -mdir=$3 - -# Set bash to 'debug' mode, it will exit on : -# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', -set -e -set -u -set -o pipefail - -# check data/loca/data -if [ ! -d $mdir/data/local/data ]; then - echo "error, set $mdir correctly" - exit 1; -elif [ ! -d data/local/data ]; then - echo "copy $mdir/data/local/data" - mkdir -p data/local - cp -r $mdir/data/local/data data/local/ -fi - -# check gmm model -if [ ! -d $mdir/exp/tri3b_tr05_multi_${train} ]; then - echo "error, set $mdir correctly" - exit 1; -elif [ ! -d exp/tri3b_tr05_multi_${train} ]; then - echo "copy $mdir/exp/tri3b_tr05_multi_${train}" - mkdir -p exp - cp -r $mdir/exp/tri3b_tr05_multi_${train} exp/ -fi - -# process for enhanced data -if [ $stage -le 0 ]; then - if [ ! -d data/dt05_real_$enhan ] || [ ! -d data/et05_real_$enhan ]; then - local/real_enhan_chime4_data_prep.sh $enhan $enhan_data - local/simu_enhan_chime4_data_prep.sh $enhan $enhan_data - fi -fi - -# Now make MFCC features for enhanced data -# mfccdir should be some place with a largish disk where you -# want to store MFCC features. -mfccdir=mfcc/$enhan -if [ $stage -le 1 ]; then - if $eval_flag; then - tasks="dt05_real_$enhan dt05_simu_$enhan et05_real_$enhan et05_simu_$enhan" - else - tasks="dt05_real_$enhan dt05_simu_$enhan" - fi - for x in $tasks; do - if [ ! -e data/$x/feats.scp ]; then - steps/make_mfcc.sh --nj 8 --cmd "$train_cmd" \ - data/$x exp/make_mfcc/$x $mfccdir - steps/compute_cmvn_stats.sh data/$x exp/make_mfcc/$x $mfccdir - fi - done -fi - -# make mixed training set from real and simulation enhanced data -# multi = simu + real -if [ $stage -le 2 ]; then - if [ ! -d data/dt05_multi_$enhan ] || [ ! 
-d data/et05_multi_$enhan ]; then - utils/combine_data.sh data/dt05_multi_$enhan data/dt05_simu_$enhan data/dt05_real_$enhan - if $eval_flag; then - utils/combine_data.sh data/et05_multi_$enhan data/et05_simu_$enhan data/et05_real_$enhan - fi - fi -fi - -# decode enhanced speech using AMs trained with enhanced data -if [ $stage -le 3 ]; then - steps/decode_fmllr.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" \ - exp/tri3b_tr05_multi_${train}/graph_tgpr_5k data/dt05_real_$enhan exp/tri3b_tr05_multi_${train}/decode_tgpr_5k_dt05_real_$enhan & - steps/decode_fmllr.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" \ - exp/tri3b_tr05_multi_${train}/graph_tgpr_5k data/dt05_simu_$enhan exp/tri3b_tr05_multi_${train}/decode_tgpr_5k_dt05_simu_$enhan & - if $eval_flag; then - steps/decode_fmllr.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" \ - exp/tri3b_tr05_multi_${train}/graph_tgpr_5k data/et05_real_$enhan exp/tri3b_tr05_multi_${train}/decode_tgpr_5k_et05_real_$enhan & - steps/decode_fmllr.sh --nj 4 --num-threads 3 --cmd "$decode_cmd" \ - exp/tri3b_tr05_multi_${train}/graph_tgpr_5k data/et05_simu_$enhan exp/tri3b_tr05_multi_${train}/decode_tgpr_5k_et05_simu_$enhan & - fi - wait; -fi - -# scoring -if [ $stage -le 4 ]; then - # decoded results of enhanced speech using AMs trained with enhanced data - local/chime4_calc_wers.sh exp/tri3b_tr05_multi_${train} $enhan exp/tri3b_tr05_multi_${train}/graph_tgpr_5k \ - > exp/tri3b_tr05_multi_${train}/best_wer_$enhan.result - head -n 15 exp/tri3b_tr05_multi_${train}/best_wer_$enhan.result -fi - -echo "`basename $0` Done." diff --git a/egs/chime4/s5_1ch/local/run_lmrescore_recog.sh b/egs/chime4/s5_1ch/local/run_lmrescore_recog.sh deleted file mode 100755 index 8b57585fda0..00000000000 --- a/egs/chime4/s5_1ch/local/run_lmrescore_recog.sh +++ /dev/null @@ -1,121 +0,0 @@ -#!/bin/bash - -# Copyright 2015 University of Sheffield (Jon Barker, Ricard Marxer) -# Inria (Emmanuel Vincent) -# Mitsubishi Electric Research Labs (Shinji Watanabe) -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -# Copyright 2015, Mitsubishi Electric Research Laboratories, MERL (Author: Takaaki Hori) - -nj=12 -stage=1 -order=5 -hidden=300 -rnnweight=0.5 -nbest=100 -train=noisy -eval_flag=true # make it true when the evaluation data are released - -. utils/parse_options.sh || exit 1; - -. ./path.sh -. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. - ## This relates to the queue. - -# This is a shell script, but it's recommended that you run the commands one by -# one by copying and pasting into the shell. - -if [ $# -ne 2 ]; then - printf "\nUSAGE: %s \n\n" `basename $0` - echo "First argument specifies a unique name for different enhancement method" - echo "Second argument specifies acoustic and language model directory" - exit 1; -fi - -# set language models -lm_suffix=${order}gkn_5k -rnnlm_suffix=rnnlm_5k_h${hidden} - -# enhan data -enhan=$1 -# set model directory -mdir=$2 -srcdir=exp/tri4a_dnn_tr05_multi_${train}_smbr_i1lats - -# check language models -if [ ! -d $mdir/data/lang ]; then - echo "error, set $mdir correctly" - exit 1; -fi - -# preparation -dir=exp/tri4a_dnn_tr05_multi_${train}_smbr_lmrescore -mkdir -p $dir -# make a symbolic link to graph info -if [ ! -e $dir/graph_tgpr_5k ]; then - if [ ! -e exp/tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k ]; then - echo "graph is missing, execute local/run_dnn.sh, correctly" - exit 1; - fi - pushd . ; cd $dir - ln -s ../tri4a_dnn_tr05_multi_${train}/graph_tgpr_5k . 
- popd -fi - -# rescore lattices by a high-order N-gram -if [ $stage -le 3 ]; then - # check the best iteration - if [ ! -f $srcdir/log/best_wer_$enhan ]; then - echo "$0: error $srcdir/log/best_wer_$enhan not found. execute local/run_dnn.sh, first" - exit 1; - fi - it=`cut -f 1 -d" " $srcdir/log/best_wer_$enhan | awk -F'[_]' '{print $1}'` - # rescore lattices - if $eval_flag; then - tasks="dt05_simu dt05_real et05_simu et05_real" - else - tasks="dt05_simu dt05_real" - fi - for t in $tasks; do - steps/lmrescore.sh --mode 3 \ - $mdir/data/lang_test_tgpr_5k \ - $mdir/data/lang_test_${lm_suffix} \ - data-fmllr-tri3b/${t}_$enhan \ - $srcdir/decode_tgpr_5k_${t}_${enhan}_it$it \ - $dir/decode_tgpr_5k_${t}_${enhan}_${lm_suffix} - done - # rescored results by high-order n-gram LM - mkdir -p $dir/log - local/chime4_calc_wers.sh $dir ${enhan}_${lm_suffix} $dir/graph_tgpr_5k \ - > $dir/best_wer_${enhan}_${lm_suffix}.result - head -n 15 $dir/best_wer_${enhan}_${lm_suffix}.result -fi - -# N-best rescoring using a RNNLM -if [ $stage -le 4 ]; then - # check the best lmw - if [ ! -f $dir/log/best_wer_${enhan}_${lm_suffix} ]; then - echo "error, rescoring with a high-order n-gram seems to be failed" - exit 1; - fi - lmw=`cut -f 1 -d" " $dir/log/best_wer_${enhan}_${lm_suffix} | awk -F'[_]' '{print $NF}'` - # rescore n-best list for all sets - if $eval_flag; then - tasks="dt05_simu dt05_real et05_simu et05_real" - else - tasks="dt05_simu dt05_real" - fi - for t in $tasks; do - steps/rnnlmrescore.sh --inv-acwt $lmw --N $nbest --use-phi true \ - $rnnweight \ - $mdir/data/lang_test_${lm_suffix} \ - $mdir/data/lang_test_${rnnlm_suffix} \ - data-fmllr-tri3b/${t}_$enhan \ - $dir/decode_tgpr_5k_${t}_${enhan}_${lm_suffix} \ - $dir/decode_tgpr_5k_${t}_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest} - done - # calc wers for RNNLM results - local/chime4_calc_wers.sh $dir ${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest} $dir/graph_tgpr_5k \ - > $dir/best_wer_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest}.result - head -n 15 $dir/best_wer_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest}.result -fi diff --git a/egs/chime4/s5_1ch/local/run_lmrescore_tdnn.sh b/egs/chime4/s5_1ch/local/run_lmrescore_tdnn.sh index 67572f0dd4c..58af793615e 100755 --- a/egs/chime4/s5_1ch/local/run_lmrescore_tdnn.sh +++ b/egs/chime4/s5_1ch/local/run_lmrescore_tdnn.sh @@ -98,7 +98,7 @@ if [ $stage -le 3 ]; then steps/lmrescore.sh --mode 3 \ data/lang_test_tgpr_5k \ data/lang_test_${lm_suffix} \ - data/${t}_${enhan}_hires \ + data/${t}_${enhan}_chunked \ $srcdir/decode_tgpr_5k_${t}_${enhan} \ $dir/decode_tgpr_5k_${t}_${enhan}_${lm_suffix} done @@ -128,7 +128,7 @@ if [ $stage -le 4 ]; then $rnnweight \ data/lang_test_${lm_suffix} \ data/lang_test_${rnnlm_suffix} \ - data/${t}_${enhan}_hires \ + data/${t}_${enhan}_chunked \ $dir/decode_tgpr_5k_${t}_${enhan}_${lm_suffix} \ $dir/decode_tgpr_5k_${t}_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest} done diff --git a/egs/chime4/s5_1ch/local/run_lmrescore_tdnn_lstm.sh b/egs/chime4/s5_1ch/local/run_lmrescore_tdnn_lstm.sh index 7173dcea78b..0bea4dd7102 100755 --- a/egs/chime4/s5_1ch/local/run_lmrescore_tdnn_lstm.sh +++ b/egs/chime4/s5_1ch/local/run_lmrescore_tdnn_lstm.sh @@ -165,4 +165,4 @@ if [ $stage -le 4 ]; then local/chime4_calc_wers_looped.sh $dir ${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest} $dir/graph_tgpr_5k \ > $dir/best_wer_looped_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest}.result head -n 15 $dir/best_wer_looped_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest}.result -fi \ No newline at 
end of file +fi diff --git a/egs/chime4/s5_1ch/local/run_lmrescore_tdnn_lstm_recog.sh b/egs/chime4/s5_1ch/local/run_lmrescore_tdnn_lstm_recog.sh deleted file mode 100755 index c4b4e238011..00000000000 --- a/egs/chime4/s5_1ch/local/run_lmrescore_tdnn_lstm_recog.sh +++ /dev/null @@ -1,153 +0,0 @@ -#!/bin/bash - -# Copyright 2015 University of Sheffield (Jon Barker, Ricard Marxer) -# Inria (Emmanuel Vincent) -# Mitsubishi Electric Research Labs (Shinji Watanabe) -# 2017 JHU CLSP (Szu-Jui Chen) -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -# Copyright 2015, Mitsubishi Electric Research Laboratories, MERL (Author: Takaaki Hori) - -nj=12 -stage=1 -order=5 -hidden=300 -rnnweight=0.5 -nbest=100 -train=noisy -eval_flag=true # make it true when the evaluation data are released - -. utils/parse_options.sh || exit 1; - -. ./path.sh -. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. - ## This relates to the queue. - -# This is a shell script, but it's recommended that you run the commands one by -# one by copying and pasting into the shell. - -if [ $# -ne 2 ]; then - printf "\nUSAGE: %s \n\n" `basename $0` - echo "First argument specifies a unique name for different enhancement method" - echo "Second argument specifies acoustic and language model directory" - exit 1; -fi - -# set language models -# You might need to change affix to the affix of your best tdnn model. -affix=1a -lm_suffix=${order}gkn_5k -rnnlm_suffix=rnnlm_5k_h${hidden} - -# enhan data -enhan=$1 -# set model directory -mdir=$2 -srcdir=exp/chain/tdnn_lstm${affix}_sp - -# check language models -if [ ! -d $mdir/data/lang ]; then - echo "error, set $mdir correctly" - exit 1; -fi - -# preparation -dir=exp/chain/tdnn_lstm${affix}_sp_smbr_lmrescore -mkdir -p $dir -# make a symbolic link to graph info -if [ ! -e $dir/graph_tgpr_5k ]; then - if [ ! -e exp/chain/tree_a_sp/graph_tgpr_5k ]; then - echo "graph is missing, execute local/run_tdnn.sh, correctly" - exit 1; - fi - pushd . ; cd $dir - ln -s ../tree_a_sp/graph_tgpr_5k . - popd -fi - -# rescore lattices by a high-order N-gram -if [ $stage -le 3 ]; then - # check the best iteration - if [ ! -f $srcdir/log/best_wer_$enhan ]; then - echo "$0: error $srcdir/log/best_wer_$enhan not found. 
execute local/run_tdnn_lstm.sh, first" - exit 1; - fi - it=`cut -f 1 -d" " $srcdir/log/best_wer_$enhan | awk -F'[_]' '{print $1}'` - # rescore lattices - if $eval_flag; then - tasks="dt05_simu dt05_real et05_simu et05_real" - else - tasks="dt05_simu dt05_real" - fi - for t in $tasks; do - steps/lmrescore.sh --mode 3 \ - $mdir/data/lang_test_tgpr_5k \ - $mdir/data/lang_test_${lm_suffix} \ - data/${t}_${enhan}_hires \ - $srcdir/decode_tgpr_5k_${t}_${enhan} \ - $dir/decode_tgpr_5k_${t}_${enhan}_${lm_suffix} - done - # rescored results by high-order n-gram LM - mkdir -p $dir/log - local/chime4_calc_wers.sh $dir ${enhan}_${lm_suffix} $dir/graph_tgpr_5k \ - > $dir/best_wer_${enhan}_${lm_suffix}.result - head -n 15 $dir/best_wer_${enhan}_${lm_suffix}.result - - # now rescore lattices after looped decoding - for t in $tasks; do - steps/lmrescore.sh --mode 3 \ - data/lang_test_tgpr_5k \ - data/lang_test_${lm_suffix} \ - data/${t}_${enhan}_hires \ - $srcdir/decode_looped_tgpr_5k_${t}_${enhan} \ - $dir/decode_looped_tgpr_5k_${t}_${enhan}_${lm_suffix} - done - # rescored results by high-order n-gram LM - local/chime4_calc_wers_looped.sh $dir ${enhan}_${lm_suffix} $dir/graph_tgpr_5k \ - > $dir/best_wer_looped_${enhan}_${lm_suffix}.result - head -n 15 $dir/best_wer_looped_${enhan}_${lm_suffix}.result -fi - -# N-best rescoring using a RNNLM -if [ $stage -le 4 ]; then - # check the best lmw - if [ ! -f $dir/log/best_wer_${enhan}_${lm_suffix} ]; then - echo "error, rescoring with a high-order n-gram seems to be failed" - exit 1; - fi - lmw=`cut -f 1 -d" " $dir/log/best_wer_${enhan}_${lm_suffix} | awk -F'[_]' '{print $NF}'` - # rescore n-best list for all sets - if $eval_flag; then - tasks="dt05_simu dt05_real et05_simu et05_real" - else - tasks="dt05_simu dt05_real" - fi - for t in $tasks; do - steps/rnnlmrescore.sh --inv-acwt $lmw --N $nbest --use-phi true \ - $rnnweight \ - $mdir/data/lang_test_${lm_suffix} \ - $mdir/data/lang_test_${rnnlm_suffix} \ - data/${t}_${enhan}_hires \ - $dir/decode_tgpr_5k_${t}_${enhan}_${lm_suffix} \ - $dir/decode_tgpr_5k_${t}_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest} - done - # calc wers for RNNLM results - local/chime4_calc_wers.sh $dir ${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest} $dir/graph_tgpr_5k \ - > $dir/best_wer_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest}.result - head -n 15 $dir/best_wer_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest}.result - - # now rescore lattices after looped decoding - for t in $tasks; do - steps/rnnlmrescore.sh --inv-acwt $lmw --N $nbest --use-phi true \ - $rnnweight \ - data/lang_test_${lm_suffix} \ - data/lang_test_${rnnlm_suffix} \ - data/${t}_${enhan}_hires \ - $dir/decode_looped_tgpr_5k_${t}_${enhan}_${lm_suffix} \ - $dir/decode_looped_tgpr_5k_${t}_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest} - done - # calc wers for RNNLM results - local/chime4_calc_wers_looped.sh $dir ${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest} $dir/graph_tgpr_5k \ - > $dir/best_wer_looped_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest}.result - head -n 15 $dir/best_wer_looped_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest}.result -fi diff --git a/egs/chime4/s5_1ch/local/run_lmrescore_tdnn_recog.sh b/egs/chime4/s5_1ch/local/run_lmrescore_tdnn_recog.sh deleted file mode 100755 index 4508ddeb9f4..00000000000 --- a/egs/chime4/s5_1ch/local/run_lmrescore_tdnn_recog.sh +++ /dev/null @@ -1,124 +0,0 @@ -#!/bin/bash - -# Copyright 2015 University of Sheffield (Jon Barker, Ricard Marxer) -# Inria (Emmanuel Vincent) -# Mitsubishi Electric Research Labs 
(Shinji Watanabe) -# 2017 JHU CLSP (Szu-Jui Chen) -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -# Copyright 2015, Mitsubishi Electric Research Laboratories, MERL (Author: Takaaki Hori) - -nj=12 -stage=1 -order=5 -hidden=300 -rnnweight=0.5 -nbest=100 -train=noisy -eval_flag=true # make it true when the evaluation data are released - -. utils/parse_options.sh || exit 1; - -. ./path.sh -. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. - ## This relates to the queue. - -# This is a shell script, but it's recommended that you run the commands one by -# one by copying and pasting into the shell. - -if [ $# -ne 2 ]; then - printf "\nUSAGE: %s \n\n" `basename $0` - echo "First argument specifies a unique name for different enhancement method" - echo "Second argument specifies acoustic and language model directory" - exit 1; -fi - -# set language models -# You might need to change affix to the affix of your best tdnn model. -affix=1a -lm_suffix=${order}gkn_5k -rnnlm_suffix=rnnlm_5k_h${hidden} - -# enhan data -enhan=$1 -# set model directory -mdir=$2 -srcdir=exp/chain/tdnn${affix}_sp - -# check language models -if [ ! -d $mdir/data/lang ]; then - echo "error, set $mdir correctly" - exit 1; -fi - -# preparation -dir=exp/chain/tdnn${affix}_sp_smbr_lmrescore -mkdir -p $dir -# make a symbolic link to graph info -if [ ! -e $dir/graph_tgpr_5k ]; then - if [ ! -e exp/chain/tree_a_sp/graph_tgpr_5k ]; then - echo "graph is missing, execute local/run_tdnn.sh, correctly" - exit 1; - fi - pushd . ; cd $dir - ln -s ../tree_a_sp/graph_tgpr_5k . - popd -fi - -# rescore lattices by a high-order N-gram -if [ $stage -le 3 ]; then - # check the best iteration - if [ ! -f $srcdir/log/best_wer_$enhan ]; then - echo "$0: error $srcdir/log/best_wer_$enhan not found. execute local/run_tdnn.sh, first" - exit 1; - fi - it=`cut -f 1 -d" " $srcdir/log/best_wer_$enhan | awk -F'[_]' '{print $1}'` - # rescore lattices - if $eval_flag; then - tasks="dt05_simu dt05_real et05_simu et05_real" - else - tasks="dt05_simu dt05_real" - fi - for t in $tasks; do - steps/lmrescore.sh --mode 3 \ - $mdir/data/lang_test_tgpr_5k \ - $mdir/data/lang_test_${lm_suffix} \ - data/${t}_${enhan}_hires \ - $srcdir/decode_tgpr_5k_${t}_${enhan} \ - $dir/decode_tgpr_5k_${t}_${enhan}_${lm_suffix} - done - # rescored results by high-order n-gram LM - mkdir -p $dir/log - local/chime4_calc_wers.sh $dir ${enhan}_${lm_suffix} $dir/graph_tgpr_5k \ - > $dir/best_wer_${enhan}_${lm_suffix}.result - head -n 15 $dir/best_wer_${enhan}_${lm_suffix}.result -fi - -# N-best rescoring using a RNNLM -if [ $stage -le 4 ]; then - # check the best lmw - if [ ! 
-f $dir/log/best_wer_${enhan}_${lm_suffix} ]; then - echo "error, rescoring with a high-order n-gram seems to be failed" - exit 1; - fi - lmw=`cut -f 1 -d" " $dir/log/best_wer_${enhan}_${lm_suffix} | awk -F'[_]' '{print $NF}'` - # rescore n-best list for all sets - if $eval_flag; then - tasks="dt05_simu dt05_real et05_simu et05_real" - else - tasks="dt05_simu dt05_real" - fi - for t in $tasks; do - steps/rnnlmrescore.sh --inv-acwt $lmw --N $nbest --use-phi true \ - $rnnweight \ - $mdir/data/lang_test_${lm_suffix} \ - $mdir/data/lang_test_${rnnlm_suffix} \ - data/${t}_${enhan}_hires \ - $dir/decode_tgpr_5k_${t}_${enhan}_${lm_suffix} \ - $dir/decode_tgpr_5k_${t}_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest} - done - # calc wers for RNNLM results - local/chime4_calc_wers.sh $dir ${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest} $dir/graph_tgpr_5k \ - > $dir/best_wer_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest}.result - head -n 15 $dir/best_wer_${enhan}_${rnnlm_suffix}_w${rnnweight}_n${nbest}.result -fi diff --git a/egs/chime4/s5_1ch/local/run_nn-gev.sh b/egs/chime4/s5_1ch/local/run_nn-gev.sh new file mode 100755 index 00000000000..a17dd3d3f15 --- /dev/null +++ b/egs/chime4/s5_1ch/local/run_nn-gev.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright 2017 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +# Apache 2.0 + +. ./cmd.sh +. ./path.sh + +if [ $# != 4 ]; then + echo "Wrong #arguments ($#, expected 3)" + echo "Usage: local/run_nn-gev.sh " + exit 1; +fi + +sdir=$1 +odir=$2 +enhancement_type=$3 +track=$4 + +gpu_id=1 +case $(hostname -f) in + *.clsp.jhu.edu) gpu_id=`free-gpu` ;; # JHU, +esac + +if [ ! -f local/nn-gev/data/BLSTM_model/mlp.tr ]; then + echo "training a BLSTM mask network" + $HOME/miniconda3/bin/python local/nn-gev/train.py --chime_dir=$sdir/data --gpu $gpu_id local/nn-gev/data BLSTM +else + echo "Not training a BLSTM mask network. Using existing model in local/nn-gev/data/BLSTM_model/" +fi +echo "enhancing signals with mask-based GEV beamformer" +local/nn-gev/beamform.sh $sdir/data local/nn-gev/data $odir local/nn-gev/data/BLSTM_model/best.nnet BLSTM --gpu $gpu_id --single $enhancement_type --track $track diff --git a/egs/chime4/s5_1ch/local/simu_noisy_chime4_data_prep.sh b/egs/chime4/s5_1ch/local/simu_noisy_chime4_data_prep.sh index 03e355a82ec..124cde82b8a 100755 --- a/egs/chime4/s5_1ch/local/simu_noisy_chime4_data_prep.sh +++ b/egs/chime4/s5_1ch/local/simu_noisy_chime4_data_prep.sh @@ -69,8 +69,12 @@ fi # make a scp file from file list for x in $list_set; do - cat $x.flist | awk -F'[/]' '{print $NF}'| sed -e 's/\.wav/_SIMU/' > ${x}_wav.ids - paste -d" " ${x}_wav.ids $x.flist | sort -k 1 > ${x}_wav.scp + cat $x.flist | awk -F'[/]' '{print $NF}'| sed -e 's/\.wav/_SIMU/' > ${x}_wav.id.temp + cat ${x}_wav.id.temp | awk -F'_' '{print $3}' | awk -F'.' '{print $2}' > $x.ch + cat ${x}_wav.id.temp | awk -F'_' '{print $1}' > $x.part1 + cat ${x}_wav.id.temp | sed -e 's/^..._//' > $x.part2 + paste -d"_" $x.part1 $x.ch $x.part2 > ${x}_wav.ids + paste -d" " ${x}_wav.ids $x.flist | sort -t_ -k1,1 -k3 > ${x}_wav.scp.temp done # make a transcription from dot @@ -80,10 +84,10 @@ if [ ! 
-e dot_files.flist ]; then echo "Could not find $dir/dot_files.flist files, first run local/clean_wsj0_data_prep.sh"; exit 1; fi -cat tr05_simu_noisy_wav.scp | awk -F'[_]' '{print $2}' | tr '[A-Z]' '[a-z]' \ +cat tr05_simu_noisy_wav.scp.temp | awk -F'[_]' '{print $3}' | tr '[A-Z]' '[a-z]' \ | $local/find_noisy_transcripts.pl dot_files.flist | cut -f 2- -d" " > tr05_simu_noisy.txt -cat tr05_simu_noisy_wav.scp | cut -f 1 -d" " > tr05_simu_noisy.ids -paste -d" " tr05_simu_noisy.ids tr05_simu_noisy.txt | sort -k 1 > tr05_simu_noisy.trans1 +cat tr05_simu_noisy_wav.scp.temp | cut -f 1 -d" " > tr05_simu_noisy.ids +paste -d" " tr05_simu_noisy.ids tr05_simu_noisy.txt | sort -t_ -k1,1 -k3 > tr05_simu_noisy.trans1 # dt05 and et05 simulation data are generated from the CHiME4 booth recording # and we use CHiME4 dot files cat dt05_simu.dot | sed -e 's/(\(.*\))/\1/' | awk '{print $NF ".CH1_SIMU"}'> dt05_simu_noisy.ids @@ -104,13 +108,17 @@ fi # data-preparation stage independent of the specific lexicon used. noiseword=""; for x in $list_set;do + cat ${x}_wav.scp.temp | awk '{print $1}' > $x.txt.part1 + cat $x.trans1 | awk '{$1=""; print $0}' | sed 's/^[ \t]*//g' > $x.txt.part2 + paste -d" " $x.txt.part1 $x.txt.part2 > $x.trans1 cat $x.trans1 | $local/normalize_transcript.pl $noiseword \ | sort > $x.txt || exit 1; done # Make the utt2spk and spk2utt files. for x in $list_set; do - cat ${x}_wav.scp | awk -F'_' '{print $1}' > $x.spk + sort ${x}_wav.scp.temp > ${x}_wav.scp + cat ${x}_wav.scp | awk -F'_' '{print $1"_"$2}' > $x.spk cat ${x}_wav.scp | awk '{print $1}' > $x.utt paste -d" " $x.utt $x.spk > $x.utt2spk cat $x.utt2spk | $utils/utt2spk_to_spk2utt.pl > $x.spk2utt || exit 1; @@ -125,4 +133,8 @@ for x in $list_set; do cp ${x}.utt2spk ../../$x/utt2spk || exit 1; done +# clean up temp files +rm *.temp +rm *.part{1,2} + echo "Data preparation succeeded" diff --git a/egs/chime4/s5_1ch/local/stoi_estoi_sdr.m b/egs/chime4/s5_1ch/local/stoi_estoi_sdr.m new file mode 100644 index 00000000000..45047fe1884 --- /dev/null +++ b/egs/chime4/s5_1ch/local/stoi_estoi_sdr.m @@ -0,0 +1,62 @@ +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% Copyright 2017 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +function stoi_estoi_sdr(nj,enhancement_method,destination_directory,set) + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% "stoi_estoi_sdr" : this function computes the average STOI, eSTOI and SDR +% scores by calling downloaded third party matlab functions +% +% Input: +% nj: number of jobs +% enhancement_method: the name of the enhacement method +% destination_directory: the directory where the results have to be stored, +% the list of the enhaced and reference files are +% stored here before calling this function +% set: name of the set to be evaluated ('et05' or 'dt05') +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +original_file_list=strcat(destination_directory,'/original_list'); +enhanced_file_list=strcat(destination_directory,'/enhanced_list'); +files1=textread(original_file_list,'%s'); +files2=textread(enhanced_file_list,'%s'); +d_stoi=zeros(1,length(files2)); +d_estoi=zeros(1,length(files2)); +SDR=zeros(1,length(files2)); +p = parpool('local', nj); +parfor i=1:length(files2) + [x, fs] = audioread(files1{i}); + [y, fs] = audioread(files2{i}); + m=length(x); + n=length(y); + d=abs(m-n); + if m>n + y=[y; 
zeros(d,1)]; + end + if n>m + x=[x; zeros(d,1)]; + end + + d_stoi(i)=stoi(x,y,fs); + d_estoi(i)=estoi(x,y,fs); + [SDR(i),SIR,SAR,perm]=bss_eval_sources(y',x'); +end +SDR_avg=mean(SDR); +STOI_avg=mean(d_stoi); +ESTOI_avg=mean(d_estoi); +SDRFile=strcat(destination_directory,'/',enhancement_method,'_',set,'_SDR'); +stoiFile=strcat(destination_directory,'/',enhancement_method,'_',set,'_STOI'); +estoiFile=strcat(destination_directory,'/',enhancement_method,'_',set,'_eSTOI'); +fileID = fopen(SDRFile,'w'); +fprintf(fileID,'%f\n',SDR_avg); +fclose(fileID); +fileID = fopen(stoiFile,'w'); +fprintf(fileID,'%f\n',STOI_avg); +fclose(fileID); +fileID = fopen(estoiFile,'w'); +fprintf(fileID,'%f\n',ESTOI_avg); +fclose(fileID); +ResultMATFile=strcat(destination_directory,'/',enhancement_method,'_',set,'_stoi_estoi_sdr.mat'); +save(ResultMATFile,'SDR','d_stoi','d_estoi'); +end diff --git a/egs/chime4/s5_1ch/local/write_se_results.sh b/egs/chime4/s5_1ch/local/write_se_results.sh new file mode 100755 index 00000000000..7ada63f8ccc --- /dev/null +++ b/egs/chime4/s5_1ch/local/write_se_results.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Copyright 2017 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +# Apache 2.0 + +. ./cmd.sh +. ./path.sh + +# Config: + +if [ $# != 1 ]; then + echo "Wrong #arguments ($#, expected 1)" + echo "Usage: local/write_se_results.sh " + exit 1; +fi + +enhancement=$1 + +echo -e "PESQ ($enhancement) \t dt05_simu=$(cat exp/compute_pesq_$enhancement/pesq_dt05) \t et05_simu=$(cat exp/compute_pesq_$enhancement/pesq_et05)" +echo -e "STOI ($enhancement) \t dt05_simu=$(cat exp/compute_stoi_estoi_sdr_$enhancement/${enhancement}_dt05_STOI) \t et05_simu=$(cat exp/compute_stoi_estoi_sdr_$enhancement/${enhancement}_et05_STOI)" +echo -e "eSTOI ($enhancement) \t dt05_simu=$(cat exp/compute_stoi_estoi_sdr_$enhancement/${enhancement}_dt05_eSTOI) \t et05_simu=$(cat exp/compute_stoi_estoi_sdr_$enhancement/${enhancement}_et05_eSTOI)" +echo -e "SDR ($enhancement) \t dt05_simu=$(cat exp/compute_stoi_estoi_sdr_$enhancement/${enhancement}_dt05_SDR) \t et05_simu=$(cat exp/compute_stoi_estoi_sdr_$enhancement/${enhancement}_et05_SDR)" +echo "" diff --git a/egs/chime4/s5_1ch/rnnlm b/egs/chime4/s5_1ch/rnnlm new file mode 120000 index 00000000000..e136939ba72 --- /dev/null +++ b/egs/chime4/s5_1ch/rnnlm @@ -0,0 +1 @@ +../../../scripts/rnnlm/ \ No newline at end of file diff --git a/egs/chime4/s5_1ch/run.sh b/egs/chime4/s5_1ch/run.sh index beb8c80207f..5b980dec827 100755 --- a/egs/chime4/s5_1ch/run.sh +++ b/egs/chime4/s5_1ch/run.sh @@ -6,26 +6,29 @@ # Inria (Emmanuel Vincent) # Mitsubishi Electric Research Labs (Shinji Watanabe) # 2017 JHU CLSP (Szu-Jui Chen) +# 2017 JHU CLSP (Aswin Shanmugam Subramanian) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) . ./path.sh . ./cmd.sh + #####Baseline settings##### # Usage: -# 1. For using original baseline, execute './run.sh --baseline chime4_official'. -# We don't provide the function to train original baseline models anymore. Instead, we provided the -# trained original baseline models in tools/ASR_models for directly using. +# Execute './run.sh' to get the models. +# We provide BLSTM masking based enhancement --enhancement single_blstmmask # -# 2. For using advanced baseline, first execute './run.sh --baseline advanced --flatstart true' to -# get the models. If you want to use DNN instead of TDNN, add option "--tdnn false". -# Then execute './run.sh --baseline advanced' for your experiments. +# We stopped to support the old CHiME-3/4 baseline. 
If you want to reproduce the old results +# Please use the old version of Kaldi, e.g., git checkout 9e8ff73648917836d0870c8f6fdd2ff4bdde384f # Config: stage=0 # resume training with --stage N - -baseline=advanced -flatstart=false -tdnn=true +enhancement=single_blstmmask #### or your method +# if the following options are true, they wouldn't train a model again and will only do decoding +gmm_decode_only=false +tdnn_decode_only=false +# make it true when you want to add enhanced data into training set. But please note that when changing enhancement method, +# you may need to retrain from run_gmm.sh and avoid using decode-only options above +add_enhanced_data=true . utils/parse_options.sh || exit 1; @@ -40,107 +43,82 @@ set -o pipefail # If you use scripts distributed in the CHiME4 package, chime4_data=`pwd`/../.. # Otherwise, please specify it, e.g., -chime4_data=/db/laputa1/data/processed/public/CHiME4 - +# chime4_data=/db/laputa1/data/processed/public/CHiME4 +# chime3_data=/data2/archive/speech-db/original/public/CHiME3 case $(hostname -f) in - *.clsp.jhu.edu) chime4_data=/export/corpora4/CHiME4/CHiME3 ;; # JHU, + *.clsp.jhu.edu) + chime4_data=/export/corpora4/CHiME4/CHiME3 # JHU, + chime3_data=/export/corpora5/CHiME3 + ;; esac if [ ! -d $chime4_data ]; then - echo "$chime4_data does not exist. Please specify chime4 data root correctly" && exit 1 + echo "$chime4_data does not exist. Please specify chime4 data root correctly" && exit 1; fi -# Set a model directory for the CHiME4 data. -case $baseline in - chime4_official) - if $flatstart; then - echo "We don't support this anymore for 'chime4_official' baseline" - echo " ... Automatically set it to false" - fi - modeldir=$chime4_data/tools/ASR_models - flatstart=false - ;; - advanced) - modeldir=`pwd` - ;; - *) - echo "Usage: './run.sh --baseline chime4_official' or './run.sh --baseline advanced'" - echo " ... If you haven't run flatstart for advanced baseline, please execute" - echo " ... './run.sh --baseline advanced --flatstart true' first"; - exit 1; -esac - -if [ "$flatstart" = false ]; then - for d in $modeldir $modeldir/data/{lang,lang_test_tgpr_5k,lang_test_5gkn_5k,lang_test_rnnlm_5k_h300,local} \ - $modeldir/exp/{tri3b_tr05_multi_noisy,tri4a_dnn_tr05_multi_noisy,tri4a_dnn_tr05_multi_noisy_smbr_i1lats}; do - [ ! -d $d ] && echo "$0: no such directory $d. specify models correctly" && \ - echo " or execute './run.sh --baseline advanced --flatstart true' first" && exit 1; - done +if [ ! -d $chime3_data ]; then + echo "$chime3_data does not exist. Please specify chime4 data root correctly" && exit 1; fi -#####check data and model paths finished####### - #####main program start################ # You can execute run_init.sh only "once" # This creates 3-gram LM, FSTs, and basic task files -if [ $stage -le 0 ] && $flatstart; then +if [ $stage -le 0 ]; then local/run_init.sh $chime4_data fi -# In this script, we use non-enhanced 6th microphone signals. -enhancement_method=isolated_1ch_track -enhancement_data=$chime4_data/data/audio/16kHz/$enhancement_method -#if [ $stage -le 1 ]; then -# put your single channel enhancement -#fi +if [[ "$enhancement" == *isolated_1ch_track* ]]; then + enhancement_data=$chime4_data/data/audio/16kHz/isolated_1ch_track +else + enhancement_data=`pwd`/enhan/$enhancement +fi -# GMM based ASR experiment without "retraining" -# Please set a directory of your speech enhancement method. -# run_gmm_recog.sh can be done every time when you change a speech enhancement technique. 
-# The directory structure and audio files must follow the attached baseline enhancement directory +if [ $stage -le 1 ]; then + local/run_blstm_gev.sh --cmd "$train_cmd" --nj 20 --track 1 $chime4_data $chime3_data $enhancement_data 0 +fi + +# Compute PESQ, STOI, eSTOI, and SDR scores if [ $stage -le 2 ]; then - if $flatstart; then - local/run_gmm.sh $enhancement_method $enhancement_data $chime4_data - else - local/run_gmm_recog.sh $enhancement_method $enhancement_data $modeldir + if [ ! -f local/bss_eval_sources.m ] || [ ! -f local/stoi.m ] || [ ! -f local/estoi.m ] || [ ! -f local/PESQ ]; then + # download and install speech enhancement evaluation tools + local/download_se_eval_tool.sh + fi + chime4_rir_data=local/nn-gev/data/audio/16kHz/isolated_ext + if [ ! -d $chime4_rir_data ]; then + echo "$chime4_rir_dir does not exist. Please run 'blstm_gev' enhancement method first;" && exit 1; fi + local/compute_pesq.sh $enhancement $enhancement_data $chime4_rir_data $PWD + local/compute_stoi_estoi_sdr.sh $enhancement $enhancement_data $chime4_rir_data + local/compute_pesq.sh NOISY_1ch $chime4_data/data/audio/16kHz/isolated_1ch_track/ $chime4_rir_data $PWD + local/compute_stoi_estoi_sdr.sh NOISY_1ch $chime4_data/data/audio/16kHz/isolated_1ch_track/ $chime4_rir_data + local/write_se_results.sh $enhancement + local/write_se_results.sh NOISY_1ch fi -# DNN based ASR experiment -# Since it takes time to evaluate DNN, we make the GMM and DNN scripts separately. -# You may execute it after you would have promising results using GMM-based ASR experiments +# GMM based ASR experiment +# Please set a directory of your speech enhancement method. +# The directory structure and audio files must follow the attached baseline enhancement directory if [ $stage -le 3 ]; then - if $tdnn; then - if $flatstart; then - local/chain/run_tdnn.sh $enhancement_method - else - local/chain/run_tdnn_recog.sh $enhancement_method $modeldir - fi - else - if $flatstart; then - local/run_dnn.sh $enhancement_method - else - local/run_dnn_recog.sh $enhancement_method $modeldir - fi - fi + local/run_gmm.sh --add-enhanced-data $add_enhanced_data \ + --decode-only $gmm_decode_only $enhancement $enhancement_data $chime4_data +fi + +# TDNN based ASR experiment +# Since it takes time to evaluate TDNN, we make the GMM and TDNN scripts separately. +# You may execute it after you would have promising results using GMM-based ASR experiments +if [ $stage -le 4 ]; then + local/chain/run_tdnn.sh --decode-only $tdnn_decode_only $enhancement fi # LM-rescoring experiment with 5-gram and RNN LMs # It takes a few days to train a RNNLM. -if [ $stage -le 4 ]; then - if $flatstart; then - if $tdnn; then - local/run_lmrescore_tdnn.sh $chime4_data $enhancement_method - else - local/run_lmrescore.sh $chime4_data $enhancement_method - fi - else - if $tdnn; then - local/run_lmrescore_tdnn_recog.sh $enhancement_method $modeldir - else - local/run_lmrescore_recog.sh $enhancement_method $modeldir - fi - fi +if [ $stage -le 5 ]; then + local/run_lmrescore_tdnn.sh $chime4_data $enhancement +fi + +# LM-rescoring experiment with LSTM LMs +if [ $stage -le 6 ]; then + local/rnnlm/run_lstm.sh $enhancement fi echo "Done." 
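Because the option interface of the single-channel run.sh changes here from --baseline/--flatstart to a single --enhancement switch plus decode-only flags, a brief usage sketch may help. It is only a sketch: it assumes the CHiME-4/CHiME-3 corpus paths inside run.sh point at local copies, and the option spellings rely on utils/parse_options.sh mapping, e.g., --gmm-decode-only onto the gmm_decode_only variable declared above.

  # full pipeline with the provided BLSTM-mask single-channel enhancement (the default)
  ./run.sh --enhancement single_blstmmask

  # reuse previously trained GMM/TDNN models and only decode; resume at a later stage with --stage N
  ./run.sh --enhancement single_blstmmask --stage 3 --gmm-decode-only true --tdnn-decode-only true

As the comments in the script note, retraining is only needed when the enhancement method (and hence the enhanced training data) changes; otherwise the decode-only flags avoid rebuilding the GMM and TDNN systems.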
diff --git a/egs/chime4/s5_2ch/RESULTS b/egs/chime4/s5_2ch/RESULTS index f506b54c5db..156b94ebfa9 100644 --- a/egs/chime4/s5_2ch/RESULTS +++ b/egs/chime4/s5_2ch/RESULTS @@ -19,7 +19,8 @@ et05_simu WER: 27.57% (Average), 20.17% (BUS), 31.81% (CAFE), 29.96% (PEDESTRIAN et05_real WER: 29.03% (Average), 39.37% (BUS), 28.43% (CAFE), 27.56% (PEDESTRIAN), 20.77% (STREET) ------------------- -Advanced baseline: +GMM noisy multi-condition with beamformit using 6 channel data +exp/tri3b_tr05_multi_noisy/best_wer_beamformit_2mics.result ------------------- best overall dt05 WER 17.26% (language model weight = 10) ------------------- @@ -32,6 +33,19 @@ et05_simu WER: 26.85% (Average), 20.08% (BUS), 30.84% (CAFE), 29.03% (PEDESTRIAN et05_real WER: 27.91% (Average), 37.05% (BUS), 29.25% (CAFE), 25.37% (PEDESTRIAN), 19.97% (STREET) ------------------- +GMM noisy multi-condition with BLSTM masking using 6 channel data plus enhanced data +exp/tri3b_tr05_multi_noisy/best_wer_blstm_gev.result +------------------- +best overall dt05 WER 14.57% (language model weight = 10) +------------------- +dt05_simu WER: 15.62% (Average), 12.89% (BUS), 20.49% (CAFE), 14.22% (PEDESTRIAN), 14.90% (STREET) +------------------- +dt05_real WER: 13.52% (Average), 15.52% (BUS), 14.34% (CAFE), 11.57% (PEDESTRIAN), 12.67% (STREET) +------------------- +et05_simu WER: 19.05% (Average), 14.51% (BUS), 21.87% (CAFE), 20.41% (PEDESTRIAN), 19.39% (STREET) +------------------- +et05_real WER: 20.94% (Average), 26.66% (BUS), 21.52% (CAFE), 19.15% (PEDESTRIAN), 16.45% (STREET) +------------------- DNN sMBR exp/tri4a_dnn_tr05_multi_noisy_smbr_i1lats/best_wer_beamformit_2mics.result @@ -48,7 +62,7 @@ et05_simu WER: 19.04% (Average), 14.76% (BUS), 21.72% (CAFE), 19.22% (PEDESTRIAN et05_real WER: 20.44% (Average), 30.02% (BUS), 19.95% (CAFE), 17.79% (PEDESTRIAN), 14.01% (STREET) ------------------- -Advanced baseline: +DNN sMBR using all 6 channel data ------------------- best overall dt05 WER 10.13% (language model weight = 12) (Number of iterations = 3) @@ -77,7 +91,7 @@ et05_simu WER: 16.88% (Average), 12.08% (BUS), 19.70% (CAFE), 16.77% (PEDESTRIAN et05_real WER: 18.07% (Average), 26.77% (BUS), 17.93% (CAFE), 14.76% (PEDESTRIAN), 12.83% (STREET) ------------------- -Advanced baseline: +5-gram rescoring using all 6 channel data ------------------- best overall dt05 WER 8.53% (language model weight = 13) ------------------- @@ -105,7 +119,7 @@ et05_simu WER: 15.33% (Average), 10.66% (BUS), 18.21% (CAFE), 15.61% (PEDESTRIAN et05_real WER: 16.58% (Average), 25.37% (BUS), 15.97% (CAFE), 13.53% (PEDESTRIAN), 11.45% (STREET) ------------------- -Advanced baseline: +RNNLM using all 6 channel data ------------------- best overall dt05 WER 7.46% (language model weight = 14) ------------------- @@ -118,7 +132,7 @@ et05_simu WER: 12.57% (Average), 8.85% (BUS), 14.85% (CAFE), 12.44% (PEDESTRIAN) et05_real WER: 13.33% (Average), 18.94% (BUS), 13.04% (CAFE), 11.85% (PEDESTRIAN), 9.49% (STREET) ------------------- -TDNN +TDNN using all 6 channel data exp/chain/tdnn1d_sp/best_wer_beamformit_5mics.result ------------------- best overall dt05 WER 7.89% (language model weight = 10) @@ -132,8 +146,8 @@ et05_simu WER: 13.15% (Average), 9.77% (BUS), 14.16% (CAFE), 13.43% (PEDESTRIAN) et05_real WER: 13.39% (Average), 19.63% (BUS), 11.64% (CAFE), 11.49% (PEDESTRIAN), 10.80% (STREET) ------------------- -TDNN+RNNLM -exp/chain/tdnn1d_sp_smbr_lmrescore/best_wer_beamformit_5mics_rnnlm_5k_h300_w0.5_n100.result +TDNN+RNNLM using all 6 channel data 
+exp/chain/tdnn1d_sp_smbr_lmrescore/best_wer_beamformit_2mics_rnnlm_5k_h300_w0.5_n100.result ------------------- best overall dt05 WER 5.82% (language model weight = 11) ------------------- @@ -145,3 +159,73 @@ et05_simu WER: 9.90% (Average), 7.00% (BUS), 11.15% (CAFE), 10.05% (PEDESTRIAN), ------------------- et05_real WER: 10.53% (Average), 16.90% (BUS), 8.65% (CAFE), 8.52% (PEDESTRIAN), 8.05% (STREET) ------------------- + +TDNN using 6 channel data plus enhanced data +exp/chain/tdnn1a_sp/best_wer_beamformit_5mics.result +------------------- +best overall dt05 WER 7.57% (language model weight = 10) +------------------- +dt05_simu WER: 8.18% (Average), 7.12% (BUS), 10.16% (CAFE), 6.33% (PEDESTRIAN), 9.12% (STREET) +------------------- +dt05_real WER: 6.96% (Average), 9.38% (BUS), 6.46% (CAFE), 4.91% (PEDESTRIAN), 7.09% (STREET) +------------------- +et05_simu WER: 13.14% (Average), 9.92% (BUS), 14.55% (CAFE), 13.26% (PEDESTRIAN), 14.83% (STREET) +------------------- +et05_real WER: 12.81% (Average), 19.27% (BUS), 10.66% (CAFE), 11.29% (PEDESTRIAN), 10.03% (STREET) +------------------- + +TDNN+RNNLM using 6 channel data plus enhanced data +exp/chain/tdnn1a_sp_smbr_lmrescore/best_wer_beamformit_2mics_rnnlm_5k_h300_w0.5_n100.result +------------------- +best overall dt05 WER 5.52% (language model weight = 10) +------------------- +dt05_simu WER: 6.02% (Average), 5.28% (BUS), 7.37% (CAFE), 4.60% (PEDESTRIAN), 6.81% (STREET) +------------------- +dt05_real WER: 5.03% (Average), 7.23% (BUS), 4.26% (CAFE), 3.26% (PEDESTRIAN), 5.35% (STREET) +------------------- +et05_simu WER: 10.35% (Average), 7.84% (BUS), 11.04% (CAFE), 10.55% (PEDESTRIAN), 11.95% (STREET) +------------------- +et05_real WER: 10.20% (Average), 16.21% (BUS), 8.18% (CAFE), 8.43% (PEDESTRIAN), 7.98% (STREET) +------------------- + +TDNN with BLSTM masking using 6 channel data plus enhanced data +exp/chain/tdnn1a_sp/best_wer_blstm_gev.result +------------------- +best overall dt05 WER 6.35% (language model weight = 9) +------------------- +dt05_simu WER: 7.03% (Average), 5.72% (BUS), 9.32% (CAFE), 6.28% (PEDESTRIAN), 6.78% (STREET) +------------------- +dt05_real WER: 5.66% (Average), 6.89% (BUS), 5.99% (CAFE), 4.44% (PEDESTRIAN), 5.34% (STREET) +------------------- +et05_simu WER: 8.80% (Average), 6.80% (BUS), 10.20% (CAFE), 8.37% (PEDESTRIAN), 9.84% (STREET) +------------------- +et05_real WER: 9.46% (Average), 13.42% (BUS), 8.31% (CAFE), 8.76% (PEDESTRIAN), 7.34% (STREET) +------------------- + +TDNN+RNNLM with BLSTM masking using 6 channel data plus enhanced data +exp/chain/tdnn1a_sp_smbr_lmrescore/best_wer_blstm_gev_rnnlm_5k_h300_w0.5_n100.result +------------------- +best overall dt05 WER 4.41% (language model weight = 11) +------------------- +dt05_simu WER: 5.03% (Average), 4.13% (BUS), 6.83% (CAFE), 4.45% (PEDESTRIAN), 4.72% (STREET) +------------------- +dt05_real WER: 3.79% (Average), 4.68% (BUS), 3.94% (CAFE), 2.95% (PEDESTRIAN), 3.61% (STREET) +------------------- +et05_simu WER: 6.07% (Average), 4.52% (BUS), 6.93% (CAFE), 6.05% (PEDESTRIAN), 6.78% (STREET) +------------------- +et05_real WER: 6.93% (Average), 10.23% (BUS), 6.13% (CAFE), 6.41% (PEDESTRIAN), 4.97% (STREET) +------------------- + +TDNN+RNNLM with BLSTM masking using 6 channel data plus enhanced data +exp/chain/tdnn1a_sp_smbr_lmrescore/best_wer_blstm_gev_rnnlm_lstm_1a_w0.5_n100.result +------------------- +best overall dt05 WER 3.39% (language model weight = 10) +------------------- +dt05_simu WER: 3.94% (Average), 2.99% (BUS), 5.65% (CAFE), 3.44% 
(PEDESTRIAN), 3.67% (STREET) +------------------- +dt05_real WER: 2.85% (Average), 3.58% (BUS), 2.89% (CAFE), 2.07% (PEDESTRIAN), 2.85% (STREET) +------------------- +et05_simu WER: 5.03% (Average), 3.66% (BUS), 5.57% (CAFE), 4.87% (PEDESTRIAN), 6.03% (STREET) +------------------- +et05_real WER: 5.40% (Average), 7.81% (BUS), 4.71% (CAFE), 4.73% (PEDESTRIAN), 4.37% (STREET) +------------------- diff --git a/egs/chime4/s5_2ch/rnnlm b/egs/chime4/s5_2ch/rnnlm new file mode 120000 index 00000000000..e136939ba72 --- /dev/null +++ b/egs/chime4/s5_2ch/rnnlm @@ -0,0 +1 @@ +../../../scripts/rnnlm/ \ No newline at end of file diff --git a/egs/chime4/s5_2ch/run.sh b/egs/chime4/s5_2ch/run.sh index e1a3fecbce5..7ae5048c6fa 100755 --- a/egs/chime4/s5_2ch/run.sh +++ b/egs/chime4/s5_2ch/run.sh @@ -6,26 +6,30 @@ # Inria (Emmanuel Vincent) # Mitsubishi Electric Research Labs (Shinji Watanabe) # 2017 JHU CLSP (Szu-Jui Chen) +# 2017 JHU CLSP (Aswin Shanmugam Subramanian) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) . ./path.sh . ./cmd.sh + #####Baseline settings##### # Usage: -# 1. For using original baseline, execute './run.sh --baseline chime4_official'. -# We don't provide the function to train original baseline models anymore. Instead, we provided the -# trained original baseline models in tools/ASR_models for directly using. +# Execute './run.sh' to get the models. +# We provide three kinds of beamform methods. Add option --enhancement blstm_gev, or --enhancement beamformit_2mics +# to use them. i.g. './run.sh --enhancement blstm_gev' # -# 2. For using advanced baseline, first execute './run.sh --baseline advanced --flatstart true' to -# get the models. If you want to use DNN instead of TDNN, add option "--tdnn false". -# Then execute './run.sh --baseline advanced' for your experiments. +# We stopped to support the old CHiME-3/4 baseline. If you want to reproduce the old results +# Please use the old version of Kaldi, e.g., git checkout 9e8ff73648917836d0870c8f6fdd2ff4bdde384f # Config: stage=0 # resume training with --stage N - -baseline=advanced -flatstart=false -tdnn=true +enhancement=blstm_gev #### or your method +# if the following options are true, they wouldn't train a model again and will only do decoding +gmm_decode_only=false +tdnn_decode_only=false +# make it true when you want to add enhanced data into training set. But please note that when changing enhancement method, +# you may need to retrain from run_gmm.sh and avoid using decode-only options above +add_enhanced_data=true . utils/parse_options.sh || exit 1; @@ -40,109 +44,89 @@ set -o pipefail # If you use scripts distributed in the CHiME4 package, chime4_data=`pwd`/../.. # Otherwise, please specify it, e.g., -chime4_data=/db/laputa1/data/processed/public/CHiME4 +# chime4_data=/db/laputa1/data/processed/public/CHiME4 +# chime3_data=/data2/archive/speech-db/original/public/CHiME3 case $(hostname -f) in - *.clsp.jhu.edu) chime4_data=/export/corpora4/CHiME4/CHiME3 ;; # JHU, + *.clsp.jhu.edu) + chime4_data=/export/corpora4/CHiME4/CHiME3 # JHU, + chime3_data=/export/corpora5/CHiME3 + ;; esac if [ ! -d $chime4_data ]; then - echo "$chime4_data does not exist. Please specify chime4 data root correctly" && exit 1 + echo "$chime4_data does not exist. Please specify chime4 data root correctly" && exit 1; fi -# Set a model directory for the CHiME4 data. -case $baseline in - chime4_official) - if $flatstart; then - echo "We don't support this anymore for 'chime4_official' baseline" - echo " ... 
Automatically set it to false" - fi - modeldir=$chime4_data/tools/ASR_models - flatstart=false - ;; - advanced) - modeldir=`pwd` - ;; - *) - echo "Usage: './run.sh --baseline chime4_official' or './run.sh --baseline advanced'" - echo " ... If you haven't run flatstart to train the model of advanced baseline," - echo " ... please execute './run.sh --baseline advanced --flatstart true' first"; - exit 1; -esac - -if [ "$flatstart" = false ]; then - for d in $modeldir $modeldir/data/{lang,lang_test_tgpr_5k,lang_test_5gkn_5k,lang_test_rnnlm_5k_h300,local} \ - $modeldir/exp/{tri3b_tr05_multi_noisy,tri4a_dnn_tr05_multi_noisy,tri4a_dnn_tr05_multi_noisy_smbr_i1lats}; do - [ ! -d $d ] && echo "$0: no such directory $d. specify models correctly" && \ - echo " or execute './run.sh --baseline advanced --flatstart true' first" && exit 1; - done +if [ ! -d $chime3_data ]; then + echo "$chime3_data does not exist. Please specify chime4 data root correctly" && exit 1; fi -#####check data and model paths finished####### - #####main program start################ # You can execute run_init.sh only "once" # This creates 3-gram LM, FSTs, and basic task files -if [ $stage -le 0 ] && $flatstart; then +if [ $stage -le 0 ]; then local/run_init.sh $chime4_data fi -# Using Beamformit -# See Hori et al, "The MERL/SRI system for the 3rd CHiME challenge using beamforming, -# robust feature extraction, and advanced speech recognition," in Proc. ASRU'15 -# note that beamformed wav files are generated in the following directory -enhancement_method=beamformit_2mics -enhancement_data=`pwd`/enhan/$enhancement_method +# Using Beamformit or mask-based beamformer +# note that beamformed WAV files are generated in the following directory +enhancement_data=`pwd`/enhan/$enhancement if [ $stage -le 1 ]; then - local/run_beamform_2ch_track.sh --cmd "$train_cmd" --nj 20 $chime4_data/data/audio/16kHz/isolated_2ch_track $enhancement_data + case $enhancement in + beamformit_2mics) + local/run_beamform_2ch_track.sh --cmd "$train_cmd" --nj 20 $chime4_data/data/audio/16kHz/isolated_2ch_track $enhancement_data + ;; + blstm_gev) + local/run_blstm_gev.sh --cmd "$train_cmd" --nj 20 --track 2 $chime4_data $chime3_data $enhancement_data 0 + ;; + *) + echo "Usage: --enhancement blstm_gev, or --enhancement beamformit_2mics" + exit 1; + esac fi -# GMM based ASR experiment without "retraining" -# Please set a directory of your speech enhancement method. -# run_gmm_recog.sh can be done every time when you change a speech enhancement technique. -# The directory structure and audio files must follow the attached baseline enhancement directory +# Compute PESQ, STOI, eSTOI, and SDR scores if [ $stage -le 2 ]; then - if $flatstart; then - local/run_gmm.sh $enhancement_method $enhancement_data $chime4_data - else - local/run_gmm_recog.sh $enhancement_method $enhancement_data $modeldir + if [ ! -f local/bss_eval_sources.m ] || [ ! -f local/stoi.m ] || [ ! -f local/estoi.m ] || [ ! -f local/PESQ ]; then + # download and install speech enhancement evaluation tools + local/download_se_eval_tool.sh + fi + chime4_rir_data=local/nn-gev/data/audio/16kHz/isolated_ext + if [ ! -d $chime4_rir_data ]; then + echo "$chime4_rir_dir does not exist. 
Please run 'blstm_gev' enhancement method first;" && exit 1; fi + local/compute_pesq.sh $enhancement $enhancement_data $chime4_rir_data $PWD + local/compute_stoi_estoi_sdr.sh $enhancement $enhancement_data $chime4_rir_data + local/compute_pesq.sh NOISY_1ch $chime4_data/data/audio/16kHz/isolated_1ch_track/ $chime4_rir_data $PWD + local/compute_stoi_estoi_sdr.sh NOISY_1ch $chime4_data/data/audio/16kHz/isolated_1ch_track/ $chime4_rir_data + local/write_se_results.sh $enhancement + local/write_se_results.sh NOISY_1ch fi -# DNN based ASR experiment -# Since it takes time to evaluate DNN, we make the GMM and DNN scripts separately. -# You may execute it after you would have promising results using GMM-based ASR experiments +# GMM based ASR experiment +# Please set a directory of your speech enhancement method. +# The directory structure and audio files must follow the attached baseline enhancement directory if [ $stage -le 3 ]; then - if $tdnn; then - if $flatstart; then - local/chain/run_tdnn.sh $enhancement_method - else - local/chain/run_tdnn_recog.sh $enhancement_method $modeldir - fi - else - if $flatstart; then - local/run_dnn.sh $enhancement_method - else - local/run_dnn_recog.sh $enhancement_method $modeldir - fi - fi + local/run_gmm.sh --add-enhanced-data $add_enhanced_data \ + --decode-only $gmm_decode_only $enhancement $enhancement_data $chime4_data +fi + +# TDNN based ASR experiment +# Since it takes time to evaluate TDNN, we make the GMM and TDNN scripts separately. +# You may execute it after you would have promising results using GMM-based ASR experiments +if [ $stage -le 4 ]; then + local/chain/run_tdnn.sh --decode-only $tdnn_decode_only $enhancement fi # LM-rescoring experiment with 5-gram and RNN LMs # It takes a few days to train a RNNLM. -if [ $stage -le 4 ]; then - if $flatstart; then - if $tdnn; then - local/run_lmrescore_tdnn.sh $chime4_data $enhancement_method - else - local/run_lmrescore.sh $chime4_data $enhancement_method - fi - else - if $tdnn; then - local/run_lmrescore_tdnn_recog.sh $enhancement_method $modeldir - else - local/run_lmrescore_recog.sh $enhancement_method $modeldir - fi - fi +if [ $stage -le 5 ]; then + local/run_lmrescore_tdnn.sh $chime4_data $enhancement +fi + +# LM-rescoring experiment with LSTM LMs +if [ $stage -le 6 ]; then + local/rnnlm/run_lstm.sh $enhancement fi echo "Done." 
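The 2-channel recipe follows the same pattern but dispatches between two front-ends in its stage-1 case statement. A minimal sketch of the two supported invocations (enhanced WAV files are written under enhan/$enhancement, as in the script above; corpus paths are again assumed to be configured inside run.sh):

  # weighted delay-and-sum beamforming with BeamformIt on the 2ch track
  ./run.sh --enhancement beamformit_2mics

  # mask-based GEV beamforming with a BLSTM mask estimator (the default); this run also
  # creates the isolated_ext reference signals needed by the PESQ/STOI/eSTOI/SDR stage
  ./run.sh --enhancement blstm_gev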
diff --git a/egs/chime4/s5_6ch/RESULTS b/egs/chime4/s5_6ch/RESULTS index 7d602d49247..266216adc16 100644 --- a/egs/chime4/s5_6ch/RESULTS +++ b/egs/chime4/s5_6ch/RESULTS @@ -19,20 +19,21 @@ et05_simu WER: 21.30% (Average), 15.73% (BUS), 22.94% (CAFE), 22.51% (PEDESTRIAN et05_real WER: 21.83% (Average), 30.17% (BUS), 20.66% (CAFE), 19.82% (PEDESTRIAN), 16.68% (STREET) ------------------- -Advanced baseline: +GMM noisy multi-condition with blstm_gev +exp/tri3b_tr05_multi_noisy/best_wer_blstm_gev.result ------------------- -best overall dt05 WER 13.60% (language model weight = 12) +best overall dt05 WER 11.17% (language model weight = 12) ------------------- -dt05_simu WER: 14.23% (Average), 12.24% (BUS), 17.20% (CAFE), 12.05% (PEDESTRIAN), 15.44% (STREET) +dt05_simu WER: 11.44% (Average), 9.78% (BUS), 14.37% (CAFE), 10.10% (PEDESTRIAN), 11.50% (STREET) ------------------- -dt05_real WER: 12.96% (Average), 15.42% (BUS), 12.94% (CAFE), 10.18% (PEDESTRIAN), 13.30% (STREET) +dt05_real WER: 10.91% (Average), 11.21% (BUS), 11.24% (CAFE), 10.34% (PEDESTRIAN), 10.84% (STREET) ------------------- -et05_simu WER: 20.46% (Average), 14.77% (BUS), 21.78% (CAFE), 22.49% (PEDESTRIAN), 22.81% (STREET) +et05_simu WER: 13.54% (Average), 11.65% (BUS), 14.90% (CAFE), 13.73% (PEDESTRIAN), 13.86% (STREET) ------------------- -et05_real WER: 21.14% (Average), 28.40% (BUS), 21.29% (CAFE), 18.68% (PEDESTRIAN), 16.19% (STREET) +et05_real WER: 14.62% (Average), 16.43% (BUS), 15.43% (CAFE), 12.99% (PEDESTRIAN), 13.63% (STREET) ------------------- -DNN sMBR +DNN sMBR with beamformit exp/tri4a_dnn_tr05_multi_noisy_smbr_i1lats/best_wer_beamformit_5mics.result ------------------- best overall dt05 WER 8.60% (language model weight = 11) @@ -47,98 +48,120 @@ et05_simu WER: 14.23% (Average), 10.72% (BUS), 15.52% (CAFE), 13.90% (PEDESTRIAN et05_real WER: 15.00% (Average), 21.74% (BUS), 13.58% (CAFE), 12.84% (PEDESTRIAN), 11.86% (STREET) ------------------- -Advanced baseline: +DNN sMBR with blstm_gev +exp/tri4a_dnn_tr05_multi_noisy_smbr_i1lats/best_wer_blstm_gev.result ------------------- -best overall dt05 WER 7.72% (language model weight = 12) - (Number of iterations = 3) +best overall dt05 WER 7.38% (language model weight = 11) + (Number of iterations = 4) ------------------- -dt05_simu WER: 7.98% (Average), 6.96% (BUS), 9.75% (CAFE), 6.56% (PEDESTRIAN), 8.66% (STREET) +dt05_simu WER: 7.49% (Average), 5.93% (BUS), 9.69% (CAFE), 6.73% (PEDESTRIAN), 7.61% (STREET) ------------------- -dt05_real WER: 7.45% (Average), 9.15% (BUS), 8.10% (CAFE), 5.40% (PEDESTRIAN), 7.17% (STREET) +dt05_real WER: 7.28% (Average), 7.83% (BUS), 7.80% (CAFE), 6.37% (PEDESTRIAN), 7.11% (STREET) ------------------- -et05_simu WER: 12.30% (Average), 9.45% (BUS), 13.26% (CAFE), 11.77% (PEDESTRIAN), 14.74% (STREET) +et05_simu WER: 9.54% (Average), 8.18% (BUS), 10.87% (CAFE), 9.81% (PEDESTRIAN), 9.32% (STREET) ------------------- -et05_real WER: 12.64% (Average), 16.34% (BUS), 12.36% (CAFE), 10.93% (PEDESTRIAN), 10.93% (STREET) +et05_real WER: 9.77% (Average), 11.42% (BUS), 10.22% (CAFE), 9.23% (PEDESTRIAN), 8.22% (STREET) ------------------- -5-gram rescoring -exp/tri4a_dnn_tr05_multi_noisy_smbr_lmrescore/best_wer_beamformit_5mics_5gkn_5k.result +RNNLM with beamformit +exp/tri4a_dnn_tr05_multi_noisy_smbr_lmrescore/best_wer_beamformit_5mics_rnnlm_5k_h300_w0.5_n100.result ------------------- -best overall dt05 WER 7.30% (language model weight = 11) +best overall dt05 WER 6.27% (language model weight = 12) ------------------- -dt05_simu WER: 7.75% (Average), 
7.14% (BUS), 9.13% (CAFE), 6.33% (PEDESTRIAN), 8.41% (STREET) +dt05_simu WER: 6.77% (Average), 6.02% (BUS), 8.10% (CAFE), 5.49% (PEDESTRIAN), 7.48% (STREET) ------------------- -dt05_real WER: 6.85% (Average), 8.53% (BUS), 6.90% (CAFE), 4.72% (PEDESTRIAN), 7.24% (STREET) +dt05_real WER: 5.76% (Average), 7.39% (BUS), 5.77% (CAFE), 3.72% (PEDESTRIAN), 6.18% (STREET) ------------------- -et05_simu WER: 12.31% (Average), 8.82% (BUS), 13.04% (CAFE), 11.84% (PEDESTRIAN), 15.54% (STREET) +et05_simu WER: 10.90% (Average), 7.68% (BUS), 11.54% (CAFE), 10.31% (PEDESTRIAN), 14.06% (STREET) ------------------- -et05_real WER: 13.23% (Average), 19.07% (BUS), 11.80% (CAFE), 11.51% (PEDESTRIAN), 10.53% (STREET) +et05_real WER: 11.51% (Average), 16.86% (BUS), 10.18% (CAFE), 9.83% (PEDESTRIAN), 9.19% (STREET) ------------------- -Advanced baseline: +######## Advanced baseline +######## All 6 channel training, enhanced data training, Lattice-free MMI TDNN, BLSTM-mask-based GEV beamformer + +TDNN with beamformit +exp/chain/tdnn1d_sp/best_wer_beamformit_5mics.result ------------------- -best overall dt05 WER 6.25% (language model weight = 13) +best overall dt05 WER 6.04% (language model weight = 9) ------------------- -dt05_simu WER: 6.58% (Average), 5.86% (BUS), 7.89% (CAFE), 5.19% (PEDESTRIAN), 7.39% (STREET) +dt05_simu WER: 6.25% (Average), 5.71% (BUS), 6.92% (CAFE), 5.37% (PEDESTRIAN), 7.02% (STREET) ------------------- -dt05_real WER: 5.92% (Average), 7.46% (BUS), 6.19% (CAFE), 4.25% (PEDESTRIAN), 5.77% (STREET) +dt05_real WER: 5.83% (Average), 7.48% (BUS), 5.28% (CAFE), 4.43% (PEDESTRIAN), 6.13% (STREET) ------------------- -et05_simu WER: 10.50% (Average), 7.81% (BUS), 11.06% (CAFE), 10.44% (PEDESTRIAN), 12.70% (STREET) +et05_simu WER: 10.30% (Average), 7.34% (BUS), 10.37% (CAFE), 10.05% (PEDESTRIAN), 13.43% (STREET) ------------------- -et05_real WER: 10.68% (Average), 13.97% (BUS), 10.48% (CAFE), 9.08% (PEDESTRIAN), 9.19% (STREET) +et05_real WER: 9.67% (Average), 12.71% (BUS), 8.33% (CAFE), 8.20% (PEDESTRIAN), 9.45% (STREET) ------------------- -RNNLM -exp/tri4a_dnn_tr05_multi_noisy_smbr_lmrescore/best_wer_beamformit_5mics_rnnlm_5k_h300_w0.5_n100.result +TDNN+RNNLM with beamformit +exp/chain/tdnn1d_sp_smbr_lmrescore/best_wer_beamformit_5mics_rnnlm_5k_h300_w0.5_n100.result ------------------- -best overall dt05 WER 6.27% (language model weight = 12) +best overall dt05 WER 4.15% (language model weight = 9) ------------------- -dt05_simu WER: 6.77% (Average), 6.02% (BUS), 8.10% (CAFE), 5.49% (PEDESTRIAN), 7.48% (STREET) +dt05_simu WER: 4.33% (Average), 3.95% (BUS), 4.87% (CAFE), 3.53% (PEDESTRIAN), 4.97% (STREET) ------------------- -dt05_real WER: 5.76% (Average), 7.39% (BUS), 5.77% (CAFE), 3.72% (PEDESTRIAN), 6.18% (STREET) +dt05_real WER: 3.97% (Average), 5.38% (BUS), 3.19% (CAFE), 2.94% (PEDESTRIAN), 4.37% (STREET) ------------------- -et05_simu WER: 10.90% (Average), 7.68% (BUS), 11.54% (CAFE), 10.31% (PEDESTRIAN), 14.06% (STREET) +et05_simu WER: 7.39% (Average), 4.87% (BUS), 7.58% (CAFE), 7.15% (PEDESTRIAN), 9.96% (STREET) ------------------- -et05_real WER: 11.51% (Average), 16.86% (BUS), 10.18% (CAFE), 9.83% (PEDESTRIAN), 9.19% (STREET) +et05_real WER: 7.04% (Average), 9.89% (BUS), 5.49% (CAFE), 5.70% (PEDESTRIAN), 7.10% (STREET) ------------------- -Advanced baseline: +TDNN using 6 channel data plus enhanced data with beamformit +exp/chain/tdnn7a_sp/best_wer_beamformit_5mics.result ------------------- -best overall dt05 WER 5.44% (language model weight = 13) +best overall dt05 WER 5.80% 
(language model weight = 10) ------------------- -dt05_simu WER: 5.82% (Average), 4.90% (BUS), 6.96% (CAFE), 4.62% (PEDESTRIAN), 6.81% (STREET) +dt05_simu WER: 6.19% (Average), 5.96% (BUS), 6.78% (CAFE), 5.10% (PEDESTRIAN), 6.92% (STREET) ------------------- -dt05_real WER: 5.05% (Average), 6.43% (BUS), 5.03% (CAFE), 3.42% (PEDESTRIAN), 5.31% (STREET) +dt05_real WER: 5.41% (Average), 6.86% (BUS), 4.87% (CAFE), 4.00% (PEDESTRIAN), 5.91% (STREET) ------------------- -et05_simu WER: 9.24% (Average), 6.65% (BUS), 9.81% (CAFE), 9.23% (PEDESTRIAN), 11.28% (STREET) +et05_simu WER: 10.26% (Average), 7.68% (BUS), 10.40% (CAFE), 10.16% (PEDESTRIAN), 12.79% (STREET) ------------------- -et05_real WER: 9.50% (Average), 12.64% (BUS), 8.76% (CAFE), 7.96% (PEDESTRIAN), 8.63% (STREET) +et05_real WER: 9.63% (Average), 13.46% (BUS), 7.98% (CAFE), 8.13% (PEDESTRIAN), 8.97% (STREET) ------------------- -TDNN -exp/chain/tdnn1d_sp/best_wer_beamformit_5mics.result +TDNN+RNNLM using 6 channel data plus enhanced data with beamformit +exp/chain/tdnn7a_sp_smbr_lmrescore/best_wer_beamformit_5mics_rnnlm_5k_h300_w0.5_n100.result +compute dt05 WER for each location ------------------- -best overall dt05 WER 6.04% (language model weight = 9) +best overall dt05 WER 4.02% (language model weight = 11) ------------------- -dt05_simu WER: 6.25% (Average), 5.71% (BUS), 6.92% (CAFE), 5.37% (PEDESTRIAN), 7.02% (STREET) +dt05_simu WER: 4.31% (Average), 4.04% (BUS), 4.88% (CAFE), 3.38% (PEDESTRIAN), 4.94% (STREET) ------------------- -dt05_real WER: 5.83% (Average), 7.48% (BUS), 5.28% (CAFE), 4.43% (PEDESTRIAN), 6.13% (STREET) +dt05_real WER: 3.74% (Average), 4.62% (BUS), 3.17% (CAFE), 3.02% (PEDESTRIAN), 4.14% (STREET) ------------------- -et05_simu WER: 10.30% (Average), 7.34% (BUS), 10.37% (CAFE), 10.05% (PEDESTRIAN), 13.43% (STREET) +et05_simu WER: 7.49% (Average), 5.16% (BUS), 7.21% (CAFE), 7.45% (PEDESTRIAN), 10.14% (STREET) ------------------- -et05_real WER: 9.67% (Average), 12.71% (BUS), 8.33% (CAFE), 8.20% (PEDESTRIAN), 9.45% (STREET) +et05_real WER: 6.84% (Average), 9.74% (BUS), 5.38% (CAFE), 5.25% (PEDESTRIAN), 7.00% (STREET) ------------------- -TDNN+RNNLM -exp/chain/tdnn1d_sp_smbr_lmrescore/best_wer_beamformit_5mics_rnnlm_5k_h300_w0.5_n100.result +TDNN+RNNLM using 6 channel data plus enhanced data with blstm_gev +exp/chain/tdnn1a_sp_smbr_lmrescore/best_wer_blstm_gev_rnnlm_5k_h300_w0.5_n100.result ------------------- -best overall dt05 WER 4.15% (language model weight = 9) +best overall dt05 WER 3.01% (language model weight = 10) ------------------- -dt05_simu WER: 4.33% (Average), 3.95% (BUS), 4.87% (CAFE), 3.53% (PEDESTRIAN), 4.97% (STREET) +dt05_simu WER: 3.10% (Average), 2.60% (BUS), 4.07% (CAFE), 2.80% (PEDESTRIAN), 2.92% (STREET) ------------------- -dt05_real WER: 3.97% (Average), 5.38% (BUS), 3.19% (CAFE), 2.94% (PEDESTRIAN), 4.37% (STREET) +dt05_real WER: 2.93% (Average), 3.32% (BUS), 2.83% (CAFE), 2.63% (PEDESTRIAN), 2.93% (STREET) ------------------- -et05_simu WER: 7.39% (Average), 4.87% (BUS), 7.58% (CAFE), 7.15% (PEDESTRIAN), 9.96% (STREET) +et05_simu WER: 3.95% (Average), 3.29% (BUS), 4.71% (CAFE), 4.30% (PEDESTRIAN), 3.53% (STREET) ------------------- -et05_real WER: 7.04% (Average), 9.89% (BUS), 5.49% (CAFE), 5.70% (PEDESTRIAN), 7.10% (STREET) -------------------- \ No newline at end of file +et05_real WER: 4.04% (Average), 4.94% (BUS), 3.66% (CAFE), 3.66% (PEDESTRIAN), 3.90% (STREET) +------------------- + +TDNN+LSTMLM using 6 channel data plus enhanced data with blstm_gev 
+exp/chain/tdnn1a_sp_smbr_lmrescore/best_wer_blstm_gev_rnnlm_lstm_1a_w0.5_n100.result +------------------- +best overall dt05 WER 2.00% (language model weight = 11) +------------------- +dt05_simu WER: 2.10% (Average), 2.06% (BUS), 2.58% (CAFE), 1.73% (PEDESTRIAN), 2.02% (STREET) +------------------- +dt05_real WER: 1.90% (Average), 2.05% (BUS), 1.78% (CAFE), 1.68% (PEDESTRIAN), 2.09% (STREET) +------------------- +et05_simu WER: 2.66% (Average), 2.33% (BUS), 2.73% (CAFE), 2.93% (PEDESTRIAN), 2.63% (STREET) +------------------- +et05_real WER: 2.74% (Average), 3.05% (BUS), 2.45% (CAFE), 2.65% (PEDESTRIAN), 2.82% (STREET) +------------------- + diff --git a/egs/chime4/s5_6ch/rnnlm b/egs/chime4/s5_6ch/rnnlm new file mode 120000 index 00000000000..e136939ba72 --- /dev/null +++ b/egs/chime4/s5_6ch/rnnlm @@ -0,0 +1 @@ +../../../scripts/rnnlm/ \ No newline at end of file diff --git a/egs/chime4/s5_6ch/run.sh b/egs/chime4/s5_6ch/run.sh index 090808c026b..1979a040bd8 100755 --- a/egs/chime4/s5_6ch/run.sh +++ b/egs/chime4/s5_6ch/run.sh @@ -1,33 +1,33 @@ -#!/bin/bash - # Kaldi ASR baseline for the CHiME-4 Challenge (6ch track: 6 channel track) # # Copyright 2016 University of Sheffield (Jon Barker, Ricard Marxer) # Inria (Emmanuel Vincent) # Mitsubishi Electric Research Labs (Shinji Watanabe) # 2017 JHU CLSP (Szu-Jui Chen) +# 2017 JHU CLSP (Aswin Shanmugam Subramanian) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) . ./path.sh . ./cmd.sh + #####Baseline settings##### # Usage: -# 1. For using original baseline, execute './run.sh --baseline chime4_official'. -# We don't provide the function to train original baseline models anymore. Instead, we provided the -# trained original baseline models in tools/ASR_models for directly using. +# Execute './run.sh' to get the models. +# We provide three kinds of beamform methods. Add option --enhancement blstm_gev, or --enhancement beamformit_5mics +# or --enhancement single_blstmmask to use them. i.g. './run.sh --enhancement blstm_gev' # -# 2. For using advanced baseline, first execute './run.sh --baseline advanced --flatstart true' to -# get the models. If you want to use TDNN instead of DNN, add option "--tdnn true". If you want to -# use TDNN-LSTM instead of DNN, add option "--tdnn-lstm true". -# Then execute './run.sh --baseline advanced' for your experiments. +# We stopped to support the old CHiME-3/4 baseline. If you want to reproduce the old results +# Please use the old version of Kaldi, e.g., git checkout 9e8ff73648917836d0870c8f6fdd2ff4bdde384f # Config: stage=0 # resume training with --stage N - -baseline=advanced -flatstart=false -tdnn=true -tdnn_lstm=false +enhancement=blstm_gev #### or your method +# if the following options are true, they wouldn't train a model again and will only do decoding +gmm_decode_only=false +tdnn_decode_only=false +# make it true when you want to add enhanced data into training set. But please note that when changing enhancement method, +# you may need to retrain from run_gmm.sh and avoid using decode-only options above +add_enhanced_data=true . utils/parse_options.sh || exit 1; @@ -42,119 +42,92 @@ set -o pipefail # If you use scripts distributed in the CHiME4 package, chime4_data=`pwd`/../.. 
# Otherwise, please specify it, e.g., -chime4_data=/db/laputa1/data/processed/public/CHiME4 +# chime4_data=/db/laputa1/data/processed/public/CHiME4 +# chime3_data=/data2/archive/speech-db/original/public/CHiME3 case $(hostname -f) in - *.clsp.jhu.edu) chime4_data=/export/corpora4/CHiME4/CHiME3 ;; # JHU, + *.clsp.jhu.edu) + chime4_data=/export/corpora4/CHiME4/CHiME3 # JHU, + chime3_data=/export/corpora5/CHiME3 + ;; esac if [ ! -d $chime4_data ]; then echo "$chime4_data does not exist. Please specify chime4 data root correctly" && exit 1; fi -# Set a model directory for the CHiME4 data. -case $baseline in - chime4_official) - if $flatstart; then - echo "We don't support this anymore for 'chime4_official' baseline" - echo " ... Automatically set it to false" - fi - modeldir=$chime4_data/tools/ASR_models - flatstart=false - ;; - advanced) - modeldir=`pwd` - ;; - *) - echo "Usage: './run.sh --baseline chime4_official' or './run.sh --baseline advanced'" - echo " ... If you haven't run flatstart for advanced baseline, please execute" - echo " ... './run.sh --baseline advanced --flatstart true' first"; - exit 1; -esac - -if [ "$flatstart" = false ]; then - for d in $modeldir $modeldir/data/{lang,lang_test_tgpr_5k,lang_test_5gkn_5k,lang_test_rnnlm_5k_h300,local} \ - $modeldir/exp/{tri3b_tr05_multi_noisy,tri4a_dnn_tr05_multi_noisy,tri4a_dnn_tr05_multi_noisy_smbr_i1lats}; do - [ ! -d $d ] && echo "$0: no such directory $d. specify models correctly" && \ - echo " or execute './run.sh --baseline advanced --flatstart true' first" && exit 1; - done +if [ ! -d $chime3_data ]; then + echo "$chime3_data does not exist. Please specify chime4 data root correctly" && exit 1; fi -#####check data and model paths finished####### - #####main program start################ # You can execute run_init.sh only "once" # This creates 3-gram LM, FSTs, and basic task files -if [ $stage -le 0 ] && $flatstart; then +if [ $stage -le 0 ]; then local/run_init.sh $chime4_data fi -# Using Beamformit -# See Hori et al, "The MERL/SRI system for the 3rd CHiME challenge using beamforming, -# robust feature extraction, and advanced speech recognition," in Proc. ASRU'15 -# note that beamformed wav files are generated in the following directory -enhancement_method=beamformit_5mics -enhancement_data=`pwd`/enhan/$enhancement_method +# Using Beamformit or mask-based beamformer +# note that beamformed WAV files are generated in the following directory +enhancement_data=`pwd`/enhan/$enhancement if [ $stage -le 1 ]; then - local/run_beamform_6ch_track.sh --cmd "$train_cmd" --nj 20 $chime4_data/data/audio/16kHz/isolated_6ch_track $enhancement_data + case $enhancement in + beamformit_5mics) + local/run_beamform_6ch_track.sh --cmd "$train_cmd" --nj 20 $chime4_data/data/audio/16kHz/isolated_6ch_track $enhancement_data + ;; + blstm_gev) + local/run_blstm_gev.sh --cmd "$train_cmd" --nj 20 $chime4_data $chime3_data $enhancement_data 0 + ;; + single_blstmmask) + local/run_blstm_gev.sh --cmd "$train_cmd" --nj 20 $chime4_data $chime3_data $enhancement_data 5 + ;; + *) + echo "Usage: --enhancement blstm_gev, or --enhancement beamformit_5mics , or --enhancement single_blstmmask" + exit 1; + esac fi -# GMM based ASR experiment without "retraining" -# Please set a directory of your speech enhancement method. -# run_gmm_recog.sh can be done every time when you change a speech enhancement technique. 
-# The directory structure and audio files must follow the attached baseline enhancement directory +# Compute PESQ, STOI, eSTOI, and SDR scores if [ $stage -le 2 ]; then - if $flatstart; then - local/run_gmm.sh $enhancement_method $enhancement_data $chime4_data - else - local/run_gmm_recog.sh $enhancement_method $enhancement_data $modeldir + if [ ! -f local/bss_eval_sources.m ] || [ ! -f local/stoi.m ] || [ ! -f local/estoi.m ] || [ ! -f local/PESQ ]; then + # download and install speech enhancement evaluation tools + local/download_se_eval_tool.sh fi + chime4_rir_data=local/nn-gev/data/audio/16kHz/isolated_ext + if [ ! -d $chime4_rir_data ]; then + echo "$chime4_rir_data does not exist. Please run 'blstm_gev' enhancement method first;" && exit 1; + fi + local/compute_pesq.sh $enhancement $enhancement_data $chime4_rir_data $PWD + local/compute_stoi_estoi_sdr.sh $enhancement $enhancement_data $chime4_rir_data + local/compute_pesq.sh NOISY_1ch $chime4_data/data/audio/16kHz/isolated_1ch_track/ $chime4_rir_data $PWD + local/compute_stoi_estoi_sdr.sh NOISY_1ch $chime4_data/data/audio/16kHz/isolated_1ch_track/ $chime4_rir_data + local/write_se_results.sh $enhancement + local/write_se_results.sh NOISY_1ch fi -# DNN based ASR experiment -# Since it takes time to evaluate DNN, we make the GMM and DNN scripts separately. -# You may execute it after you would have promising results using GMM-based ASR experiments +# GMM based ASR experiment +# Please set a directory of your speech enhancement method. +# The directory structure and audio files must follow the attached baseline enhancement directory if [ $stage -le 3 ]; then - if $tdnn; then - if $flatstart; then - local/chain/run_tdnn.sh $enhancement_method - else - local/chain/run_tdnn_recog.sh $enhancement_method $modeldir - fi - elif $tdnn_lstm; then - if $flatstart; then - local/chain/run_tdnn_lstm.sh $enhancement_method - else - local/chain/run_tdnn_lstm_recog.sh $enhancement_method $modeldir - fi - else - if $flatstart; then - local/run_dnn.sh $enhancement_method - else - local/run_dnn_recog.sh $enhancement_method $modeldir - fi - fi + local/run_gmm.sh --add-enhanced-data $add_enhanced_data \ + --decode-only $gmm_decode_only $enhancement $enhancement_data $chime4_data +fi + +# TDNN based ASR experiment +# Since it takes time to evaluate TDNN, we make the GMM and TDNN scripts separately. +# You may execute it after you would have promising results using GMM-based ASR experiments +if [ $stage -le 4 ]; then + local/chain/run_tdnn.sh --decode-only $tdnn_decode_only $enhancement fi -flatstart=false + # LM-rescoring experiment with 5-gram and RNN LMs # It takes a few days to train a RNNLM. -if [ $stage -le 4 ]; then - if $flatstart; then - if $tdnn; then - local/run_lmrescore_tdnn.sh $chime4_data $enhancement_method - elif $tdnn_lstm; then - local/run_lmrescore_tdnn_lstm.sh $chime4_data $enhancement_method - else - local/run_lmrescore.sh $chime4_data $enhancement_method - fi - else - if $tdnn; then - local/run_lmrescore_tdnn_recog.sh $enhancement_method $modeldir - elif $tdnn_lstm; then - local/run_lmrescore_tdnn_lstm_recog.sh $enhancement_method $modeldir - else - local/run_lmrescore_recog.sh $enhancement_method $modeldir - fi - fi +if [ $stage -le 5 ]; then + local/run_lmrescore_tdnn.sh $chime4_data $enhancement +fi + +# LM-rescoring experiment with LSTM LMs +if [ $stage -le 6 ]; then + local/rnnlm/run_lstm.sh $enhancement fi echo "Done." 
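The new speech-enhancement scoring stage can also be driven by hand for a method of your own, using the helper scripts called in stage 2 above. The sketch below assumes blstm_gev has been run once so that the isolated_ext reference signals exist; my_enhan is a placeholder name for an enhancement whose output follows the baseline directory layout.

  # fetch the PESQ/STOI/eSTOI/SDR evaluation tools on first use
  [ -f local/PESQ ] || local/download_se_eval_tool.sh
  enhan_dir=$PWD/enhan/my_enhan                        # placeholder: your enhanced WAV files
  rir_dir=local/nn-gev/data/audio/16kHz/isolated_ext   # reference signals produced by the blstm_gev run
  local/compute_pesq.sh my_enhan $enhan_dir $rir_dir $PWD
  local/compute_stoi_estoi_sdr.sh my_enhan $enhan_dir $rir_dir
  local/write_se_results.sh my_enhan                   # prints dt05/et05 averages for each metric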
diff --git a/egs/chime5/s5/local/chain/tuning/run_tdnn_1a.sh b/egs/chime5/s5/local/chain/tuning/run_tdnn_1a.sh index 45a7fd84bd6..5418ecf2b4f 100755 --- a/egs/chime5/s5/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/chime5/s5/local/chain/tuning/run_tdnn_1a.sh @@ -133,7 +133,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.05" output_opts="l2-regularize=0.01 bottleneck-dim=320" diff --git a/egs/chime5/s5/local/json2text.py b/egs/chime5/s5/local/json2text.py index 4df0160efb6..a0142ad916e 100755 --- a/egs/chime5/s5/local/json2text.py +++ b/egs/chime5/s5/local/json2text.py @@ -25,8 +25,8 @@ def hms_to_seconds(hms): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('json', type=str, help='JSON transcription file') - parser.add_argument('--mictype', type=str, + parser.add_argument('json', help='JSON transcription file') + parser.add_argument('--mictype', choices=['ref', 'worn', 'u01', 'u02', 'u03', 'u04', 'u05', 'u06'], help='Type of microphones') args = parser.parse_args() diff --git a/egs/chime5/s5/local/score_for_submit.sh b/egs/chime5/s5/local/score_for_submit.sh index 5502c5994e5..23121d68b93 100755 --- a/egs/chime5/s5/local/score_for_submit.sh +++ b/egs/chime5/s5/local/score_for_submit.sh @@ -43,7 +43,7 @@ for session in S02 S09; do # get nerror nerr=`grep "\#csid" $score_result | grep $room | grep $session | awk '{sum+=$4+$5+$6} END {print sum}'` # get nwords from references (NF-2 means to exclude utterance id and " ref ") - nwrd=`grep " ref " $score_result | grep $room | grep $session | sed -e "s/\*//g" | awk '{sum+=NF-2} END {print sum}'` + nwrd=`grep "\#csid" $score_result | grep $room | grep $session | awk '{sum+=$3+$4+$6} END {print sum}'` # compute wer with scale=2 wer=`echo "scale=2; 100 * $nerr / $nwrd" | bc` @@ -59,7 +59,7 @@ echo -n "overall: " # get nerror nerr=`grep "\#csid" $score_result | awk '{sum+=$4+$5+$6} END {print sum}'` # get nwords from references (NF-2 means to exclude utterance id and " ref ") -nwrd=`grep " ref " $score_result | sed -e "s/\*//g" | awk '{sum+=NF-2} END {print sum}'` +nwrd=`grep "\#csid" $score_result | awk '{sum+=$3+$4+$6} END {print sum}'` # compute wer with scale=2 wer=`echo "scale=2; 100 * $nerr / $nwrd" | bc` echo -n "#words $nwrd, " @@ -81,7 +81,7 @@ for session in S01 S21; do # get nerror nerr=`grep "\#csid" $score_result | grep $room | grep $session | awk '{sum+=$4+$5+$6} END {print sum}'` # get nwords from references (NF-2 means to exclude utterance id and " ref ") - nwrd=`grep " ref " $score_result | grep $room | grep $session | sed -e "s/\*//g" | awk '{sum+=NF-2} END {print sum}'` + nwrd=`grep "\#csid" $score_result | grep $room | grep $session | awk '{sum+=$3+$4+$6} END {print sum}'` # compute wer with scale=2 wer=`echo "scale=2; 100 * $nerr / $nwrd" | bc` @@ -98,7 +98,7 @@ if $do_eval; then # get nerror nerr=`grep "\#csid" $score_result | awk '{sum+=$4+$5+$6} END {print sum}'` # get nwords from references (NF-2 means to exclude utterance id and " ref ") - nwrd=`grep " ref " $score_result | sed -e "s/\*//g" | awk '{sum+=NF-2} END {print sum}'` + nwrd=`grep "\#csid" $score_result | awk '{sum+=$3+$4+$6} END {print sum}'` # compute wer with scale=2 wer=`echo "scale=2; 100 * $nerr / $nwrd" | bc` echo -n "overall: " diff --git 
a/egs/cifar/v1/image/copy_data_dir.sh b/egs/cifar/v1/image/copy_data_dir.sh new file mode 100755 index 00000000000..c923f5cc07a --- /dev/null +++ b/egs/cifar/v1/image/copy_data_dir.sh @@ -0,0 +1,118 @@ +#!/bin/bash + +# Copyright 2013 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +# This script operates on a directory, such as in data/train/, +# that contains some subset of the following files: +# feats.scp +# images.scp +# vad.scp +# spk2utt +# utt2spk +# text +# +# It copies to another directory, possibly adding a specified prefix or a suffix +# to the utterance and/or speaker names. Note, the recording-ids stay the same. +# + + +# begin configuration section +spk_prefix= +utt_prefix= +spk_suffix= +utt_suffix= +validate_opts= # should rarely be needed. +# end configuration section + +. utils/parse_options.sh + +if [ $# != 2 ]; then + echo "Usage: " + echo " $0 [options] <srcdir> <destdir>" + echo "e.g.:" + echo " $0 --spk-prefix=1- --utt-prefix=1- data/train data/train_1" + echo "Options" + echo " --spk-prefix=<prefix> # Prefix for speaker ids, default empty" + echo " --utt-prefix=<prefix> # Prefix for utterance ids, default empty" + echo " --spk-suffix=<suffix> # Suffix for speaker ids, default empty" + echo " --utt-suffix=<suffix> # Suffix for utterance ids, default empty" + exit 1; +fi + + +export LC_ALL=C + +srcdir=$1 +destdir=$2 + +if [ ! -f $srcdir/utt2spk ]; then + echo "copy_data_dir.sh: no such file $srcdir/utt2spk" + exit 1; +fi + +if [ "$destdir" == "$srcdir" ]; then + echo "$0: this script requires <srcdir> and <destdir> to be different." + exit 1 +fi + +set -e; + +mkdir -p $destdir + +cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/utt_map +cat $srcdir/spk2utt | awk -v p=$spk_prefix -v s=$spk_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/spk_map + +if [ ! -f $srcdir/utt2uniq ]; then + if [[ ! -z $utt_prefix || ! -z $utt_suffix ]]; then + cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $1);}' > $destdir/utt2uniq + fi +else + cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq +fi + +cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map | \ + utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk + +utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt + +if [ -f $srcdir/feats.scp ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp +fi + +if [ -f $srcdir/vad.scp ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp +fi + +if [ -f $srcdir/images.scp ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/images.scp >$destdir/images.scp +fi + +if [ -f $srcdir/text ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text +fi +if [ -f $srcdir/utt2dur ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur +fi +if [ -f $srcdir/cmvn.scp ]; then + utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp +fi + +rm $destdir/spk_map $destdir/utt_map + +echo "$0: copied data from $srcdir to $destdir" + +for f in feats.scp cmvn.scp vad.scp utt2uniq utt2dur utt2num_frames text images.scp; do + if [ -f $destdir/$f ] && [ ! -f $srcdir/$f ]; then + echo "$0: file $f exists in dest $destdir but not in src $srcdir. Moving it to" + echo " ... $destdir/.backup/$f" + mkdir -p $destdir/.backup + mv $destdir/$f $destdir/.backup/ + fi +done + + +[ !
-f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats" +[ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text" + +utils/validate_data_dir.sh $validate_opts $destdir diff --git a/egs/cifar/v1/image/get_allowed_lengths.py b/egs/cifar/v1/image/get_allowed_lengths.py index 02321fdd2df..33996c8eef1 100755 --- a/egs/cifar/v1/image/get_allowed_lengths.py +++ b/egs/cifar/v1/image/get_allowed_lengths.py @@ -10,6 +10,7 @@ file is later used by make_features.py to pad each image sufficiently so that they all have an allowed length. This is intended for end2end chain training. """ +from __future__ import division import argparse import os @@ -117,14 +118,14 @@ def find_allowed_durations(start_len, end_len, args): (length // args.frame_subsampling_factor)) allowed_lengths.append(length) fp.write("{}\n".format(int(length))) - length *= args.factor + length = max(length * args.factor, length + args.frame_subsampling_factor) return allowed_lengths def main(): args = get_args() - args.factor = 1.0 + args.factor / 100.0 + args.factor = 1.0 + args.factor/100.0 image2length = read_kaldi_mapfile(os.path.join(args.srcdir, 'image2num_frames')) @@ -133,7 +134,7 @@ def main(): "Coverage rate: {}%".format(start_dur, end_dur, 100.0 - args.coverage_factor * 2)) logger.info("There will be {} unique allowed lengths " - "for the images.".format(int(math.log(end_dur / start_dur) / + "for the images.".format(int((math.log(float(end_dur)/start_dur))/ math.log(args.factor)))) allowed_durations = find_allowed_durations(start_dur, end_dur, args) diff --git a/egs/cifar/v1/image/matrix_to_image.py b/egs/cifar/v1/image/matrix_to_image.py index 52dcead7479..908b1f8b3ed 100755 --- a/egs/cifar/v1/image/matrix_to_image.py +++ b/egs/cifar/v1/image/matrix_to_image.py @@ -26,6 +26,7 @@ copy-feats --binary=false $(grep $imgid data/train/feats.scp | cut -d' ' -f2) - | \ image/matrix_to_image.py --color=1 > $imgid.bmp """ +from __future__ import division import argparse import sys @@ -59,7 +60,7 @@ num_cols = len(line) # initialize if len(line) != num_cols: raise Exception("All rows should be of the same length") - line = map(float, line) # string to float + line = [float(i) for i in line] # string to float if max(line) > 1: raise Excetion("Element value in the matrix should be normalized and no larger than 1") line = [int(x * 255) for x in line] # float to integer ranging from 0 to 255 @@ -70,7 +71,7 @@ if num_cols % 3 != 0: raise Exception("Number of columns should be a multiple of 3 in the color mode") width = num_rows - height = num_cols / 3 + height = num_cols/3 # reform the image matrix image_array = [[0 for i in range(width * 3)] for j in range(height)] for i in range(height): diff --git a/egs/madcat_ar/v1/local/make_features.py b/egs/cifar/v1/image/ocr/make_features.py similarity index 51% rename from egs/madcat_ar/v1/local/make_features.py rename to egs/cifar/v1/image/ocr/make_features.py index a21276d32c2..aa909f596c9 100755 --- a/egs/madcat_ar/v1/local/make_features.py +++ b/egs/cifar/v1/image/ocr/make_features.py @@ -2,27 +2,33 @@ # Copyright 2017 Chun Chieh Chang # 2017 Ashish Arora +# 2017 Yiwen Shao # 2018 Hossein Hadian +# 2018 Desh Raj """ This script converts images to Kaldi-format feature matrices. The input to this script is the path to a data directory, e.g. "data/train". This script reads the images listed in images.scp and writes them to standard output (by default) as Kaldi-formatted matrices (in text form). It also scales the images so they have the same height (via --feat-dim). 
It can optionally pad - the images (on left/right sides) with white pixels. + the images (on left/right sides) with white pixels. It by default performs + augmentation, (directly scaling down and scaling up). It will double the + data but we can turn augmentation off (via --no-augment). If an 'image2num_frames' file is found in the data dir, it will be used to enforce the images to have the specified length in that file by padding white pixels (the --padding option will be ignored in this case). This relates to end2end chain training. - eg. local/make_features.py data/train --feat-dim 40 """ - +import random import argparse import os import sys import numpy as np from scipy import misc +import math +from signal import signal, SIGPIPE, SIG_DFL +signal(SIGPIPE, SIG_DFL) parser = argparse.ArgumentParser(description="""Converts images (in 'dir'/images.scp) to features and writes them to standard output in text format.""") @@ -38,8 +44,15 @@ parser.add_argument('--padding', type=int, default=5, help='Number of white pixels to pad on the left' 'and right side of the image.') - - +parser.add_argument('--num-channels', type=int, default=1, + help='Number of color channels') +parser.add_argument('--vertical-shift', type=int, default=0, + help='total number of padding pixel per column') +parser.add_argument('--fliplr', type=lambda x: (str(x).lower()=='true'), default=False, + help="Flip the image left-right for right to left languages") +parser.add_argument('--augment_type', type=str, default='no_aug', + choices=['no_aug', 'random_scale','random_shift'], + help='Subset of data to process.') args = parser.parse_args() @@ -59,18 +72,6 @@ def write_kaldi_matrix(file_handle, matrix, key): file_handle.write("\n") file_handle.write(" ]\n") - -def get_scaled_image(im): - scale_size = args.feat_dim - sx = im.shape[1] # width - sy = im.shape[0] # height - scale = (1.0 * scale_size) / sy - nx = int(scale_size) - ny = int(scale * sx) - im = misc.imresize(im, (nx, ny)) - return im - - def horizontal_pad(im, allowed_lengths = None): if allowed_lengths is None: left_padding = right_padding = args.padding @@ -88,21 +89,73 @@ def horizontal_pad(im, allowed_lengths = None): left_padding = int(padding // 2) right_padding = padding - left_padding dim_y = im.shape[0] # height - im_pad = np.concatenate((255 * np.ones((dim_y, left_padding), - dtype=int), im), axis=1) - im_pad1 = np.concatenate((im_pad, 255 * np.ones((dim_y, right_padding), - dtype=int)), axis=1) + if args.num_channels in [1,4]: + im_pad = np.concatenate((255 * np.ones((dim_y, left_padding), + dtype=int), im), axis=1) + im_pad1 = np.concatenate((im_pad, 255 * np.ones((dim_y, right_padding), + dtype=int)), axis=1) + else: + im_pad = np.concatenate((255 * np.ones((dim_y, left_padding, args.num_channels), + dtype=int), im), axis=1) + im_pad1 = np.concatenate((im_pad, 255 * np.ones((dim_y, right_padding, args.num_channels), + dtype=int)), axis=1) return im_pad1 +def get_scaled_image_aug(im, mode='normal'): + scale_size = args.feat_dim + sx = im.shape[1] + sy = im.shape[0] + scale = (1.0 * scale_size) / sy + nx = int(scale_size) + ny = int(scale * sx) + scale_size = random.randint(10, 30) + scale = (1.0 * scale_size) / sy + down_nx = int(scale_size) + down_ny = int(scale * sx) + if mode == 'normal': + im = misc.imresize(im, (nx, ny)) + return im + else: + im_scaled_down = misc.imresize(im, (down_nx, down_ny)) + im_scaled_up = misc.imresize(im_scaled_down, (nx, ny)) + return im_scaled_up + return im -### main ### +def vertical_shift(im, mode='normal'): + if 
args.vertical_shift == 0: + return im + total = args.vertical_shift + if mode == 'notmid': + val = random.randint(0, 1) + if val == 0: + mode = 'top' + else: + mode = 'bottom' + if mode == 'normal': + top = int(total / 2) + bottom = total - top + elif mode == 'top': # more padding on top + top = random.randint(total / 2, total) + bottom = total - top + elif mode == 'bottom': # more padding on bottom + top = random.randint(0, total / 2) + bottom = total - top + width = im.shape[1] + im_pad = np.concatenate( + (255 * np.ones((top, width), dtype=int) - + np.random.normal(2, 1, (top, width)).astype(int), im), axis=0) + im_pad = np.concatenate( + (im_pad, 255 * np.ones((bottom, width), dtype=int) - + np.random.normal(2, 1, (bottom, width)).astype(int)), axis=0) + return im_pad +### main ### +random.seed(1) data_list_path = args.images_scp_path - if args.out_ark == '-': out_fh = sys.stdout else: - out_fh = open(args.out_ark,'wb') + out_fh = open(args.out_ark,'w') allowed_lengths = None allowed_len_handle = args.allowed_len_file_path @@ -123,13 +176,31 @@ def horizontal_pad(im, allowed_lengths = None): line_vect = line.split(' ') image_id = line_vect[0] image_path = line_vect[1] - im = misc.imread(image_path) - im_scaled = get_scaled_image(im) - im_horizontal_padded = horizontal_pad(im_scaled, allowed_lengths) - if im_horizontal_padded is None: + if args.num_channels == 4: + im = misc.imread(image_path, mode='L') + else: + im = misc.imread(image_path) + if args.fliplr: + im = np.fliplr(im) + if args.augment_type == 'no_aug' or 'random_shift': + im = get_scaled_image_aug(im, 'normal') + elif args.augment_type == 'random_scale': + im = get_scaled_image_aug(im, 'scaled') + im = horizontal_pad(im, allowed_lengths) + if im is None: num_fail += 1 continue - data = np.transpose(im_horizontal_padded, (1, 0)) + if args.augment_type == 'no_aug' or 'random_scale': + im = vertical_shift(im, 'normal') + elif args.augment_type == 'random_shift': + im = vertical_shift(im, 'notmid') + if args.num_channels in [1,4]: + data = np.transpose(im, (1, 0)) + elif args.num_channels == 3: + H = im.shape[0] + W = im.shape[1] + C = im.shape[2] + data = np.reshape(np.transpose(im, (1, 0, 2)), (W, H * C)) data = np.divide(data, 255.0) num_ok += 1 write_kaldi_matrix(out_fh, data, image_id) diff --git a/egs/cifar/v1/image/select_image_in_egs.py b/egs/cifar/v1/image/select_image_in_egs.py index 88d7d568e66..dbf48e6403d 100755 --- a/egs/cifar/v1/image/select_image_in_egs.py +++ b/egs/cifar/v1/image/select_image_in_egs.py @@ -9,6 +9,7 @@ # --vertical-shift=0.3 --srand=27 --num-channels=3 ark:exp/cifar10_egs/egs.1.ark ark,t:- | \ # image/select_image_in_egs.py $id | image/matrix_to_image.py --color 3 > $id.bmp +from __future__ import print_function import argparse import sys diff --git a/egs/cifar/v1/local/process_data.py b/egs/cifar/v1/local/process_data.py index 51173dafc6f..38a599297d2 100755 --- a/egs/cifar/v1/local/process_data.py +++ b/egs/cifar/v1/local/process_data.py @@ -6,6 +6,7 @@ """ This script prepares the training and test data for CIFAR-10 or CIFAR-100. 
""" +from __future__ import division import argparse import os @@ -14,13 +15,13 @@ parser = argparse.ArgumentParser(description="""Converts train/test data of CIFAR-10 or CIFAR-100 to Kaldi feature format""") -parser.add_argument('database', type=str, +parser.add_argument('database', default='data/dl/cifar-10-batches-bin', help='path to downloaded cifar data (binary version)') -parser.add_argument('dir', type=str, help='output dir') -parser.add_argument('--cifar-version', type=str, default='CIFAR-10', choices=['CIFAR-10', 'CIFAR-100']) -parser.add_argument('--dataset', type=str, default='train', choices=['train', 'test']) -parser.add_argument('--out-ark', type=str, default='-', help='where to write output feature data') +parser.add_argument('dir', help='output dir') +parser.add_argument('--cifar-version', default='CIFAR-10', choices=['CIFAR-10', 'CIFAR-100']) +parser.add_argument('--dataset', default='train', choices=['train', 'test']) +parser.add_argument('--out-ark', default='-', help='where to write output feature data') args = parser.parse_args() @@ -37,7 +38,7 @@ def load_cifar10_data_batch(datafile): for i in range(num_images_in_batch): label = ord(fh.read(1)) bin_img = fh.read(C * H * W) - img = [[[ord(byte) / 255.0 for byte in bin_img[channel*H*W+row*W:channel*H*W+(row+1)*W]] + img = [[[ord(byte)/255.0 for byte in bin_img[channel*H*W+row*W:channel*H*W+(row+1)*W]] for row in range(H)] for channel in range(C)] labels += [label] data += [img] @@ -52,7 +53,7 @@ def load_cifar100_data_batch(datafile, num_images_in_batch): coarse_label = ord(fh.read(1)) fine_label = ord(fh.read(1)) bin_img = fh.read(C * H * W) - img = [[[ord(byte) / 255.0 for byte in bin_img[channel*H*W+row*W:channel*H*W+(row+1)*W]] + img = [[[ord(byte)/255.0 for byte in bin_img[channel*H*W+row*W:channel*H*W+(row+1)*W]] for row in range(H)] for channel in range(C)] fine_labels += [fine_label] coarse_labels += [coarse_label] @@ -80,7 +81,7 @@ def write_kaldi_matrix(file_handle, matrix, key): if num_cols != len(matrix[row_index]): raise Exception("All the rows of a matrix are expected to " "have the same length") - file_handle.write(" ".join(map(lambda x: str(x), matrix[row_index]))) + file_handle.write(" ".join([str(x) for x in matrix[row_index]])) if row_index != num_rows - 1: file_handle.write("\n") file_handle.write(" ]\n") diff --git a/egs/commonvoice/s5/local/chain/tuning/run_tdnn_1a.sh b/egs/commonvoice/s5/local/chain/tuning/run_tdnn_1a.sh index 635e3de1076..d4acd0fed4b 100755 --- a/egs/commonvoice/s5/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/commonvoice/s5/local/chain/tuning/run_tdnn_1a.sh @@ -141,7 +141,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/commonvoice/s5/local/prepare_dict.sh b/egs/commonvoice/s5/local/prepare_dict.sh index d6d1aba41fb..cdfffe42080 100755 --- a/egs/commonvoice/s5/local/prepare_dict.sh +++ b/egs/commonvoice/s5/local/prepare_dict.sh @@ -52,7 +52,7 @@ if [[ "$(uname)" == "Darwin" ]]; then alias readlink=greadlink fi -sequitur=$KALDI_ROOT/tools/sequitur +sequitur=$KALDI_ROOT/tools/sequitur-g2p export PATH=$PATH:$sequitur/bin export PYTHONPATH=$PYTHONPATH:`utils/make_absolute.sh $sequitur/lib/python*/site-packages` diff --git 
a/egs/csj/s5/local/chain/tuning/run_tdnn_1a.sh b/egs/csj/s5/local/chain/tuning/run_tdnn_1a.sh index a463db77066..75ceb80e3e0 100755 --- a/egs/csj/s5/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/csj/s5/local/chain/tuning/run_tdnn_1a.sh @@ -133,7 +133,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/csj/s5/local/csj_data_prep.sh b/egs/csj/s5/local/csj_data_prep.sh index 55738bf0e37..69e2865e316 100755 --- a/egs/csj/s5/local/csj_data_prep.sh +++ b/egs/csj/s5/local/csj_data_prep.sh @@ -45,7 +45,9 @@ if [ ! -d $CSJ ]; then fi # CSJ dictionary file check -[ ! -f $dir/lexicon.txt ] && cp $CSJ/lexicon/lexicon.txt $dir || exit 1; +if [ ! -f $dir/lexicon.txt ]; then + cp $CSJ/lexicon/lexicon.txt $dir || exit 1; +fi ### Config of using wav data that relates with acoustic model training ### if [ $mode -eq 3 ] diff --git a/egs/csj/s5/local/csj_make_trans/csj_autorun.sh b/egs/csj/s5/local/csj_make_trans/csj_autorun.sh index f288e4fb4d3..5cd78ee94ae 100755 --- a/egs/csj/s5/local/csj_make_trans/csj_autorun.sh +++ b/egs/csj/s5/local/csj_make_trans/csj_autorun.sh @@ -61,7 +61,7 @@ if [ ! -e $outd/.done_make_trans ];then mkdir -p $outd/$vol/$id case "$csjv" in - "usb" ) TPATH="$resource/${SDB}$vol" ; WPATH="$resource/$WAV" ;; + "usb" ) TPATH="$resource/${SDB}$vol" ; WPATH="$resource/${WAV}$vol" ;; "dvd" ) TPATH="$resource/$vol/$id" ; WPATH="$resource/$vol/$id" ;; "merl" ) TPATH="$resource/$vol/$SDB" ; WPATH="$resource/$vol/$WAV" ;; esac diff --git a/egs/csj/s5/local/nnet/run_dnn_tandem_uc.sh b/egs/csj/s5/local/nnet/run_dnn_tandem_uc.sh index 4677ff473cb..297aed1f486 100755 --- a/egs/csj/s5/local/nnet/run_dnn_tandem_uc.sh +++ b/egs/csj/s5/local/nnet/run_dnn_tandem_uc.sh @@ -280,4 +280,4 @@ exit 0 %WER 14.88 [ 2557 / 17189, 556 ins, 359 del, 1642 sub ] exp/tandem2uc-tri4/decode_eval3_csj/wer_20_0.5 %WER 17.03 [ 2927 / 17189, 592 ins, 417 del, 1918 sub ] exp/tandem2uc-tri4/decode_eval3_csj.si/wer_20_1.0 %WER 13.44 [ 2311 / 17189, 430 ins, 340 del, 1541 sub ] exp/tandem2uc-tri4_mmi_b0.1/decode_eval3_csj/wer_20_1.0 -EOF \ No newline at end of file +EOF diff --git a/egs/csj/s5/local/run_sgmm2.sh b/egs/csj/s5/local/run_sgmm2.sh index 619c6c5d1ef..c66b43c4f7f 100755 --- a/egs/csj/s5/local/run_sgmm2.sh +++ b/egs/csj/s5/local/run_sgmm2.sh @@ -18,7 +18,7 @@ fi if [ ! -f exp/ubm5/final.ubm ]; then steps/train_ubm.sh --cmd "$train_cmd" 1400 data/train_nodup data/lang \ exp/tri4_ali_nodup exp/ubm5 || exit 1; -fi +fi # steps/train_sgmm2.sh --cmd "$train_cmd" \ steps/train_sgmm2_group.sh --cmd "$train_cmd" \ diff --git a/egs/dihard_2018/README.txt b/egs/dihard_2018/README.txt new file mode 100644 index 00000000000..a7a00c8bf4e --- /dev/null +++ b/egs/dihard_2018/README.txt @@ -0,0 +1,14 @@ + + This is a Kaldi recipe for The First DIHARD Speech Diarization Challenge. 
+ DIHARD is a new annual challenge focusing on "hard" diarization; that is, + speech diarization for challenging corpora where there is an expectation that + the current state-of-the-art will fare poorly, including, but not limited + to: clinical interviews, extended child language acquisition recordings, + YouTube videos and "speech in the wild" (e.g., recordings in restaurants) + See https://coml.lscp.ens.fr/dihard/index.html for details. + + The subdirectories "v1" and so on are different speaker diarization + recipes. The recipe in v1 demonstrates a standard approach using a + full-covariance GMM-UBM, i-vectors, PLDA scoring and agglomerative + hierarchical clustering. The example in v2 demonstrates DNN speaker + embeddings, PLDA scoring and agglomerative hierarchical clustering. diff --git a/egs/dihard_2018/v1/README.txt b/egs/dihard_2018/v1/README.txt new file mode 100644 index 00000000000..98bf3641b03 --- /dev/null +++ b/egs/dihard_2018/v1/README.txt @@ -0,0 +1,13 @@ + This recipe is the speaker diarization recipe for The First DIHARD Speech + Diarization Challenge (DIHARD 2018). There are two tracks in the DIHARD 2018 + competition , one uses oracle SAD (track1) and the other required that SAD + was performed from scratch (track2). This script is for track1. + + The recipe is closely based on the following paper: + http://www.danielpovey.com/files/2018_interspeech_dihard.pdf but doesn't + contain the VB refinement. The whole system mainly contains full-covariance + GMM-UBM, i-vector extractor (T-matrix), PLDA scoring and agglomerative + hierarchical clustering. The VoxCeleb datasets are used for training i-vectors + and PLDA. The development set of the DIHARD 2018 competition is used as + validation set to tune parameters. The system is tested on the DIHARD 2018 + evaluation set. diff --git a/egs/dihard_2018/v1/cmd.sh b/egs/dihard_2018/v1/cmd.sh new file mode 100755 index 00000000000..c35cd18f287 --- /dev/null +++ b/egs/dihard_2018/v1/cmd.sh @@ -0,0 +1,15 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export train_cmd="queue.pl" + + diff --git a/egs/dihard_2018/v1/conf/mfcc.conf b/egs/dihard_2018/v1/conf/mfcc.conf new file mode 100644 index 00000000000..649cffb9de8 --- /dev/null +++ b/egs/dihard_2018/v1/conf/mfcc.conf @@ -0,0 +1,7 @@ +--sample-frequency=16000 +--frame-length=25 # the default is 25 +--low-freq=20 # the default. +--high-freq=7600 # the default is zero meaning use the Nyquist (8k in this case). 
+--num-mel-bins=30 +--num-ceps=24 +--snip-edges=false diff --git a/egs/dihard_2018/v1/conf/vad.conf b/egs/dihard_2018/v1/conf/vad.conf new file mode 100644 index 00000000000..a0ca2449b10 --- /dev/null +++ b/egs/dihard_2018/v1/conf/vad.conf @@ -0,0 +1,2 @@ +--vad-energy-threshold=5.5 +--vad-energy-mean-scale=0.5 diff --git a/egs/dihard_2018/v1/diarization b/egs/dihard_2018/v1/diarization new file mode 120000 index 00000000000..bad937c1444 --- /dev/null +++ b/egs/dihard_2018/v1/diarization @@ -0,0 +1 @@ +../../callhome_diarization/v1/diarization \ No newline at end of file diff --git a/egs/dihard_2018/v1/local/make_dihard_2018_dev.py b/egs/dihard_2018/v1/local/make_dihard_2018_dev.py new file mode 100755 index 00000000000..fa652da8b4c --- /dev/null +++ b/egs/dihard_2018/v1/local/make_dihard_2018_dev.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +# This script is called by local/make_dihard_2018_dev.sh, and it creates the +# necessary files for DIHARD 2018 development directory. + +import sys, os + +def prepare_dihard_2018_dev(src_dir, data_dir): + wavscp_fi = open(data_dir + "/wav.scp" , 'w') + utt2spk_fi = open(data_dir + "/utt2spk" , 'w') + segments_fi = open(data_dir + "/segments" , 'w') + rttm_fi = open(data_dir + "/rttm" , 'w') + reco2num_spk_fi = open(data_dir + "/reco2num_spk" , 'w') + + for subdir, dirs, files in os.walk(src_dir): + for file in files: + filename = os.path.join(subdir, file) + if filename.endswith(".lab"): + utt = os.path.basename(filename).split(".")[0] + lines = open(filename, 'r').readlines() + segment_id = 0 + for line in lines: + start, end, speech = line.split() + segment_id_str = "{}_{}".format(utt, str(segment_id).zfill(4)) + segments_str = "{} {} {} {}\n".format(segment_id_str, utt, start, end) + utt2spk_str = "{} {}\n".format(segment_id_str, utt) + segments_fi.write(segments_str) + utt2spk_fi.write(utt2spk_str) + segment_id += 1 + wav_str = "{} sox -t flac {}/data/flac/{}.flac -t wav -r 16k "\ + "-b 16 - channels 1 |\n".format(utt, src_dir, utt) + wavscp_fi.write(wav_str) + with open("{}/data/rttm/{}.rttm".format(src_dir, utt), 'r') as fh: + rttm_str = fh.read() + rttm_fi.write(rttm_str) + with open("{}/data/rttm/{}.rttm".format(src_dir, utt), 'r') as fh: + rttm_list = fh.readlines() + spk_list = [(x.split())[7] for x in rttm_list] + num_spk = len(set(spk_list)) + reco2num_spk_fi.write("{} {}\n".format(utt, num_spk)) + wavscp_fi.close() + utt2spk_fi.close() + segments_fi.close() + rttm_fi.close() + reco2num_spk_fi.close() + return 0 + +def main(): + src_dir = sys.argv[1] + data_dir = sys.argv[2] + if not os.path.exists(data_dir): + os.makedirs(data_dir) + prepare_dihard_2018_dev(src_dir, data_dir) + return 0 + +if __name__=="__main__": + main() diff --git a/egs/dihard_2018/v1/local/make_dihard_2018_dev.sh b/egs/dihard_2018/v1/local/make_dihard_2018_dev.sh new file mode 100755 index 00000000000..cc48e2e792a --- /dev/null +++ b/egs/dihard_2018/v1/local/make_dihard_2018_dev.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Copyright 2018 Zili Huang +# Apache 2.0. +# +# This script, called by ../run.sh, creates the DIHARD 2018 development data directory. + +if [ $# != 2 ]; then + echo "Usage: $0 " + echo " e.g.: $0 /export/corpora/LDC/LDC2018E31 data/dihard_2018_dev" +fi + +path_to_dihard_2018_dev=$1 +data_dir=$2 + +echo "Preparing ${data_dir}..." 
+local/make_dihard_2018_dev.py ${path_to_dihard_2018_dev} ${data_dir} + +sort -k 2,2 -s ${data_dir}/rttm > ${data_dir}/rttm_tmp +mv ${data_dir}/rttm_tmp ${data_dir}/rttm +sort -k 1,1 -s ${data_dir}/reco2num_spk > ${data_dir}/reco2num_spk_tmp +mv ${data_dir}/reco2num_spk_tmp ${data_dir}/reco2num_spk +utils/fix_data_dir.sh ${data_dir} diff --git a/egs/dihard_2018/v1/local/make_dihard_2018_eval.py b/egs/dihard_2018/v1/local/make_dihard_2018_eval.py new file mode 100755 index 00000000000..2a8acbee58d --- /dev/null +++ b/egs/dihard_2018/v1/local/make_dihard_2018_eval.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +# This script is called by local/make_dihard_2018_eval.sh, and it creates the +# necessary files for DIHARD 2018 evaluation directory. + +import sys, os + +def prepare_dihard_2018_eval(src_dir, data_dir): + wavscp_fi = open(data_dir + "/wav.scp" , 'w') + utt2spk_fi = open(data_dir + "/utt2spk" , 'w') + segments_fi = open(data_dir + "/segments" , 'w') + rttm_fi = open(data_dir + "/rttm" , 'w') + reco2num_spk_fi = open(data_dir + "/reco2num_spk" , 'w') + + for subdir, dirs, files in os.walk(src_dir): + for file in files: + filename = os.path.join(subdir, file) + if filename.endswith(".lab"): + utt = os.path.basename(filename).split(".")[0] + lines = open(filename, 'r').readlines() + segment_id = 0 + for line in lines: + start, end, speech = line.split() + segment_id_str = "{}_{}".format(utt, str(segment_id).zfill(4)) + segments_str = "{} {} {} {}\n".format(segment_id_str, utt, start, end) + utt2spk_str = "{} {}\n".format(segment_id_str, utt) + segments_fi.write(segments_str) + utt2spk_fi.write(utt2spk_str) + segment_id += 1 + wav_str = "{} sox -t flac {}/data/flac/{}.flac -t wav -r 16k "\ + "-b 16 - channels 1 |\n".format(utt, src_dir, utt) + wavscp_fi.write(wav_str) + with open("{}/data/rttm/{}.rttm".format(src_dir, utt), 'r') as fh: + rttm_str = fh.read() + rttm_fi.write(rttm_str) + with open("{}/data/rttm/{}.rttm".format(src_dir, utt), 'r') as fh: + rttm_list = fh.readlines() + spk_list = [(x.split())[7] for x in rttm_list] + num_spk = len(set(spk_list)) + reco2num_spk_fi.write("{} {}\n".format(utt, num_spk)) + wavscp_fi.close() + utt2spk_fi.close() + segments_fi.close() + rttm_fi.close() + reco2num_spk_fi.close() + return 0 + +def main(): + src_dir = sys.argv[1] + data_dir = sys.argv[2] + if not os.path.exists(data_dir): + os.makedirs(data_dir) + prepare_dihard_2018_eval(src_dir, data_dir) + return 0 + +if __name__=="__main__": + main() diff --git a/egs/dihard_2018/v1/local/make_dihard_2018_eval.sh b/egs/dihard_2018/v1/local/make_dihard_2018_eval.sh new file mode 100755 index 00000000000..0a461c635ec --- /dev/null +++ b/egs/dihard_2018/v1/local/make_dihard_2018_eval.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Copyright 2018 Zili Huang +# Apache 2.0. +# +# This script, called by ../run.sh, creates the DIHARD 2018 evaluation directory. + +if [ $# != 2 ]; then + echo "Usage: $0 " + echo " e.g.: $0 /export/corpora/LDC/LDC2018E32v1.1 data/dihard_2018_eval" +fi + +path_to_dihard_2018_eval=$1 +data_dir=$2 + +echo "Preparing ${data_dir}..." 
+local/make_dihard_2018_eval.py ${path_to_dihard_2018_eval} ${data_dir} + +sort -k 2,2 -s ${data_dir}/rttm > ${data_dir}/rttm_tmp +mv ${data_dir}/rttm_tmp ${data_dir}/rttm +sort -k 1,1 -s ${data_dir}/reco2num_spk > ${data_dir}/reco2num_spk_tmp +mv ${data_dir}/reco2num_spk_tmp ${data_dir}/reco2num_spk +utils/fix_data_dir.sh ${data_dir} diff --git a/egs/dihard_2018/v1/local/make_voxceleb1.pl b/egs/dihard_2018/v1/local/make_voxceleb1.pl new file mode 100755 index 00000000000..2268c20ab52 --- /dev/null +++ b/egs/dihard_2018/v1/local/make_voxceleb1.pl @@ -0,0 +1,130 @@ +#!/usr/bin/perl +# +# Copyright 2018 Ewald Enzinger +# 2018 David Snyder +# +# Usage: make_voxceleb1.pl /export/voxceleb1 data/ + +if (@ARGV != 2) { + print STDERR "Usage: $0 <path-to-voxceleb1> <path-to-data-dir>\n"; + print STDERR "e.g. $0 /export/voxceleb1 data/\n"; + exit(1); +} + +($data_base, $out_dir) = @ARGV; +my $out_test_dir = "$out_dir/voxceleb1_test"; +my $out_train_dir = "$out_dir/voxceleb1_train"; + +if (system("mkdir -p $out_test_dir") != 0) { + die "Error making directory $out_test_dir"; +} + +if (system("mkdir -p $out_train_dir") != 0) { + die "Error making directory $out_train_dir"; +} + +opendir my $dh, "$data_base/voxceleb1_wav" or die "Cannot open directory: $!"; +my @spkr_dirs = grep {-d "$data_base/voxceleb1_wav/$_" && ! /^\.{1,2}$/} readdir($dh); +closedir $dh; + +if (! -e "$data_base/voxceleb1_test.txt") { + system("wget -O $data_base/voxceleb1_test.txt http://www.openslr.org/resources/49/voxceleb1_test.txt"); +} + +if (! -e "$data_base/vox1_meta.csv") { + system("wget -O $data_base/vox1_meta.csv http://www.openslr.org/resources/49/vox1_meta.csv"); +} + +open(TRIAL_IN, "<", "$data_base/voxceleb1_test.txt") or die "Could not open the verification trials file $data_base/voxceleb1_test.txt"; +open(META_IN, "<", "$data_base/vox1_meta.csv") or die "Could not open the meta data file $data_base/vox1_meta.csv"; +open(SPKR_TEST, ">", "$out_test_dir/utt2spk") or die "Could not open the output file $out_test_dir/utt2spk"; +open(WAV_TEST, ">", "$out_test_dir/wav.scp") or die "Could not open the output file $out_test_dir/wav.scp"; +open(SPKR_TRAIN, ">", "$out_train_dir/utt2spk") or die "Could not open the output file $out_train_dir/utt2spk"; +open(WAV_TRAIN, ">", "$out_train_dir/wav.scp") or die "Could not open the output file $out_train_dir/wav.scp"; +open(TRIAL_OUT, ">", "$out_test_dir/trials") or die "Could not open the output file $out_test_dir/trials"; + +my %id2spkr = (); +while (<META_IN>) { + chomp; + my ($vox_id, $spkr_id, $gender, $nation, $set) = split; + $id2spkr{$vox_id} = $spkr_id; +} + +my %test_spkrs = (); +while (<TRIAL_IN>) { + chomp; + my ($tar_or_non, $path1, $path2) = split; + + # Create entry for left-hand side of trial + my ($spkr_id, $filename) = split('/', $path1); + my $rec_id = substr($filename, 0, 11); + my $segment = substr($filename, 12, 7); + my $utt_id1 = "$spkr_id-$rec_id-$segment"; + $test_spkrs{$spkr_id} = (); + + # Create entry for right-hand side of trial + my ($spkr_id, $filename) = split('/', $path2); + my $rec_id = substr($filename, 0, 11); + my $segment = substr($filename, 12, 7); + my $utt_id2 = "$spkr_id-$rec_id-$segment"; + $test_spkrs{$spkr_id} = (); + + my $target = "nontarget"; + if ($tar_or_non eq "1") { + $target = "target"; + } + print TRIAL_OUT "$utt_id1 $utt_id2 $target\n"; +} + +foreach (@spkr_dirs) { + my $spkr_id = $_; + my $new_spkr_id = $spkr_id; + # If we're using a newer version of VoxCeleb1, we need to "deanonymize" + # the speaker labels.
+ if (exists $id2spkr{$spkr_id}) { + $new_spkr_id = $id2spkr{$spkr_id}; + } + opendir my $dh, "$data_base/voxceleb1_wav/$spkr_id/" or die "Cannot open directory: $!"; + my @files = map{s/\.[^.]+$//;$_}grep {/\.wav$/} readdir($dh); + closedir $dh; + foreach (@files) { + my $filename = $_; + my $rec_id = substr($filename, 0, 11); + my $segment = substr($filename, 12, 7); + my $wav = "$data_base/voxceleb1_wav/$spkr_id/$filename.wav"; + my $utt_id = "$new_spkr_id-$rec_id-$segment"; + if (exists $test_spkrs{$new_spkr_id}) { + print WAV_TEST "$utt_id", " $wav", "\n"; + print SPKR_TEST "$utt_id", " $new_spkr_id", "\n"; + } else { + print WAV_TRAIN "$utt_id", " $wav", "\n"; + print SPKR_TRAIN "$utt_id", " $new_spkr_id", "\n"; + } + } +} + +close(SPKR_TEST) or die; +close(WAV_TEST) or die; +close(SPKR_TRAIN) or die; +close(WAV_TRAIN) or die; +close(TRIAL_OUT) or die; +close(TRIAL_IN) or die; +close(META_IN) or die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_test_dir/utt2spk >$out_test_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_test_dir"; +} +system("env LC_COLLATE=C utils/fix_data_dir.sh $out_test_dir"); +if (system("env LC_COLLATE=C utils/validate_data_dir.sh --no-text --no-feats $out_test_dir") != 0) { + die "Error validating directory $out_test_dir"; +} + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_train_dir/utt2spk >$out_train_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_train_dir"; +} +system("env LC_COLLATE=C utils/fix_data_dir.sh $out_train_dir"); +if (system("env LC_COLLATE=C utils/validate_data_dir.sh --no-text --no-feats $out_train_dir") != 0) { + die "Error validating directory $out_train_dir"; +} diff --git a/egs/dihard_2018/v1/local/make_voxceleb2.pl b/egs/dihard_2018/v1/local/make_voxceleb2.pl new file mode 100755 index 00000000000..34c1591eba3 --- /dev/null +++ b/egs/dihard_2018/v1/local/make_voxceleb2.pl @@ -0,0 +1,70 @@ +#!/usr/bin/perl +# +# Copyright 2018 Ewald Enzinger +# +# Usage: make_voxceleb2.pl /export/voxceleb2 dev data/dev +# +# Note: This script requires ffmpeg to be installed and its location included in $PATH. + +if (@ARGV != 3) { + print STDERR "Usage: $0 \n"; + print STDERR "e.g. $0 /export/voxceleb2 dev data/dev\n"; + exit(1); +} + +# Check that ffmpeg is installed. +if (`which ffmpeg` eq "") { + die "Error: this script requires that ffmpeg is installed."; +} + +($data_base, $dataset, $out_dir) = @ARGV; + +if ("$dataset" ne "dev" && "$dataset" ne "test") { + die "dataset parameter must be 'dev' or 'test'!"; +} + +opendir my $dh, "$data_base/$dataset/aac" or die "Cannot open directory: $!"; +my @spkr_dirs = grep {-d "$data_base/$dataset/aac/$_" && ! /^\.{1,2}$/} readdir($dh); +closedir $dh; + +if (system("mkdir -p $out_dir") != 0) { + die "Error making directory $out_dir"; +} + +open(SPKR, ">", "$out_dir/utt2spk") or die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">", "$out_dir/wav.scp") or die "Could not open the output file $out_dir/wav.scp"; + +foreach (@spkr_dirs) { + my $spkr_id = $_; + + opendir my $dh, "$data_base/$dataset/aac/$spkr_id/" or die "Cannot open directory: $!"; + my @rec_dirs = grep {-d "$data_base/$dataset/aac/$spkr_id/$_" && ! 
/^\.{1,2}$/} readdir($dh); + closedir $dh; + + foreach (@rec_dirs) { + my $rec_id = $_; + + opendir my $dh, "$data_base/$dataset/aac/$spkr_id/$rec_id/" or die "Cannot open directory: $!"; + my @files = map{s/\.[^.]+$//;$_}grep {/\.m4a$/} readdir($dh); + closedir $dh; + + foreach (@files) { + my $name = $_; + my $wav = "ffmpeg -v 8 -i $data_base/$dataset/aac/$spkr_id/$rec_id/$name.m4a -f wav -acodec pcm_s16le -|"; + my $utt_id = "$spkr_id-$rec_id-$name"; + print WAV "$utt_id", " $wav", "\n"; + print SPKR "$utt_id", " $spkr_id", "\n"; + } + } +} +close(SPKR) or die; +close(WAV) or die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} +system("env LC_COLLATE=C utils/fix_data_dir.sh $out_dir"); +if (system("env LC_COLLATE=C utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} diff --git a/egs/dihard_2018/v1/local/prepare_feats.sh b/egs/dihard_2018/v1/local/prepare_feats.sh new file mode 100755 index 00000000000..9fa70a2d91e --- /dev/null +++ b/egs/dihard_2018/v1/local/prepare_feats.sh @@ -0,0 +1,91 @@ +#!/bin/bash +# +# Apache 2.0. + +# This script adds deltas, applies sliding window CMVN and writes the features to disk. +# +# Although this kind of script isn't necessary in speaker recognition recipes, +# it can be helpful in the diarization recipes. The script +# diarization/extract_ivectors.sh extracts i-vectors from very +# short (e.g., 1-2 seconds) segments. Therefore, in order to apply the sliding +# window CMVN in a meaningful way, it must be performed prior to performing +# the subsegmentation. + +nj=40 +cmd="run.pl" +stage=0 +norm_vars=false +center=true +compress=true +cmn_window=300 +delta_window=3 +delta_order=2 + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; +if [ $# != 3 ]; then + echo "Usage: $0 " + echo "e.g.: $0 data/train data/train_no_sil exp/make_ivector_features" + echo "Options: " + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --norm-vars # If true, normalize variances in the sliding window cmvn" + exit 1; +fi + +data_in=$1 +data_out=$2 +dir=$3 + +name=`basename $data_in` + +for f in $data_in/feats.scp ; do + [ ! -f $f ] && echo "$0: No such file $f" && exit 1; +done + +# Set various variables. +mkdir -p $dir/log +mkdir -p $data_out +featdir=$(utils/make_absolute.sh $dir) + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $featdir/storage ]; then + utils/create_split_dir.pl \ + /export/b{14,15,16,17}/$USER/kaldi-data/egs/dihard_2018/v1/ivector-$(date +'%m_%d_%H_%M')/ivector_cmvn_feats/storage $featdir/storage +fi + +for n in $(seq $nj); do + # the next command does nothing unless $featdir/storage/ exists, see + # utils/create_data_link.pl for more info. 
+ utils/create_data_link.pl $featdir/ivector_cmvn_feats_${name}.${n}.ark +done + +cp $data_in/utt2spk $data_out/utt2spk +cp $data_in/spk2utt $data_out/spk2utt +cp $data_in/wav.scp $data_out/wav.scp + +write_num_frames_opt="--write-num-frames=ark,t:$featdir/log/utt2num_frames.JOB" + +sdata_in=$data_in/split$nj; +utils/split_data.sh $data_in $nj || exit 1; + +delta_opts="--delta-window=$delta_window --delta-order=$delta_order" + +$cmd JOB=1:$nj $dir/log/create_ivector_cmvn_feats_${name}.JOB.log \ + add-deltas $delta_opts scp:${sdata_in}/JOB/feats.scp ark:- \| \ + apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=$cmn_window \ + ark:- ark:- \| \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ + ark,scp:$featdir/ivector_cmvn_feats_${name}.JOB.ark,$featdir/ivector_cmvn_feats_${name}.JOB.scp || exit 1; + +for n in $(seq $nj); do + cat $featdir/ivector_cmvn_feats_${name}.$n.scp || exit 1; +done > ${data_out}/feats.scp || exit 1 + +for n in $(seq $nj); do + cat $featdir/log/utt2num_frames.$n || exit 1; +done > $data_out/utt2num_frames || exit 1 +rm $featdir/log/utt2num_frames.* + +echo "$0: Succeeded creating ivector features for $name" diff --git a/egs/dihard_2018/v1/path.sh b/egs/dihard_2018/v1/path.sh new file mode 100755 index 00000000000..851c14e27c3 --- /dev/null +++ b/egs/dihard_2018/v1/path.sh @@ -0,0 +1,5 @@ +export KALDI_ROOT=`pwd`/../../.. +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sph2pipe_v2.5:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/dihard_2018/v1/run.sh b/egs/dihard_2018/v1/run.sh new file mode 100755 index 00000000000..44af9f48c3f --- /dev/null +++ b/egs/dihard_2018/v1/run.sh @@ -0,0 +1,235 @@ +#!/bin/bash +# Copyright 2017 Johns Hopkins University (Author: Daniel Garcia-Romero) +# 2017 Johns Hopkins University (Author: Daniel Povey) +# 2017-2018 David Snyder +# 2018 Ewald Enzinger +# 2018 Zili Huang +# Apache 2.0. +# +# See ../README.txt for more info on data required. +# Results (diarization error rate) are inline in comments below. + +. ./cmd.sh +. ./path.sh +set -e +mfccdir=`pwd`/mfcc +vaddir=`pwd`/mfcc + +voxceleb1_root=/export/corpora/VoxCeleb1 +voxceleb2_root=/export/corpora/VoxCeleb2 +dihard_2018_dev=/export/corpora/LDC/LDC2018E31 +dihard_2018_eval=/export/corpora/LDC/LDC2018E32v1.1 +num_components=2048 +ivector_dim=400 +ivec_dir=exp/extractor_c${num_components}_i${ivector_dim} + +stage=0 + +if [ $stage -le 0 ]; then + local/make_voxceleb2.pl $voxceleb2_root dev data/voxceleb2_train + local/make_voxceleb2.pl $voxceleb2_root test data/voxceleb2_test + # This script creates data/voxceleb1_test and data/voxceleb1_train. + # Our evaluation set is the test portion of VoxCeleb1. + local/make_voxceleb1.pl $voxceleb1_root data + # We'll train on all of VoxCeleb2, plus the training portion of VoxCeleb1. + # This should give 7,351 speakers and 1,277,503 utterances. + utils/combine_data.sh data/train data/voxceleb2_train data/voxceleb2_test data/voxceleb1_train + + # Prepare the development and evaluation set for DIHARD 2018. 
+ local/make_dihard_2018_dev.sh $dihard_2018_dev data/dihard_2018_dev + local/make_dihard_2018_eval.sh $dihard_2018_eval data/dihard_2018_eval +fi + +if [ $stage -le 1 ]; then + # Make MFCCs for each dataset + for name in train dihard_2018_dev dihard_2018_eval; do + steps/make_mfcc.sh --write-utt2num-frames true \ + --mfcc-config conf/mfcc.conf --nj 40 --cmd "$train_cmd --max-jobs-run 20" \ + data/${name} exp/make_mfcc $mfccdir + utils/fix_data_dir.sh data/${name} + done + + # Compute the energy-based VAD for train + sid/compute_vad_decision.sh --nj 40 --cmd "$train_cmd" \ + data/train exp/make_vad $vaddir + utils/fix_data_dir.sh data/train + + # This writes features to disk after adding deltas and applying the sliding window CMN. + # Although this is somewhat wasteful in terms of disk space, for diarization + # it ends up being preferable to performing the CMN in memory. If the CMN + # were performed in memory it would need to be performed after the subsegmentation, + # which leads to poorer results. + for name in train dihard_2018_dev dihard_2018_eval; do + local/prepare_feats.sh --nj 40 --cmd "$train_cmd" \ + data/$name data/${name}_cmn exp/${name}_cmn + if [ -f data/$name/vad.scp ]; then + cp data/$name/vad.scp data/${name}_cmn/ + fi + if [ -f data/$name/segments ]; then + cp data/$name/segments data/${name}_cmn/ + fi + utils/fix_data_dir.sh data/${name}_cmn + done + + echo "0.01" > data/train_cmn/frame_shift + # Create segments to extract i-vectors from for PLDA training data. + # The segments are created using an energy-based speech activity + # detection (SAD) system, but this is not necessary. You can replace + # this with segments computed from your favorite SAD. + diarization/vad_to_segments.sh --nj 40 --cmd "$train_cmd" \ + data/train_cmn data/train_cmn_segmented +fi + +if [ $stage -le 2 ]; then + # Train the UBM on VoxCeleb 1 and 2. + sid/train_diag_ubm.sh --cmd "$train_cmd --mem 4G" \ + --nj 40 --num-threads 8 \ + data/train $num_components \ + exp/diag_ubm + + sid/train_full_ubm.sh --cmd "$train_cmd --mem 25G" \ + --nj 40 --remove-low-count-gaussians false \ + data/train \ + exp/diag_ubm exp/full_ubm +fi + +if [ $stage -le 3 ]; then + # In this stage, we train the i-vector extractor on a subset of VoxCeleb 1 + # and 2. + # + # Note that there are well over 1 million utterances in our training set, + # and it takes an extremely long time to train the extractor on all of this. + # Also, most of those utterances are very short. Short utterances are + # harmful for training the i-vector extractor. Therefore, to reduce the + # training time and improve performance, we will only train on the 100k + # longest utterances. + utils/subset_data_dir.sh \ + --utt-list <(sort -n -k 2 data/train/utt2num_frames | tail -n 100000) \ + data/train data/train_100k + + # Train the i-vector extractor. + sid/train_ivector_extractor.sh --cmd "$train_cmd --mem 16G" \ + --ivector-dim $ivector_dim --num-iters 5 \ + exp/full_ubm/final.ubm data/train_100k \ + $ivec_dir +fi + +if [ $stage -le 4 ]; then + # Extract i-vectors for DIHARD 2018 development and evaluation set. + # We set apply-cmn false and apply-deltas false because we already add + # deltas and apply cmn in stage 1. 
+ diarization/extract_ivectors.sh --cmd "$train_cmd --mem 20G" \ + --nj 40 --window 1.5 --period 0.75 --apply-cmn false --apply-deltas false \ + --min-segment 0.5 $ivec_dir \ + data/dihard_2018_dev_cmn $ivec_dir/ivectors_dihard_2018_dev + + diarization/extract_ivectors.sh --cmd "$train_cmd --mem 20G" \ + --nj 40 --window 1.5 --period 0.75 --apply-cmn false --apply-deltas false \ + --min-segment 0.5 $ivec_dir \ + data/dihard_2018_eval_cmn $ivec_dir/ivectors_dihard_2018_eval + + # Reduce the amount of training data for the PLDA training. + utils/subset_data_dir.sh data/train_cmn_segmented 128000 data/train_cmn_segmented_128k + # Extract i-vectors for the VoxCeleb, which is our PLDA training + # data. A long period is used here so that we don't compute too + # many i-vectors for each recording. + diarization/extract_ivectors.sh --cmd "$train_cmd --mem 25G" \ + --nj 40 --window 3.0 --period 10.0 --min-segment 1.5 --apply-cmn false --apply-deltas false \ + --hard-min true $ivec_dir \ + data/train_cmn_segmented_128k $ivec_dir/ivectors_train_segmented_128k +fi + +if [ $stage -le 5 ]; then + # Train a PLDA model on VoxCeleb, using DIHARD 2018 development set to whiten. + "$train_cmd" $ivec_dir/ivectors_dihard_2018_dev/log/plda.log \ + ivector-compute-plda ark:$ivec_dir/ivectors_train_segmented_128k/spk2utt \ + "ark:ivector-subtract-global-mean \ + scp:$ivec_dir/ivectors_train_segmented_128k/ivector.scp ark:- \ + | transform-vec $ivec_dir/ivectors_dihard_2018_dev/transform.mat ark:- ark:- \ + | ivector-normalize-length ark:- ark:- |" \ + $ivec_dir/ivectors_dihard_2018_dev/plda || exit 1; +fi + +# Perform PLDA scoring +if [ $stage -le 6 ]; then + # Perform PLDA scoring on all pairs of segments for each recording. + diarization/score_plda.sh --cmd "$train_cmd --mem 4G" \ + --nj 20 $ivec_dir/ivectors_dihard_2018_dev $ivec_dir/ivectors_dihard_2018_dev \ + $ivec_dir/ivectors_dihard_2018_dev/plda_scores + + diarization/score_plda.sh --cmd "$train_cmd --mem 4G" \ + --nj 20 $ivec_dir/ivectors_dihard_2018_dev $ivec_dir/ivectors_dihard_2018_eval \ + $ivec_dir/ivectors_dihard_2018_eval/plda_scores +fi + +# Cluster the PLDA scores using a stopping threshold. +if [ $stage -le 7 ]; then + # First, we find the threshold that minimizes the DER on DIHARD 2018 development set. + mkdir -p $ivec_dir/tuning + echo "Tuning clustering threshold for DIHARD 2018 development set" + best_der=100 + best_threshold=0 + + # The threshold is in terms of the log likelihood ratio provided by the + # PLDA scores. In a perfectly calibrated system, the threshold is 0. + # In the following loop, we evaluate DER performance on DIHARD 2018 development + # set using some reasonable thresholds for a well-calibrated system. + for threshold in -0.5 -0.4 -0.3 -0.2 -0.1 -0.05 0 0.05 0.1 0.2 0.3 0.4 0.5; do + diarization/cluster.sh --cmd "$train_cmd --mem 4G" --nj 20 \ + --threshold $threshold --rttm-channel 1 $ivec_dir/ivectors_dihard_2018_dev/plda_scores \ + $ivec_dir/ivectors_dihard_2018_dev/plda_scores_t$threshold + + md-eval.pl -r data/dihard_2018_dev/rttm \ + -s $ivec_dir/ivectors_dihard_2018_dev/plda_scores_t$threshold/rttm \ + 2> $ivec_dir/tuning/dihard_2018_dev_t${threshold}.log \ + > $ivec_dir/tuning/dihard_2018_dev_t${threshold} + + der=$(grep -oP 'DIARIZATION\ ERROR\ =\ \K[0-9]+([.][0-9]+)?' \ + $ivec_dir/tuning/dihard_2018_dev_t${threshold}) + if [ $(perl -e "print ($der < $best_der ? 
1 : 0);") -eq 1 ]; then + best_der=$der + best_threshold=$threshold + fi + done + echo "$best_threshold" > $ivec_dir/tuning/dihard_2018_dev_best + + diarization/cluster.sh --cmd "$train_cmd --mem 4G" --nj 20 \ + --threshold $(cat $ivec_dir/tuning/dihard_2018_dev_best) --rttm-channel 1 \ + $ivec_dir/ivectors_dihard_2018_dev/plda_scores $ivec_dir/ivectors_dihard_2018_dev/plda_scores + + # Cluster DIHARD 2018 evaluation set using the best threshold found for the DIHARD + # 2018 development set. The DIHARD 2018 development set is used as the validation + # set to tune the parameters. + diarization/cluster.sh --cmd "$train_cmd --mem 4G" --nj 20 \ + --threshold $(cat $ivec_dir/tuning/dihard_2018_dev_best) --rttm-channel 1 \ + $ivec_dir/ivectors_dihard_2018_eval/plda_scores $ivec_dir/ivectors_dihard_2018_eval/plda_scores + + mkdir -p $ivec_dir/results + # Compute the DER on the DIHARD 2018 evaluation set. We use the official metrics of + # the DIHARD challenge. The DER is calculated with no unscored collars and including + # overlapping speech. + md-eval.pl -r data/dihard_2018_eval/rttm \ + -s $ivec_dir/ivectors_dihard_2018_eval/plda_scores/rttm 2> $ivec_dir/results/threshold.log \ + > $ivec_dir/results/DER_threshold.txt + der=$(grep -oP 'DIARIZATION\ ERROR\ =\ \K[0-9]+([.][0-9]+)?' \ + $ivec_dir/results/DER_threshold.txt) + # Using supervised calibration, DER: 28.51% + echo "Using supervised calibration, DER: $der%" +fi + +# Cluster the PLDA scores using the oracle number of speakers +if [ $stage -le 8 ]; then + # In this section, we show how to do the clustering if the number of speakers + # (and therefore, the number of clusters) per recording is known in advance. + diarization/cluster.sh --cmd "$train_cmd --mem 4G" --nj 20 \ + --reco2num-spk data/dihard_2018_eval/reco2num_spk --rttm-channel 1 \ + $ivec_dir/ivectors_dihard_2018_eval/plda_scores $ivec_dir/ivectors_dihard_2018_eval/plda_scores_num_spk + + md-eval.pl -r data/dihard_2018_eval/rttm \ + -s $ivec_dir/ivectors_dihard_2018_eval/plda_scores_num_spk/rttm 2> $ivec_dir/results/num_spk.log \ + > $ivec_dir/results/DER_num_spk.txt + der=$(grep -oP 'DIARIZATION\ ERROR\ =\ \K[0-9]+([.][0-9]+)?' \ + $ivec_dir/results/DER_num_spk.txt) + # Using the oracle number of speakers, DER: 24.42% + echo "Using the oracle number of speakers, DER: $der%" +fi diff --git a/egs/dihard_2018/v1/sid b/egs/dihard_2018/v1/sid new file mode 120000 index 00000000000..893a12f30c9 --- /dev/null +++ b/egs/dihard_2018/v1/sid @@ -0,0 +1 @@ +../../sre08/v1/sid \ No newline at end of file diff --git a/egs/dihard_2018/v1/steps b/egs/dihard_2018/v1/steps new file mode 120000 index 00000000000..6e99bf5b5ad --- /dev/null +++ b/egs/dihard_2018/v1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps \ No newline at end of file diff --git a/egs/dihard_2018/v1/utils b/egs/dihard_2018/v1/utils new file mode 120000 index 00000000000..b240885218f --- /dev/null +++ b/egs/dihard_2018/v1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils \ No newline at end of file diff --git a/egs/dihard_2018/v2/README.txt b/egs/dihard_2018/v2/README.txt new file mode 100644 index 00000000000..5487a911184 --- /dev/null +++ b/egs/dihard_2018/v2/README.txt @@ -0,0 +1,17 @@ + This recipe is the speaker diarization recipe for The First DIHARD Speech + Diarization Challenge (DIHARD 2018). There are two tracks in the DIHARD 2018 + competition , one uses oracle SAD (track1) and the other required that SAD + was performed from scratch (track2). This script is for track1. 
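The stage 7 loop in the v1 run.sh above picks the PLDA-score clustering threshold purely by sweeping a small grid and keeping whichever value gives the lowest DER on the DIHARD 2018 development set; that threshold is then reused to cluster the evaluation set, and stage 8 shows the alternative of clustering with the oracle number of speakers from reco2num_spk. A minimal sketch of the selection logic, with the call to diarization/cluster.sh plus md-eval.pl replaced by a stand-in function (der_for_threshold) and invented DER values:

def tune_threshold(thresholds, der_for_threshold):
    # Return (best_threshold, best_der) over a grid of clustering thresholds.
    best_threshold, best_der = None, float("inf")
    for t in thresholds:
        der = der_for_threshold(t)  # in the recipe: cluster at t, then score with md-eval.pl
        if der < best_der:
            best_threshold, best_der = t, der
    return best_threshold, best_der

# Example with made-up DER values (percent), mimicking the dev-set sweep:
fake_ders = {-0.2: 29.4, -0.1: 28.7, 0.0: 28.5, 0.1: 28.9}
print(tune_threshold(sorted(fake_ders), fake_ders.get))
# -> (0.0, 28.5); the recipe writes the winning threshold to $ivec_dir/tuning/dihard_2018_dev_best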
+ + The recipe is closely based on the following paper: + http://www.danielpovey.com/files/2018_interspeech_dihard.pdf but doesn't + contain the VB refinement. The whole system mainly contains training and + extract x-vectors, PLDA scoring and agglomerative hierarchical clustering. + The VoxCeleb datasets are used for training x-vectors and PLDA. The + development set of the DIHARD 2018 competition is used as validation set to + tune parameters. The system is tested on the DIHARD 2018 evaluation set. + + We also use the following datasets for augmentation. + + MUSAN http://www.openslr.org/17 + RIR_NOISES http://www.openslr.org/28 diff --git a/egs/dihard_2018/v2/cmd.sh b/egs/dihard_2018/v2/cmd.sh new file mode 100755 index 00000000000..c35cd18f287 --- /dev/null +++ b/egs/dihard_2018/v2/cmd.sh @@ -0,0 +1,15 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export train_cmd="queue.pl" + + diff --git a/egs/dihard_2018/v2/conf/mfcc.conf b/egs/dihard_2018/v2/conf/mfcc.conf new file mode 100755 index 00000000000..9e125706aae --- /dev/null +++ b/egs/dihard_2018/v2/conf/mfcc.conf @@ -0,0 +1,7 @@ +--sample-frequency=16000 +--frame-length=25 # the default is 25 +--low-freq=20 # the default. +--high-freq=7600 # the default is zero meaning use the Nyquist (8k in this case). 
+--num-mel-bins=30 +--num-ceps=30 +--snip-edges=false diff --git a/egs/dihard_2018/v2/conf/vad.conf b/egs/dihard_2018/v2/conf/vad.conf new file mode 100755 index 00000000000..c9f5e8b3072 --- /dev/null +++ b/egs/dihard_2018/v2/conf/vad.conf @@ -0,0 +1,4 @@ +--vad-energy-threshold=5.5 +--vad-energy-mean-scale=0.5 +--vad-proportion-threshold=0.12 +--vad-frames-context=2 diff --git a/egs/dihard_2018/v2/diarization b/egs/dihard_2018/v2/diarization new file mode 120000 index 00000000000..bad937c1444 --- /dev/null +++ b/egs/dihard_2018/v2/diarization @@ -0,0 +1 @@ +../../callhome_diarization/v1/diarization \ No newline at end of file diff --git a/egs/dihard_2018/v2/local/make_dihard_2018_dev.py b/egs/dihard_2018/v2/local/make_dihard_2018_dev.py new file mode 120000 index 00000000000..3c69bc08240 --- /dev/null +++ b/egs/dihard_2018/v2/local/make_dihard_2018_dev.py @@ -0,0 +1 @@ +../../v1/local/make_dihard_2018_dev.py \ No newline at end of file diff --git a/egs/dihard_2018/v2/local/make_dihard_2018_dev.sh b/egs/dihard_2018/v2/local/make_dihard_2018_dev.sh new file mode 120000 index 00000000000..6fe340e9df2 --- /dev/null +++ b/egs/dihard_2018/v2/local/make_dihard_2018_dev.sh @@ -0,0 +1 @@ +../../v1/local/make_dihard_2018_dev.sh \ No newline at end of file diff --git a/egs/dihard_2018/v2/local/make_dihard_2018_eval.py b/egs/dihard_2018/v2/local/make_dihard_2018_eval.py new file mode 120000 index 00000000000..d107a5446ca --- /dev/null +++ b/egs/dihard_2018/v2/local/make_dihard_2018_eval.py @@ -0,0 +1 @@ +../../v1/local/make_dihard_2018_eval.py \ No newline at end of file diff --git a/egs/dihard_2018/v2/local/make_dihard_2018_eval.sh b/egs/dihard_2018/v2/local/make_dihard_2018_eval.sh new file mode 120000 index 00000000000..0c01aee4fa7 --- /dev/null +++ b/egs/dihard_2018/v2/local/make_dihard_2018_eval.sh @@ -0,0 +1 @@ +../../v1/local/make_dihard_2018_eval.sh \ No newline at end of file diff --git a/egs/dihard_2018/v2/local/make_musan.py b/egs/dihard_2018/v2/local/make_musan.py new file mode 100755 index 00000000000..c4b5c9359b4 --- /dev/null +++ b/egs/dihard_2018/v2/local/make_musan.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# Copyright 2015 David Snyder +# 2018 Ewald Enzinger +# Apache 2.0. +# +# Modified version of egs/sre16/v1/local/make_musan.py (commit e3fb7c4a0da4167f8c94b80f4d3cc5ab4d0e22e8). +# This version uses the raw MUSAN audio files (16 kHz) and does not use sox to resample at 8 kHz. +# +# This file is meant to be invoked by make_musan.sh. 
+ +import os, sys + +def process_music_annotations(path): + utt2spk = {} + utt2vocals = {} + lines = open(path, 'r').readlines() + for line in lines: + utt, genres, vocals, musician = line.rstrip().split()[:4] + # For this application, the musician ID isn't important + utt2spk[utt] = utt + utt2vocals[utt] = vocals == "Y" + return utt2spk, utt2vocals + +def prepare_music(root_dir, use_vocals): + utt2vocals = {} + utt2spk = {} + utt2wav = {} + num_good_files = 0 + num_bad_files = 0 + music_dir = os.path.join(root_dir, "music") + for root, dirs, files in os.walk(music_dir): + for file in files: + file_path = os.path.join(root, file) + if file.endswith(".wav"): + utt = str(file).replace(".wav", "") + utt2wav[utt] = file_path + elif str(file) == "ANNOTATIONS": + utt2spk_part, utt2vocals_part = process_music_annotations(file_path) + utt2spk.update(utt2spk_part) + utt2vocals.update(utt2vocals_part) + utt2spk_str = "" + utt2wav_str = "" + for utt in utt2vocals: + if utt in utt2wav: + if use_vocals or not utt2vocals[utt]: + utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n" + utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" + num_good_files += 1 + else: + print("Missing file {}".format(utt)) + num_bad_files += 1 + print("In music directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) + return utt2spk_str, utt2wav_str + +def prepare_speech(root_dir): + utt2spk = {} + utt2wav = {} + num_good_files = 0 + num_bad_files = 0 + speech_dir = os.path.join(root_dir, "speech") + for root, dirs, files in os.walk(speech_dir): + for file in files: + file_path = os.path.join(root, file) + if file.endswith(".wav"): + utt = str(file).replace(".wav", "") + utt2wav[utt] = file_path + utt2spk[utt] = utt + utt2spk_str = "" + utt2wav_str = "" + for utt in utt2spk: + if utt in utt2wav: + utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n" + utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" + num_good_files += 1 + else: + print("Missing file {}".format(utt)) + num_bad_files += 1 + print("In speech directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) + return utt2spk_str, utt2wav_str + +def prepare_noise(root_dir): + utt2spk = {} + utt2wav = {} + num_good_files = 0 + num_bad_files = 0 + noise_dir = os.path.join(root_dir, "noise") + for root, dirs, files in os.walk(noise_dir): + for file in files: + file_path = os.path.join(root, file) + if file.endswith(".wav"): + utt = str(file).replace(".wav", "") + utt2wav[utt] = file_path + utt2spk[utt] = utt + utt2spk_str = "" + utt2wav_str = "" + for utt in utt2spk: + if utt in utt2wav: + utt2spk_str = utt2spk_str + utt + " " + utt2spk[utt] + "\n" + utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" + num_good_files += 1 + else: + print("Missing file {}".format(utt)) + num_bad_files += 1 + print("In noise directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) + return utt2spk_str, utt2wav_str + +def main(): + in_dir = sys.argv[1] + out_dir = sys.argv[2] + use_vocals = sys.argv[3] == "Y" + utt2spk_music, utt2wav_music = prepare_music(in_dir, use_vocals) + utt2spk_speech, utt2wav_speech = prepare_speech(in_dir) + utt2spk_noise, utt2wav_noise = prepare_noise(in_dir) + utt2spk = utt2spk_speech + utt2spk_music + utt2spk_noise + utt2wav = utt2wav_speech + utt2wav_music + utt2wav_noise + wav_fi = open(os.path.join(out_dir, "wav.scp"), 'w') + wav_fi.write(utt2wav) + utt2spk_fi = open(os.path.join(out_dir, "utt2spk"), 'w') +
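# (wav.scp and utt2spk are written as plain text, one "<utt-id> <wav-path>" or
#  "<utt-id> <spk-id>" pair per line, i.e. the standard Kaldi data-dir format;
#  here each utterance acts as its own "speaker".)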
utt2spk_fi.write(utt2spk) + + +if __name__=="__main__": + main() diff --git a/egs/dihard_2018/v2/local/make_musan.sh b/egs/dihard_2018/v2/local/make_musan.sh new file mode 100755 index 00000000000..1565ef0d85c --- /dev/null +++ b/egs/dihard_2018/v2/local/make_musan.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Copyright 2015 David Snyder +# Apache 2.0. +# +# Copy of egs/sre16/v1/local/make_musan.sh (commit e3fb7c4a0da4167f8c94b80f4d3cc5ab4d0e22e8). +# +# This script, called by ../run.sh, creates the MUSAN +# data directory. The required dataset is freely available at +# http://www.openslr.org/17/ + +set -e +in_dir=$1 +data_dir=$2 +use_vocals='Y' + +mkdir -p local/musan.tmp + +echo "Preparing ${data_dir}/musan..." +mkdir -p ${data_dir}/musan +local/make_musan.py ${in_dir} ${data_dir}/musan ${use_vocals} + +utils/fix_data_dir.sh ${data_dir}/musan + +grep "music" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_music +grep "speech" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_speech +grep "noise" ${data_dir}/musan/utt2spk > local/musan.tmp/utt2spk_noise +utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_music \ + ${data_dir}/musan ${data_dir}/musan_music +utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_speech \ + ${data_dir}/musan ${data_dir}/musan_speech +utils/subset_data_dir.sh --utt-list local/musan.tmp/utt2spk_noise \ + ${data_dir}/musan ${data_dir}/musan_noise + +utils/fix_data_dir.sh ${data_dir}/musan_music +utils/fix_data_dir.sh ${data_dir}/musan_speech +utils/fix_data_dir.sh ${data_dir}/musan_noise + +rm -rf local/musan.tmp + diff --git a/egs/dihard_2018/v2/local/make_voxceleb1.pl b/egs/dihard_2018/v2/local/make_voxceleb1.pl new file mode 120000 index 00000000000..c54d69af919 --- /dev/null +++ b/egs/dihard_2018/v2/local/make_voxceleb1.pl @@ -0,0 +1 @@ +../../v1/local/make_voxceleb1.pl \ No newline at end of file diff --git a/egs/dihard_2018/v2/local/make_voxceleb2.pl b/egs/dihard_2018/v2/local/make_voxceleb2.pl new file mode 120000 index 00000000000..701225dfa57 --- /dev/null +++ b/egs/dihard_2018/v2/local/make_voxceleb2.pl @@ -0,0 +1 @@ +../../v1/local/make_voxceleb2.pl \ No newline at end of file diff --git a/egs/dihard_2018/v2/local/nnet3/xvector/prepare_feats.sh b/egs/dihard_2018/v2/local/nnet3/xvector/prepare_feats.sh new file mode 100755 index 00000000000..4ad2c42d8b9 --- /dev/null +++ b/egs/dihard_2018/v2/local/nnet3/xvector/prepare_feats.sh @@ -0,0 +1,86 @@ +#!/bin/bash +# +# Apache 2.0. + +# This script applies sliding window CMVN and writes the features to disk. +# +# Although this kind of script isn't necessary in speaker recognition recipes, +# it can be helpful in the diarization recipes. The script +# diarization/nnet3/xvector/extract_xvectors.sh extracts x-vectors from very +# short (e.g., 1-2 seconds) segments. Therefore, in order to apply the sliding +# window CMVN in a meaningful way, it must be performed prior to performing +# the subsegmentation. + +nj=40 +cmd="run.pl" +stage=0 +norm_vars=false +center=true +compress=true +cmn_window=300 + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; +if [ $# != 3 ]; then + echo "Usage: $0 " + echo "e.g.: $0 data/train data/train_no_sil exp/make_xvector_features" + echo "Options: " + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." 
+ echo " --norm-vars # If true, normalize variances in the sliding window cmvn" + exit 1; +fi + +data_in=$1 +data_out=$2 +dir=$3 + +name=`basename $data_in` + +for f in $data_in/feats.scp ; do + [ ! -f $f ] && echo "$0: No such file $f" && exit 1; +done + +# Set various variables. +mkdir -p $dir/log +mkdir -p $data_out +featdir=$(utils/make_absolute.sh $dir) + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $featdir/storage ]; then + utils/create_split_dir.pl \ + /export/b{14,15,16,17}/$USER/kaldi-data/egs/dihard_2018/v2/xvector-$(date +'%m_%d_%H_%M')/xvector_cmvn_feats/storage $featdir/storage +fi + +for n in $(seq $nj); do + # the next command does nothing unless $featdir/storage/ exists, see + # utils/create_data_link.pl for more info. + utils/create_data_link.pl $featdir/xvector_cmvn_feats_${name}.${n}.ark +done + +cp $data_in/utt2spk $data_out/utt2spk +cp $data_in/spk2utt $data_out/spk2utt +cp $data_in/wav.scp $data_out/wav.scp + +write_num_frames_opt="--write-num-frames=ark,t:$featdir/log/utt2num_frames.JOB" + +sdata_in=$data_in/split$nj; +utils/split_data.sh $data_in $nj || exit 1; + +$cmd JOB=1:$nj $dir/log/create_xvector_cmvn_feats_${name}.JOB.log \ + apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=$cmn_window \ + scp:${sdata_in}/JOB/feats.scp ark:- \| \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ + ark,scp:$featdir/xvector_cmvn_feats_${name}.JOB.ark,$featdir/xvector_cmvn_feats_${name}.JOB.scp || exit 1; + +for n in $(seq $nj); do + cat $featdir/xvector_cmvn_feats_${name}.$n.scp || exit 1; +done > ${data_out}/feats.scp || exit 1 + +for n in $(seq $nj); do + cat $featdir/log/utt2num_frames.$n || exit 1; +done > $data_out/utt2num_frames || exit 1 +rm $featdir/log/utt2num_frames.* + +echo "$0: Succeeded creating xvector features for $name" diff --git a/egs/dihard_2018/v2/local/nnet3/xvector/prepare_feats_for_egs.sh b/egs/dihard_2018/v2/local/nnet3/xvector/prepare_feats_for_egs.sh new file mode 100755 index 00000000000..1d8ac6153e7 --- /dev/null +++ b/egs/dihard_2018/v2/local/nnet3/xvector/prepare_feats_for_egs.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# +# Copied from egs/sre16/v1/local/nnet3/xvector/prepare_feats_for_egs.sh (commit 3ea534070fd2cccd2e4ee21772132230033022ce). +# +# Apache 2.0. + +# This script applies sliding window cmvn and removes silence frames. This +# is performed on the raw features prior to generating examples for training +# the xvector system. + +nj=40 +cmd="run.pl" +stage=0 +norm_vars=false +center=true +compress=true +cmn_window=300 + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; +if [ $# != 3 ]; then + echo "Usage: $0 " + echo "e.g.: $0 data/train data/train_no_sil exp/make_xvector_features" + echo "Options: " + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --norm-vars # If true, normalize variances in the sliding window cmvn" + exit 1; +fi + +data_in=$1 +data_out=$2 +dir=$3 + +name=`basename $data_in` + +for f in $data_in/feats.scp $data_in/vad.scp ; do + [ ! -f $f ] && echo "$0: No such file $f" && exit 1; +done + +# Set various variables. +mkdir -p $dir/log +mkdir -p $data_out +featdir=$(utils/make_absolute.sh $dir) + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $featdir/storage ]; then + utils/create_split_dir.pl \ + /export/b{14,15,16,17}/$USER/kaldi-data/egs/dihard_2018/v2/xvector-$(date +'%m_%d_%H_%M')/xvector_feats/storage $featdir/storage +fi + +for n in $(seq $nj); do + # the next command does nothing unless $featdir/storage/ exists, see + # utils/create_data_link.pl for more info. + utils/create_data_link.pl $featdir/xvector_feats_${name}.${n}.ark +done + +cp $data_in/utt2spk $data_out/utt2spk +cp $data_in/spk2utt $data_out/spk2utt +cp $data_in/wav.scp $data_out/wav.scp + +write_num_frames_opt="--write-num-frames=ark,t:$featdir/log/utt2num_frames.JOB" + +sdata_in=$data_in/split$nj; +utils/split_data.sh $data_in $nj || exit 1; + +$cmd JOB=1:$nj $dir/log/create_xvector_feats_${name}.JOB.log \ + apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=$cmn_window \ + scp:${sdata_in}/JOB/feats.scp ark:- \| \ + select-voiced-frames ark:- scp,s,cs:${sdata_in}/JOB/vad.scp ark:- \| \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ + ark,scp:$featdir/xvector_feats_${name}.JOB.ark,$featdir/xvector_feats_${name}.JOB.scp || exit 1; + +for n in $(seq $nj); do + cat $featdir/xvector_feats_${name}.$n.scp || exit 1; +done > ${data_out}/feats.scp || exit 1 + +for n in $(seq $nj); do + cat $featdir/log/utt2num_frames.$n || exit 1; +done > $data_out/utt2num_frames || exit 1 +rm $featdir/log/utt2num_frames.* + +echo "$0: Succeeded creating xvector features for $name" diff --git a/egs/dihard_2018/v2/local/nnet3/xvector/run_xvector.sh b/egs/dihard_2018/v2/local/nnet3/xvector/run_xvector.sh new file mode 120000 index 00000000000..585b63fd2dd --- /dev/null +++ b/egs/dihard_2018/v2/local/nnet3/xvector/run_xvector.sh @@ -0,0 +1 @@ +tuning/run_xvector_1a.sh \ No newline at end of file diff --git a/egs/dihard_2018/v2/local/nnet3/xvector/tuning/run_xvector_1a.sh b/egs/dihard_2018/v2/local/nnet3/xvector/tuning/run_xvector_1a.sh new file mode 100755 index 00000000000..4ee472b1c71 --- /dev/null +++ b/egs/dihard_2018/v2/local/nnet3/xvector/tuning/run_xvector_1a.sh @@ -0,0 +1,155 @@ +#!/bin/bash +# Copyright 2017 David Snyder +# 2017 Johns Hopkins University (Author: Daniel Garcia-Romero) +# 2017 Johns Hopkins University (Author: Daniel Povey) +# +# Copied from egs/sre16/v1/local/nnet3/xvector/tuning/run_xvector_1a.sh (commit e082c17d4a8f8a791428ae4d9f7ceb776aef3f0b). +# +# Apache 2.0. + +# This script trains a DNN similar to the recipe described in +# http://www.danielpovey.com/files/2018_icassp_xvectors.pdf + +. ./cmd.sh +set -e + +stage=1 +train_stage=0 +use_gpu=true +remove_egs=false + +data=data/train +nnet_dir=exp/xvector_nnet_1a/ +egs_dir=exp/xvector_nnet_1a/egs + +. ./path.sh +. ./cmd.sh +. ./utils/parse_options.sh + +num_pdfs=$(awk '{print $2}' $data/utt2spk | sort | uniq -c | wc -l) + +# Now we create the nnet examples using sid/nnet3/xvector/get_egs.sh. +# The argument --num-repeats is related to the number of times a speaker +# repeats per archive. If it seems like you're getting too many archives +# (e.g., more than 200) try increasing the --frames-per-iter option. The +# arguments --min-frames-per-chunk and --max-frames-per-chunk specify the +# minimum and maximum length (in terms of number of frames) of the features +# in the examples. +# +# To make sense of the egs script, it may be necessary to put an "exit 1" +# command immediately after stage 3. Then, inspect +# exp//egs/temp/ranges.* . The ranges files specify the examples that +# will be created, and which archives they will be stored in. 
Each line of +# ranges.* has the following form: +# +# For example: +# 100304-f-sre2006-kacg-A 1 2 4079 881 23 + +# If you're satisfied with the number of archives (e.g., 50-150 archives is +# reasonable) and with the number of examples per speaker (e.g., 1000-5000 +# is reasonable) then you can let the script continue to the later stages. +# Otherwise, try increasing or decreasing the --num-repeats option. You might +# need to fiddle with --frames-per-iter. Increasing this value decreases the +# the number of archives and increases the number of examples per archive. +# Decreasing this value increases the number of archives, while decreasing the +# number of examples per archive. +if [ $stage -le 6 ]; then + echo "$0: Getting neural network training egs"; + # dump egs. + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/dihard_2018/v2/xvector-$(date +'%m_%d_%H_%M')/$egs_dir/storage $egs_dir/storage + fi + sid/nnet3/xvector/get_egs.sh --cmd "$train_cmd" \ + --nj 8 \ + --stage 0 \ + --frames-per-iter 1000000000 \ + --frames-per-iter-diagnostic 100000 \ + --min-frames-per-chunk 200 \ + --max-frames-per-chunk 400 \ + --num-diagnostic-archives 3 \ + --num-repeats 50 \ + "$data" $egs_dir +fi + +if [ $stage -le 7 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(wc -w $egs_dir/pdf2num | awk '{print $1}') + feat_dim=$(cat $egs_dir/info/feat_dim) + + # This chunk-size corresponds to the maximum number of frames the + # stats layer is able to pool over. In this script, it corresponds + # to 100 seconds. If the input recording is greater than 100 seconds, + # we will compute multiple xvectors from the same recording and average + # to produce the final xvector. + max_chunk_size=10000 + + # The smallest number of frames we're comfortable computing an xvector from. + # Note that the hard minimum is given by the left and right context of the + # frame-level layers. + min_chunk_size=25 + mkdir -p $nnet_dir/configs + cat < $nnet_dir/configs/network.xconfig + # please note that it is important to have input layer with the name=input + + # The frame-level layers + input dim=${feat_dim} name=input + relu-batchnorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=512 + relu-batchnorm-layer name=tdnn2 input=Append(-2,0,2) dim=512 + relu-batchnorm-layer name=tdnn3 input=Append(-3,0,3) dim=512 + relu-batchnorm-layer name=tdnn4 dim=512 + relu-batchnorm-layer name=tdnn5 dim=1500 + + # The stats pooling layer. Layers after this are segment-level. + # In the config below, the first and last argument (0, and ${max_chunk_size}) + # means that we pool over an input segment starting at frame 0 + # and ending at frame ${max_chunk_size} or earlier. The other arguments (1:1) + # mean that no subsampling is performed. + stats-layer name=stats config=mean+stddev(0:1:1:${max_chunk_size}) + + # This is where we usually extract the embedding (aka xvector) from. + relu-batchnorm-layer name=tdnn6 dim=512 input=stats + + # This is where another layer the embedding could be extracted + # from, but usually the previous one works better. 
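  # (The layer actually used at extraction time is chosen later in this script via
  #  extract.config, i.e. "output-node name=output input=tdnn6.affine", which takes
  #  the output of tdnn6's affine component before the ReLU/batchnorm.)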
+ relu-batchnorm-layer name=tdnn7 dim=512 + output-layer name=output include-log-softmax=true dim=${num_targets} +EOF + + steps/nnet3/xconfig_to_configs.py \ + --xconfig-file $nnet_dir/configs/network.xconfig \ + --config-dir $nnet_dir/configs/ + cp $nnet_dir/configs/final.config $nnet_dir/nnet.config + + # These three files will be used by sid/nnet3/xvector/extract_xvectors.sh + echo "output-node name=output input=tdnn6.affine" > $nnet_dir/extract.config + echo "$max_chunk_size" > $nnet_dir/max_chunk_size + echo "$min_chunk_size" > $nnet_dir/min_chunk_size +fi + +dropout_schedule='0,0@0.20,0.1@0.50,0' +srand=123 +if [ $stage -le 8 ]; then + steps/nnet3/train_raw_dnn.py --stage=$train_stage \ + --cmd="$train_cmd" \ + --trainer.optimization.proportional-shrink 10 \ + --trainer.optimization.momentum=0.5 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=8 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.minibatch-size=64 \ + --trainer.srand=$srand \ + --trainer.max-param-change=2 \ + --trainer.num-epochs=3 \ + --trainer.dropout-schedule="$dropout_schedule" \ + --trainer.shuffle-buffer-size=1000 \ + --egs.frames-per-eg=1 \ + --egs.dir="$egs_dir" \ + --cleanup.remove-egs $remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --dir=$nnet_dir || exit 1; +fi + +exit 0; diff --git a/egs/dihard_2018/v2/path.sh b/egs/dihard_2018/v2/path.sh new file mode 100755 index 00000000000..851c14e27c3 --- /dev/null +++ b/egs/dihard_2018/v2/path.sh @@ -0,0 +1,5 @@ +export KALDI_ROOT=`pwd`/../../.. +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sph2pipe_v2.5:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/dihard_2018/v2/run.sh b/egs/dihard_2018/v2/run.sh new file mode 100755 index 00000000000..0da1f330ea7 --- /dev/null +++ b/egs/dihard_2018/v2/run.sh @@ -0,0 +1,309 @@ +#!/bin/bash +# Copyright 2017 Johns Hopkins University (Author: Daniel Garcia-Romero) +# 2017 Johns Hopkins University (Author: Daniel Povey) +# 2017-2018 David Snyder +# 2018 Ewald Enzinger +# 2018 Zili Huang +# Apache 2.0. +# +# See ../README.txt for more info on data required. +# Results (diarization error rate) are inline in comments below. + +. ./cmd.sh +. ./path.sh +set -e +mfccdir=`pwd`/mfcc +vaddir=`pwd`/mfcc + +voxceleb1_root=/export/corpora/VoxCeleb1 +voxceleb2_root=/export/corpora/VoxCeleb2 +nnet_dir=exp/xvector_nnet_1a +musan_root=/export/corpora/JHU/musan +dihard_2018_dev=/export/corpora/LDC/LDC2018E31 +dihard_2018_eval=/export/corpora/LDC/LDC2018E32v1.1 + +stage=0 + +if [ $stage -le 0 ]; then + local/make_voxceleb2.pl $voxceleb2_root dev data/voxceleb2_train + local/make_voxceleb2.pl $voxceleb2_root test data/voxceleb2_test + # This script creates data/voxceleb1_test and data/voxceleb1_train. + # Our evaluation set is the test portion of VoxCeleb1. + local/make_voxceleb1.pl $voxceleb1_root data + # We'll train on all of VoxCeleb2, plus the training portion of VoxCeleb1. + # This should give 7,351 speakers and 1,277,503 utterances. + utils/combine_data.sh data/train data/voxceleb2_train data/voxceleb2_test data/voxceleb1_train + + # Prepare the development and evaluation set for DIHARD 2018. 
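  # (Besides wav.scp/segments/utt2spk, these prep scripts are expected to provide
  #  the reference rttm and reco2num_spk files that the scoring and oracle
  #  clustering stages further below rely on.)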
+ local/make_dihard_2018_dev.sh $dihard_2018_dev data/dihard_2018_dev + local/make_dihard_2018_eval.sh $dihard_2018_eval data/dihard_2018_eval +fi + +if [ $stage -le 1 ]; then + # Make MFCCs for each dataset. + for name in train dihard_2018_dev dihard_2018_eval; do + steps/make_mfcc.sh --write-utt2num-frames true --mfcc-config conf/mfcc.conf --nj 40 --cmd "$train_cmd --max-jobs-run 20" \ + data/${name} exp/make_mfcc $mfccdir + utils/fix_data_dir.sh data/${name} + done + + # Compute the energy-based VAD for training set. + sid/compute_vad_decision.sh --nj 40 --cmd "$train_cmd" \ + data/train exp/make_vad $vaddir + utils/fix_data_dir.sh data/train + + # This writes features to disk after applying the sliding window CMN. + # Although this is somewhat wasteful in terms of disk space, for diarization + # it ends up being preferable to performing the CMN in memory. If the CMN + # were performed in memory (e.g., we used --apply-cmn true in + # diarization/nnet3/xvector/extract_xvectors.sh) it would need to be + # performed after the subsegmentation, which leads to poorer results. + for name in train dihard_2018_dev dihard_2018_eval; do + local/nnet3/xvector/prepare_feats.sh --nj 40 --cmd "$train_cmd" \ + data/$name data/${name}_cmn exp/${name}_cmn + if [ -f data/$name/vad.scp ]; then + cp data/$name/vad.scp data/${name}_cmn/ + fi + if [ -f data/$name/segments ]; then + cp data/$name/segments data/${name}_cmn/ + fi + utils/fix_data_dir.sh data/${name}_cmn + done + + echo "0.01" > data/train_cmn/frame_shift + # Create segments to extract x-vectors from for PLDA training data. + # The segments are created using an energy-based speech activity + # detection (SAD) system, but this is not necessary. You can replace + # this with segments computed from your favorite SAD. + diarization/vad_to_segments.sh --nj 40 --cmd "$train_cmd" \ + data/train_cmn data/train_cmn_segmented +fi + +# In this section, we augment the training data with reverberation, +# noise, music, and babble, and combine it with the clean data. +if [ $stage -le 2 ]; then + frame_shift=0.01 + awk -v frame_shift=$frame_shift '{print $1, $2*frame_shift;}' data/train/utt2num_frames > data/train/reco2dur + + if [ ! -d "RIRS_NOISES" ]; then + # Download the package that includes the real RIRs, simulated RIRs, isotropic noises and point-source noises + wget --no-check-certificate http://www.openslr.org/resources/28/rirs_noises.zip + unzip rirs_noises.zip + fi + + # Make a version with reverberated speech + rvb_opts=() + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") + + # Make a reverberated version of the training data. Note that we don't add any + # additive noise here. + steps/data/reverberate_data_dir.py \ + "${rvb_opts[@]}" \ + --speech-rvb-probability 1 \ + --pointsource-noise-addition-probability 0 \ + --isotropic-noise-addition-probability 0 \ + --num-replications 1 \ + --source-sampling-rate 16000 \ + data/train data/train_reverb + cp data/train/vad.scp data/train_reverb/ + utils/copy_data_dir.sh --utt-suffix "-reverb" data/train_reverb data/train_reverb.new + rm -rf data/train_reverb + mv data/train_reverb.new data/train_reverb + + # Prepare the MUSAN corpus, which consists of music, speech, and noise + # suitable for augmentation. + local/make_musan.sh $musan_root data + + # Get the duration of the MUSAN recordings. This will be used by the + # script augment_data_dir.py. 
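  # (reco2dur is plain text with one "<recording-id> <duration-in-seconds>" pair
  #  per line; here it is simply the utt2dur produced by get_utt2dur.sh, renamed.)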
+ for name in speech noise music; do + utils/data/get_utt2dur.sh data/musan_${name} + mv data/musan_${name}/utt2dur data/musan_${name}/reco2dur + done + + # Augment with musan_noise + steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/train data/train_noise + # Augment with musan_music + steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/train data/train_music + # Augment with musan_speech + steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/train data/train_babble + + # Combine reverb, noise, music, and babble into one directory. + utils/combine_data.sh data/train_aug data/train_reverb data/train_noise data/train_music data/train_babble +fi + +if [ $stage -le 3 ]; then + # Take a random subset of the augmentations + utils/subset_data_dir.sh data/train_aug 1000000 data/train_aug_1m + utils/fix_data_dir.sh data/train_aug_1m + + # Make MFCCs for the augmented data. Note that we do not compute a new + # vad.scp file here. Instead, we use the vad.scp from the clean version of + # the list. + steps/make_mfcc.sh --mfcc-config conf/mfcc.conf --nj 40 --cmd "$train_cmd --max-jobs-run 20" \ + data/train_aug_1m exp/make_mfcc $mfccdir + + # Combine the clean and augmented training data. This is now roughly + # double the size of the original clean list. + utils/combine_data.sh data/train_combined data/train_aug_1m data/train +fi + +# Now we prepare the features to generate examples for xvector training. +if [ $stage -le 4 ]; then + # This script applies CMVN and removes nonspeech frames. Note that this is somewhat + # wasteful, as it roughly doubles the amount of training data on disk. After + # creating training examples, this can be removed. + local/nnet3/xvector/prepare_feats_for_egs.sh --nj 40 --cmd "$train_cmd" \ + data/train_combined data/train_combined_no_sil exp/train_combined_no_sil + utils/fix_data_dir.sh data/train_combined_no_sil +fi + +if [ $stage -le 5 ]; then + # Now, we need to remove features that are too short after removing silence + # frames. We want at least 4s (400 frames) per utterance. + min_len=400 + mv data/train_combined_no_sil/utt2num_frames data/train_combined_no_sil/utt2num_frames.bak + awk -v min_len=${min_len} '$2 > min_len {print $1, $2}' data/train_combined_no_sil/utt2num_frames.bak > data/train_combined_no_sil/utt2num_frames + utils/filter_scp.pl data/train_combined_no_sil/utt2num_frames data/train_combined_no_sil/utt2spk > data/train_combined_no_sil/utt2spk.new + mv data/train_combined_no_sil/utt2spk.new data/train_combined_no_sil/utt2spk + utils/fix_data_dir.sh data/train_combined_no_sil + + # We also want several utterances per speaker. Now we'll throw out speakers + # with fewer than 8 utterances. 
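  # (spk2num below is an intermediate "<spk-id> <num-utts>" list derived from
  #  spk2utt; only speakers passing the threshold are kept, and utt2spk and
  #  utt2num_frames are then re-filtered to stay consistent.)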
+ min_num_utts=8 + awk '{print $1, NF-1}' data/train_combined_no_sil/spk2utt > data/train_combined_no_sil/spk2num + awk -v min_num_utts=${min_num_utts} '$2 >= min_num_utts {print $1, $2}' data/train_combined_no_sil/spk2num | utils/filter_scp.pl - data/train_combined_no_sil/spk2utt > data/train_combined_no_sil/spk2utt.new + mv data/train_combined_no_sil/spk2utt.new data/train_combined_no_sil/spk2utt + utils/spk2utt_to_utt2spk.pl data/train_combined_no_sil/spk2utt > data/train_combined_no_sil/utt2spk + + utils/filter_scp.pl data/train_combined_no_sil/utt2spk data/train_combined_no_sil/utt2num_frames > data/train_combined_no_sil/utt2num_frames.new + mv data/train_combined_no_sil/utt2num_frames.new data/train_combined_no_sil/utt2num_frames + + # Now we're ready to create training examples. + utils/fix_data_dir.sh data/train_combined_no_sil +fi + +# Stages 6 through 8 are handled in run_xvector.sh, a TDNN embedding extractor is trained. +local/nnet3/xvector/run_xvector.sh --stage $stage --train-stage -1 \ + --data data/train_combined_no_sil --nnet-dir $nnet_dir \ + --egs-dir $nnet_dir/egs + +if [ $stage -le 9 ]; then + # Extract x-vectors for DIHARD 2018 development and evaluation set. + diarization/nnet3/xvector/extract_xvectors.sh --cmd "$train_cmd --mem 5G" \ + --nj 40 --window 1.5 --period 0.75 --apply-cmn false \ + --min-segment 0.5 $nnet_dir \ + data/dihard_2018_dev_cmn $nnet_dir/xvectors_dihard_2018_dev + + diarization/nnet3/xvector/extract_xvectors.sh --cmd "$train_cmd --mem 5G" \ + --nj 40 --window 1.5 --period 0.75 --apply-cmn false \ + --min-segment 0.5 $nnet_dir \ + data/dihard_2018_eval_cmn $nnet_dir/xvectors_dihard_2018_eval + + # Reduce the amount of training data for the PLDA training. + utils/subset_data_dir.sh data/train_cmn_segmented 128000 data/train_cmn_segmented_128k + # Extract x-vectors for the VoxCeleb, which is our PLDA training + # data. A long period is used here so that we don't compute too + # many x-vectors for each recording. + diarization/nnet3/xvector/extract_xvectors.sh --cmd "$train_cmd --mem 10G" \ + --nj 40 --window 3.0 --period 10.0 --min-segment 1.5 --apply-cmn false \ + --hard-min true $nnet_dir \ + data/train_cmn_segmented_128k $nnet_dir/xvectors_train_segmented_128k +fi + +# Train PLDA models +if [ $stage -le 10 ]; then + # Train a PLDA model on VoxCeleb, using DIHARD 2018 development set to whiten. + "$train_cmd" $nnet_dir/xvectors_dihard_2018_dev/log/plda.log \ + ivector-compute-plda ark:$nnet_dir/xvectors_train_segmented_128k/spk2utt \ + "ark:ivector-subtract-global-mean \ + scp:$nnet_dir/xvectors_train_segmented_128k/xvector.scp ark:- \ + | transform-vec $nnet_dir/xvectors_dihard_2018_dev/transform.mat ark:- ark:- \ + | ivector-normalize-length ark:- ark:- |" \ + $nnet_dir/xvectors_dihard_2018_dev/plda || exit 1; +fi + +# Perform PLDA scoring +if [ $stage -le 11 ]; then + # Perform PLDA scoring on all pairs of segments for each recording. + diarization/nnet3/xvector/score_plda.sh --cmd "$train_cmd --mem 4G" \ + --nj 20 $nnet_dir/xvectors_dihard_2018_dev $nnet_dir/xvectors_dihard_2018_dev \ + $nnet_dir/xvectors_dihard_2018_dev/plda_scores + + diarization/nnet3/xvector/score_plda.sh --cmd "$train_cmd --mem 4G" \ + --nj 20 $nnet_dir/xvectors_dihard_2018_dev $nnet_dir/xvectors_dihard_2018_eval \ + $nnet_dir/xvectors_dihard_2018_eval/plda_scores +fi + +# Cluster the PLDA scores using a stopping threshold. +if [ $stage -le 12 ]; then + # First, we find the threshold that minimizes the DER on DIHARD 2018 development set. 
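  # (md-eval.pl prints a summary line roughly of the form
  #  "OVERALL SPEAKER DIARIZATION ERROR = 26.47 percent of scored speaker time";
  #  the grep -oP commands below extract just that number as the DER.)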
+ mkdir -p $nnet_dir/tuning + echo "Tuning clustering threshold for DIHARD 2018 development set" + best_der=100 + best_threshold=0 + + # The threshold is in terms of the log likelihood ratio provided by the + # PLDA scores. In a perfectly calibrated system, the threshold is 0. + # In the following loop, we evaluate DER performance on DIHARD 2018 development + # set using some reasonable thresholds for a well-calibrated system. + for threshold in -0.5 -0.4 -0.3 -0.2 -0.1 -0.05 0 0.05 0.1 0.2 0.3 0.4 0.5; do + diarization/cluster.sh --cmd "$train_cmd --mem 4G" --nj 20 \ + --threshold $threshold --rttm-channel 1 $nnet_dir/xvectors_dihard_2018_dev/plda_scores \ + $nnet_dir/xvectors_dihard_2018_dev/plda_scores_t$threshold + + md-eval.pl -r data/dihard_2018_dev/rttm \ + -s $nnet_dir/xvectors_dihard_2018_dev/plda_scores_t$threshold/rttm \ + 2> $nnet_dir/tuning/dihard_2018_dev_t${threshold}.log \ + > $nnet_dir/tuning/dihard_2018_dev_t${threshold} + + der=$(grep -oP 'DIARIZATION\ ERROR\ =\ \K[0-9]+([.][0-9]+)?' \ + $nnet_dir/tuning/dihard_2018_dev_t${threshold}) + if [ $(perl -e "print ($der < $best_der ? 1 : 0);") -eq 1 ]; then + best_der=$der + best_threshold=$threshold + fi + done + echo "$best_threshold" > $nnet_dir/tuning/dihard_2018_dev_best + + diarization/cluster.sh --cmd "$train_cmd --mem 4G" --nj 20 \ + --threshold $(cat $nnet_dir/tuning/dihard_2018_dev_best) --rttm-channel 1 \ + $nnet_dir/xvectors_dihard_2018_dev/plda_scores $nnet_dir/xvectors_dihard_2018_dev/plda_scores + + # Cluster DIHARD 2018 evaluation set using the best threshold found for the DIHARD + # 2018 development set. The DIHARD 2018 development set is used as the validation + # set to tune the parameters. + diarization/cluster.sh --cmd "$train_cmd --mem 4G" --nj 20 \ + --threshold $(cat $nnet_dir/tuning/dihard_2018_dev_best) --rttm-channel 1 \ + $nnet_dir/xvectors_dihard_2018_eval/plda_scores $nnet_dir/xvectors_dihard_2018_eval/plda_scores + + mkdir -p $nnet_dir/results + # Compute the DER on the DIHARD 2018 evaluation set. We use the official metrics of + # the DIHARD challenge. The DER is calculated with no unscored collars and including + # overlapping speech. + md-eval.pl -r data/dihard_2018_eval/rttm \ + -s $nnet_dir/xvectors_dihard_2018_eval/plda_scores/rttm 2> $nnet_dir/results/threshold.log \ + > $nnet_dir/results/DER_threshold.txt + der=$(grep -oP 'DIARIZATION\ ERROR\ =\ \K[0-9]+([.][0-9]+)?' \ + $nnet_dir/results/DER_threshold.txt) + # Using supervised calibration, DER: 26.47% + echo "Using supervised calibration, DER: $der%" +fi + +# Cluster the PLDA scores using the oracle number of speakers +if [ $stage -le 13 ]; then + # In this section, we show how to do the clustering if the number of speakers + # (and therefore, the number of clusters) per recording is known in advance. + diarization/cluster.sh --cmd "$train_cmd --mem 4G" --nj 20 \ + --reco2num-spk data/dihard_2018_eval/reco2num_spk --rttm-channel 1 \ + $nnet_dir/xvectors_dihard_2018_eval/plda_scores $nnet_dir/xvectors_dihard_2018_eval/plda_scores_num_spk + + md-eval.pl -r data/dihard_2018_eval/rttm \ + -s $nnet_dir/xvectors_dihard_2018_eval/plda_scores_num_spk/rttm 2> $nnet_dir/results/num_spk.log \ + > $nnet_dir/results/DER_num_spk.txt + der=$(grep -oP 'DIARIZATION\ ERROR\ =\ \K[0-9]+([.][0-9]+)?' 
\ + $nnet_dir/results/DER_num_spk.txt) + # Using the oracle number of speakers, DER: 23.90% + echo "Using the oracle number of speakers, DER: $der%" +fi diff --git a/egs/dihard_2018/v2/sid b/egs/dihard_2018/v2/sid new file mode 120000 index 00000000000..893a12f30c9 --- /dev/null +++ b/egs/dihard_2018/v2/sid @@ -0,0 +1 @@ +../../sre08/v1/sid \ No newline at end of file diff --git a/egs/dihard_2018/v2/steps b/egs/dihard_2018/v2/steps new file mode 120000 index 00000000000..6e99bf5b5ad --- /dev/null +++ b/egs/dihard_2018/v2/steps @@ -0,0 +1 @@ +../../wsj/s5/steps \ No newline at end of file diff --git a/egs/dihard_2018/v2/utils b/egs/dihard_2018/v2/utils new file mode 120000 index 00000000000..b240885218f --- /dev/null +++ b/egs/dihard_2018/v2/utils @@ -0,0 +1 @@ +../../wsj/s5/utils \ No newline at end of file diff --git a/egs/fame/s5/run.sh b/egs/fame/s5/run.sh index 26a8485ff7d..de6fe46b7c4 100755 --- a/egs/fame/s5/run.sh +++ b/egs/fame/s5/run.sh @@ -106,8 +106,8 @@ fi if [ $stage -le 7 ]; then echo "Starting SGMM training." steps/align_fmllr.sh --nj $train_nj --cmd "$train_cmd" data/train data/lang exp/tri3 exp/tri3_ali || exit 1; - steps/train_ubm.sh --cmd "$train_cmd" $numGaussUBM data/train data/lang exp/tri3_ali exp/ubm || exit 1; - steps/train_sgmm2.sh --cmd "$train_cmd" $numLeavesSGMM $numGaussSGMM data/train data/lang exp/tri3_ali exp/ubm/final.ubm exp/sgmm2 || exit 1; + steps/train_ubm.sh --cmd "$train_cmd" $numGaussUBM data/train data/lang exp/tri3_ali exp/ubm || exit 1; + steps/train_sgmm2.sh --cmd "$train_cmd" $numLeavesSGMM $numGaussSGMM data/train data/lang exp/tri3_ali exp/ubm/final.ubm exp/sgmm2 || exit 1; echo "SGMM training done." echo "Decoding the development and test sets using SGMM models" diff --git a/egs/fame/v1/local/prepare_for_eer.py b/egs/fame/v1/local/prepare_for_eer.py index 59d2985e7c2..f1dbcfa9ab6 100755 --- a/egs/fame/v1/local/prepare_for_eer.py +++ b/egs/fame/v1/local/prepare_for_eer.py @@ -1,3 +1,4 @@ +from __future__ import print_function # Copyright 2015 David Snyder # Apache 2.0. # @@ -12,4 +13,4 @@ spkrutt2target[spkr+utt]=target for line in scores: spkr, utt, score = line.strip().split() - print score, spkrutt2target[spkr+utt] + print(score, spkrutt2target[spkr+utt]) diff --git a/egs/farsdat/s5/run.sh b/egs/farsdat/s5/run.sh index 81f353c301c..4c3d3c5882b 100755 --- a/egs/farsdat/s5/run.sh +++ b/egs/farsdat/s5/run.sh @@ -8,7 +8,7 @@ # farsdat, description of the database: # http://www.assta.org/sst/SST-94-Vol-ll/cache/SST-94-VOL2-Chapter15-p20.pdf -. ./cmd.sh +. ./cmd.sh [ -f path.sh ] && . ./path.sh set -e @@ -54,7 +54,7 @@ echo =========================================================================== # Now make MFCC features. mfccdir=mfcc -for x in train dev test; do +for x in train dev test; do steps/make_mfcc.sh --cmd "$train_cmd" --nj $feats_nj data/$x exp/make_mfcc/$x $mfccdir steps/compute_cmvn_stats.sh data/$x exp/make_mfcc/$x $mfccdir done diff --git a/egs/fisher_callhome_spanish/s5/conf/mfcc_hires.conf b/egs/fisher_callhome_spanish/s5/conf/mfcc_hires.conf new file mode 100644 index 00000000000..d870ab04c38 --- /dev/null +++ b/egs/fisher_callhome_spanish/s5/conf/mfcc_hires.conf @@ -0,0 +1,10 @@ +# config for high-resolution MFCC features, intended for neural network training. +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. 
+--sample-frequency=8000 # Switchboard is sampled at 8kHz +--num-mel-bins=40 # similar to Google's setup. +--num-ceps=40 # there is no dimensionality reduction. +--low-freq=40 # low cutoff frequency for mel bins +--high-freq=-200 # high cutoff frequently, relative to Nyquist of 4000 (=3800) diff --git a/egs/fisher_callhome_spanish/s5/conf/online_cmvn.conf b/egs/fisher_callhome_spanish/s5/conf/online_cmvn.conf new file mode 100644 index 00000000000..7748a4a4dd3 --- /dev/null +++ b/egs/fisher_callhome_spanish/s5/conf/online_cmvn.conf @@ -0,0 +1 @@ +# configuration file for apply-cmvn-online, used in the script ../local/run_online_decoding.sh diff --git a/egs/fisher_callhome_spanish/s5/local/callhome_get_lattices.py b/egs/fisher_callhome_spanish/s5/local/callhome_get_lattices.py index 9112d868c25..4c96e01ce7e 100755 --- a/egs/fisher_callhome_spanish/s5/local/callhome_get_lattices.py +++ b/egs/fisher_callhome_spanish/s5/local/callhome_get_lattices.py @@ -5,6 +5,7 @@ # The list of files in the conversations for which 1 best output has to be extracted # words.txt +from __future__ import print_function import os import sys import subprocess @@ -76,7 +77,7 @@ def findLattice(timeDetail): # Concatenate lattices mergedTranslation = latticeConcatenate(mergedTranslation, tmp) - print mergedTranslation + print(mergedTranslation) if mergedTranslation != "": # Sanjeev's Recipe : Remove epsilons and topo sort @@ -95,16 +96,16 @@ def findLattice(timeDetail): # file so it can be checked later proc = subprocess.Popen("/export/a04/gkumar/moses/mosesdecoder/checkplf < " + finalPLFFile + " 2>&1 | awk 'FNR == 2 {print}'", stdout=subprocess.PIPE, shell=True) line = proc.stdout.readline() - print line + " " + str(lineNo) + print("{} {}".format(line, lineNo)) if line.strip() != "PLF format appears to be correct.": os.system("cp " + finalFST + " " + invalidplfdir + "/" + timeInfo[0]) invalidPLF.write(invalidplfdir + "/" + timeInfo[0] + "\n") - rmLines.write(str(lineNo) + "\n") + rmLines.write("{}\n".format(lineNo)) else: provFile.write(PLFline) else: blankPLF.write(timeInfo[0] + "\n") - rmLines.write(str(lineNo) + "\n") + rmLines.write("{}\n".format(lineNo)) # Now convert to PLF lineNo += 1 diff --git a/egs/fisher_callhome_spanish/s5/local/chain/run_tdnn_1g.sh b/egs/fisher_callhome_spanish/s5/local/chain/run_tdnn_1g.sh new file mode 100755 index 00000000000..7f407552c2e --- /dev/null +++ b/egs/fisher_callhome_spanish/s5/local/chain/run_tdnn_1g.sh @@ -0,0 +1,288 @@ +#!/bin/bash + +# 1g is like 1f but upgrading to a "resnet-style TDNN-F model", i.e. +# with bypass resnet connections, and re-tuned. +# compute-wer --text --mode=present ark:exp/chain/multipsplice_tdnn/decode_fsp_train_test/scoring_kaldi/test_filt.txt ark,p:- +# %WER 22.21 [ 8847 / 39831, 1965 ins, 2127 del, 4755 sub ] +# %SER 56.98 [ 3577 / 6278 ] +# Scored 6278 sentences, 0 not present in hyp. + +# steps/info/chain_dir_info.pl exp/chain/multipsplice_tdnn +# exp/chain/multipsplice_tdnn: num-iters=296 nj=1..2 num-params=8.2M dim=40+100->2489 combine=-0.170->-0.165 (over 8) xent:train/valid[196,295,final]=(-2.30,-1.93,-1.83/-2.24,-1.96,-1.86) logprob:train/valid[196,295,final]=(-0.208,-0.169,-0.164/-0.189,-0.161,-0.158) + +set -e -o pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +nj=30 +train_set=train +test_sets="test dev" +gmm=tri5a # this is the source gmm-dir that we'll use for alignments; it + # should have alignments for the specified training data. 
+num_threads_ubm=32 +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. + +# Options which are not passed through to run_ivector_common.sh +affix=1g #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# LSTM/chain options +train_stage=-10 +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.3@0.50,0' + +# training chunk-options +chunk_width=140,100,160 +# we don't need extra left/right context for TDNN systems. +chunk_left_context=0 +chunk_right_context=0 + +# training options +srand=0 +remove_egs=true + +#decode options +test_online_decoding=false # if true, it will run the last decoding stage. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 17 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 18 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 19 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + tdnn_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim-continuous=true" + tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.01" + output_opts="l2-regularize=0.005" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-dropout-layer name=tdnn1 $tdnn_opts dim=1024 + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + tdnnf-layer name=tdnnf3 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + tdnnf-layer name=tdnnf4 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + tdnnf-layer name=tdnnf5 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=0 + tdnnf-layer name=tdnnf6 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf7 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts 
dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + linear-component name=prefinal-l dim=192 $linear_opts + + + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1024 small-dim=192 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=1024 small-dim=192 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 20 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/wsj-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.0 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.srand $srand \ + --trainer.max-param-change 2.0 \ + --trainer.num-epochs 4 \ + --trainer.frames-per-iter 5000000 \ + --trainer.optimization.num-jobs-initial 1 \ + --trainer.optimization.num-jobs-final=2 \ + --trainer.optimization.initial-effective-lrate 0.0005 \ + --trainer.optimization.final-effective-lrate 0.00005 \ + --trainer.num-chunk-per-minibatch 128,64 \ + --trainer.optimization.momentum 0.0 \ + --egs.chunk-width $chunk_width \ + --egs.chunk-left-context 0 \ + --egs.chunk-right-context 0 \ + --egs.dir "$common_egs_dir" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --cleanup.remove-egs $remove_egs \ + --use-gpu true \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir exp/tri5a_lats_nodup_sp \ + --dir $dir || exit 1; +fi + +if [ $stage -le 21 ]; then + # The reason we are using data/lang_test here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + #LM was trained only on Fisher Spanish train subset. 
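  # A hypothetical variant of the command below, following the comment above: any
  # lang directory with a compatible phones.txt could be substituted, e.g.
  #   utils/mkgraph.sh --self-loop-scale 1.0 data/lang_test_biglm \
  #     $tree_dir $tree_dir/graph_biglm
  # (data/lang_test_biglm is a made-up name, for illustration only.)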
+ + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_test \ + $tree_dir $tree_dir/graph_fsp_train || exit 1; + +fi + +rnnlmdir=exp/rnnlm_lstm_tdnn_1b +if [ $stage -le 22 ]; then + local/rnnlm/train_rnnlm.sh --dir $rnnlmdir || exit 1; +fi + +if [ $stage -le 23 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l &1 | awk 'FNR == 2 {print}'", stdout=subprocess.PIPE, shell=True) line = proc.stdout.readline() - print line + " " + str(lineNo) + print("{} {}".format(line, lineNo)) if line.strip() != "PLF format appears to be correct.": os.system("cp " + finalFST + " " + invalidplfdir + "/" + timeInfo[0]) invalidPLF.write(invalidplfdir + "/" + timeInfo[0] + "\n") - rmLines.write(str(lineNo) + "\n") + rmLines.write("{}\n".format(lineNo)) else: provFile.write(PLFline) else: blankPLF.write(timeInfo[0] + "\n") - rmLines.write(str(lineNo) + "\n") + rmLines.write("{}\n".format(lineNo)) # Now convert to PLF lineNo += 1 diff --git a/egs/fisher_callhome_spanish/s5/local/merge_lexicons.py b/egs/fisher_callhome_spanish/s5/local/merge_lexicons.py index 5c09f09bc35..b42eb52d20a 100755 --- a/egs/fisher_callhome_spanish/s5/local/merge_lexicons.py +++ b/egs/fisher_callhome_spanish/s5/local/merge_lexicons.py @@ -1,10 +1,11 @@ -#!/usr/bin/env python # Copyright 2014 Gaurav Kumar. Apache 2.0 +# 2018 Saikiran Valluri, GoVivace inc., Avaaya +#!/usr/bin/env python # -*- coding: utf-8 -*- # # Merges unique words from Spanish Fisher, Gigaword and the LDC spanish lexicon - -import sys +from __future__ import print_function +import sys, re import json import codecs import operator @@ -16,6 +17,7 @@ uw_gigaword = tmpdir + "/es_wordlist.json" uw_LDC = ldc_lexicon + "/callhome_spanish_lexicon_970908/preferences" +filtered_letters = re.compile(u'[¡¥ª°º¿àçèëìîôö0123456789]') merged_lexicon = [] # All three lexicons are in different formats # First add the data from lexicon_fisher (A) into the dictionary @@ -24,8 +26,7 @@ merged_lexicon.append(line.strip()) fisher.close() -print "After adding the fisher data, the lexicon contains " \ - + str(len(merged_lexicon)) + " entries." +print("After adding the fisher data, the lexicon contains {} entries".format(len(merged_lexicon))) # Now add data from the LDC lexicon ldc = codecs.open(uw_LDC, encoding='iso-8859-1') @@ -34,12 +35,11 @@ if entries[0].lower() not in merged_lexicon: merged_lexicon.append(entries[0].lower()) -print "After adding the LDC data, the lexicon contains " \ - + str(len(merged_lexicon)) + " entries." +print("After adding the LDC data, the lexicon contains {} entries".format(len(merged_lexicon))) # Finally add the gigaword data gigaword = json.load(open(uw_gigaword)) -gigaword = reversed(sorted(gigaword.iteritems(), key=operator.itemgetter(1))) +gigaword = reversed(sorted(gigaword.items(), key=operator.itemgetter(1))) for item in gigaword: # We need a maximum of wordlimit words in the lexicon @@ -49,16 +49,16 @@ if item[0].lower() not in merged_lexicon: merged_lexicon.append(item[0].lower()) -print "After adding the Gigaword data, the lexicon contains " \ - + str(len(merged_lexicon)) + " entries." 
+print("After adding the Gigaword data, the lexicon contains {} entries".format(len(merged_lexicon))) # Now write the uniquewords to a file lf = codecs.open(tmpdir + '/uniquewords64k', encoding='utf-8', mode='w+') ltuples = sorted(merged_lexicon) for item in ltuples: - lf.write(item + "\n") + if not item==u'ñ' and not re.search(filtered_letters, item): + lf.write(item + "\n") lf.close() -print "Finshed writing unique words" +print("Finshed writing unique words") diff --git a/egs/fisher_callhome_spanish/s5/local/nnet3/run_ivector_common.sh b/egs/fisher_callhome_spanish/s5/local/nnet3/run_ivector_common.sh new file mode 100755 index 00000000000..cc9de4d26c5 --- /dev/null +++ b/egs/fisher_callhome_spanish/s5/local/nnet3/run_ivector_common.sh @@ -0,0 +1,187 @@ +#!/bin/bash + +set -e -o pipefail + +# This script is called from scripts like local/nnet3/run_tdnn.sh and +# local/chain/run_tdnn.sh (and may eventually be called by more scripts). It +# contains the common feature preparation and iVector-related parts of the +# script. See those scripts for examples of usage. + + +stage=7 +nj=30 +train_set=train # you might set this to e.g. train. +test_sets="test dev" +gmm=tri5a # This specifies a GMM-dir from the features of the type you're training the system on; + # it should contain alignments for 'train_set'. + +num_threads_ubm=32 +nnet3_affix= # affix for exp/nnet3 directory to put iVector stuff in (e.g. + # in the tedlium recip it's _cleaned). + +. ./cmd.sh +. ./path.sh +. utils/parse_options.sh + + +gmm_dir=exp/${gmm} +ali_dir=exp/${gmm}_ali_${train_set}_sp + +for f in data/${train_set}/feats.scp ${gmm_dir}/final.mdl; do + if [ ! -f $f ]; then + echo "$0: expected file $f to exist" + exit 1 + fi +done + + + +if [ $stage -le 7 ] && [ -f data/${train_set}_sp_hires/feats.scp ]; then + echo "$0: data/${train_set}_sp_hires/feats.scp already exists." + echo " ... Please either remove it, or rerun this script with stage > 7." + exit 1 +fi + + +if [ $stage -le 8 ]; then + echo "$0: preparing directory for speed-perturbed data" + utils/data/perturb_data_dir_speed_3way.sh data/${train_set} data/${train_set}_sp +fi + +if [ $stage -le 9 ]; then + echo "$0: creating high-resolution MFCC features" + + # this shows how you can split across multiple file-systems. we'll split the + # MFCC dir across multiple locations. You might want to be careful here, if you + # have multiple copies of Kaldi checked out and run the same recipe, not to let + # them overwrite each other. + mfccdir=data/${train_set}_sp_hires/data + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/b0{5,6,7,8}/$USER/kaldi-data/mfcc/wsj-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + for datadir in ${train_set}_sp ${test_sets}; do + utils/copy_data_dir.sh data/$datadir data/${datadir}_hires + done + + # do volume-perturbation on the training data prior to extracting hires + # features; this helps make trained nnets more invariant to test data volume. + utils/data/perturb_data_dir_volume.sh data/${train_set}_sp_hires + + for datadir in ${train_set}_sp ${test_sets}; do + steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/${datadir}_hires + steps/compute_cmvn_stats.sh data/${datadir}_hires + utils/fix_data_dir.sh data/${datadir}_hires + done +fi + +if [ $stage -le 10 ]; then + echo "$0: computing a subset of data to train the diagonal UBM." 
+ + mkdir -p exp/nnet3${nnet3_affix}/diag_ubm + temp_data_root=exp/nnet3${nnet3_affix}/diag_ubm + + # train a diagonal UBM using a subset of about a quarter of the data + num_utts_total=$(wc -l $text_dir/ami.txt + cat $dev | cut -d ' ' -f2- > $text_dir/dev.txt +fi + +if [ $stage -le 1 ]; then + cp $wordlist $dir/config/ + n=`cat $dir/config/words.txt | wc -l` + echo " $n" >> $dir/config/words.txt + + # words that are not present in words.txt but are in the training or dev data, will be + # mapped to during training. + echo "" >$dir/config/oov.txt + + cat > $dir/config/data_weights.txt <$dir/config/unigram_probs.txt + + # choose features + rnnlm/choose_features.py --unigram-probs=$dir/config/unigram_probs.txt \ + --use-constant-feature=true \ + --top-word-features 10000 \ + --min-frequency 1.0e-03 \ + --special-words=',,,,[noise],[laughter]' \ + $dir/config/words.txt > $dir/config/features.txt + +lstm_opts="l2-regularize=$comp_l2" +tdnn_opts="l2-regularize=$comp_l2" +output_opts="l2-regularize=$output_l2" + + cat >$dir/config/xconfig <&1 | awk 'FNR == 2 {print}'", stdout=subprocess.PIPE, shell=True) line = proc.stdout.readline() - print line + " " + str(lineNo) + print("{} {}".format(line, lineNo)) if line.strip() != "PLF format appears to be correct.": os.system("cp " + finalFST + " " + invalidplfdir + "/" + timeInfo[0]) invalidPLF.write(invalidplfdir + "/" + timeInfo[0] + "\n") - rmLines.write(str(lineNo) + "\n") + rmLines.write("{}\n".format(lineNo)) else: provFile.write(PLFline) else: blankPLF.write(timeInfo[0] + "\n") - rmLines.write(str(lineNo) + "\n") + rmLines.write("{}\n".format(lineNo)) # Now convert to PLF lineNo += 1 diff --git a/egs/fisher_callhome_spanish/s5/path.sh b/egs/fisher_callhome_spanish/s5/path.sh index 1a6fb5f891b..17ffb0369f8 100755 --- a/egs/fisher_callhome_spanish/s5/path.sh +++ b/egs/fisher_callhome_spanish/s5/path.sh @@ -3,3 +3,4 @@ export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 . $KALDI_ROOT/tools/config/common_path.sh export LC_ALL=C +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/dpovey/libs diff --git a/egs/fisher_callhome_spanish/s5/rnnlm b/egs/fisher_callhome_spanish/s5/rnnlm new file mode 120000 index 00000000000..fb754622d5e --- /dev/null +++ b/egs/fisher_callhome_spanish/s5/rnnlm @@ -0,0 +1 @@ +../../wsj/s5/rnnlm \ No newline at end of file diff --git a/egs/fisher_callhome_spanish/s5/run.sh b/egs/fisher_callhome_spanish/s5/run.sh index 57902a98fed..6e2752a7b68 100755 --- a/egs/fisher_callhome_spanish/s5/run.sh +++ b/egs/fisher_callhome_spanish/s5/run.sh @@ -1,20 +1,22 @@ #!/bin/bash # +# Copyright 2018 Nagendra Goel, Saikiran Valluri Apache 2.0 # Copyright 2014 Gaurav Kumar. Apache 2.0 # Recipe for Fisher/Callhome-Spanish -# Made to integrate KALDI with JOSHUA for end-to-end ASR and SMT stage=0 +train_stage=-20 +train_sgmm2=false # call the next line with the directory where the Spanish Fisher data is # (the values below are just an example). 
-sfisher_speech=/veu4/jadrian/data/LDC/LDC2010S01 -sfisher_transcripts=/veu4/jadrian/data/LDC/LDC2010T04 -spanish_lexicon=/veu4/jadrian/data/LDC/LDC96L16 +sfisher_speech=/export/corpora/LDC/LDC2010S01 +sfisher_transcripts=/export/corpora/LDC/LDC2010T04 +spanish_lexicon=/export/corpora/LDC/LDC96L16 split=local/splits/split_fisher -callhome_speech=/veu4/jadrian/data/LDC/LDC96S35 -callhome_transcripts=/veu4/jadrian/data/LDC/LDC96T17 +callhome_speech=/export/corpora/LDC/LDC96S35 +callhome_transcripts=/export/corpora/LDC/LDC96T17 split_callhome=local/splits/split_callhome mfccdir=`pwd`/mfcc @@ -25,7 +27,7 @@ if [ -f path.sh ]; then . ./path.sh; fi set -e -if [ $stage -lt 1 ]; then +if [ $stage -le 1 ]; then local/fsp_data_prep.sh $sfisher_speech $sfisher_transcripts local/callhome_data_prep.sh $callhome_speech $callhome_transcripts @@ -95,7 +97,7 @@ if [ $stage -lt 1 ]; then local/callhome_create_splits.sh $split_callhome fi -if [ $stage -lt 2 ]; then +if [ $stage -le 2 ]; then # Now compute CMVN stats for the train, dev and test subsets steps/compute_cmvn_stats.sh data/dev exp/make_mfcc/dev $mfccdir steps/compute_cmvn_stats.sh data/test exp/make_mfcc/test $mfccdir @@ -124,90 +126,95 @@ if [ $stage -lt 2 ]; then utils/subset_data_dir.sh --speakers data/train 90000 data/train_100k fi +if [ $stage -le 3 ]; then + steps/train_mono.sh --nj 10 --cmd "$train_cmd" \ + data/train_10k_nodup data/lang exp/mono0a -steps/train_mono.sh --nj 10 --cmd "$train_cmd" \ - data/train_10k_nodup data/lang exp/mono0a + steps/align_si.sh --nj 30 --cmd "$train_cmd" \ + data/train_30k data/lang exp/mono0a exp/mono0a_ali || exit 1; -steps/align_si.sh --nj 30 --cmd "$train_cmd" \ - data/train_30k data/lang exp/mono0a exp/mono0a_ali || exit 1; - -steps/train_deltas.sh --cmd "$train_cmd" \ + steps/train_deltas.sh --cmd "$train_cmd" \ 2500 20000 data/train_30k data/lang exp/mono0a_ali exp/tri1 || exit 1; -(utils/mkgraph.sh data/lang_test exp/tri1 exp/tri1/graph - steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ - exp/tri1/graph data/dev exp/tri1/decode_dev)& + (utils/mkgraph.sh data/lang_test exp/tri1 exp/tri1/graph + steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + exp/tri1/graph data/dev exp/tri1/decode_dev)& -steps/align_si.sh --nj 30 --cmd "$train_cmd" \ - data/train_30k data/lang exp/tri1 exp/tri1_ali || exit 1; + steps/align_si.sh --nj 30 --cmd "$train_cmd" \ + data/train_30k data/lang exp/tri1 exp/tri1_ali || exit 1; -steps/train_deltas.sh --cmd "$train_cmd" \ + steps/train_deltas.sh --cmd "$train_cmd" \ 2500 20000 data/train_30k data/lang exp/tri1_ali exp/tri2 || exit 1; -( - utils/mkgraph.sh data/lang_test exp/tri2 exp/tri2/graph || exit 1; - steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ - exp/tri2/graph data/dev exp/tri2/decode_dev || exit 1; -)& - + ( + utils/mkgraph.sh data/lang_test exp/tri2 exp/tri2/graph || exit 1; + steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + exp/tri2/graph data/dev exp/tri2/decode_dev || exit 1; + )& +fi -steps/align_si.sh --nj 30 --cmd "$train_cmd" \ - data/train_100k data/lang exp/tri2 exp/tri2_ali || exit 1; +if [ $stage -le 4 ]; then + steps/align_si.sh --nj 30 --cmd "$train_cmd" \ + data/train_100k data/lang exp/tri2 exp/tri2_ali || exit 1; # Train tri3a, which is LDA+MLLT, on 100k data. 
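+  # (LDA+MLLT here means: splice each frame with +/-3 neighbours, project down
+  # with LDA (40 dims by default in train_lda_mllt.sh), then estimate a global
+  # MLLT/STC transform on top; the 3000/40000 arguments below are the numbers
+  # of tree leaves and Gaussians.)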
-steps/train_lda_mllt.sh --cmd "$train_cmd" \ + steps/train_lda_mllt.sh --cmd "$train_cmd" \ --splice-opts "--left-context=3 --right-context=3" \ 3000 40000 data/train_100k data/lang exp/tri2_ali exp/tri3a || exit 1; -( - utils/mkgraph.sh data/lang_test exp/tri3a exp/tri3a/graph || exit 1; - steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ - exp/tri3a/graph data/dev exp/tri3a/decode_dev || exit 1; -)& - + ( + utils/mkgraph.sh data/lang_test exp/tri3a exp/tri3a/graph || exit 1; + steps/decode.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + exp/tri3a/graph data/dev exp/tri3a/decode_dev || exit 1; + )& +fi +if [ $stage -le 5 ]; then # Next we'll use fMLLR and train with SAT (i.e. on # fMLLR features) -steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ - data/train_100k data/lang exp/tri3a exp/tri3a_ali || exit 1; + steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ + data/train_100k data/lang exp/tri3a exp/tri3a_ali || exit 1; -steps/train_sat.sh --cmd "$train_cmd" \ - 4000 60000 data/train_100k data/lang exp/tri3a_ali exp/tri4a || exit 1; + steps/train_sat.sh --cmd "$train_cmd" \ + 4000 60000 data/train_100k data/lang exp/tri3a_ali exp/tri4a || exit 1; -( - utils/mkgraph.sh data/lang_test exp/tri4a exp/tri4a/graph - steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ - exp/tri4a/graph data/dev exp/tri4a/decode_dev + ( + utils/mkgraph.sh data/lang_test exp/tri4a exp/tri4a/graph + steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + exp/tri4a/graph data/dev exp/tri4a/decode_dev )& -steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ - data/train data/lang exp/tri4a exp/tri4a_ali || exit 1; + steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ + data/train data/lang exp/tri4a exp/tri4a_ali || exit 1; # Reduce the number of gaussians -steps/train_sat.sh --cmd "$train_cmd" \ - 5000 120000 data/train data/lang exp/tri4a_ali exp/tri5a || exit 1; + steps/train_sat.sh --cmd "$train_cmd" \ + 5000 120000 data/train data/lang exp/tri4a_ali exp/tri5a || exit 1; -( - utils/mkgraph.sh data/lang_test exp/tri5a exp/tri5a/graph - steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ - exp/tri5a/graph data/dev exp/tri5a/decode_dev - steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ - exp/tri5a/graph data/test exp/tri5a/decode_test + ( + utils/mkgraph.sh data/lang_test exp/tri5a exp/tri5a/graph + steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + exp/tri5a/graph data/dev exp/tri5a/decode_dev + steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + exp/tri5a/graph data/test exp/tri5a/decode_test # Decode CALLHOME - steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ - exp/tri5a/graph data/callhome_test exp/tri5a/decode_callhome_test - steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ - exp/tri5a/graph data/callhome_dev exp/tri5a/decode_callhome_dev - steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ - exp/tri5a/graph data/callhome_train exp/tri5a/decode_callhome_train -) & - + steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + exp/tri5a/graph data/callhome_test exp/tri5a/decode_callhome_test + steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ + exp/tri5a/graph data/callhome_dev exp/tri5a/decode_callhome_dev + steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" 
--config conf/decode.config \ + exp/tri5a/graph data/callhome_train exp/tri5a/decode_callhome_train + ) & + + + steps/align_fmllr.sh \ + --boost-silence 0.5 --nj 32 --cmd "$train_cmd" \ + data/train data/lang exp/tri5a exp/tri5a_ali +fi -steps/align_fmllr.sh \ - --boost-silence 0.5 --nj 32 --cmd "$train_cmd" \ - data/train data/lang exp/tri5a exp/tri5a_ali +if $train_sgmm2; then steps/train_ubm.sh \ --cmd "$train_cmd" 750 \ @@ -258,22 +265,7 @@ for iter in 1 2 3 4; do done ) & -dnn_cpu_parallel_opts=(--minibatch-size 128 --max-change 10 --num-jobs-nnet 8 --num-threads 16 \ - --parallel-opts "--num-threads 16") -dnn_gpu_parallel_opts=(--minibatch-size 512 --max-change 40 --num-jobs-nnet 4 --num-threads 1 \ - --parallel-opts "--gpu 1") - -steps/nnet2/train_pnorm_ensemble.sh \ - --mix-up 5000 --initial-learning-rate 0.008 --final-learning-rate 0.0008\ - --num-hidden-layers 4 --pnorm-input-dim 2000 --pnorm-output-dim 200\ - --cmd "$train_cmd" \ - "${dnn_gpu_parallel_opts[@]}" \ - --ensemble-size 4 --initial-beta 0.1 --final-beta 5 \ - data/train data/lang exp/tri5a_ali exp/tri6a_dnn +fi -( - steps/nnet2/decode.sh --nj 13 --cmd "$decode_cmd" --num-threads 4 \ - --scoring-opts "--min-lmwt 8 --max-lmwt 16" --transform-dir exp/tri5a/decode_dev exp/tri5a/graph data/dev exp/tri6a_dnn/decode_dev -) & -wait +local/chain/run_tdnn_1g.sh --stage $stage --train-stage $train_stage || exit 1; exit 0; diff --git a/egs/fisher_english/s5/local/chain/run_tdnn.sh b/egs/fisher_english/s5/local/chain/run_tdnn.sh index 14174e617c4..1fd0f1fdf3a 100755 --- a/egs/fisher_english/s5/local/chain/run_tdnn.sh +++ b/egs/fisher_english/s5/local/chain/run_tdnn.sh @@ -112,7 +112,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_100k_semisupervised_1a.sh b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_100k_semisupervised_1a.sh index e95de232304..b76efc4f1de 100644 --- a/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_100k_semisupervised_1a.sh +++ b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_100k_semisupervised_1a.sh @@ -231,7 +231,7 @@ if [ $stage -le 11 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $sup_tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_1a.sh b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_1a.sh index e76df666e8a..b1c133942ef 100755 --- a/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_1a.sh +++ b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_1a.sh @@ -142,7 +142,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git 
a/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_50k_semisupervised_1a.sh b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_50k_semisupervised_1a.sh index 2d5b2f8480e..53aac8c08ea 100755 --- a/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_50k_semisupervised_1a.sh +++ b/egs/fisher_english/s5/local/semisup/chain/tuning/run_tdnn_50k_semisupervised_1a.sh @@ -250,7 +250,7 @@ if [ $stage -le 11 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $sup_tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/fisher_swbd/s5/local/chain/run_blstm_6j.sh b/egs/fisher_swbd/s5/local/chain/run_blstm_6j.sh index cbf0ef6cb6c..c12f604f26b 100755 --- a/egs/fisher_swbd/s5/local/chain/run_blstm_6j.sh +++ b/egs/fisher_swbd/s5/local/chain/run_blstm_6j.sh @@ -133,7 +133,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/fisher_swbd/s5/local/chain/run_tdnn_7c.sh b/egs/fisher_swbd/s5/local/chain/run_tdnn_7c.sh index 12b3187a5fa..efcd1eced4a 100644 --- a/egs/fisher_swbd/s5/local/chain/run_tdnn_7c.sh +++ b/egs/fisher_swbd/s5/local/chain/run_tdnn_7c.sh @@ -129,7 +129,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/fisher_swbd/s5/local/chain/run_tdnn_7d.sh b/egs/fisher_swbd/s5/local/chain/run_tdnn_7d.sh index 7d640c3262a..e4a555abfdd 100644 --- a/egs/fisher_swbd/s5/local/chain/run_tdnn_7d.sh +++ b/egs/fisher_swbd/s5/local/chain/run_tdnn_7d.sh @@ -134,7 +134,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.002" linear_opts="orthonormal-constraint=1.0" output_opts="l2-regularize=0.0005 bottleneck-dim=256" diff --git a/egs/fisher_swbd/s5/local/chain/run_tdnn_lstm_1a.sh b/egs/fisher_swbd/s5/local/chain/run_tdnn_lstm_1a.sh index 07e88b59ddc..5650cedca28 100755 --- a/egs/fisher_swbd/s5/local/chain/run_tdnn_lstm_1a.sh +++ b/egs/fisher_swbd/s5/local/chain/run_tdnn_lstm_1a.sh @@ -142,7 +142,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" mkdir -p $dir/configs diff --git a/egs/fisher_swbd/s5/local/chain/run_tdnn_lstm_1b.sh b/egs/fisher_swbd/s5/local/chain/run_tdnn_lstm_1b.sh index c9d50d1f7bd..f3cc869e6de 100755 --- 
a/egs/fisher_swbd/s5/local/chain/run_tdnn_lstm_1b.sh +++ b/egs/fisher_swbd/s5/local/chain/run_tdnn_lstm_1b.sh @@ -151,7 +151,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20 dropout-proportion=0.0" mkdir -p $dir/configs diff --git a/egs/fisher_swbd/s5/local/chain/run_tdnn_opgru_1a.sh b/egs/fisher_swbd/s5/local/chain/run_tdnn_opgru_1a.sh index 1cce08abeee..059a81e15fc 100755 --- a/egs/fisher_swbd/s5/local/chain/run_tdnn_opgru_1a.sh +++ b/egs/fisher_swbd/s5/local/chain/run_tdnn_opgru_1a.sh @@ -148,7 +148,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) gru_opts="dropout-per-frame=true dropout-proportion=0.0 " mkdir -p $dir/configs diff --git a/egs/fisher_swbd/s5/local/chain/run_tdnn_opgru_1b.sh b/egs/fisher_swbd/s5/local/chain/run_tdnn_opgru_1b.sh index 2334c6a1bc1..d86b699d6f6 100755 --- a/egs/fisher_swbd/s5/local/chain/run_tdnn_opgru_1b.sh +++ b/egs/fisher_swbd/s5/local/chain/run_tdnn_opgru_1b.sh @@ -149,7 +149,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) gru_opts="dropout-per-frame=true dropout-proportion=0.0 " mkdir -p $dir/configs diff --git a/egs/fisher_swbd/s5/local/format_acronyms_ctm_eval2000.py b/egs/fisher_swbd/s5/local/format_acronyms_ctm_eval2000.py index 3c447c5976a..75cc4458d85 100755 --- a/egs/fisher_swbd/s5/local/format_acronyms_ctm_eval2000.py +++ b/egs/fisher_swbd/s5/local/format_acronyms_ctm_eval2000.py @@ -10,6 +10,7 @@ # en_4156 B 414.58 0.16 l # en_4156 B 414.74 0.17 a +from __future__ import division import argparse,re __author__ = 'Minhua Wu' @@ -27,7 +28,7 @@ if items[4].find(".") != -1: letters = items[4].split("._") acronym_period = round(float(items[3]), 2) - letter_slot = round(acronym_period / len(letters), 2) + letter_slot = round(acronym_period/len(letters), 2) time_start = round(float(items[2]), 2) for l in letters[:-1]: time = " %.2f %.2f " % (time_start, letter_slot) diff --git a/egs/fisher_swbd/s5/local/format_acronyms_ctm_rt03.py b/egs/fisher_swbd/s5/local/format_acronyms_ctm_rt03.py index 59814beb4ea..c3f9af09c99 100755 --- a/egs/fisher_swbd/s5/local/format_acronyms_ctm_rt03.py +++ b/egs/fisher_swbd/s5/local/format_acronyms_ctm_rt03.py @@ -10,6 +10,7 @@ # en_4156 B 414.58 0.16 l # en_4156 B 414.74 0.17 a +from __future__ import division import argparse,re __author__ = 'Minhua Wu' @@ -27,7 +28,7 @@ if items[4].find(".") != -1: letters = items[4].split("._") acronym_period = round(float(items[3]), 2) - letter_slot = round(acronym_period / len(letters), 2) + letter_slot = round(acronym_period/ len(letters), 2) time_start = round(float(items[2]), 2) for l in letters[:-1]: time = " %.2f %.2f " % (time_start, letter_slot) diff --git a/egs/formosa/README.txt b/egs/formosa/README.txt new file mode 100644 index 00000000000..3b9d78dad92 --- /dev/null +++ 
b/egs/formosa/README.txt @@ -0,0 +1,22 @@ +### Welcome to the demo recipe of the Formosa Speech in the Wild (FSW) Project ### + +The language habits of Taiwanese people differ from those of other Mandarin speakers (in both accent and culture) [1]. In particular, Taiwanese people use traditional Chinese characters (i.e., 繁體中文). To address this issue, a Taiwanese speech corpus collection project, "Formosa Speech in the Wild (FSW)", was initiated in 2017 to improve the development of Taiwanese-specific speech recognition techniques. + +The FSW corpus will be a large-scale database of real-life, multi-genre Taiwanese spontaneous speech collected and transcribed from various sources (radio, TV, open courses, etc.). To demonstrate that this database is a reasonable data resource for Taiwanese spontaneous speech recognition research, a baseline recipe is provided here for everybody, especially students, to develop their own systems easily and quickly. + +This recipe is based on the "NER-Trs-Vol1" corpus (about 150 hours of broadcast radio speech selected from FSW). For more details, please visit: +* Formosa Speech in the Wild (FSW) project (https://sites.google.com/speech.ntut.edu.tw/fsw) + +If you want to apply for the NER-Trs-Vol1 corpus, please contact Yuan-Fu Liao (廖元甫) via "yfliao@mail.ntut.edu.tw". This corpus is only for non-commercial research/education use and will be distributed via our GitLab server at https://speech.nchc.org.tw. + +Any bugs, errors, comments or suggestions are very welcome. + +Yuan-Fu Liao (廖元甫) +Associate Professor +Department of Electronic Engineering, +National Taipei University of Technology +http://www.ntut.edu.tw/~yfliao +yfliao@mail.ntut.edu.tw + +............ +[1] The languages of Taiwan consist of several varieties under the Austronesian and Sino-Tibetan language families. Taiwanese Mandarin, Hokkien, Hakka and the Formosan languages are used by 83.5%, 81.9%, 6.6% and 1.4% of the population respectively (2010). Given the prevalent use of Taiwanese Hokkien, the Mandarin spoken in Taiwan has been influenced by it to a great extent.
diff --git a/egs/formosa/s5/RESULTS b/egs/formosa/s5/RESULTS new file mode 100644 index 00000000000..b047e5cefe4 --- /dev/null +++ b/egs/formosa/s5/RESULTS @@ -0,0 +1,43 @@ +# +# Reference results +# +# Experimental settings: +# +# training set: show CS, BG, DA, QG, SR, SY and WK, in total 18977 utt., 1,088,948 words +# test set: show JZ, GJ, KX and YX, in total 2112 utt., 135,972 words +# eval set: show JX, TD and WJ, in total 2222 utt., 104,648 words +# +# lexicon: 274,036 words +# phones (IPA): 196 (tonal) +# + +# WER: test + +%WER 61.32 [ 83373 / 135972, 5458 ins, 19156 del, 58759 sub ] exp/mono/decode_test/wer_11_0.0 +%WER 41.00 [ 55742 / 135972, 6725 ins, 12763 del, 36254 sub ] exp/tri1/decode_test/wer_15_0.0 +%WER 40.41 [ 54948 / 135972, 7366 ins, 11505 del, 36077 sub ] exp/tri2/decode_test/wer_14_0.0 +%WER 38.67 [ 52574 / 135972, 6855 ins, 11250 del, 34469 sub ] exp/tri3a/decode_test/wer_15_0.0 +%WER 35.70 [ 48546 / 135972, 7197 ins, 9717 del, 31632 sub ] exp/tri4a/decode_test/wer_17_0.0 +%WER 32.11 [ 43661 / 135972, 6112 ins, 10185 del, 27364 sub ] exp/tri5a/decode_test/wer_17_0.5 +%WER 31.36 [ 42639 / 135972, 6846 ins, 8860 del, 26933 sub ] exp/tri5a_cleaned/decode_test/wer_17_0.5 +%WER 24.43 [ 33218 / 135972, 5524 ins, 7583 del, 20111 sub ] exp/nnet3/tdnn_sp/decode_test/wer_12_0.0 +%WER 23.95 [ 32568 / 135972, 4457 ins, 10271 del, 17840 sub ] exp/chain/tdnn_1a_sp/decode_test/wer_10_0.0 +%WER 23.54 [ 32006 / 135972, 4717 ins, 8644 del, 18645 sub ] exp/chain/tdnn_1b_sp/decode_test/wer_10_0.0 +%WER 20.64 [ 28067 / 135972, 4434 ins, 7946 del, 15687 sub ] exp/chain/tdnn_1c_sp/decode_test/wer_11_0.0 +%WER 20.98 [ 28527 / 135972, 4706 ins, 7816 del, 16005 sub ] exp/chain/tdnn_1d_sp/decode_test/wer_10_0.0 + +# CER: test + +%WER 54.09 [ 116688 / 215718, 4747 ins, 24510 del, 87431 sub ] exp/mono/decode_test/cer_10_0.0 +%WER 32.61 [ 70336 / 215718, 5866 ins, 16282 del, 48188 sub ] exp/tri1/decode_test/cer_13_0.0 +%WER 32.10 [ 69238 / 215718, 6186 ins, 15772 del, 47280 sub ] exp/tri2/decode_test/cer_13_0.0 +%WER 30.40 [ 65583 / 215718, 6729 ins, 13115 del, 45739 sub ] exp/tri3a/decode_test/cer_12_0.0 +%WER 27.53 [ 59389 / 215718, 6311 ins, 13008 del, 40070 sub ] exp/tri4a/decode_test/cer_15_0.0 +%WER 24.21 [ 52232 / 215718, 6425 ins, 11543 del, 34264 sub ] exp/tri5a/decode_test/cer_15_0.0 +%WER 23.41 [ 50492 / 215718, 6645 ins, 10997 del, 32850 sub ] exp/tri5a_cleaned/decode_test/cer_17_0.0 +%WER 17.07 [ 36829 / 215718, 4734 ins, 9938 del, 22157 sub ] exp/nnet3/tdnn_sp/decode_test/cer_12_0.0 +%WER 16.83 [ 36305 / 215718, 4772 ins, 10810 del, 20723 sub ] exp/chain/tdnn_1a_sp/decode_test/cer_9_0.0 +%WER 16.44 [ 35459 / 215718, 4216 ins, 11278 del, 19965 sub ] exp/chain/tdnn_1b_sp/decode_test/cer_10_0.0 +%WER 13.72 [ 29605 / 215718, 4678 ins, 8066 del, 16861 sub ] exp/chain/tdnn_1c_sp/decode_test/cer_10_0.0 +%WER 14.08 [ 30364 / 215718, 5182 ins, 7588 del, 17594 sub ] exp/chain/tdnn_1d_sp/decode_test/cer_9_0.0 + diff --git a/egs/formosa/s5/cmd.sh b/egs/formosa/s5/cmd.sh new file mode 100755 index 00000000000..66ae9090820 --- /dev/null +++ b/egs/formosa/s5/cmd.sh @@ -0,0 +1,27 @@ +# "queue.pl" uses qsub. The options to it are +# options to qsub. If you have GridEngine installed, +# change this to a queue you have access to. +# Otherwise, use "run.pl", which will run jobs locally +# (make sure your --num-jobs options are no more than +# the number of cpus on your machine. 
+ +# Run locally: +#export train_cmd=run.pl +#export decode_cmd=run.pl + +# JHU cluster (or most clusters using GridEngine, with a suitable +# conf/queue.conf). +export train_cmd="queue.pl" +export decode_cmd="queue.pl --mem 4G" + +host=$(hostname -f) +if [ ${host#*.} == "fit.vutbr.cz" ]; then + queue_conf=$HOME/queue_conf/default.conf # see example /homes/kazi/iveselyk/queue_conf/default.conf, + export train_cmd="queue.pl --config $queue_conf --mem 2G --matylda 0.2" + export decode_cmd="queue.pl --config $queue_conf --mem 3G --matylda 0.1" + export cuda_cmd="queue.pl --config $queue_conf --gpu 1 --mem 10G --tmp 40G" +elif [ ${host#*.} == "cm.cluster" ]; then + # MARCC bluecrab cluster: + export train_cmd="slurm.pl --time 4:00:00 " + export decode_cmd="slurm.pl --mem 4G --time 4:00:00 " +fi diff --git a/egs/formosa/s5/conf/decode.config b/egs/formosa/s5/conf/decode.config new file mode 100644 index 00000000000..d91f86183af --- /dev/null +++ b/egs/formosa/s5/conf/decode.config @@ -0,0 +1,5 @@ +beam=11.0 # beam for decoding. Was 13.0 in the scripts. +first_beam=8.0 # beam for 1st-pass decoding in SAT. + + + diff --git a/egs/formosa/s5/conf/mfcc.conf b/egs/formosa/s5/conf/mfcc.conf new file mode 100644 index 00000000000..a1aa3d6c158 --- /dev/null +++ b/egs/formosa/s5/conf/mfcc.conf @@ -0,0 +1,2 @@ +--use-energy=false # only non-default option. +--sample-frequency=16000 diff --git a/egs/formosa/s5/conf/mfcc_hires.conf b/egs/formosa/s5/conf/mfcc_hires.conf new file mode 100644 index 00000000000..ca067e77b37 --- /dev/null +++ b/egs/formosa/s5/conf/mfcc_hires.conf @@ -0,0 +1,10 @@ +# config for high-resolution MFCC features, intended for neural network training. +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--sample-frequency=16000 # the NER-Trs-Vol1 audio is sampled at 16kHz +--num-mel-bins=40 # similar to Google's setup. +--num-ceps=40 # there is no dimensionality reduction. +--low-freq=40 # low cutoff frequency for mel bins +--high-freq=-200 # high cutoff frequency, relative to the Nyquist of 8000 (=7800) diff --git a/egs/formosa/s5/conf/online_cmvn.conf b/egs/formosa/s5/conf/online_cmvn.conf new file mode 100644 index 00000000000..591367e7ae9 --- /dev/null +++ b/egs/formosa/s5/conf/online_cmvn.conf @@ -0,0 +1 @@ +# configuration file for apply-cmvn-online, used when invoking online2-wav-nnet3-latgen-faster. diff --git a/egs/formosa/s5/conf/pitch.conf b/egs/formosa/s5/conf/pitch.conf new file mode 100644 index 00000000000..e959a19d5b8 --- /dev/null +++ b/egs/formosa/s5/conf/pitch.conf @@ -0,0 +1 @@ +--sample-frequency=16000 diff --git a/egs/formosa/s5/local/chain/run_tdnn.sh b/egs/formosa/s5/local/chain/run_tdnn.sh new file mode 120000 index 00000000000..e1adaa9346d --- /dev/null +++ b/egs/formosa/s5/local/chain/run_tdnn.sh @@ -0,0 +1 @@ +tuning/run_tdnn_1d.sh \ No newline at end of file diff --git a/egs/formosa/s5/local/chain/tuning/run_tdnn_1a.sh b/egs/formosa/s5/local/chain/tuning/run_tdnn_1a.sh new file mode 100755 index 00000000000..66c5ad3335f --- /dev/null +++ b/egs/formosa/s5/local/chain/tuning/run_tdnn_1a.sh @@ -0,0 +1,181 @@ +#!/bin/bash + +# This script is based on run_tdnn_7h.sh in the swbd chain recipe.
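+# In brief, the stages below: build the 'chain' topology and lang dir, build
+# the tree on it (stage 9), generate the network config with the xconfig
+# parser (stage 10), train with steps/nnet3/chain/train.py (stage 11), make
+# the decoding graph (stage 12) and decode the test and eval sets (stage 13).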
+ +set -e + +# configs for 'chain' +affix=1a +stage=0 +train_stage=-10 +get_egs_stage=-10 +dir=exp/chain/tdnn # Note: _sp will get added to this +decode_iter= + +# training options +num_epochs=4 +initial_effective_lrate=0.001 +final_effective_lrate=0.0001 +max_param_change=2.0 +final_layer_normalize_target=0.5 +num_jobs_initial=2 +num_jobs_final=12 +minibatch_size=128 +frames_per_eg=150,110,90 +remove_egs=false +common_egs_dir= +xent_regularize=0.1 + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 9 ]; then + # Build a tree using our new topology. This is the critically different + # step compared with other recipes. + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 5000 data/$train_set $lang $ali_dir $treedir +fi + +if [ $stage -le 10 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=43 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=625 + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=625 + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=625 + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=625 + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=625 + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=625 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=625 target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. 
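+  # For example: with the default xent_regularize=0.1 above,
+  # learning_rate_factor = 0.5 / 0.1 = 5.0, so the xent output layer's
+  # learning rate is scaled up 5x to offset the 0.1 weight on its objective.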
+ relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=625 target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 11 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir exp/nnet3/ivectors_${train_set} \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $frames_per_eg \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate $initial_effective_lrate \ + --trainer.optimization.final-effective-lrate $final_effective_lrate \ + --trainer.max-param-change $max_param_change \ + --cleanup.remove-egs $remove_egs \ + --feat-dir data/${train_set}_hires \ + --tree-dir $treedir \ + --lat-dir exp/tri5a_sp_lats \ + --use-gpu wait \ + --dir $dir || exit 1; +fi + +if [ $stage -le 12 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_test $dir $dir/graph +fi + +graph_dir=$dir/graph +if [ $stage -le 13 ]; then + for test_set in test eval; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 10 --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet3/ivectors_$test_set \ + $graph_dir data/${test_set}_hires $dir/decode_${test_set} || exit 1; + done + wait; +fi + +exit 0; diff --git a/egs/formosa/s5/local/chain/tuning/run_tdnn_1b.sh b/egs/formosa/s5/local/chain/tuning/run_tdnn_1b.sh new file mode 100755 index 00000000000..1981bb0530d --- /dev/null +++ b/egs/formosa/s5/local/chain/tuning/run_tdnn_1b.sh @@ -0,0 +1,188 @@ +#!/bin/bash + +# This script shows improvement arising from data cleaning. + +# CER: +# %WER 16.83 [ 36305 / 215718, 4772 ins, 10810 del, 20723 sub ] exp/chain/tdnn_1a_sp/decode_test/cer_9_0.0 +# %WER 16.44 [ 35459 / 215718, 4216 ins, 11278 del, 19965 sub ] exp/chain/tdnn_1b_sp/decode_test/cer_10_0.0 + +# steps/info/chain_dir_info.pl exp/chain/tdnn_1b_sp +# exp/chain/tdnn_1b_sp: num-iters=133 nj=2..12 num-params=12.5M dim=43+100->4528 combine=-0.073->-0.073 (over 2) xent:train/valid[87,132,final]=(-1.05,-0.964,-0.963/-1.10,-1.06,-1.05) logprob:train/valid[87,132,final]=(-0.079,-0.065,-0.065/-0.094,-0.092,-0.092) + +set -e + +# configs for 'chain' +affix=1b +nnet3_affix=_1b +stage=0 +train_stage=-10 +get_egs_stage=-10 +dir=exp/chain/tdnn # Note: _sp will get added to this +decode_iter= + +# training options +num_epochs=4 +initial_effective_lrate=0.001 +final_effective_lrate=0.0001 +max_param_change=2.0 +final_layer_normalize_target=0.5 +num_jobs_initial=2 +num_jobs_final=12 +minibatch_size=128 +frames_per_eg=150,110,90 +remove_egs=false +common_egs_dir= +xent_regularize=0.1 + +# End configuration section. 
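+# Note: frames_per_eg=150,110,90 is a list of allowed chunk widths for the
+# training examples; most chunks use the first (principal) width, and the
+# smaller widths help cover the leftover ends of utterances.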
+echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 9 ]; then + # Build a tree using our new topology. This is the critically different + # step compared with other recipes. + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 5000 data/$train_set $lang $ali_dir $treedir +fi + +if [ $stage -le 10 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=43 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=625 + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=625 + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=625 + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=625 + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=625 + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=625 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=625 target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. 
+ relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=625 target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 11 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${train_set} \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $frames_per_eg \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate $initial_effective_lrate \ + --trainer.optimization.final-effective-lrate $final_effective_lrate \ + --trainer.max-param-change $max_param_change \ + --cleanup.remove-egs $remove_egs \ + --feat-dir data/${train_set}_hires \ + --tree-dir $treedir \ + --lat-dir exp/tri5a_sp_lats \ + --use-gpu wait \ + --dir $dir || exit 1; +fi + +if [ $stage -le 12 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_test $dir $dir/graph +fi + +graph_dir=$dir/graph +if [ $stage -le 13 ]; then + for test_set in test eval; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 10 --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_$test_set \ + $graph_dir data/${test_set}_hires $dir/decode_${test_set} || exit 1; + done + wait; +fi +exit 0; diff --git a/egs/formosa/s5/local/chain/tuning/run_tdnn_1c.sh b/egs/formosa/s5/local/chain/tuning/run_tdnn_1c.sh new file mode 100755 index 00000000000..6fa10344cfc --- /dev/null +++ b/egs/formosa/s5/local/chain/tuning/run_tdnn_1c.sh @@ -0,0 +1,191 @@ +#!/bin/bash + +# CER: +# %WER 16.44 [ 35459 / 215718, 4216 ins, 11278 del, 19965 sub ] exp/chain/tdnn_1b_sp/decode_test/cer_10_0.0 +# %WER 13.72 [ 29605 / 215718, 4678 ins, 8066 del, 16861 sub ] exp/chain/tdnn_1c_sp/decode_test/cer_10_0.0 + +# steps/info/chain_dir_info.pl exp/chain/tdnn_1c_sp +# exp/chain/tdnn_1c_sp: num-iters=147 nj=3..16 num-params=17.9M dim=43+100->4528 combine=-0.041->-0.041 (over 2) xent:train/valid[97,146,final]=(-0.845,-0.625,-0.618/-0.901,-0.710,-0.703) logprob:train/valid[97,146,final]=(-0.064,-0.040,-0.039/-0.072,-0.058,-0.057) + +set -e + +# configs for 'chain' +affix=1c +nnet3_affix=_1b +stage=0 +train_stage=-10 +get_egs_stage=-10 +dir=exp/chain/tdnn # Note: _sp will get added to this +decode_iter= + +# training options +num_epochs=6 +initial_effective_lrate=0.00025 +final_effective_lrate=0.000025 +max_param_change=2.0 +final_layer_normalize_target=0.5 +num_jobs_initial=3 +num_jobs_final=16 +minibatch_size=64 +frames_per_eg=150,110,90 +remove_egs=false +common_egs_dir= +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.5@0.50,0' + +# End configuration section. 
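+# The dropout_schedule above is a piecewise-linear schedule of
+# 'proportion@fraction-of-training' points: dropout stays at 0 for the first
+# 20% of training, rises to 0.5 at the halfway point, then decays back to 0
+# by the end of training.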
+echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 9 ]; then + # Build a tree using our new topology. This is the critically different + # step compared with other recipes. + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 5000 data/$train_set $lang $ali_dir $treedir +fi + +if [ $stage -le 10 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + affine_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" + tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.01" + output_opts="l2-regularize=0.002" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=43 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-dropout-layer name=tdnn1 $affine_opts dim=1536 + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf3 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf4 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf5 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=0 + tdnnf-layer name=tdnnf6 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf7 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf14 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf15 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + linear-component name=prefinal-l dim=256 $linear_opts + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 11 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir exp/nnet3$nnet3_affix/ivectors_${train_set} \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + 
--chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.0 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0 --constrained false" \ + --egs.chunk-width $frames_per_eg \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate $initial_effective_lrate \ + --trainer.optimization.final-effective-lrate $final_effective_lrate \ + --trainer.max-param-change $max_param_change \ + --cleanup.remove-egs $remove_egs \ + --feat-dir data/${train_set}_hires \ + --tree-dir $treedir \ + --lat-dir exp/tri5a_sp_lats \ + --use-gpu wait \ + --dir $dir || exit 1; +fi + +if [ $stage -le 12 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_test $dir $dir/graph +fi + +graph_dir=$dir/graph +if [ $stage -le 13 ]; then + for test_set in test eval; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 10 --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet3${nnet3_affix:+_$nnet3_affix}/ivectors_$test_set \ + $graph_dir data/${test_set}_hires $dir/decode_${test_set} || exit 1; + done + wait; +fi + +exit 0; diff --git a/egs/formosa/s5/local/chain/tuning/run_tdnn_1d.sh b/egs/formosa/s5/local/chain/tuning/run_tdnn_1d.sh new file mode 100755 index 00000000000..1f4b7e12850 --- /dev/null +++ b/egs/formosa/s5/local/chain/tuning/run_tdnn_1d.sh @@ -0,0 +1,190 @@ +#!/bin/bash + +# CER: +# 1a: %WER 16.83 [ 36305 / 215718, 4772 ins, 10810 del, 20723 sub ] exp/chain/tdnn_1a_sp/decode_test/cer_9_0.0 +# 1d: %WER 14.08 [ 30364 / 215718, 5182 ins, 7588 del, 17594 sub ] exp/chain/tdnn_1d_sp/decode_test/cer_9_0.0 + +# steps/info/chain_dir_info.pl exp/chain/tdnn_1d_sp +# exp/chain/tdnn_1d_sp: num-iters=157 nj=3..16 num-params=18.6M dim=43+100->5792 combine=-0.050->-0.050 (over 1) xent:train/valid[103,156,final]=(-0.977,-0.735,-0.725/-0.953,-0.772,-0.768) logprob:train/valid[103,156,final]=(-0.077,-0.052,-0.052/-0.079,-0.065,-0.066) + +set -e + +# configs for 'chain' +affix=1d +stage=0 +train_stage=-10 +get_egs_stage=-10 +dir=exp/chain/tdnn # Note: _sp will get added to this +decode_iter= + +# training options +num_epochs=6 +initial_effective_lrate=0.00025 +final_effective_lrate=0.000025 +max_param_change=2.0 +final_layer_normalize_target=0.5 +num_jobs_initial=3 +num_jobs_final=16 +minibatch_size=64 +frames_per_eg=150,110,90 +remove_egs=false +common_egs_dir= +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.5@0.50,0' + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 9 ]; then + # Build a tree using our new topology. This is the critically different + # step compared with other recipes. 
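+  # (Note: 1d uses 7000 tree leaves here, compared with 5000 in the 1a-1c setups.)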
+ steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 7000 data/$train_set $lang $ali_dir $treedir +fi + +if [ $stage -le 10 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + affine_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" + tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.01" + output_opts="l2-regularize=0.002" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=43 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-dropout-layer name=tdnn1 $affine_opts dim=1536 + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf3 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf4 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf5 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=0 + tdnnf-layer name=tdnnf6 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf7 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf14 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf15 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + linear-component name=prefinal-l dim=256 $linear_opts + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 11 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir exp/nnet3$nnet3_affix/ivectors_${train_set} \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.0 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + 
--trainer.add-option="--optimization.memory-compression-level=2" \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0 --constrained false" \ + --egs.chunk-width $frames_per_eg \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate $initial_effective_lrate \ + --trainer.optimization.final-effective-lrate $final_effective_lrate \ + --trainer.max-param-change $max_param_change \ + --cleanup.remove-egs $remove_egs \ + --feat-dir data/${train_set}_hires \ + --tree-dir $treedir \ + --lat-dir exp/tri5a_sp_lats \ + --use-gpu wait \ + --dir $dir || exit 1; +fi + +if [ $stage -le 12 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_test $dir $dir/graph +fi + +graph_dir=$dir/graph +if [ $stage -le 13 ]; then + for test_set in test eval; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 10 --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet3${nnet3_affix:+_$nnet3_affix}/ivectors_$test_set \ + $graph_dir data/${test_set}_hires $dir/decode_${test_set} || exit 1; + done + wait; +fi + +exit 0; diff --git a/egs/formosa/s5/local/nnet3/run_ivector_common.sh b/egs/formosa/s5/local/nnet3/run_ivector_common.sh new file mode 100755 index 00000000000..723589ddd2e --- /dev/null +++ b/egs/formosa/s5/local/nnet3/run_ivector_common.sh @@ -0,0 +1,145 @@ +#!/bin/bash + +set -euo pipefail + +# This script is modified based on mini_librispeech/s5/local/nnet3/run_ivector_common.sh + +# This script is called from local/nnet3/run_tdnn.sh and +# local/chain/run_tdnn.sh (and may eventually be called by more +# scripts). It contains the common feature preparation and +# iVector-related parts of the script. See those scripts for examples +# of usage. + +stage=0 +train_set=train +test_sets="test eval" +gmm=tri5a + +nnet3_affix= + +. ./cmd.sh +. ./path.sh +. utils/parse_options.sh + +gmm_dir=exp/${gmm} +ali_dir=exp/${gmm}_sp_ali + +for f in data/${train_set}/feats.scp ${gmm_dir}/final.mdl; do + if [ ! -f $f ]; then + echo "$0: expected file $f to exist" + exit 1 + fi +done + +if [ $stage -le 1 ]; then + # Although the nnet will be trained by high resolution data, we still have to + # perturb the normal data to get the alignment _sp stands for speed-perturbed + echo "$0: preparing directory for low-resolution speed-perturbed data (for alignment)" + utils/data/perturb_data_dir_speed_3way.sh data/${train_set} data/${train_set}_sp + echo "$0: making MFCC features for low-resolution speed-perturbed data" + steps/make_mfcc_pitch.sh --cmd "$train_cmd" --nj 70 data/${train_set}_sp \ + exp/make_mfcc/${train_set}_sp mfcc_perturbed || exit 1; + steps/compute_cmvn_stats.sh data/${train_set}_sp \ + exp/make_mfcc/${train_set}_sp mfcc_perturbed || exit 1; + utils/fix_data_dir.sh data/${train_set}_sp +fi + +if [ $stage -le 2 ]; then + echo "$0: aligning with the perturbed low-resolution data" + steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ + data/${train_set}_sp data/lang $gmm_dir $ali_dir || exit 1 +fi + +if [ $stage -le 3 ]; then + # Create high-resolution MFCC features (with 40 cepstra instead of 13). 
+ # this shows how you can split across multiple file-systems. + echo "$0: creating high-resolution MFCC features" + mfccdir=mfcc_perturbed_hires + + for datadir in ${train_set}_sp ${test_sets}; do + utils/copy_data_dir.sh data/$datadir data/${datadir}_hires + done + + # do volume-perturbation on the training data prior to extracting hires + # features; this helps make trained nnets more invariant to test data volume. + utils/data/perturb_data_dir_volume.sh data/${train_set}_sp_hires || exit 1; + + for datadir in ${train_set}_sp ${test_sets}; do + steps/make_mfcc_pitch.sh --nj 10 --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/${datadir}_hires exp/make_hires/$datadir $mfccdir || exit 1; + steps/compute_cmvn_stats.sh data/${datadir}_hires exp/make_hires/$datadir $mfccdir || exit 1; + utils/fix_data_dir.sh data/${datadir}_hires || exit 1; + # create MFCC data dir without pitch to extract iVector + utils/data/limit_feature_dim.sh 0:39 data/${datadir}_hires data/${datadir}_hires_nopitch || exit 1; + steps/compute_cmvn_stats.sh data/${datadir}_hires_nopitch exp/make_hires/$datadir $mfccdir || exit 1; + done +fi + +if [ $stage -le 4 ]; then + echo "$0: computing a subset of data to train the diagonal UBM." + # We'll use about a quarter of the data. + mkdir -p exp/nnet3${nnet3_affix}/diag_ubm + temp_data_root=exp/nnet3${nnet3_affix}/diag_ubm + + num_utts_total=$(wc -l $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=43 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=850 + relu-batchnorm-layer name=tdnn2 dim=850 input=Append(-1,0,2) + relu-batchnorm-layer name=tdnn3 dim=850 input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn4 dim=850 input=Append(-7,0,2) + relu-batchnorm-layer name=tdnn5 dim=850 input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn6 dim=850 + output-layer name=output input=tdnn6 dim=$num_targets max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 8 ]; then + steps/nnet3/train_dnn.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir exp/nnet3/ivectors_${train_set} \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate $initial_effective_lrate \ + --trainer.optimization.final-effective-lrate $final_effective_lrate \ + --egs.dir "$common_egs_dir" \ + --cleanup.remove-egs $remove_egs \ + --cleanup.preserve-model-interval 500 \ + --use-gpu wait \ + --feat-dir=data/${train_set}_hires \ + --ali-dir $ali_dir \ + --lang data/lang \ + --reporting.email="$reporting_email" \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 9 ]; then + # this version of the decoding treats each utterance separately + # without carrying forward speaker information. 
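+  # (i.e. i-vectors are estimated per utterance rather than per speaker; this
+  # usually costs a little accuracy but better matches online decoding, where
+  # no speaker-level statistics are available up front.)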
+ + for decode_set in test eval; do + num_jobs=`cat data/${decode_set}_hires/utt2spk|cut -d' ' -f2|sort -u|wc -l` + decode_dir=${dir}/decode_$decode_set + steps/nnet3/decode.sh --nj $num_jobs --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet3/ivectors_${decode_set} \ + $graph_dir data/${decode_set}_hires $decode_dir || exit 1; + done + wait; +fi + +exit 0; diff --git a/egs/formosa/s5/local/prepare_data.sh b/egs/formosa/s5/local/prepare_data.sh new file mode 100755 index 00000000000..68f342e1549 --- /dev/null +++ b/egs/formosa/s5/local/prepare_data.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# Copyright 2015-2016 Sarah Flora Juan +# Copyright 2016 Johns Hopkins University (Author: Yenda Trmal) +# Copyright 2018 Yuan-Fu Liao, National Taipei University of Technology +# AsusTek Computer Inc. (Author: Alex Hung) + +# Apache 2.0 + +set -e -o pipefail + +train_dir=NER-Trs-Vol1/Train +eval_dir=NER-Trs-Vol1-Eval +eval_key_dir=NER-Trs-Vol1-Eval-Key + +. ./path.sh +. parse_options.sh + +for x in $train_dir $eval_dir; do + if [ ! -d "$x" ] ; then + echo >&2 "The directory $x does not exist" + fi +done + +if [ -z "$(command -v dos2unix 2>/dev/null)" ]; then + echo "dos2unix not found on PATH. Please install it manually." + exit 1; +fi + +# have to remvoe previous files to avoid filtering speakers according to cmvn.scp and feats.scp +rm -rf data/all data/train data/test data/eval data/local/train +mkdir -p data/all data/train data/test data/eval data/local/train + + +# make utt2spk, wav.scp and text +find $train_dir -name *.wav -exec sh -c 'x={}; y=$(basename -s .wav $x); printf "%s %s\n" $y $y' \; | dos2unix > data/all/utt2spk +find $train_dir -name *.wav -exec sh -c 'x={}; y=$(basename -s .wav $x); printf "%s %s\n" $y $x' \; | dos2unix > data/all/wav.scp +find $train_dir -name *.txt -exec sh -c 'x={}; y=$(basename -s .txt $x); printf "%s " $y; cat $x' \; | dos2unix > data/all/text + +# fix_data_dir.sh fixes common mistakes (unsorted entries in wav.scp, +# duplicate entries and so on). Also, it regenerates the spk2utt from +# utt2spk +utils/fix_data_dir.sh data/all + +echo "Preparing train and test data" +# test set: JZ, GJ, KX, YX +grep -E "(JZ|GJ|KX|YX)_" data/all/utt2spk | awk '{print $1}' > data/all/cv.spk +utils/subset_data_dir_tr_cv.sh --cv-spk-list data/all/cv.spk data/all data/train data/test + +# for LM training +echo "cp data/train/text data/local/train/text for language model training" +cat data/train/text | awk '{$1=""}1;' | awk '{$1=$1}1;' > data/local/train/text + +# preparing EVAL set. +find $eval_dir -name *.wav -exec sh -c 'x={}; y=$(basename -s .wav $x); printf "%s %s\n" $y $y' \; | dos2unix > data/eval/utt2spk +find $eval_dir -name *.wav -exec sh -c 'x={}; y=$(basename -s .wav $x); printf "%s %s\n" $y $x' \; | dos2unix > data/eval/wav.scp +find $eval_key_dir -name *.txt -exec sh -c 'x={}; y=$(basename -s .txt $x); printf "%s " $y; cat $x' \; | dos2unix > data/eval/text +utils/fix_data_dir.sh data/eval + +echo "Data preparation completed." 
+exit 0; diff --git a/egs/formosa/s5/local/prepare_dict.sh b/egs/formosa/s5/local/prepare_dict.sh new file mode 100755 index 00000000000..4e580f5f6e8 --- /dev/null +++ b/egs/formosa/s5/local/prepare_dict.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2015-2016 Sarah Flora Juan +# Copyright 2016 Johns Hopkins University (Author: Yenda Trmal) +# Copyright 2018 Yuan-Fu Liao, National Taipei University of Technology +# Apache 2.0 + +source_dir=NER-Trs-Vol1/Language +dict_dir=data/local/dict +rm -rf $dict_dir +mkdir -p $dict_dir + +# +# +# +rm -f $dict_dir/lexicon.txt +touch $dict_dir/lexicon.txt +cat $source_dir/lexicon.txt > $dict_dir/lexicon.txt +echo " SIL" >> $dict_dir/lexicon.txt + +# +# define silence phone +# +rm -f $dict_dir/silence_phones.txt +touch $dict_dir/silence_phones.txt + +echo "SIL" > $dict_dir/silence_phones.txt + +# +# find nonsilence phones +# +rm -f $dict_dir/nonsilence_phones.txt +touch $dict_dir/nonsilence_phones.txt + +cat $source_dir/lexicon.txt | grep -v -F -f $dict_dir/silence_phones.txt | \ + perl -ane 'print join("\n", @F[1..$#F]) . "\n"; ' | \ + sort -u > $dict_dir/nonsilence_phones.txt + +# +# add optional silence phones +# + +rm -f $dict_dir/optional_silence.txt +touch $dict_dir/optional_silence.txt +echo "SIL" > $dict_dir/optional_silence.txt + +# +# extra questions +# +rm -f $dict_dir/extra_questions.txt +touch $dict_dir/extra_questions.txt +cat $dict_dir/silence_phones.txt | awk '{printf("%s ", $1);} END{printf "\n";}' > $dict_dir/extra_questions.txt || exit 1; +cat $dict_dir/nonsilence_phones.txt | awk '{printf("%s ", $1);} END{printf "\n";}' >> $dict_dir/extra_questions.txt || exit 1; + +echo "Dictionary preparation succeeded" +exit 0; diff --git a/egs/formosa/s5/local/prepare_lm.sh b/egs/formosa/s5/local/prepare_lm.sh new file mode 100755 index 00000000000..59fe1529658 --- /dev/null +++ b/egs/formosa/s5/local/prepare_lm.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Copyright 2015-2016 Sarah Flora Juan +# Copyright 2016 Johns Hopkins University (Author: Yenda Trmal) +# Apache 2.0 + +set -e -o pipefail + +# To create G.fst from ARPA language model +. 
./path.sh || { echo "$0: path.sh expected"; exit 1; }
+
+local/train_lms_srilm.sh --train-text data/train/text data/ data/srilm
+
+#nl -nrz -w10 corpus/LM/iban-bp-2012.txt | utils/shuffle_list.pl > data/local/external_text
+local/train_lms_srilm.sh --train-text data/local/external_text data/ data/srilm_external
+
+# let's do ngram interpolation of the previous two LMs
+# the lm.gz is always a symlink to the model with the best perplexity, so we use that
+
+mkdir -p data/srilm_interp
+for w in 0.9 0.8 0.7 0.6 0.5; do
+  ngram -lm data/srilm/lm.gz -mix-lm data/srilm_external/lm.gz \
+    -lambda $w -write-lm data/srilm_interp/lm.${w}.gz
+  echo -n "data/srilm_interp/lm.${w}.gz "
+  ngram -lm data/srilm_interp/lm.${w}.gz -ppl data/srilm/dev.txt | paste -s -
+done | sort -k15,15g > data/srilm_interp/perplexities.txt
+
+# for basic decoding, let's use only a trigram LM
+[ -d data/lang_test/ ] && rm -rf data/lang_test
+cp -R data/lang data/lang_test
+lm=$(cat data/srilm/perplexities.txt | grep 3gram | head -n1 | awk '{print $1}')
+local/arpa2G.sh $lm data/lang_test data/lang_test
+
+# for decoding with a bigger LM, let's find which interpolation weight gave the most improvement
+[ -d data/lang_big ] && rm -rf data/lang_big
+cp -R data/lang data/lang_big
+lm=$(cat data/srilm_interp/perplexities.txt | head -n1 | awk '{print $1}')
+local/arpa2G.sh $lm data/lang_big data/lang_big
+
+# for a really big LM, we should decode only with the small LM
+# and rescore using the big LM
+utils/build_const_arpa_lm.sh $lm data/lang_big data/lang_big
+exit 0;
diff --git a/egs/formosa/s5/local/run_cleanup_segmentation.sh b/egs/formosa/s5/local/run_cleanup_segmentation.sh
new file mode 100755
index 00000000000..b72cd89b4d1
--- /dev/null
+++ b/egs/formosa/s5/local/run_cleanup_segmentation.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+
+# Copyright 2016  Vimal Manohar
+#           2016  Johns Hopkins University (author: Daniel Povey)
+#           2017  Nagendra Kumar Goel
+#           2019  AsusTek Computer Inc. (author: Alex Hung)
+# Apache 2.0
+
+# This script demonstrates how to re-segment training data, selecting only the
+# "good" audio that matches the transcripts.
+# The basic idea is to decode with an existing in-domain acoustic model and a
+# biased language model built from the reference, and then work out the
+# segmentation from a ctm-like file.
+
+# For nnet3 and chain results after cleanup, see the scripts in
+# local/nnet3/run_tdnn.sh and local/chain/run_tdnn.sh
+
+# GMM Results for speaker-independent (SI) and speaker adaptive training (SAT) systems on dev and test sets
+# [will add these later].
+
+set -e
+set -o pipefail
+set -u
+
+stage=0
+cleanup_stage=0
+data=data/train
+cleanup_affix=cleaned
+srcdir=exp/tri5a
+langdir=data/lang_test
+nj=20
+decode_nj=20
+decode_num_threads=1
+
+. ./cmd.sh
+if [ -f ./path.sh ]; then . ./path.sh; fi
+. utils/parse_options.sh
+
+cleaned_data=${data}_${cleanup_affix}
+
+dir=${srcdir}_${cleanup_affix}_work
+cleaned_dir=${srcdir}_${cleanup_affix}
+
+if [ $stage -le 1 ]; then
+  # This does the actual data cleanup.
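  # [Editor's note: an illustrative check, not part of the original script.]
  # Once clean_and_segment_data.sh below has finished, the amount of data that
  # survived the cleanup can be inspected, e.g.:
  #   wc -l ${data}/utt2spk ${cleaned_data}/utt2spk
  #   utils/data/get_utt2dur.sh ${cleaned_data}
  #   awk '{t+=$2} END{printf("%.1f hours\n", t/3600)}' ${cleaned_data}/utt2dur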
+ steps/cleanup/clean_and_segment_data.sh --stage $cleanup_stage \ + --nj $nj --cmd "$train_cmd" \ + $data $langdir $srcdir $dir $cleaned_data +fi + +if [ $stage -le 2 ]; then + steps/align_fmllr.sh --nj $nj --cmd "$train_cmd" \ + $cleaned_data $langdir $srcdir ${srcdir}_ali_${cleanup_affix} +fi + +if [ $stage -le 3 ]; then + steps/train_sat.sh --cmd "$train_cmd" \ + 3500 100000 $cleaned_data $langdir ${srcdir}_ali_${cleanup_affix} ${cleaned_dir} +fi + +utils/data/get_utt2dur.sh data/train_cleaned +ori_avg_dur=$(awk 'BEGIN{total=0}{total += $2}END{printf("%.2f", total/NR)}' ${data}/utt2dur) +new_avg_dur=$(awk 'BEGIN{total=0}{total += $2}END{printf("%.2f", total/NR)}' ${cleaned_data}/utt2dur) +echo "average duration was reduced from ${ori_avg_dur}s to ${new_avg_dur}s." +# average duration was reduced from 21.68s to 10.97s. +exit 0; diff --git a/egs/formosa/s5/local/score.sh b/egs/formosa/s5/local/score.sh new file mode 100755 index 00000000000..a9786169973 --- /dev/null +++ b/egs/formosa/s5/local/score.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +set -e -o pipefail +set -x +steps/score_kaldi.sh "$@" +steps/scoring/score_kaldi_cer.sh --stage 2 "$@" + +echo "$0: Done" diff --git a/egs/formosa/s5/local/train_lms.sh b/egs/formosa/s5/local/train_lms.sh new file mode 100755 index 00000000000..efc5b92c573 --- /dev/null +++ b/egs/formosa/s5/local/train_lms.sh @@ -0,0 +1,63 @@ +#!/bin/bash + + +# To be run from one directory above this script. +. ./path.sh + +text=data/local/train/text +lexicon=data/local/dict/lexicon.txt + +for f in "$text" "$lexicon"; do + [ ! -f $x ] && echo "$0: No such file $f" && exit 1; +done + +# This script takes no arguments. It assumes you have already run +# aishell_data_prep.sh. +# It takes as input the files +# data/local/train/text +# data/local/dict/lexicon.txt +dir=data/local/lm +mkdir -p $dir + +kaldi_lm=`which train_lm.sh` +if [ -z $kaldi_lm ]; then + echo "$0: train_lm.sh is not found. That might mean it's not installed" + echo "$0: or it is not added to PATH" + echo "$0: Use the script tools/extra/install_kaldi_lm.sh to install it" + exit 1 +fi + +cleantext=$dir/text.no_oov + +cat $text | awk -v lex=$lexicon 'BEGIN{while((getline0){ seen[$1]=1; } } + {for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf(" ");} } printf("\n");}' \ + > $cleantext || exit 1; + +cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | \ + sort -nr > $dir/word.counts || exit 1; + +# Get counts from acoustic training transcripts, and add one-count +# for each word in the lexicon (but not silence, we don't want it +# in the LM-- we'll add it optionally later). +cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \ + cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \ + sort | uniq -c | sort -nr > $dir/unigram.counts || exit 1; + +# note: we probably won't really make use of as there aren't any OOVs +cat $dir/unigram.counts | awk '{print $2}' | get_word_map.pl "" "" "" > $dir/word_map \ + || exit 1; + +# note: ignore 1st field of train.txt, it's the utterance-id. +cat $cleantext | awk -v wmap=$dir/word_map 'BEGIN{while((getline0)map[$1]=$2;} + { for(n=2;n<=NF;n++) { printf map[$n]; if(n$dir/train.gz \ + || exit 1; + +train_lm.sh --arpa --lmtype 3gram-mincount $dir || exit 1; + +# LM is small enough that we don't need to prune it (only about 0.7M N-grams). 
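# [Editor's note: illustrative only, not run by this script; the threshold is a
#  made-up value.] If the LM ever grew large enough to matter, it could be
#  pruned with the kaldi_lm tools before building G.fst, e.g.:
#    prune_lm.sh --arpa 3.0 $dir/3gram-mincount/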
+# Perplexity over 128254.000000 words is 90.446690 + +# note: output is +# data/local/lm/3gram-mincount/lm_unpruned.gz + +exit 0; diff --git a/egs/formosa/s5/local/wer_hyp_filter b/egs/formosa/s5/local/wer_hyp_filter new file mode 100755 index 00000000000..519d92ee80d --- /dev/null +++ b/egs/formosa/s5/local/wer_hyp_filter @@ -0,0 +1,19 @@ +#!/usr/bin/env perl + +@filters=(''); + +foreach $w (@filters) { + $bad{$w} = 1; +} + +while() { + @A = split(" ", $_); + $id = shift @A; + print "$id "; + foreach $a (@A) { + if (!defined $bad{$a}) { + print "$a "; + } + } + print "\n"; +} diff --git a/egs/formosa/s5/local/wer_output_filter b/egs/formosa/s5/local/wer_output_filter new file mode 100755 index 00000000000..06a99a43e34 --- /dev/null +++ b/egs/formosa/s5/local/wer_output_filter @@ -0,0 +1,25 @@ +#!/usr/bin/env perl +# Copyright 2012-2014 Johns Hopkins University (Author: Yenda Trmal) +# Apache 2.0 +use utf8; + +use open qw(:encoding(utf8)); +binmode STDIN, ":utf8"; +binmode STDOUT, ":utf8"; +binmode STDERR, ":utf8"; + +while (<>) { + @F = split " "; + print $F[0] . " "; + foreach $s (@F[1..$#F]) { + if (($s =~ /\[.*\]/) || ($s =~ /\<.*\>/) || ($s =~ "")) { + print ""; + } else { + print "$s" + } + print " "; + } + print "\n"; +} + + diff --git a/egs/formosa/s5/local/wer_ref_filter b/egs/formosa/s5/local/wer_ref_filter new file mode 100755 index 00000000000..519d92ee80d --- /dev/null +++ b/egs/formosa/s5/local/wer_ref_filter @@ -0,0 +1,19 @@ +#!/usr/bin/env perl + +@filters=(''); + +foreach $w (@filters) { + $bad{$w} = 1; +} + +while() { + @A = split(" ", $_); + $id = shift @A; + print "$id "; + foreach $a (@A) { + if (!defined $bad{$a}) { + print "$a "; + } + } + print "\n"; +} diff --git a/egs/formosa/s5/path.sh b/egs/formosa/s5/path.sh new file mode 100755 index 00000000000..2d17b17a84a --- /dev/null +++ b/egs/formosa/s5/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/formosa/s5/run.sh b/egs/formosa/s5/run.sh new file mode 100755 index 00000000000..a4d0f2dcd1d --- /dev/null +++ b/egs/formosa/s5/run.sh @@ -0,0 +1,217 @@ +#!/bin/bash +# +# Copyright 2018, Yuan-Fu Liao, National Taipei University of Technology, yfliao@mail.ntut.edu.tw +# +# Before you run this recipe, please apply, download and put or make a link of the corpus under this folder (folder name: "NER-Trs-Vol1"). +# For more detail, please check: +# 1. Formosa Speech in the Wild (FSW) project (https://sites.google.com/speech.ntut.edu.tw/fsw/home/corpus) +# 2. Formosa Speech Recognition Challenge (FSW) 2018 (https://sites.google.com/speech.ntut.edu.tw/fsw/home/challenge) +stage=-2 +num_jobs=20 + +train_dir=NER-Trs-Vol1/Train +eval_dir=NER-Trs-Vol1-Eval +eval_key_dir=NER-Trs-Vol1-Eval-Key + +# shell options +set -eo pipefail + +. ./cmd.sh +. 
./utils/parse_options.sh + +# configure number of jobs running in parallel, you should adjust these numbers according to your machines +# data preparation +if [ $stage -le -2 ]; then + # Lexicon Preparation, + echo "$0: Lexicon Preparation" + local/prepare_dict.sh || exit 1; + + # Data Preparation + echo "$0: Data Preparation" + local/prepare_data.sh --train-dir $train_dir --eval-dir $eval_dir --eval-key-dir $eval_key_dir || exit 1; + + # Phone Sets, questions, L compilation + echo "$0: Phone Sets, questions, L compilation Preparation" + rm -rf data/lang + utils/prepare_lang.sh --position-dependent-phones false data/local/dict \ + "" data/local/lang data/lang || exit 1; + + # LM training + echo "$0: LM training" + rm -rf data/local/lm/3gram-mincount + local/train_lms.sh || exit 1; + + # G compilation, check LG composition + echo "$0: G compilation, check LG composition" + utils/format_lm.sh data/lang data/local/lm/3gram-mincount/lm_unpruned.gz \ + data/local/dict/lexicon.txt data/lang_test || exit 1; + +fi + +# Now make MFCC plus pitch features. +# mfccdir should be some place with a largish disk where you +# want to store MFCC features. +mfccdir=mfcc + +# mfcc +if [ $stage -le -1 ]; then + echo "$0: making mfccs" + for x in train test eval; do + steps/make_mfcc_pitch.sh --cmd "$train_cmd" --nj $num_jobs data/$x exp/make_mfcc/$x $mfccdir || exit 1; + steps/compute_cmvn_stats.sh data/$x exp/make_mfcc/$x $mfccdir || exit 1; + utils/fix_data_dir.sh data/$x || exit 1; + done +fi + +# mono +if [ $stage -le 0 ]; then + echo "$0: train mono model" + # Make some small data subsets for early system-build stages. + echo "$0: make training subsets" + utils/subset_data_dir.sh --shortest data/train 3000 data/train_mono + + # train mono + steps/train_mono.sh --boost-silence 1.25 --cmd "$train_cmd" --nj $num_jobs \ + data/train_mono data/lang exp/mono || exit 1; + + # Get alignments from monophone system. 
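  # [Editor's note: illustrative only.] After the alignment step below has run,
  # the resulting alignments can be eyeballed with e.g.:
  #   show-alignments data/lang/phones.txt exp/mono/final.mdl \
  #     "ark:gunzip -c exp/mono_ali/ali.1.gz |" | head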
+ steps/align_si.sh --boost-silence 1.25 --cmd "$train_cmd" --nj $num_jobs \ + data/train data/lang exp/mono exp/mono_ali || exit 1; + + # Monophone decoding + ( + utils/mkgraph.sh data/lang_test exp/mono exp/mono/graph || exit 1; + steps/decode.sh --cmd "$decode_cmd" --config conf/decode.config --nj $num_jobs \ + exp/mono/graph data/test exp/mono/decode_test + )& +fi + +# tri1 +if [ $stage -le 1 ]; then + echo "$0: train tri1 model" + # train tri1 [first triphone pass] + steps/train_deltas.sh --boost-silence 1.25 --cmd "$train_cmd" \ + 2500 20000 data/train data/lang exp/mono_ali exp/tri1 || exit 1; + + # align tri1 + steps/align_si.sh --cmd "$train_cmd" --nj $num_jobs \ + data/train data/lang exp/tri1 exp/tri1_ali || exit 1; + + # decode tri1 + ( + utils/mkgraph.sh data/lang_test exp/tri1 exp/tri1/graph || exit 1; + steps/decode.sh --cmd "$decode_cmd" --config conf/decode.config --nj $num_jobs \ + exp/tri1/graph data/test exp/tri1/decode_test + )& +fi + +# tri2 +if [ $stage -le 2 ]; then + echo "$0: train tri2 model" + # train tri2 [delta+delta-deltas] + steps/train_deltas.sh --cmd "$train_cmd" \ + 2500 20000 data/train data/lang exp/tri1_ali exp/tri2 || exit 1; + + # align tri2b + steps/align_si.sh --cmd "$train_cmd" --nj $num_jobs \ + data/train data/lang exp/tri2 exp/tri2_ali || exit 1; + + # decode tri2 + ( + utils/mkgraph.sh data/lang_test exp/tri2 exp/tri2/graph + steps/decode.sh --cmd "$decode_cmd" --config conf/decode.config --nj $num_jobs \ + exp/tri2/graph data/test exp/tri2/decode_test + )& +fi + +# tri3a +if [ $stage -le 3 ]; then + echo "$-: train tri3 model" + # Train tri3a, which is LDA+MLLT, + steps/train_lda_mllt.sh --cmd "$train_cmd" \ + 2500 20000 data/train data/lang exp/tri2_ali exp/tri3a || exit 1; + + # decode tri3a + ( + utils/mkgraph.sh data/lang_test exp/tri3a exp/tri3a/graph || exit 1; + steps/decode.sh --cmd "$decode_cmd" --nj $num_jobs --config conf/decode.config \ + exp/tri3a/graph data/test exp/tri3a/decode_test + )& +fi + +# tri4 +if [ $stage -le 4 ]; then + echo "$0: train tri4 model" + # From now, we start building a more serious system (with SAT), and we'll + # do the alignment with fMLLR. + steps/align_fmllr.sh --cmd "$train_cmd" --nj $num_jobs \ + data/train data/lang exp/tri3a exp/tri3a_ali || exit 1; + + steps/train_sat.sh --cmd "$train_cmd" \ + 2500 20000 data/train data/lang exp/tri3a_ali exp/tri4a || exit 1; + + # align tri4a + steps/align_fmllr.sh --cmd "$train_cmd" --nj $num_jobs \ + data/train data/lang exp/tri4a exp/tri4a_ali + + # decode tri4a + ( + utils/mkgraph.sh data/lang_test exp/tri4a exp/tri4a/graph + steps/decode_fmllr.sh --cmd "$decode_cmd" --nj $num_jobs --config conf/decode.config \ + exp/tri4a/graph data/test exp/tri4a/decode_test + )& +fi + +# tri5 +if [ $stage -le 5 ]; then + echo "$0: train tri5 model" + # Building a larger SAT system. 
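  # [Editor's note: illustrative only.] The two numbers passed to train_sat.sh
  # below are the target numbers of tree leaves and total Gaussians (here 3500
  # and 100000); the size of the trained model can be confirmed with e.g.:
  #   gmm-info exp/tri5a/final.mdl | grep -E 'number of (pdfs|gaussians)'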
+ steps/train_sat.sh --cmd "$train_cmd" \ + 3500 100000 data/train data/lang exp/tri4a_ali exp/tri5a || exit 1; + + # align tri5a + steps/align_fmllr.sh --cmd "$train_cmd" --nj $num_jobs \ + data/train data/lang exp/tri5a exp/tri5a_ali || exit 1; + + # decode tri5 + ( + utils/mkgraph.sh data/lang_test exp/tri5a exp/tri5a/graph || exit 1; + steps/decode_fmllr.sh --cmd "$decode_cmd" --nj $num_jobs --config conf/decode.config \ + exp/tri5a/graph data/test exp/tri5a/decode_test || exit 1; + )& +fi + +# nnet3 tdnn models +# commented out by default, since the chain model is usually faster and better +#if [ $stage -le 6 ]; then + # echo "$0: train nnet3 model" + # local/nnet3/run_tdnn.sh +#fi + +# chain model +if [ $stage -le 7 ]; then + # The iVector-extraction and feature-dumping parts coulb be skipped by setting "--train_stage 7" + echo "$0: train chain model" + local/chain/run_tdnn.sh +fi + +# getting results (see RESULTS file) +if [ $stage -le 8 ]; then + echo "$0: extract the results" + for test_set in test eval; do + echo "WER: $test_set" + for x in exp/*/decode_${test_set}*; do [ -d $x ] && grep WER $x/wer_* | utils/best_wer.sh; done 2>/dev/null + for x in exp/*/*/decode_${test_set}*; do [ -d $x ] && grep WER $x/wer_* | utils/best_wer.sh; done 2>/dev/null + echo + + echo "CER: $test_set" + for x in exp/*/decode_${test_set}*; do [ -d $x ] && grep WER $x/cer_* | utils/best_wer.sh; done 2>/dev/null + for x in exp/*/*/decode_${test_set}*; do [ -d $x ] && grep WER $x/cer_* | utils/best_wer.sh; done 2>/dev/null + echo + done +fi + +# finish +echo "$0: all done" + +exit 0; diff --git a/egs/formosa/s5/steps b/egs/formosa/s5/steps new file mode 120000 index 00000000000..6e99bf5b5ad --- /dev/null +++ b/egs/formosa/s5/steps @@ -0,0 +1 @@ +../../wsj/s5/steps \ No newline at end of file diff --git a/egs/formosa/s5/utils b/egs/formosa/s5/utils new file mode 120000 index 00000000000..b240885218f --- /dev/null +++ b/egs/formosa/s5/utils @@ -0,0 +1 @@ +../../wsj/s5/utils \ No newline at end of file diff --git a/egs/gale_arabic/s5/local/gale_format_data.sh b/egs/gale_arabic/s5/local/gale_format_data.sh index 85a946a58d9..053323dc194 100755 --- a/egs/gale_arabic/s5/local/gale_format_data.sh +++ b/egs/gale_arabic/s5/local/gale_format_data.sh @@ -57,4 +57,4 @@ fsttablecompose data/lang/L_disambig.fst data/lang_test/G.fst | \ echo gale_format_data succeeded. -exit 0 \ No newline at end of file +exit 0 diff --git a/egs/gale_arabic/s5/local/gale_prep_dict.sh b/egs/gale_arabic/s5/local/gale_prep_dict.sh index 74ef789eda7..f6fd83378d0 100755 --- a/egs/gale_arabic/s5/local/gale_prep_dict.sh +++ b/egs/gale_arabic/s5/local/gale_prep_dict.sh @@ -25,9 +25,8 @@ echo SIL > $dir/optional_silence.txt cat $dir/lexicon.txt | cut -d ' ' -f2- | tr -s ' ' '\n' |\ sort -u > $dir/nonsilence_phones.txt || exit 1; +perl -i -pe 'print " SIL\n" if $.==1' $dir/lexicon.txt - sed -i '1i SIL' $dir/lexicon.txt - echo Dictionary preparation succeeded exit 0 diff --git a/egs/gale_arabic/s5/local/gale_train_lms.sh b/egs/gale_arabic/s5/local/gale_train_lms.sh index 1b5d4665a19..8f8e715390f 100755 --- a/egs/gale_arabic/s5/local/gale_train_lms.sh +++ b/egs/gale_arabic/s5/local/gale_train_lms.sh @@ -113,4 +113,4 @@ fi echo train lm succeeded -exit 0 \ No newline at end of file +exit 0 diff --git a/egs/gale_arabic/s5/local/run_sgmm.sh b/egs/gale_arabic/s5/local/run_sgmm.sh index f9ba9b193a8..a5d32d18038 100755 --- a/egs/gale_arabic/s5/local/run_sgmm.sh +++ b/egs/gale_arabic/s5/local/run_sgmm.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/bash . 
./path.sh @@ -10,17 +10,17 @@ nDecodeJobs=40 galeData=GALE mfccdir=mfcc - -if [[ ! -d exp/tri3b_ali ]]; then + +if [[ ! -d exp/tri3b_ali ]]; then echo "exp/tri3b_ali lattices are required for alignmnet" - exit 1 + exit 1 fi ## SGMM (subspace gaussian mixture model), excluding the "speaker-dependent weights" steps/train_ubm.sh --cmd "$train_cmd" 700 \ data/train data/lang exp/tri3b_ali exp/ubm5a || exit 1; - + steps/train_sgmm2.sh --cmd "$train_cmd" 5000 20000 data/train data/lang exp/tri3b_ali \ exp/ubm5a/final.ubm exp/sgmm_5a || exit 1; @@ -38,16 +38,16 @@ steps/align_sgmm2.sh --nj $nJobs --cmd "$train_cmd" --transform-dir exp/tri3b_al steps/make_denlats_sgmm2.sh --nj $nJobs --sub-split 30 --beam 9.0 --lattice-beam 6 \ --cmd "$decode_cmd" --transform-dir \ exp/tri3b_ali data/train data/lang exp/sgmm_5a_ali exp/sgmm_5a_denlats || exit 1; - + steps/train_mmi_sgmm2.sh --cmd "$train_cmd" --num-iters 8 --transform-dir exp/tri3b_ali --boost 0.1 \ data/train data/lang exp/sgmm_5a exp/sgmm_5a_denlats exp/sgmm_5a_mmi_b0.1 - + #decode SGMM MMI utils/mkgraph.sh data/lang_test exp/sgmm_5a_mmi_b0.1 exp/sgmm_5a_mmi_b0.1/graph steps/decode_sgmm2.sh --nj $nDecodeJobs --cmd "$decode_cmd" \ --config conf/decode.config --transform-dir exp/tri3b/decode \ exp/sgmm_5a_mmi_b0.1/graph data/test exp/sgmm_5a_mmi_b0.1/decode - + for n in 1 2 3 4; do steps/decode_sgmm2_rescore.sh --cmd "$decode_cmd" --iter $n \ --transform-dir exp/tri3b/decode data/lang_test data/test \ diff --git a/egs/gale_arabic/s5b/RESULTS b/egs/gale_arabic/s5b/RESULTS index 2260a106654..a485240ff6b 100644 --- a/egs/gale_arabic/s5b/RESULTS +++ b/egs/gale_arabic/s5b/RESULTS @@ -2,13 +2,7 @@ # This file is generated using local/split_wer.sh $galeData //galeData is a local folder to keep intermediate gale data # look at the end of run.sh in the same folder ## -##### RESULTS generated by amali at 2017-01-01-08-05-59 - Report Results WER: -%WER 9.50 [ 2124 / 22363, 160 ins, 275 del, 1689 sub ] exp/chain_cleaned/tdnn_lstm1a_sp_bi/decode/wer_report_9 -%WER 10.72 [ 2398 / 22363, 163 ins, 313 del, 1922 sub ] exp/chain_cleaned/tdnn1b_sp_bi/decode/wer_report_9 -%WER 12.04 [ 2693 / 22363, 226 ins, 271 del, 2196 sub ] exp/nnet3_cleaned/lstm_ld5_sp/decode/wer_report_9 -%WER 12.29 [ 2749 / 22363, 273 ins, 266 del, 2210 sub ] exp/nnet3_cleaned/tdnn_sp/decode/wer_report_10 %WER 17.82 [ 3986 / 22363, 315 ins, 618 del, 3053 sub ] exp/sgmm_5a_mmi_b0.1/decode/wer_report_12 %WER 18.15 [ 4059 / 22363, 335 ins, 589 del, 3135 sub ] exp/sgmm_5a_mmi_b0.1/decode4/wer_report_11 %WER 18.42 [ 4119 / 22363, 346 ins, 590 del, 3183 sub ] exp/sgmm_5a_mmi_b0.1/decode3/wer_report_11 @@ -27,10 +21,6 @@ Report Results WER: %WER 25.66 [ 5738 / 22363, 478 ins, 838 del, 4422 sub ] exp/tri2a/decode/wer_report_14 %WER 26.38 [ 5900 / 22363, 435 ins, 929 del, 4536 sub ] exp/tri1/decode/wer_report_15 Conversational Results WER: -%WER 21.59 [ 10213 / 47305, 944 ins, 3092 del, 6177 sub ] exp/chain_cleaned/tdnn_lstm1a_sp_bi/decode/wer_conversational_9 -%WER 24.77 [ 11716 / 47305, 1098 ins, 3579 del, 7039 sub ] exp/chain_cleaned/tdnn1b_sp_bi/decode/wer_conversational_9 -%WER 26.78 [ 12670 / 47305, 1741 ins, 2434 del, 8495 sub ] exp/nnet3_cleaned/lstm_ld5_sp/decode/wer_conversational_9 -%WER 27.55 [ 13032 / 47305, 1800 ins, 2666 del, 8566 sub ] exp/nnet3_cleaned/tdnn_sp/decode/wer_conversational_11 %WER 34.10 [ 16133 / 47305, 1903 ins, 3245 del, 10985 sub ] exp/sgmm_5a_mmi_b0.1/decode/wer_conversational_11 %WER 34.81 [ 16466 / 47305, 2077 ins, 3037 del, 11352 sub ] 
exp/sgmm_5a_mmi_b0.1/decode4/wer_conversational_10 %WER 35.19 [ 16648 / 47305, 1933 ins, 3264 del, 11451 sub ] exp/sgmm_5a_mmi_b0.1/decode3/wer_conversational_11 @@ -49,10 +39,6 @@ Conversational Results WER: %WER 45.92 [ 21724 / 47305, 1995 ins, 5213 del, 14516 sub ] exp/tri2a/decode/wer_conversational_14 %WER 46.86 [ 22166 / 47305, 2212 ins, 4819 del, 15135 sub ] exp/tri1/decode/wer_conversational_13 Combined Results for Reports and Conversational WER: -%WER 17.64 [ 12286 / 69668, 1310 ins, 2807 del, 8169 sub ] exp/chain_cleaned/tdnn_lstm1a_sp_bi/decode/wer_8 -%WER 20.26 [ 14114 / 69668, 1261 ins, 3892 del, 8961 sub ] exp/chain_cleaned/tdnn1b_sp_bi/decode/wer_9 -%WER 22.05 [ 15363 / 69668, 1967 ins, 2705 del, 10691 sub ] exp/nnet3_cleaned/lstm_ld5_sp/decode/wer_9 -%WER 22.66 [ 15786 / 69668, 2047 ins, 2955 del, 10784 sub ] exp/nnet3_cleaned/tdnn_sp/decode/wer_11 %WER 28.89 [ 20127 / 69668, 2244 ins, 3829 del, 14054 sub ] exp/sgmm_5a_mmi_b0.1/decode/wer_11 %WER 29.48 [ 20541 / 69668, 2243 ins, 3860 del, 14438 sub ] exp/sgmm_5a_mmi_b0.1/decode4/wer_11 %WER 29.81 [ 20767 / 69668, 2279 ins, 3854 del, 14634 sub ] exp/sgmm_5a_mmi_b0.1/decode3/wer_11 @@ -65,8 +51,30 @@ Combined Results for Reports and Conversational WER: %WER 32.36 [ 22542 / 69668, 2156 ins, 4184 del, 16202 sub ] exp/tri2b_mmi/decode_it4/wer_11 %WER 32.50 [ 22640 / 69668, 2393 ins, 3956 del, 16291 sub ] exp/tri2b_mmi/decode_it3/wer_11 %WER 32.79 [ 22847 / 69668, 2407 ins, 4760 del, 15680 sub ] exp/tri2b_mpe/decode_it3/wer_13 +# WER with train_sat_basis +%WER 33.35 [ 23233 / 69668, 2385 ins, 5274 del, 15574 sub ] exp/tri3b/decode/wer_16_0.5 +# WER with train_sat %WER 33.61 [ 23413 / 69668, 2817 ins, 4577 del, 16019 sub ] exp/tri3b/decode/wer_17 %WER 35.73 [ 24894 / 69668, 2630 ins, 4944 del, 17320 sub ] exp/tri3b/decode.si/wer_15 %WER 36.17 [ 25196 / 69668, 2429 ins, 5393 del, 17374 sub ] exp/tri2b/decode/wer_16 %WER 39.42 [ 27462 / 69668, 2473 ins, 6051 del, 18938 sub ] exp/tri2a/decode/wer_14 %WER 40.35 [ 28113 / 69668, 2713 ins, 5635 del, 19765 sub ] exp/tri1/decode/wer_13 + + +# Effect of GMM seed model (tri2b instead of tri3b). Using tri3b give a slightly better result +# as compared to using tri2b as seed. +%WER 16.66 [ 11610 / 69668, 1233 ins, 2747 del, 7630 sub ] exp/chain/tdnn_1a_3b_sp/decode_test/wer_10_0.0 +%WER 16.71 [ 11642 / 69668, 1145 ins, 2908 del, 7589 sub ] exp/chain/tdnn_1a_2b_sp/decode_test/wer_9_0.0 + +# Effect of Tree-size (3500, 4500, 7000, 11000) +%WER 16.66 [ 11610 / 69668, 1233 ins, 2747 del, 7630 sub ] exp/chain/tdnn_1a_3500_sp/decode_test/wer_10_0.0 +%WER 16.59 [ 11557 / 69668, 1234 ins, 2646 del, 7677 sub ] exp/chain/tdnn_1a_4500_sp/decode_test/wer_10_0.0 +%WER 16.47 [ 11474 / 69668, 1421 ins, 2207 del, 7846 sub ] exp/chain/tdnn_1a_7000_sp/decode_test/wer_9_0.0 +%WER 16.62 [ 11580 / 69668, 1164 ins, 2789 del, 7627 sub ] exp/chain/tdnn_1a_11000_sp/decode_test/wer_10_0.0 + +# Effect of l2-regularization on the output with tree-size=7000. 
l2 on the output (0.005,0.002) +%WER 16.54 [ 11522 / 69668, 1123 ins, 2739 del, 7660 sub ] exp/chain/tdnn_1a_7000_005_sp/decode_test/wer_9_0.5 +%WER 16.47 [ 11474 / 69668, 1421 ins, 2207 del, 7846 sub ] exp/chain/tdnn_1a_7000_002_sp/decode_test/wer_9_0.0 + +#current best 'chain' models (see local/chain/tuning/run_tdnn_1a.sh) +%WER 16.47 [ 11474 / 69668, 1421 ins, 2207 del, 7846 sub ] exp/chain/tdnn_1a_sp/decode_test/wer_9_0.0 diff --git a/egs/gale_arabic/s5b/cmd.sh b/egs/gale_arabic/s5b/cmd.sh index 71dd849a93b..ea341c98d4a 100755 --- a/egs/gale_arabic/s5b/cmd.sh +++ b/egs/gale_arabic/s5b/cmd.sh @@ -10,6 +10,6 @@ # conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, # or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. -export train_cmd="queue.pl --mem 2G" -export decode_cmd="queue.pl --mem 4G" -export mkgraph_cmd="queue.pl --mem 8G" +export train_cmd="retry.pl queue.pl --mem 2G" +export decode_cmd="retry.pl queue.pl --mem 4G" +export mkgraph_cmd="retry.pl queue.pl --mem 8G" diff --git a/egs/gale_arabic/s5b/local/chain/compare_wer.sh b/egs/gale_arabic/s5b/local/chain/compare_wer.sh new file mode 100755 index 00000000000..1a40523355a --- /dev/null +++ b/egs/gale_arabic/s5b/local/chain/compare_wer.sh @@ -0,0 +1,72 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/cnn{1a,1b} + +# ./local/chain/compare_wer.sh exp/chain/cnn1a +# System cnn1a +# WER 0.61 +# CER 0.15 +# Final train prob -0.0377 +# Final valid prob -0.0380 +# Final train prob (xent) -0.0830 +# Final valid prob (xent) -0.0838 + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... ]" + echo "e.g.: $0 exp/chain/cnn{1a,1b}" + exit 1 +fi + +echo "# $0 $*" +used_epochs=false + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +echo -n "# WER " +for x in $*; do + wer=$(cat $x/decode_test/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER " +for x in $*; do + cer=$(cat $x/decode_test/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo diff --git a/egs/gale_arabic/s5b/local/chain/run_chain_common.sh b/egs/gale_arabic/s5b/local/chain/run_chain_common.sh new file mode 100755 index 00000000000..da37e148441 --- /dev/null +++ b/egs/gale_arabic/s5b/local/chain/run_chain_common.sh @@ -0,0 +1,82 @@ +#!/bin/bash + +# this script has common stages shared across librispeech chain recipes. 
+# It generates a new topology in a new lang directory, gets the alignments as +# lattices, and builds a tree for the new topology +set -e + +stage=11 + +# input directory names. These options are actually compulsory, and they have +# been named for convenience +gmm_dir= +ali_dir= +lores_train_data_dir= + +num_leaves=6000 + +# output directory names. They are also compulsory. +lang= +lat_dir= +tree_dir= +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +[ -z $lang ] && echo "Set --lang, this specifies the new lang directory which will have the new topology" && exit 1; +[ -z $lat_dir ] && echo "Set --lat-dir, this specifies the experiment directory to store lattice" && exit 1; +[ -z $tree_dir ] && echo "Set --tree-dir, this specifies the directory to store new tree " && exit 1; + +for f in $gmm_dir/final.mdl $ali_dir/ali.1.gz $lores_train_data_dir/feats.scp; do + [ ! -f $f ] && echo "$0: expected file $f to exist" && exit 1 +done + +if [ $stage -le 11 ]; then + echo "$0: creating lang directory with one state per phone." + # Create a version of the lang/ directory that has one state per phone in the + # topo file. [note, it really has two states.. the first one is only repeated + # once, the second one has zero or more repeats.] + if [ -d $lang ]; then + if [ $lang/L.fst -nt data/lang/L.fst ]; then + echo "$0: $lang already exists, not overwriting it; continuing" + else + echo "$0: $lang already exists and seems to be older than data/lang..." + echo " ... not sure what to do. Exiting." + exit 1; + fi + else + cp -r data/lang $lang + silphonelist=$(cat $lang/phones/silence.csl) || exit 1; + nonsilphonelist=$(cat $lang/phones/nonsilence.csl) || exit 1; + # Use our special topology... note that later on may have to tune this + # topology. + steps/nnet3/chain/gen_topo.py $nonsilphonelist $silphonelist >$lang/topo + fi +fi + +if [ $stage -le 12 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + nj=$(cat ${ali_dir}/num_jobs) || exit 1; + steps/align_fmllr_lats.sh --nj $nj --cmd "$train_cmd" ${lores_train_data_dir} \ + $lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 13 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" $num_leaves ${lores_train_data_dir} $lang $ali_dir $tree_dir +fi + +exit 0; diff --git a/egs/gale_arabic/s5b/local/chain/tuning/run_tdnn_1a.sh b/egs/gale_arabic/s5b/local/chain/tuning/run_tdnn_1a.sh index 7afafb31ff6..bf2e45c9914 100755 --- a/egs/gale_arabic/s5b/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/gale_arabic/s5b/local/chain/tuning/run_tdnn_1a.sh @@ -1,31 +1,51 @@ #!/bin/bash -#started from tedlium recipe with few edits +# ./local/chain/compare_wer.sh exp/chain/tdnn_1a_sp +# System tdnn_1a_sp +# WER 16.47 +# CER 6.68 +# Final train prob -0.0652 +# Final valid prob -0.0831 +# Final train prob (xent) -0.8965 +# Final valid prob (xent) -0.9964 +# steps/info/chain_dir_info.pl exp/chain/tdnn_1a_sp/ +# exp/chain/tdnn_1a_sp/: num-iters=441 nj=3..16 num-params=18.6M dim=40+100->5816 combine=-0.063->-0.062 (over 6) xent:train/valid[293,440,final]=(-1.22,-0.912,-0.896/-1.29,-1.01,-0.996) logprob:train/valid[293,440,final]=(-0.097,-0.066,-0.065/-0.108,-0.084,-0.083) -set -e -o pipefail -# First the options that are passed through to run_ivector_common.sh -# (some of which are also used in this script directly). +set -e -o pipefail stage=0 nj=30 -decode_nj=30 -min_seg_len=1.55 -xent_regularize=0.1 train_set=train -gmm=tri2b # the gmm for the target data +test_set=test +gmm=tri3b # this is the source gmm-dir that we'll use for alignments; it + # should have alignments for the specified training data. num_threads_ubm=32 -nnet3_affix=_cleaned # cleanup affix for nnet3 and chain dirs, e.g. _cleaned - -# The rest are configs specific to this script. Most of the parameters -# are just hardcoded at this level, in the commands below. -train_stage=-10 #default -10 -tree_affix= # affix for tree directory, e.g. "a" or "b", in case we change the configuration. -tdnn_affix=1b #affix for TDNN directory, e.g. "a" or "b", in case we change the configuration. -common_egs_dir= # you can set this to use previously dumped egs. +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. + +# Options which are not passed through to run_ivector_common.sh +affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# LSTM/chain options +train_stage=-10 +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.5@0.50,0' + +# training chunk-options +chunk_width=150,110,100 +get_egs_stage=-10 + +# training options +srand=0 +remove_egs=true +run_ivector_common=true +run_chain_common=true # End configuration section. echo "$0 $@" # Print the command line for logging + . ./cmd.sh . ./path.sh . ./utils/parse_options.sh @@ -39,169 +59,162 @@ where "nvcc" is installed. 
EOF fi -local/nnet3/run_ivector_common.sh --stage $stage \ - --nj $nj \ - --min-seg-len $min_seg_len \ - --train-set $train_set \ - --gmm $gmm \ - --num-threads-ubm $num_threads_ubm \ - --nnet3-affix "$nnet3_affix" - - -gmm_dir=exp/$gmm -ali_dir=exp/${gmm}_ali_${train_set}_sp_comb -tree_dir=exp/chain${nnet3_affix}/tree_bi${tree_affix} -lat_dir=exp/chain${nnet3_affix}/${gmm}_${train_set}_sp_comb_lats -dir=exp/chain${nnet3_affix}/tdnn${tdnn_affix}_sp_bi -train_data_dir=data/${train_set}_sp_hires_comb -lores_train_data_dir=data/${train_set}_sp_comb -train_ivector_dir=exp/nnet3${nnet3_affix}/ivectors_${train_set}_sp_hires_comb - +if $run_ivector_common; then + local/nnet3/run_ivector_common.sh \ + --stage $stage --nj $nj \ + --train-set $train_set --gmm $gmm \ + --num-threads-ubm $num_threads_ubm \ + --nnet3-affix "$nnet3_affix" +fi -for f in $gmm_dir/final.mdl $train_data_dir/feats.scp $train_ivector_dir/ivector_online.scp \ - $lores_train_data_dir/feats.scp $ali_dir/ali.1.gz $gmm_dir/final.mdl; do +gmm_dir=exp/${gmm} +ali_dir=exp/${gmm}_ali_${train_set}_sp +lat_dir=exp/chain${nnet3_affix}/${gmm}_${train_set}_sp_lats +dir=exp/chain${nnet3_affix}/tdnn${affix}_sp +train_data_dir=data/${train_set}_sp_hires +train_ivector_dir=exp/nnet3${nnet3_affix}/ivectors_${train_set}_sp_hires +lores_train_data_dir=data/${train_set}_sp + +# note: you don't necessarily have to change the treedir name +# each time you do a new experiment-- only if you change the +# configuration in a way that affects the tree. +tree_dir=exp/chain${nnet3_affix}/tree_a_sp +# the 'lang' directory is created by this script. +# If you create such a directory with a non-standard topology +# you should probably name it differently. +lang=data/lang_chain + +for f in $train_data_dir/feats.scp $train_ivector_dir/ivector_online.scp \ + $lores_train_data_dir/feats.scp $gmm_dir/final.mdl \ + $ali_dir/ali.1.gz $gmm_dir/final.mdl; do [ ! -f $f ] && echo "$0: expected file $f to exist" && exit 1 done -if [ $stage -le 14 ]; then - echo "$0: creating lang directory with one state per phone." - # Create a version of the lang/ directory that has one state per phone in the - # topo file. [note, it really has two states.. the first one is only repeated - # once, the second one has zero or more repeats.] - if [ -d data/lang_chain ]; then - if [ data/lang_chain/L.fst -nt data/lang/L.fst ]; then - echo "$0: data/lang_chain already exists, not overwriting it; continuing" - else - echo "$0: data/lang_chain already exists and seems to be older than data/lang..." - echo " ... not sure what to do. Exiting." - exit 1; - fi - else - cp -r data/lang data/lang_chain - silphonelist=$(cat data/lang_chain/phones/silence.csl) || exit 1; - nonsilphonelist=$(cat data/lang_chain/phones/nonsilence.csl) || exit 1; - # Use our special topology... note that later on may have to tune this - # topology. - steps/nnet3/chain/gen_topo.py $nonsilphonelist $silphonelist >data/lang_chain/topo - fi +# Please take this as a reference on how to specify all the options of +# local/chain/run_chain_common.sh +if $run_chain_common; then + local/chain/run_chain_common.sh --stage $stage \ + --gmm-dir $gmm_dir \ + --ali-dir $ali_dir \ + --lores-train-data-dir ${lores_train_data_dir} \ + --lang $lang \ + --lat-dir $lat_dir \ + --num-leaves 7000 \ + --tree-dir $tree_dir || exit 1; fi if [ $stage -le 15 ]; then - # Get the alignments as lattices (gives the chain training more freedom). 
- # use the same num-jobs as the alignments - steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \ - data/lang $gmm_dir $lat_dir - rm $lat_dir/fsts.*.gz # save space -fi - -if [ $stage -le 16 ]; then - # Build a tree using our new topology. We know we have alignments for the - # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use - # those. - if [ -f $tree_dir/final.mdl ]; then - echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." - exit 1; - fi - steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ - --context-opts "--context-width=2 --central-position=1" \ - --cmd "$train_cmd" 4000 ${lores_train_data_dir} data/lang_chain $ali_dir $tree_dir -fi - -if [ $stage -le 17 ]; then mkdir -p $dir - echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + affine_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" + tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.01" + output_opts="l2-regularize=0.002" mkdir -p $dir/configs + cat < $dir/configs/network.xconfig input dim=100 name=ivector input dim=40 name=input - # please note that it is important to have input layer with the name=input # as the layer immediately preceding the fixed-affine-layer to enable # the use of short notation for the descriptor fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat - # the first splicing is moved before the lda layer, so no splicing here - relu-renorm-layer name=tdnn1 dim=450 - relu-renorm-layer name=tdnn2 input=Append(-1,0,1) dim=450 - relu-renorm-layer name=tdnn3 input=Append(-1,0,1,2) dim=450 - relu-renorm-layer name=tdnn4 input=Append(-3,0,3) dim=450 - relu-renorm-layer name=tdnn5 input=Append(-3,0,3) dim=450 - relu-renorm-layer name=tdnn6 input=Append(-6,-3,0) dim=450 - - ## adding the layers for chain branch - relu-renorm-layer name=prefinal-chain input=tdnn6 dim=450 target-rms=0.5 - output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 - - # adding the layers for xent branch - # This block prints the configs for a separate output that will be - # trained with a cross-entropy objective in the 'chain' models... this - # has the effect of regularizing the hidden parts of the model. we use - # 0.5 / args.xent_regularize as the learning rate factor- the factor of - # 0.5 / args.xent_regularize is suitable as it means the xent - # final-layer learns at a rate independent of the regularization - # constant; and the 0.5 was tuned so as to make the relative progress - # similar in the xent and regular final layers. 
- relu-renorm-layer name=prefinal-xent input=tdnn6 dim=450 target-rms=0.5 - output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 - + relu-batchnorm-dropout-layer name=tdnn1 $affine_opts dim=1536 + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf3 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf4 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf5 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=0 + tdnnf-layer name=tdnnf6 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf7 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf14 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf15 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + linear-component name=prefinal-l dim=256 $linear_opts + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts EOF steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ - fi -if [ $stage -le 18 ]; then + +if [ $stage -le 16 ]; then if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then utils/create_split_dir.pl \ - /export/b0{5,6,7,8}/$USER/kaldi-data/egs/gale_arabic-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/wsj-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage fi - steps/nnet3/chain/train.py --stage $train_stage \ + steps/nnet3/chain/train.py --stage $train_stage \ --cmd "$decode_cmd" \ --feat.online-ivector-dir $train_ivector_dir \ --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ --chain.xent-regularize $xent_regularize \ --chain.leaky-hmm-coefficient 0.1 \ - --chain.l2-regularize 0.00005 \ + --chain.l2-regularize 0.0 \ --chain.apply-deriv-weights false \ --chain.lm-opts="--num-extra-lm-states=2000" \ - --egs.dir "$common_egs_dir" \ - --egs.opts "--frames-overlap-per-eg 0" \ - --egs.chunk-width 150 \ - --trainer.num-chunk-per-minibatch 128 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs 6 \ --trainer.frames-per-iter 1500000 \ - --trainer.num-epochs 4 \ - --trainer.optimization.num-jobs-initial 2 \ - --trainer.optimization.num-jobs-final 2 \ - --trainer.optimization.initial-effective-lrate 0.001 \ - --trainer.optimization.final-effective-lrate 0.0001 \ - --trainer.max-param-change 2.0 \ - --cleanup.remove-egs true \ - --feat-dir $train_data_dir \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.00025 \ + --trainer.optimization.final-effective-lrate 0.000025 \ + --trainer.num-chunk-per-minibatch=64,32 \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts "--frames-overlap-per-eg 0 --constrained false" \ + --egs.stage $get_egs_stage \ + --reporting.email="$reporting_email" \ + --cleanup.remove-egs=$remove_egs \ + --feat-dir=$train_data_dir \ --tree-dir $tree_dir \ - --lat-dir $lat_dir \ - --dir $dir -fi - + --lat-dir=$lat_dir \ + --dir $dir || exit 1; +fi -if [ $stage -le 19 ]; then - # Note: it might appear that this data/lang_chain directory is mismatched, and it is as - # far as the 'topo' is concerned, but this script doesn't read the 'topo' from - # the lang directory. - utils/mkgraph.sh --left-biphone --self-loop-scale 1.0 data/lang_test $dir $dir/graph +if [ $stage -le 17 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. 
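  # [Editor's note: illustrative only; data/lang_test_big is a hypothetical
  #  directory name.] For instance, a graph could equally be built against a
  #  lang directory carrying a different LM, provided the phone sets match:
  #    utils/lang/check_phones_compatible.sh data/lang_test_big/phones.txt $lang/phones.txt
  #    utils/mkgraph.sh --self-loop-scale 1.0 data/lang_test_big $tree_dir $tree_dir/graph_big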
+ + utils/lang/check_phones_compatible.sh \ + data/lang_test/phones.txt $lang/phones.txt + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_test \ + $tree_dir $tree_dir/graph || exit 1; fi -if [ $stage -le 20 ]; then +if [ $stage -le 18 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) rm $dir/.error 2>/dev/null || true - steps/nnet3/decode.sh --num-threads 4 --nj $decode_nj --cmd "$decode_cmd" \ - --acwt 1.0 --post-decode-acwt 10.0 \ - --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_test_hires \ - --scoring-opts "--min-lmwt 5 " \ - $dir/graph data/test_hires $dir/decode || exit 1; + + steps/nnet3/decode.sh \ + --acwt 1.0 --post-decode-acwt 10.0 \ + --extra-left-context 0 --extra-right-context 0 \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$decode_cmd" --num-threads 4 \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${test_set}_hires \ + $tree_dir/graph data/${test_set}_hires ${dir}/decode_${test_set} || exit 1 fi -exit 0 diff --git a/egs/gale_arabic/s5b/local/chain/tuning/run_tdnn_lstm_1a.sh b/egs/gale_arabic/s5b/local/chain/tuning/run_tdnn_lstm_1a.sh index 604f32a1de4..deebafc95e4 100755 --- a/egs/gale_arabic/s5b/local/chain/tuning/run_tdnn_lstm_1a.sh +++ b/egs/gale_arabic/s5b/local/chain/tuning/run_tdnn_lstm_1a.sh @@ -120,7 +120,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/gale_arabic/s5b/local/gale_data_prep_audio.sh b/egs/gale_arabic/s5b/local/gale_data_prep_audio.sh deleted file mode 100755 index 0125272d06c..00000000000 --- a/egs/gale_arabic/s5b/local/gale_data_prep_audio.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/bin/bash - -# Copyright 2014 QCRI (author: Ahmed Ali) -# Apache 2.0 - - -galeData=$(utils/make_absolute.sh "${@: -1}" ); # last argumnet; the local folder -audio_dvds=${@:1:${#}-1} # all the audio dvds for GALE corpus; ; check audio=( in ../run.sh - -mkdir -p $galeData - -# check that sox is installed -which sox &>/dev/null -if [[ $? != 0 ]]; then - echo "sox is not installed"; exit 1 -fi - -for dvd in $audio_dvds; do - dvd_full_path=$(utils/make_absolute.sh $dvd) - if [[ ! 
-e $dvd_full_path ]]; then - echo missing $dvd_full_path; exit 1; - fi - find $dvd_full_path \( -name "*.wav" -o -name "*.flac" \) | while read file; do - id=$(basename $file | awk '{gsub(".wav","");gsub(".flac","");print}') - echo "$id sox $file -r 16000 -t wav - |" - done -done | sort -u > $galeData/wav.scp - -echo data prep audio succeded - -exit 0 - diff --git a/egs/gale_arabic/s5b/local/gale_data_prep_split.sh b/egs/gale_arabic/s5b/local/gale_data_prep_split.sh deleted file mode 100755 index b18a4e5b105..00000000000 --- a/egs/gale_arabic/s5b/local/gale_data_prep_split.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash - -# Copyright 2014 QCRI (author: Ahmed Ali) -# Apache 2.0 - -if [ $# -ne 1 ]; then - echo "Arguments should be the "; exit 1 -fi - - -#data will data/local - -galeData=$(utils/make_absolute.sh $1) -mkdir -p data/local -dir=$(utils/make_absolute.sh data/local) - - -grep -f local/test_list $galeData/all | grep -v -f local/bad_segments > $galeData/all.test -grep -v -f local/test_list $galeData/all | grep -v -f local/bad_segments > $galeData/all.train - -for x in test train; do - outdir=$dir/$x - file=$galeData/all.$x - mkdir -p $outdir - awk '{print $2 " " $2}' $file | sort -u > $outdir/utt2spk - cp -pr $outdir/utt2spk $outdir/spk2utt - awk '{print $2 " " $1 " " $3 " " $4}' $file | sort -u > $outdir/segments - awk '{printf $2 " "; for (i=5; i<=NF; i++) {printf $i " "} printf "\n"}' $file | sort -u > $outdir/text -done - - -grep -f local/test_list $galeData/wav.scp > $dir/test/wav.scp - -cat $galeData/wav.scp | awk -v seg=$dir/train/segments 'BEGIN{while((getline0) {seen[$2]=1;}} - {if (seen[$1]) { print $0}}' > $dir/train/wav.scp - -echo data prep split succeeded - -exit 0 diff --git a/egs/gale_arabic/s5b/local/gale_data_prep_txt.sh b/egs/gale_arabic/s5b/local/gale_data_prep_txt.sh deleted file mode 100755 index 04529d88ac0..00000000000 --- a/egs/gale_arabic/s5b/local/gale_data_prep_txt.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash - -# Copyright 2014 QCRI (author: Ahmed Ali) -# Apache 2.0 - -galeData=$(utils/make_absolute.sh "${@: -1}" ); # last argumnet; the local folder -txt_dvds=${@:1:${#}-1} # all the txt cds correspoding to the audio corpus; check text=( in ../run.sh - - -top_pwd=`pwd` -txtdir=$galeData/txt -mkdir -p $txtdir; cd $txtdir - -for cdx in $txt_dvds; do - echo "Preparing $cdx" - if [[ $cdx == *.tgz ]] ; then - tar -xvf $cdx - elif [ -d "$cdx" ]; then - ln -s $cdx `basename $cdx` - else - echo "I don't really know what I shall do with $cdx " >&2 - fi -done - -find -L . 
-type f -name "*.tdf" | while read file; do -sed '1,3d' $file # delete the first 3 lines -done > all.tmp$$ - -perl -e ' - ($inFile,$idFile,$txtFile)= split /\s+/, $ARGV[0]; - open(IN, "$inFile"); - open(ID, ">$idFile"); - open(TXT, ">$txtFile"); - while () { - @arr= split /\t/,$_; - $start=sprintf ("%0.3f",$arr[2]);$rStart=$start;$start=~s/\.//; $start=~s/^0+$/0/; $start=~s/^0+([^0])/$1/; # remove zeros at the beginning - $end=sprintf ("%0.3f",$arr[3]);$rEnd=$end;$end=~s/^0+([^0])/$1/;$end=~s/\.//; - if ( ($arr[11] !~ m/report/) && ($arr[11] !~ m/conversational/) ){$arr[11]="UNK";} - $id="$arr[11] $arr[0] $arr[0]_${start}_${end} $rStart $rEnd\n"; - next if ($rStart == $rEnd); - $id =~ s/.sph//g; - print ID $id; - print TXT "$arr[7]\n"; - }' "all.tmp$$ allid.tmp$$ contentall.tmp$$" - - -perl ${top_pwd}/local/normalize_transcript_BW.pl contentall.tmp$$ contentall.buck.tmp$$ - -paste allid.tmp$$ contentall.buck.tmp$$ | sed 's: $::' | awk '{if (NF>5) {print $0}}' > all_1.tmp$$ - -awk '{$1="";print $0}' all_1.tmp$$ | sed 's:^ ::' > $galeData/all -awk '{if ($1 == "report") {$1="";print $0}}' all_1.tmp$$ | sed 's:^ ::' > $galeData/report -awk '{if ($1 == "conversational") {$1="";print $0}}' all_1.tmp$$ | sed 's:^ ::' > $galeData/conversational - -cd ..; -rm -fr $txtdir -cd $top_pwd -echo data prep text succeeded - -exit 0 diff --git a/egs/gale_arabic/s5b/local/gale_format_data.sh b/egs/gale_arabic/s5b/local/gale_format_data.sh deleted file mode 100755 index b69c34e68b9..00000000000 --- a/egs/gale_arabic/s5b/local/gale_format_data.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash - -# Copyright 2014 QCRI (author: Ahmed Ali) -# Apache 2.0 - -if [ -f path.sh ]; then - . ./path.sh; else - echo "$0: missing path.sh"; exit 1; -fi - -for dir in test train; do - cp -pr data/local/$dir data/$dir -done - - -mkdir -p data/lang_test - -arpa_lm=data/local/lm/3gram-mincount/lm_unpruned.gz -[ ! -f $arpa_lm ] && echo No such file $arpa_lm && exit 1; - -rm -r data/lang_test -cp -r data/lang data/lang_test - -gunzip -c "$arpa_lm" | \ - arpa2fst --disambig-symbol=#0 \ - --read-symbol-table=data/lang_test/words.txt - data/lang_test/G.fst - - -echo "$0: Checking how stochastic G is (the first of these numbers should be small):" -fstisstochastic data/lang_test/G.fst - -## Check lexicon. -## just have a look and make sure it seems sane. -echo "$0: First few lines of lexicon FST:" -fstprint --isymbols=data/lang/phones.txt --osymbols=data/lang/words.txt data/lang/L.fst | head - -echo "$0: Performing further checks" - -# Checking that G.fst is determinizable. -fstdeterminize data/lang_test/G.fst /dev/null || echo Error determinizing G. - -# Checking that L_disambig.fst is determinizable. -fstdeterminize data/lang_test/L_disambig.fst /dev/null || echo Error determinizing L. - -# Checking that disambiguated lexicon times G is determinizable -# Note: we do this with fstdeterminizestar not fstdeterminize, as -# fstdeterminize was taking forever (presumbaly relates to a bug -# in this version of OpenFst that makes determinization slow for -# some case). -fsttablecompose data/lang_test/L_disambig.fst data/lang_test/G.fst | \ - fstdeterminizestar >/dev/null || echo Error - -# Checking that LG is stochastic: -fsttablecompose data/lang/L_disambig.fst data/lang_test/G.fst | \ - fstisstochastic || echo LG is not stochastic - - -echo gale_format_data succeeded. 
- -exit 0 diff --git a/egs/gale_arabic/s5b/local/gale_prep_grapheme_dict.sh b/egs/gale_arabic/s5b/local/gale_prep_grapheme_dict.sh deleted file mode 100755 index 5f101f8245b..00000000000 --- a/egs/gale_arabic/s5b/local/gale_prep_grapheme_dict.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -# Copyright 2017 QCRI (author: Ahmed Ali) -# Apache 2.0 - - -# run this from ../ -dir=$(utils/make_absolute.sh data/local/dict) -mkdir -p $dir - - -# (1) Get all avaialble dictionaries, since this is a grapheme model, so we mainly need the most frequent word lists -wget http://alt.qcri.org//resources/speech/dictionary/ar-ar_grapheme_lexicon_2016-02-09.bz2 || exit 1; -wget http://alt.qcri.org//resources/speech/dictionary/ar-ar_lexicon_2014-03-17.txt.bz2 || exit 1; -bzcat ar-ar_grapheme_lexicon_2016-02-09.bz2 | sed '1,3d' | awk '{print $1}' > tmp$$ -bzcat ar-ar_lexicon_2014-03-17.txt.bz2 | sed '1,3d' | awk '{print $1}' >> tmp$$ -# (2) Now we add all the words appeared in the training data -cat data/local/train/text | cut -d ' ' -f 2- | tr -s " " "\n" | sort -u >> tmp$$ -grep -v [0-9] tmp$$ | sed -e 's:[FNKaui\~o\`]::g' -e 's:{:}:g' | sort -u > tmp1.$$ # remove vowels and rare alef wasla -cat tmp1.$$ | sed 's:\(\):\1 :g' | sed -e 's: : :g' -e 's: : :g' -e 's:\s*: :g' -e 's:\*:V:g' > tmp2.$$ -paste -d ' ' tmp1.$$ tmp2.$$ > $dir/lexicon.txt - -#(2) Dictionary preparation: - -# silence phones, one per line. -echo SIL > $dir/silence_phones.txt -echo SIL > $dir/optional_silence.txt - -# nonsilence phones; on each line is a list of phones that correspond -# really to the same base phone. -cat tmp2.$$ | tr -s ' ' '\n' | grep -v ^$ | sort -u > $dir/nonsilence_phones.txt || exit 1; - -sed -i '1i SIL' $dir/lexicon.txt # insert word with phone sil at the begining of the dictionary - -rm -fr ar-ar_lexicon_2014-03-17.txt.bz2 ar-ar_grapheme_lexicon_2016-02-09.bz2 tmp$$ tmp1.$$ tmp2.$$ -echo Dictionary preparation succeeded - -# The script is still missing dates and numbers - -exit 0 - diff --git a/egs/gale_arabic/s5b/local/gale_train_lms.sh b/egs/gale_arabic/s5b/local/gale_train_lms.sh deleted file mode 100755 index 3988ec3818f..00000000000 --- a/egs/gale_arabic/s5b/local/gale_train_lms.sh +++ /dev/null @@ -1,81 +0,0 @@ -#!/bin/bash - - -# To be run from one directory above this script. - - -lexicon=data/local/dict/lexicon.txt -[ ! -f $lexicon ] && echo "$0: No such file $lexicon" && exit 1; - - -# This script takes no arguments. It assumes you have already run -# previus steps successfully -# It takes as input the files -#data/local/train.*/text -#data/local/dict/lexicon.txt - - -export LC_ALL=C # You'll get errors about things being not sorted, if you -# have a different locale. -export PATH=$PATH:./../../../tools/kaldi_lm -( # First make sure the kaldi_lm toolkit is installed. - cd $KALDI_ROOT/tools || exit 1; - if [ -d kaldi_lm ]; then - echo Not installing the kaldi_lm toolkit since it is already there. - else - echo Downloading and installing the kaldi_lm tools - if [ ! -f kaldi_lm.tar.gz ]; then - wget http://www.danielpovey.com/files/kaldi/kaldi_lm.tar.gz || exit 1; - fi - tar -xvzf kaldi_lm.tar.gz || exit 1; - cd kaldi_lm - make || exit 1; - echo Done making the kaldi_lm tools - fi -) || exit 1; - - -dir=data/local/lm - mkdir -p $dir - text=data/local/train/text - [ ! 
-f $text ] && echo "$0: No such file $text" && exit 1; - - cleantext=$dir/text.no_oov - - cat $text | awk -v lex=$lexicon 'BEGIN{while((getline0){ seen[$1]=1; } } - {for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf(" ",$n);} } printf("\n");}' \ - > $cleantext || exit 1; - - - cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | \ - sort -nr > $dir/word.counts || exit 1; - - -# Get counts from acoustic training transcripts, and add one-count -# for each word in the lexicon (but not silence, we don't want it -# in the LM-- we'll add it optionally later). - cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \ - cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \ - sort | uniq -c | sort -nr > $dir/unigram.counts || exit 1; - -# note: we probably won't really make use of as there aren't any OOVs - cat $dir/unigram.counts | awk '{print $2}' | get_word_map.pl "" "" "" > $dir/word_map \ - || exit 1; - -# note: ignore 1st field of train.txt, it's the utterance-id. - cat $cleantext | awk -v wmap=$dir/word_map 'BEGIN{while((getline0)map[$1]=$2;} - { for(n=2;n<=NF;n++) { printf map[$n]; if(n$dir/train.gz \ - || exit 1; - - train_lm.sh --arpa --lmtype 3gram-mincount $dir || exit 1; - -# LM is small enough that we don't need to prune it (only about 0.7M N-grams). -# Perplexity over 128254.000000 words is 90.446690 - -# note: output is -# data/local/lm/3gram-mincount/lm_unpruned.gz - - -echo train lm succeeded - -exit 0 diff --git a/egs/gale_arabic/s5b/local/nnet3/run_ivector_common.sh b/egs/gale_arabic/s5b/local/nnet3/run_ivector_common.sh index f14c8441869..f071842dc0b 100755 --- a/egs/gale_arabic/s5b/local/nnet3/run_ivector_common.sh +++ b/egs/gale_arabic/s5b/local/nnet3/run_ivector_common.sh @@ -2,31 +2,29 @@ set -e -o pipefail -# This script is called from local/nnet3/run_tdnn.sh and local/chain/run_tdnn.sh (and may eventually -# be called by more scripts). It contains the common feature preparation and iVector-related parts -# of the script. See those scripts for examples of usage. +# This script is called from scripts like local/nnet3/run_tdnn.sh and +# local/chain/run_tdnn.sh (and may eventually be called by more scripts). It +# contains the common feature preparation and iVector-related parts of the +# script. See those scripts for examples of usage. stage=0 nj=100 -min_seg_len=1.55 # min length in seconds... we do this because chain training - # will discard segments shorter than 1.5 seconds. Must remain in sync - # with the same option given to prepare_lores_feats_and_alignments.sh train_set=train # you might set this to e.g. train. -gmm=tri2b # This specifies a GMM-dir from the features of the type you're training the system on; +test_sets="test" +gmm=tri3b # This specifies a GMM-dir from the features of the type you're training the system on; # it should contain alignments for 'train_set'. num_threads_ubm=32 -nnet3_affix=_cleaned # affix for exp/nnet3 directory to put iVector stuff in, so it - # becomes exp/nnet3_cleaned or whatever. +nnet3_affix= # affix for exp/nnet3 directory to put iVector stuff . ./cmd.sh . ./path.sh -. ./utils/parse_options.sh +. utils/parse_options.sh gmm_dir=exp/${gmm} -ali_dir=exp/${gmm}_ali_${train_set}_sp_comb +ali_dir=exp/${gmm}_ali_${train_set}_sp for f in data/${train_set}/feats.scp ${gmm_dir}/final.mdl; do if [ ! 
-f $f ]; then @@ -61,7 +59,7 @@ if [ $stage -le 2 ]; then utils/create_split_dir.pl /export/b0{5,6,7,8}/$USER/kaldi-data/mfcc/gale_arabic-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage fi - for datadir in ${train_set}_sp test; do + for datadir in ${train_set}_sp ${test_sets}; do utils/copy_data_dir.sh data/$datadir data/${datadir}_hires done @@ -69,7 +67,7 @@ if [ $stage -le 2 ]; then # features; this helps make trained nnets more invariant to test data volume. utils/data/perturb_data_dir_volume.sh data/${train_set}_sp_hires - for datadir in ${train_set}_sp test; do + for datadir in ${train_set}_sp ${test_sets}; do steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \ --cmd "$train_cmd" data/${datadir}_hires steps/compute_cmvn_stats.sh data/${datadir}_hires @@ -78,75 +76,33 @@ if [ $stage -le 2 ]; then fi if [ $stage -le 3 ]; then - echo "$0: combining short segments of speed-perturbed high-resolution MFCC training data" - # we have to combine short segments or we won't be able to train chain models - # on those segments. - utils/data/combine_short_segments.sh \ - data/${train_set}_sp_hires $min_seg_len data/${train_set}_sp_hires_comb - - # just copy over the CMVN to avoid having to recompute it. - cp data/${train_set}_sp_hires/cmvn.scp data/${train_set}_sp_hires_comb/ - utils/fix_data_dir.sh data/${train_set}_sp_hires_comb/ -fi - -if [ $stage -le 4 ]; then - echo "$0: selecting segments of hires training data that were also present in the" - echo " ... original training data." - - # note, these data-dirs are temporary; we put them in a sub-directory - # of the place where we'll make the alignments. - temp_data_root=exp/nnet3${nnet3_affix}/tri5 - mkdir -p $temp_data_root - - utils/data/subset_data_dir.sh --utt-list data/${train_set}/feats.scp \ - data/${train_set}_sp_hires $temp_data_root/${train_set}_hires - - # note: essentially all the original segments should be in the hires data. - n1=$(wc -l /dev/null +if [[ $? != 0 ]]; then + echo "$0: sox is not installed"; exit 1 +fi + +for dvd in $dir1 $dir2 $dir3; do + dvd_full_path=$(utils/make_absolute.sh $dvd) + if [[ ! -e $dvd_full_path ]]; then + echo "$0: missing $dvd_full_path"; exit 1; + fi + find $dvd_full_path \( -name "*.wav" -o -name "*.flac" \) | while read file; do + id=$(basename $file | awk '{gsub(".wav","");gsub(".flac","");print}') + echo "$id sox $file -r 16000 -t wav - |" + done +done | sort -u > $gale_data/wav.scp +echo "$0:data prep audio succeded" + +gale_data=$(utils/make_absolute.sh "GALE" ); +top_pwd=`pwd` +txtdir=$gale_data/txt +mkdir -p $txtdir; cd $txtdir + +for cdx in $text1 $text2 $text3; do + echo "$0:Preparing $cdx" + if [[ $cdx == *.tgz ]] ; then + tar -xvf $cdx + elif [ -d "$cdx" ]; then + ln -s $cdx `basename $cdx` + else + echo "$0:I don't really know what I shall do with $cdx " >&2 + fi +done + +find -L . 
-type f -name "*.tdf" | while read file; do +sed '1,3d' $file # delete the first 3 lines +done > all.tmp$$ + +perl -e ' + ($inFile,$idFile,$txtFile)= split /\s+/, $ARGV[0]; + open(IN, "$inFile"); + open(ID, ">$idFile"); + open(TXT, ">$txtFile"); + while () { + @arr= split /\t/,$_; + $start=sprintf ("%0.3f",$arr[2]);$rStart=$start;$start=~s/\.//; $start=~s/^0+$/0/; $start=~s/^0+([^0])/$1/; # remove zeros at the beginning + $end=sprintf ("%0.3f",$arr[3]);$rEnd=$end;$end=~s/^0+([^0])/$1/;$end=~s/\.//; + if ( ($arr[11] !~ m/report/) && ($arr[11] !~ m/conversational/) ){$arr[11]="UNK";} + $id="$arr[11] $arr[0] $arr[0]_${start}_${end} $rStart $rEnd\n"; + next if ($rStart == $rEnd); + $id =~ s/.sph//g; + print ID $id; + print TXT "$arr[7]\n"; + }' "all.tmp$$ allid.tmp$$ contentall.tmp$$" + +perl ${top_pwd}/local/normalize_transcript_BW.pl contentall.tmp$$ contentall.buck.tmp$$ +paste allid.tmp$$ contentall.buck.tmp$$ | sed 's: $::' | awk '{if (NF>5) {print $0}}' > all_1.tmp$$ + + +awk '{$1="";print $0}' all_1.tmp$$ | sed 's:^ ::' > $gale_data/all +awk '{if ($1 == "report") {$1="";print $0}}' all_1.tmp$$ | sed 's:^ ::' > $gale_data/report +awk '{if ($1 == "conversational") {$1="";print $0}}' all_1.tmp$$ | sed 's:^ ::' > $gale_data/conversational + +cd ..; +rm -fr $txtdir +cd $top_pwd +echo "$0:dat a prep text succeeded" + +mkdir -p data +dir=$(utils/make_absolute.sh data/) +grep -f local/test_list $gale_data/all | grep -v -f local/bad_segments > $gale_data/all.test +grep -v -f local/test_list $gale_data/all | grep -v -f local/bad_segments > $gale_data/all.train + +for x in test train; do + outdir=data/$x + file=$gale_data/all.$x + mkdir -p $outdir + awk '{print $2 " " $2}' $file | sort -u > $outdir/utt2spk + cp -pr $outdir/utt2spk $outdir/spk2utt + awk '{print $2 " " $1 " " $3 " " $4}' $file | sort -u > $outdir/segments + awk '{printf $2 " "; for (i=5; i<=NF; i++) {printf $i " "} printf "\n"}' $file | sort -u > $outdir/text +done + +grep -f local/test_list $gale_data/wav.scp > $dir/test/wav.scp + +cat $gale_data/wav.scp | awk -v seg=$dir/train/segments 'BEGIN{while((getline0) {seen[$2]=1;}} + {if (seen[$1]) { print $0}}' > $dir/train/wav.scp + +echo "$0:data prep split succeeded" +exit 0 diff --git a/egs/gale_arabic/s5b/local/prepare_dict.sh b/egs/gale_arabic/s5b/local/prepare_dict.sh new file mode 100755 index 00000000000..47b5869fdf1 --- /dev/null +++ b/egs/gale_arabic/s5b/local/prepare_dict.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash + +# Copyright 2017 QCRI (author: Ahmed Ali) +# Apache 2.0 +# This script prepares the dictionary. + +set -e +dir=data/local/dict +lexicon_url1="http://alt.qcri.org//resources/speech/dictionary/ar-ar_grapheme_lexicon_2016-02-09.bz2"; +lexicon_url2="http://alt.qcri.org//resources/speech/dictionary/ar-ar_lexicon_2014-03-17.txt.bz2"; +stage=0 +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; +mkdir -p $dir data/local/lexicon_data + +if [ $stage -le 0 ]; then + echo "$0: Downloading text for lexicon... $(date)." 
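  # Note: only the word column of the downloaded lexicons is kept (the
  # awk '{print $1}' commands below); pronunciations for this grapheme system
  # are rebuilt letter-by-letter by local/prepare_lexicon.py, so a word such as
  # "ktb" (placeholder example) would appear in lexicon.txt roughly as
  #   ktb k t b
  # ('*' is additionally mapped to the symbol 'V').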
+ wget -P data/local/lexicon_data $lexicon_url1 + wget -P data/local/lexicon_data $lexicon_url2 + bzcat data/local/lexicon_data/ar-ar_grapheme_lexicon_2016-02-09.bz2 | sed '1,3d' | awk '{print $1}' > data/local/lexicon_data/grapheme_lexicon + bzcat data/local/lexicon_data/ar-ar_lexicon_2014-03-17.txt.bz2 | sed '1,3d' | awk '{print $1}' >> data/local/lexicon_data/grapheme_lexicon + cat data/train/text | cut -d ' ' -f 2- | tr -s " " "\n" | sort -u >> data/local/lexicon_data/grapheme_lexicon +fi + + +if [ $stage -le 0 ]; then + echo "$0: processing lexicon text and creating lexicon... $(date)." + # remove vowels and rare alef wasla + grep -v [0-9] data/local/lexicon_data/grapheme_lexicon | sed -e 's:[FNKaui\~o\`]::g' -e 's:{:}:g' | sort -u > data/local/lexicon_data/processed_lexicon + local/prepare_lexicon.py +fi + +cut -d' ' -f2- $dir/lexicon.txt | sed 's/SIL//g' | tr ' ' '\n' | sort -u | sed '/^$/d' >$dir/nonsilence_phones.txt || exit 1; + +sed -i '1i UNK' $dir/lexicon.txt + +echo UNK >> $dir/nonsilence_phones.txt + +echo ' SIL' >> $dir/lexicon.txt + +echo SIL > $dir/silence_phones.txt + +echo SIL >$dir/optional_silence.txt + +echo -n "" >$dir/extra_questions.txt + +echo "$0: Dictionary preparation succeeded" diff --git a/egs/gale_arabic/s5b/local/prepare_lexicon.py b/egs/gale_arabic/s5b/local/prepare_lexicon.py new file mode 100755 index 00000000000..215541585eb --- /dev/null +++ b/egs/gale_arabic/s5b/local/prepare_lexicon.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# Apache 2.0 + +# This script prepares lexicon. + +import argparse +import os + +parser = argparse.ArgumentParser(description="""Creates the list of characters and words in lexicon""") +args = parser.parse_args() + +### main ### +lex = {} +text_path = os.path.join('data','local', 'lexicon_data', 'processed_lexicon') +with open(text_path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + characters = list(line) + characters = " ".join(['V' if char == '*' else char for char in characters]) + lex[line] = characters + +with open(os.path.join('data','local','dict', 'lexicon.txt'), 'w', encoding='utf-8') as fp: + for key in sorted(lex): + fp.write(key + " " + lex[key] + "\n") diff --git a/egs/gale_arabic/s5b/local/prepare_lm.sh b/egs/gale_arabic/s5b/local/prepare_lm.sh new file mode 100755 index 00000000000..6fdf35f471a --- /dev/null +++ b/egs/gale_arabic/s5b/local/prepare_lm.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +# Copyright 2012 Vassil Panayotov +# 2017 Ewald Enzinger +# Apache 2.0 + +. ./path.sh || exit 1 + +echo "=== Building a language model ..." + +dir=data/local/lm/ +text=data/train/text +lexicon=data/local/dict/lexicon.txt +# Language model order +order=3 + +. utils/parse_options.sh + +# Prepare a LM training corpus from the transcripts +mkdir -p $dir + +for f in "$text" "$lexicon"; do + [ ! -f $f ] && echo "$0: No such file $f" && exit 1; +done + +loc=`which ngram-count`; +if [ -z $loc ]; then + if uname -a | grep 64 >/dev/null; then # some kind of 64 bit... + sdir=$KALDI_ROOT/tools/srilm/bin/i686-m64 + else + sdir=$KALDI_ROOT/tools/srilm/bin/i686 + fi + if [ -f $sdir/ngram-count ]; then + echo Using SRILM tools from $sdir + export PATH=$PATH:$sdir + else + echo You appear to not have SRILM tools installed, either on your path, + echo or installed in $sdir. See tools/install_srilm.sh for installation + echo instructions. 
+ exit 1 + fi +fi + +cat data/train/text | cut -d " " -f 2- > $dir/text.txt +cut -d' ' -f1 $lexicon > $dir/wordlist + +ngram-count -text $dir/text.txt -order $order -limit-vocab -vocab $dir/wordlist \ + -unk -map-unk "" -kndiscount -interpolate -lm $dir/lm.gz + +#ngram -lm $dir/lm.gz -ppl $dir/dev.txt +echo "*** Finished building the LM model!" diff --git a/egs/gale_arabic/s5b/local/score.sh b/egs/gale_arabic/s5b/local/score.sh index 83366f7c7fc..1d84815fc69 100755 --- a/egs/gale_arabic/s5b/local/score.sh +++ b/egs/gale_arabic/s5b/local/score.sh @@ -1,60 +1,6 @@ -#!/bin/bash -# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) -# Apache 2.0 - -[ -f ./path.sh ] && . ./path.sh - -# begin configuration section. -cmd=run.pl -stage=0 -decode_mbr=true -word_ins_penalty=0.0 -min_lmwt=7 -max_lmwt=17 -iter= #some of the scripts from steps/ seem to use it -#end configuration section. - -echo "$0 $#" - -[ -f ./path.sh ] && . ./path.sh -. parse_options.sh || exit 1; - -if [ $# -ne 3 ]; then - echo "Usage: local/score.sh [--cmd (run.pl|queue.pl...)] " - echo " Options:" - echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." - echo " --stage (0|1|2) # start scoring script from part-way through." - echo " --decode_mbr (true/false) # maximum bayes risk decoding (confusion network)." - echo " --min_lmwt # minumum LM-weight for lattice rescoring " - echo " --max_lmwt # maximum LM-weight for lattice rescoring " - exit 1; -fi -data=$1 -lang_or_graph=$2 -dir=$3 - -symtab=$lang_or_graph/words.txt - -for f in $symtab $dir/lat.1.gz $data/text; do - [ ! -f $f ] && echo "score.sh: no such file $f" && exit 1; -done - -mkdir -p $dir/scoring/log - -cat $data/text | sed 's:::g' | sed 's:::g' > $dir/scoring/test_filt.txt - -$cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/best_path.LMWT.log \ - lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ - lattice-add-penalty --word-ins-penalty=$word_ins_penalty ark:- ark:- \| \ - lattice-best-path --word-symbol-table=$symtab \ - ark:- ark,t:$dir/scoring/LMWT.tra || exit 1; +#!/bin/bash -# Note: the double level of quoting for the sed command -$cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.log \ - cat $dir/scoring/LMWT.tra \| \ - utils/int2sym.pl -f 2- $symtab \| sed 's:\::g' \| \ - compute-wer --text --mode=present \ - ark:$dir/scoring/test_filt.txt ark,p:- ">&" $dir/wer_LMWT || exit 1; -exit 0; +steps/scoring/score_kaldi_wer.sh "$@" +steps/scoring/score_kaldi_cer.sh --stage 2 "$@" diff --git a/egs/gale_arabic/s5b/local/wer_output_filter b/egs/gale_arabic/s5b/local/wer_output_filter new file mode 100755 index 00000000000..cf48b434144 --- /dev/null +++ b/egs/gale_arabic/s5b/local/wer_output_filter @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Hossein Hadian + +# Apache 2.0 +# This script converts a BPE-encoded text to normal text. It is used in scoring + +import sys, io +import string + +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') + +for line in infile: + words = line.strip().split() + words = [word for word in words if '' not in word] + uttid = words[0] + transcript = ' '.join(words[1:]) + output.write(uttid + ' ' + transcript + '\n') diff --git a/egs/gale_arabic/s5b/run.sh b/egs/gale_arabic/s5b/run.sh index c45f5119949..3f12d22495e 100755 --- a/egs/gale_arabic/s5b/run.sh +++ b/egs/gale_arabic/s5b/run.sh @@ -3,177 +3,121 @@ # Copyright 2014 QCRI (author: Ahmed Ali) # Apache 2.0 -. ./path.sh -. 
./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. - ## This relates to the queue. num_jobs=120 num_decode_jobs=40 +decode_gmm=true +stage=0 +overwrite=false -#NB: You can add whatever number of copora you like. The supported extensions -#NB: (formats) are wav and flac. Flac will be converted using sox and in contrast -#NB: with the old approach, the conversion will be on-the-fly and one-time-only -#NB: during the parametrization. - -#NB: Text corpora scpecification. We support either tgz files, which are unpacked -#NB: or just plain (already unpacked) directories. The list of transcript is then -#NB: obtained using find command - -#Make sure you edit this section to reflect whers you keep the LDC data on your cluster - -#This is CLSP configuration. We add the 2014 GALE data. We got around 2 % -#improvement just by including it. The gain might be large if someone would tweak -# the number of leaves and states and so on. - -#audio=( -# /export/corpora/LDC/LDC2013S02/ -# /export/corpora/LDC/LDC2013S07/ -# /export/corpora/LDC/LDC2014S07/ -#) -#text=( -# /export/corpora/LDC/LDC2013T17 -# /export/corpora/LDC/LDC2013T04 -# /export/corpora/LDC/LDC2014T17 -#) - -audio=( - /data/sls/scratch/amali/data/GALE/LDC2013S02 - /data/sls/scratch/amali/data/GALE/LDC2013S07 - /data/sls/scratch/amali/data/GALE/LDC2014S07 -) -text=( - /data/sls/scratch/amali/data/GALE/LDC2013T17.tgz - /data/sls/scratch/amali/data/GALE/LDC2013T04.tgz - /data/sls/scratch/amali/data/GALE/LDC2014T17.tgz -) +dir1=/export/corpora/LDC/LDC2013S02/ +dir2=/export/corpora/LDC/LDC2013S07/ +dir3=/export/corpora/LDC/LDC2014S07/ +text1=/export/corpora/LDC/LDC2013T17/ +text2=/export/corpora/LDC/LDC2013T04/ +text3=/export/corpora/LDC/LDC2014T17/ galeData=GALE -#prepare the data -#split train dev test -#prepare lexicon and LM - -# You can run the script from here automatically, but it is recommended to run the data preparation, -# and features extraction manually and and only once. -# By copying and pasting into your shell. - -#copy the audio files to local folder wav and convet flac files to wav -local/gale_data_prep_audio.sh "${audio[@]}" $galeData || exit 1; +. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. + ## This relates to the queue. +. ./path.sh +. ./utils/parse_options.sh # e.g. this parses the above options + # if supplied. -#get the transcription and remove empty prompts and all noise markers -local/gale_data_prep_txt.sh "${text[@]}" $galeData || exit 1; +if [ $stage -le 0 ]; then -# split the data to reports and conversational and for each class will have rain/dev and test -local/gale_data_prep_split.sh $galeData || exit 1; + if [ -f data/train/text ] && ! $overwrite; then + echo "$0: Not processing, probably script have run from wrong stage" + echo "Exiting with status 1 to avoid data corruption" + exit 1; + fi -# get all Arabic grapheme dictionaries and add silence and UNK -local/gale_prep_grapheme_dict.sh || exit 1; + echo "$0: Preparing data..." + local/prepare_data.sh --dir1 $dir1 --dir2 $dir2 --dir3 $dir3 \ + --text1 $text1 --text2 $text2 --text3 $text3 + echo "$0: Preparing lexicon and LM..." 
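  # Roughly, the next four calls produce (see the individual scripts for the
  # exact file lists):
  #   local/prepare_dict.sh  -> data/local/dict/{lexicon,nonsilence_phones,
  #                             silence_phones,optional_silence,extra_questions}.txt
  #   utils/prepare_lang.sh  -> data/lang (L.fst, words.txt, phones.txt, ...)
  #   local/prepare_lm.sh    -> data/local/lm/lm.gz (3-gram, via SRILM ngram-count)
  #   utils/format_lm.sh     -> data/lang_test (adds G.fst for decoding)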
+ local/prepare_dict.sh -#prepare the langauge resources -utils/prepare_lang.sh data/local/dict "" data/local/lang data/lang || exit 1; + utils/prepare_lang.sh data/local/dict "" data/local/lang data/lang -# LM training -local/gale_train_lms.sh || exit 1; + local/prepare_lm.sh -local/gale_format_data.sh || exit 1; -# G compilation, check LG composition + utils/format_lm.sh data/lang data/local/lm/lm.gz \ + data/local/dict/lexicon.txt data/lang_test +fi -# Now make MFCC features. -# mfccdir should be some place with a largish disk where you -# want to store MFCC features. mfccdir=mfcc - -for x in train test ; do - steps/make_mfcc.sh --cmd "$train_cmd" --nj $num_jobs \ - data/$x exp/make_mfcc/$x $mfccdir - utils/fix_data_dir.sh data/$x # some files fail to get mfcc for many reasons - steps/compute_cmvn_stats.sh data/$x exp/make_mfcc/$x $mfccdir -done - - -# Here we start the AM - -# Let's create a subset with 10k segments to make quick flat-start training: -utils/subset_data_dir.sh data/train 10000 data/train.10K || exit 1; - -# Train monophone models on a subset of the data, 10K segment -# Note: the --boost-silence option should probably be omitted by default -steps/train_mono.sh --nj 40 --cmd "$train_cmd" \ - data/train.10K data/lang exp/mono || exit 1; - - -# Get alignments from monophone system. -steps/align_si.sh --nj $num_jobs --cmd "$train_cmd" \ - data/train data/lang exp/mono exp/mono_ali || exit 1; - -# train tri1 [first triphone pass] -steps/train_deltas.sh --cmd "$train_cmd" \ - 2500 30000 data/train data/lang exp/mono_ali exp/tri1 || exit 1; - -# First triphone decoding -utils/mkgraph.sh data/lang_test exp/tri1 exp/tri1/graph -steps/decode.sh --nj $num_decode_jobs --cmd "$decode_cmd" \ - exp/tri1/graph data/test exp/tri1/decode - -steps/align_si.sh --nj $num_jobs --cmd "$train_cmd" \ - data/train data/lang exp/tri1 exp/tri1_ali || exit 1; - -# Train tri2a, which is deltas+delta+deltas -steps/train_deltas.sh --cmd "$train_cmd" \ - 3000 40000 data/train data/lang exp/tri1_ali exp/tri2a || exit 1; - -# tri2a decoding -utils/mkgraph.sh data/lang_test exp/tri2a exp/tri2a/graph -steps/decode.sh --nj $num_decode_jobs --cmd "$decode_cmd" \ - exp/tri2a/graph data/test exp/tri2a/decode - -# train and decode tri2b [LDA+MLLT] -steps/train_lda_mllt.sh --cmd "$train_cmd" 4000 50000 \ - data/train data/lang exp/tri1_ali exp/tri2b || exit 1; - -utils/mkgraph.sh data/lang_test exp/tri2b exp/tri2b/graph -steps/decode.sh --nj $num_decode_jobs --cmd "$decode_cmd" \ - exp/tri2b/graph data/test exp/tri2b/decode - -# Align all data with LDA+MLLT system (tri2b) -steps/align_si.sh --nj $num_jobs --cmd "$train_cmd" \ - --use-graphs true data/train data/lang exp/tri2b exp/tri2b_ali || exit 1; - - -# From 2b system, train 3b which is LDA + MLLT + SAT. -steps/train_sat.sh --cmd "$train_cmd" \ - 5000 100000 data/train data/lang exp/tri2b_ali exp/tri3b || exit 1; - -utils/mkgraph.sh data/lang_test exp/tri3b exp/tri3b/graph -steps/decode_fmllr.sh --nj $num_decode_jobs --cmd \ - "$decode_cmd" exp/tri3b/graph data/test exp/tri3b/decode - -# From 3b system, align all data. 
-steps/align_fmllr.sh --nj $num_jobs --cmd "$train_cmd" \ - data/train data/lang exp/tri3b exp/tri3b_ali || exit 1; - - -# nnet3 cross-entropy -local/nnet3/run_tdnn.sh #tdnn recipe: -local/nnet3/run_lstm.sh --stage 12 #lstm recipe (we skip ivector training) - -# chain lattice-free -local/chain/run_tdnn.sh #tdnn recipe: -local/chain/run_tdnn_lstm.sh #tdnn-lstm recipe: - -time=$(date +"%Y-%m-%d-%H-%M-%S") - -#get detailed WER; reports, conversational and combined -local/split_wer.sh $galeData > RESULTS.details.$USER.$time # to make sure you keep the results timed and owned - -echo training succedded +if [ $stage -le 1 ]; then + echo "$0: Preparing the test and train feature files..." + for x in train test ; do + steps/make_mfcc.sh --cmd "$train_cmd" --nj $num_jobs \ + data/$x exp/make_mfcc/$x $mfccdir + utils/fix_data_dir.sh data/$x # some files fail to get mfcc for many reasons + steps/compute_cmvn_stats.sh data/$x exp/make_mfcc/$x $mfccdir + done +fi + +if [ $stage -le 2 ]; then + echo "$0: creating sub-set and training monophone system" + utils/subset_data_dir.sh data/train 10000 data/train.10K || exit 1; + + steps/train_mono.sh --nj 40 --cmd "$train_cmd" \ + data/train.10K data/lang exp/mono || exit 1; +fi + +if [ $stage -le 3 ]; then + echo "$0: Aligning data using monophone system" + steps/align_si.sh --nj $num_jobs --cmd "$train_cmd" \ + data/train data/lang exp/mono exp/mono_ali || exit 1; + + echo "$0: training triphone system with delta features" + steps/train_deltas.sh --cmd "$train_cmd" \ + 2500 30000 data/train data/lang exp/mono_ali exp/tri1 || exit 1; +fi + +if [ $stage -le 4 ] && $decode_gmm; then + utils/mkgraph.sh data/lang_test exp/tri1 exp/tri1/graph + steps/decode.sh --nj $num_decode_jobs --cmd "$decode_cmd" \ + exp/tri1/graph data/test exp/tri1/decode +fi + +if [ $stage -le 5 ]; then + echo "$0: Aligning data and retraining and realigning with lda_mllt" + steps/align_si.sh --nj $num_jobs --cmd "$train_cmd" \ + data/train data/lang exp/tri1 exp/tri1_ali || exit 1; + + steps/train_lda_mllt.sh --cmd "$train_cmd" 4000 50000 \ + data/train data/lang exp/tri1_ali exp/tri2b || exit 1; +fi + +if [ $stage -le 6 ] && $decode_gmm; then + utils/mkgraph.sh data/lang_test exp/tri2b exp/tri2b/graph + steps/decode.sh --nj $num_decode_jobs --cmd "$decode_cmd" \ + exp/tri2b/graph data/test exp/tri2b/decode +fi + +if [ $stage -le 7 ]; then + echo "$0: Aligning data and retraining and realigning with sat_basis" + steps/align_si.sh --nj $num_jobs --cmd "$train_cmd" \ + data/train data/lang exp/tri2b exp/tri2b_ali || exit 1; + + steps/train_sat_basis.sh --cmd "$train_cmd" \ + 5000 100000 data/train data/lang exp/tri2b_ali exp/tri3b || exit 1; + + steps/align_fmllr.sh --nj $num_jobs --cmd "$train_cmd" \ + data/train data/lang exp/tri3b exp/tri3b_ali || exit 1; +fi + +if [ $stage -le 8 ] && $decode_gmm; then + utils/mkgraph.sh data/lang_test exp/tri3b exp/tri3b/graph + steps/decode_fmllr.sh --nj $num_decode_jobs --cmd \ + "$decode_cmd" exp/tri3b/graph data/test exp/tri3b/decode +fi + +if [ $stage -le 9 ]; then + echo "$0: Training a regular chain model using the e2e alignments..." 
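  # local/chain/run_tdnn.sh is, by the usual Kaldi convention, a symlink into
  # local/chain/tuning/ (e.g. run_tdnn_1a.sh); it builds on the tri3b model and
  # alignments produced by the stages above.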
+ local/chain/run_tdnn.sh +fi + +echo "$0: training succedded" exit 0 - -#TODO: -#LM (4-gram and RNN) rescoring -#combine lattices -#dialect detection - - - - - diff --git a/egs/gale_mandarin/s5/local/gale_segment.py b/egs/gale_mandarin/s5/local/gale_segment.py index 975ddb9c143..d652eb837f3 100755 --- a/egs/gale_mandarin/s5/local/gale_segment.py +++ b/egs/gale_mandarin/s5/local/gale_segment.py @@ -1,6 +1,7 @@ #!/usr/bin/env python #coding:utf-8 #!/usr/bin/env python +from __future__ import print_function import sys from mmseg import seg_txt for line in sys.stdin: @@ -12,4 +13,4 @@ continue for j in seg_txt(blks[i]): out_line += " " + j - print out_line + print(out_line) diff --git a/egs/gp/s1/local/gp_convert_audio.sh b/egs/gp/s1/local/gp_convert_audio.sh index a7c2d7285c4..b3db909c9b6 100755 --- a/egs/gp/s1/local/gp_convert_audio.sh +++ b/egs/gp/s1/local/gp_convert_audio.sh @@ -108,4 +108,4 @@ done < "$INLIST" echo "sox: error converting following $nsoxerr file(s):" >&2 [ -f "$soxerr" ] && cat "$soxerr" >&2 -exit 0; \ No newline at end of file +exit 0; diff --git a/egs/gp/s1/utils/mkgraph.sh b/egs/gp/s1/utils/mkgraph.sh index 2e45296593b..3aba742832d 100755 --- a/egs/gp/s1/utils/mkgraph.sh +++ b/egs/gp/s1/utils/mkgraph.sh @@ -131,4 +131,4 @@ cp $lang/silphones.csl $dir/ # to make const fst: # fstconvert --fst_type=const $dir/HCLG.fst $dir/HCLG_c.fst -echo "Finished making decoding graphs in $dir" \ No newline at end of file +echo "Finished making decoding graphs in $dir" diff --git a/egs/heroico/s5/RESULTS b/egs/heroico/s5/RESULTS index 9717e95e6e2..7942c03b1d9 100644 --- a/egs/heroico/s5/RESULTS +++ b/egs/heroico/s5/RESULTS @@ -1,22 +1,48 @@ # for dir in $(echo exp/tri*/decode* | grep -v 'si/'); do grep WER $dir/wer* | utils/best_wer.sh; done -%WER 67.01 [ 5126 / 7650, 837 ins, 575 del, 3714 sub ] exp/tri1/decode_devtest/wer_14_1.0 -%WER 62.39 [ 4678 / 7498, 768 ins, 397 del, 3513 sub ] exp/tri1/decode_native/wer_13_1.0 -%WER 67.05 [ 6179 / 9215, 895 ins, 606 del, 4678 sub ] exp/tri1/decode_nonnative/wer_13_1.0 -%WER 64.97 [ 10859 / 16713, 1678 ins, 999 del, 8182 sub ] exp/tri1/decode_test/wer_13_1.0 -%WER 65.90 [ 5041 / 7650, 1016 ins, 416 del, 3609 sub ] exp/tri2b/decode_devtest/wer_12_1.0 -%WER 61.26 [ 4593 / 7498, 908 ins, 300 del, 3385 sub ] exp/tri2b/decode_native/wer_14_1.0 -%WER 67.51 [ 6221 / 9215, 1085 ins, 524 del, 4612 sub ] exp/tri2b/decode_nonnative/wer_14_1.0 -%WER 64.87 [ 10842 / 16713, 2004 ins, 838 del, 8000 sub ] exp/tri2b/decode_test/wer_14_1.0 -%WER 66.09 [ 5056 / 7650, 1078 ins, 402 del, 3576 sub ] exp/tri3b/decode_devtest/wer_16_1.0 -%WER 74.88 [ 5728 / 7650, 1210 ins, 426 del, 4092 sub ] exp/tri3b/decode_devtest.si/wer_15_1.0 -%WER 61.19 [ 4588 / 7498, 1038 ins, 255 del, 3295 sub ] exp/tri3b/decode_native/wer_14_1.0 -%WER 70.99 [ 5323 / 7498, 1185 ins, 301 del, 3837 sub ] exp/tri3b/decode_native.si/wer_16_1.0 -%WER 66.35 [ 6114 / 9215, 1186 ins, 421 del, 4507 sub ] exp/tri3b/decode_nonnative/wer_17_1.0 -%WER 76.36 [ 7037 / 9215, 1420 ins, 467 del, 5150 sub ] exp/tri3b/decode_nonnative.si/wer_16_1.0 -%WER 64.06 [ 10706 / 16713, 2245 ins, 657 del, 7804 sub ] exp/tri3b/decode_test/wer_15_1.0 -%WER 73.97 [ 12362 / 16713, 2608 ins, 766 del, 8988 sub ] exp/tri3b/decode_test.si/wer_16_1.0 -%WER 53.07 [ 4060 / 7650, 744 ins, 376 del, 2940 sub ] exp/chain/tdnn1e_sp/decode_devtest/wer_7_1.0 -%WER 54.47 [ 4084 / 7498, 536 ins, 475 del, 3073 sub ] exp/chain/tdnn1e_sp/decode_native/wer_7_1.0 -%WER 63.01 [ 5806 / 9215, 685 ins, 784 del, 4337 sub ] 
exp/chain/tdnn1e_sp/decode_nonnative/wer_7_1.0 -%WER 59.25 [ 9903 / 16713, 1226 ins, 1259 del, 7418 sub ] exp/chain/tdnn1e_sp/decode_test/wer_7_1.0 +# old results before adding Movie subtitles text corpus in LM training: +# %WER 67.01 [ 5126 / 7650, 837 ins, 575 del, 3714 sub ] exp/tri1/decode_devtest/wer_14_1.0 +# %WER 62.39 [ 4678 / 7498, 768 ins, 397 del, 3513 sub ] exp/tri1/decode_native/wer_13_1.0 +# %WER 67.05 [ 6179 / 9215, 895 ins, 606 del, 4678 sub ] exp/tri1/decode_nonnative/wer_13_1.0 +# %WER 64.97 [ 10859 / 16713, 1678 ins, 999 del, 8182 sub ] exp/tri1/decode_test/wer_13_1.0 +# %WER 65.90 [ 5041 / 7650, 1016 ins, 416 del, 3609 sub ] exp/tri2b/decode_devtest/wer_12_1.0 +# %WER 61.26 [ 4593 / 7498, 908 ins, 300 del, 3385 sub ] exp/tri2b/decode_native/wer_14_1.0 +# %WER 67.51 [ 6221 / 9215, 1085 ins, 524 del, 4612 sub ] exp/tri2b/decode_nonnative/wer_14_1.0 +# %WER 64.87 [ 10842 / 16713, 2004 ins, 838 del, 8000 sub ] exp/tri2b/decode_test/wer_14_1.0 +# %WER 66.09 [ 5056 / 7650, 1078 ins, 402 del, 3576 sub ] exp/tri3b/decode_devtest/wer_16_1.0 +# %WER 74.88 [ 5728 / 7650, 1210 ins, 426 del, 4092 sub ] exp/tri3b/decode_devtest.si/wer_15_1.0 +# %WER 61.19 [ 4588 / 7498, 1038 ins, 255 del, 3295 sub ] exp/tri3b/decode_native/wer_14_1.0 +# %WER 70.99 [ 5323 / 7498, 1185 ins, 301 del, 3837 sub ] exp/tri3b/decode_native.si/wer_16_1.0 +# %WER 66.35 [ 6114 / 9215, 1186 ins, 421 del, 4507 sub ] exp/tri3b/decode_nonnative/wer_17_1.0 +# %WER 76.36 [ 7037 / 9215, 1420 ins, 467 del, 5150 sub ] exp/tri3b/decode_nonnative.si/wer_16_1.0 +# %WER 64.06 [ 10706 / 16713, 2245 ins, 657 del, 7804 sub ] exp/tri3b/decode_test/wer_15_1.0 +# %WER 73.97 [ 12362 / 16713, 2608 ins, 766 del, 8988 sub ] exp/tri3b/decode_test.si/wer_16_1.0 +# %WER 53.07 [ 4060 / 7650, 744 ins, 376 del, 2940 sub ] exp/chain/tdnn1e_sp/decode_devtest/wer_7_1.0 +# %WER 54.47 [ 4084 / 7498, 536 ins, 475 del, 3073 sub ] exp/chain/tdnn1e_sp/decode_native/wer_7_1.0 +# %WER 63.01 [ 5806 / 9215, 685 ins, 784 del, 4337 sub ] exp/chain/tdnn1e_sp/decode_nonnative/wer_7_1.0 +# %WER 59.25 [ 9903 / 16713, 1226 ins, 1259 del, 7418 sub ] exp/chain/tdnn1e_sp/decode_test/wer_7_1.0 + +# new results: +%WER 18.27 [ 1398 / 7650, 213 ins, 253 del, 932 sub ] exp/tri1/decode_devtest/wer_15_0.5 +%WER 9.95 [ 746 / 7498, 74 ins, 108 del, 564 sub ] exp/tri1/decode_native/wer_13_0.5 +%WER 16.63 [ 1532 / 9215, 197 ins, 183 del, 1152 sub ] exp/tri1/decode_nonnative/wer_17_0.0 +%WER 13.68 [ 2287 / 16713, 207 ins, 360 del, 1720 sub ] exp/tri1/decode_test/wer_17_0.5 +%WER 17.19 [ 1315 / 7650, 227 ins, 231 del, 857 sub ] exp/tri2b/decode_devtest/wer_17_0.5 +%WER 9.23 [ 692 / 7498, 60 ins, 103 del, 529 sub ] exp/tri2b/decode_native/wer_16_0.5 +%WER 17.16 [ 1581 / 9215, 184 ins, 216 del, 1181 sub ] exp/tri2b/decode_nonnative/wer_17_0.5 +%WER 13.64 [ 2279 / 16713, 241 ins, 326 del, 1712 sub ] exp/tri2b/decode_test/wer_17_0.5 +%WER 15.36 [ 1175 / 7650, 212 ins, 210 del, 753 sub ] exp/tri3b/decode_devtest/wer_17_0.5 +%WER 20.27 [ 1551 / 7650, 269 ins, 257 del, 1025 sub ] exp/tri3b/decode_devtest.si/wer_14_1.0 +%WER 6.40 [ 480 / 7498, 50 ins, 58 del, 372 sub ] exp/tri3b/decode_native/wer_16_0.0 +%WER 10.91 [ 818 / 7498, 100 ins, 112 del, 606 sub ] exp/tri3b/decode_native.si/wer_16_1.0 +%WER 14.30 [ 1318 / 9215, 206 ins, 134 del, 978 sub ] exp/tri3b/decode_nonnative/wer_17_0.0 +%WER 21.62 [ 1992 / 9215, 286 ins, 224 del, 1482 sub ] exp/tri3b/decode_nonnative.si/wer_16_1.0 +%WER 10.78 [ 1802 / 16713, 247 ins, 195 del, 1360 sub ] exp/tri3b/decode_test/wer_17_0.0 +%WER 16.81 [ 
2809 / 16713, 374 ins, 338 del, 2097 sub ] exp/tri3b/decode_test.si/wer_16_1.0 + +# chain model results: +# for dir in $(echo exp/chain/tdnn1b_sp/decode* | grep -v 'si/'); do grep WER $dir/wer* | utils/best_wer.sh; done +%WER 12.99 [ 994 / 7650, 192 ins, 163 del, 639 sub ] exp/chain/tdnn1b_sp/decode_devtest/wer_10_1.0 +%WER 12.47 [ 1149 / 9215, 119 ins, 174 del, 856 sub ] exp/chain/tdnn1b_sp/decode_nonnative/wer_12_0.0 +%WER 9.64 [ 1611 / 16713, 169 ins, 240 del, 1202 sub ] exp/chain/tdnn1b_sp/decode_test/wer_12_0.0 +%WER 6.13 [ 460 / 7498, 52 ins, 55 del, 353 sub ] exp/chain/tdnn1b_sp/decode_native/wer_10_0.0 diff --git a/egs/heroico/s5/cmd.sh b/egs/heroico/s5/cmd.sh index a427f3c16a5..533aad25db1 100755 --- a/egs/heroico/s5/cmd.sh +++ b/egs/heroico/s5/cmd.sh @@ -10,6 +10,7 @@ # conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, # or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. +export cmd="retry.pl queue.pl" export train_cmd="retry.pl queue.pl" export decode_cmd="retry.pl queue.pl --mem 2G" diff --git a/egs/heroico/s5/local/chain/tuning/run_cnn_tdnn_1a.sh b/egs/heroico/s5/local/chain/tuning/run_cnn_tdnn_1a.sh new file mode 100755 index 00000000000..361879b4142 --- /dev/null +++ b/egs/heroico/s5/local/chain/tuning/run_cnn_tdnn_1a.sh @@ -0,0 +1,318 @@ +#!/bin/bash + +# run_cnn_tdnn_1a.sh is modified from run_tdnn_1b.sh but taking +# the xconfig from mini-librispeech's run_cnn_tdnn_1a54.sh; only +# reducing the bottleneck-dim from 96 to 64, which is the value +# the run_tdnn1b.sh script here has. Results are better. +# local/chain/compare_wer.sh exp/chain/tdnn1a_sp exp/chain/tdnn1b_sp exp/chain/cnn_tdnn1a_sp +# System tdnn1a_sp tdnn1b_sp cnn_tdnn1a_sp +# %WER devtest 53.07 52.54 51.10 +# %WER test 59.25 53.70 52.07 +# %WER native 54.47 48.76 47.88 +# %WER nonnative 63.01 57.66 55.51 +# Final train prob -0.0253 -0.0547 -0.0502 +# Final valid prob -0.0687 -0.0694 -0.0661 +# Final train prob (xent) -0.7715 -0.9502 -0.8513 +# Final valid prob (xent) -1.0719 -1.0849 -0.9915 +# Num-params 6567648 3321312 3345088 + +# Set -e here so that we catch if any executable fails immediately +set -euo pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +decode_nj=10 +train_set=train +test_sets="native nonnative devtest test" +gmm=tri3b +nnet3_affix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=1a # affix for the TDNN directory name +tree_affix= +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +num_leaves=3500 + +# training options +# training chunk-options +chunk_width=140,100,160 +# we don't need extra left/right context for TDNN systems. +dropout_schedule='0,0@0.20,0.3@0.50,0' +common_egs_dir= +xent_regularize=0.1 + +# training options +srand=0 +remove_egs=true +reporting_email= + +#decode options +test_online_decoding=false # if true, it will run the last decoding stage. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 11 ]; then + # Get the alignments as lattices (gives the chain training more freedom). 
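  # (The lattices, rather than single best paths, are what let the LF-MMI
  #  numerator consider alternative alignments of each transcript.)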
+ # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 75 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 12 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --cmd "$train_cmd" \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + $num_leaves \ + ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 13 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + + cnn_opts="l2-regularize=0.03" + ivector_layer_opts="l2-regularize=0.03" + ivector_affine_opts="l2-regularize=0.03" + tdnn_opts="l2-regularize=0.03 dropout-proportion=0.0 dropout-per-dim-continuous=true" + tdnnf_first_opts="l2-regularize=0.03 dropout-proportion=0.0 bypass-scale=0.0" + tdnnf_opts="l2-regularize=0.03 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.03 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.015" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # this takes the MFCCs and generates filterbank coefficients. The MFCCs + # are more compressible so we prefer to dump the MFCCs to disk rather + # than filterbanks. + idct-layer name=idct input=input dim=40 cepstral-lifter=22 affine-transform-file=$dir/configs/idct.mat + + linear-component name=ivector-linear $ivector_affine_opts dim=200 input=ReplaceIndex(ivector, t, 0) + batchnorm-component name=ivector-batchnorm target-rms=0.025 + + batchnorm-component name=idct-batchnorm input=idct + combine-feature-maps-layer name=combine_inputs input=Append(idct-batchnorm, ivector-batchnorm) num-filters1=1 num-filters2=5 height=40 + + conv-relu-batchnorm-layer name=cnn1 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=48 learning-rate-factor=0.333 max-change=0.25 + conv-relu-batchnorm-layer name=cnn2 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=48 + conv-relu-batchnorm-layer name=cnn3 $cnn_opts height-in=40 height-out=20 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn4 $cnn_opts height-in=20 height-out=20 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn5 $cnn_opts height-in=20 height-out=10 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn6 $cnn_opts height-in=10 height-out=5 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=128 + + # the first TDNN-F layer has no bypass (since dims don't match), and a larger bottleneck so the + # information bottleneck doesn't become a problem. (we use time-stride=0 so no splicing, to + # limit the num-parameters). 
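  # (Concretely: tdnnf7 below uses tdnnf_first_opts, i.e. bypass-scale=0.0 and
  #  a 192-dim bottleneck, while the later layers use tdnnf_opts with
  #  bypass-scale=0.66, 64-dim bottlenecks and time-stride=3.)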
+ tdnnf-layer name=tdnnf7 $tdnnf_first_opts dim=768 bottleneck-dim=192 time-stride=0 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + tdnnf-layer name=tdnnf14 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + tdnnf-layer name=tdnnf15 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + linear-component name=prefinal-l dim=192 $linear_opts + + ## adding the layers for chain branch + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts small-dim=192 big-dim=768 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + # adding the layers for xent branch + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts small-dim=192 big-dim=768 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 14 ]; then + steps/nnet3/chain/train.py \ + --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.0 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=8 \ + --trainer.frames-per-iter=3000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=5 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.num-chunk-per-minibatch=128,64 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 15 ]; then + # Note: it's not important to give mkgraph.sh the lang directory with the + # matched topology (since it gets the topology file from the model). 
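  # (--self-loop-scale 1.0, rather than the conventional 0.1, is the standard
  #  setting when compiling decoding graphs for 'chain' models.)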
+ utils/mkgraph.sh \ + --self-loop-scale 1.0 \ + data/lang_test \ + $tree_dir \ + $tree_dir/graph || exit 1; +fi + +if [ $stage -le 16 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l /dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l 1392 combine=-0.040->-0.033 (over 7) xent:train/valid[69,104,final]=(-1.12,-0.880,-0.771/-1.33,-1.21,-1.07) logprob:train/valid[69,104,final]=(-0.050,-0.031,-0.025/-0.079,-0.080,-0.069) +# exp/chain/tdnn1a_sp: num-iters=105 nj=1..1 num-params=6.6M dim=40+100->1384 combine=-0.032->-0.026 (over 7) xent:train/valid[69,104,final]=(-1.14,-0.892,-0.811/-1.19,-1.07,-0.990) logprob:train/valid[69,104,final]=(-0.045,-0.029,-0.023/-0.083,-0.080,-0.072) # Set -e here so that we catch if any executable fails immediately set -euo pipefail @@ -149,7 +150,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.01" output_opts="l2-regularize=0.0025" diff --git a/egs/heroico/s5/local/chain/tuning/run_tdnn_1b.sh b/egs/heroico/s5/local/chain/tuning/run_tdnn_1b.sh index 33ce1556d29..cfb4dc1f697 100755 --- a/egs/heroico/s5/local/chain/tuning/run_tdnn_1b.sh +++ b/egs/heroico/s5/local/chain/tuning/run_tdnn_1b.sh @@ -3,21 +3,20 @@ # 1b is as 1a but a re-tuned model with quite a few changes, including moving to # a resnet-style factored TDNN-F model. # -# local/chain/compare_wer.sh exp/chain/tdnn1a_sp exp/chain/tdnn1b_sp +# ./local/chain/compare_wer.sh exp/chain/tdnn1a_sp exp/chain/tdnn1b_sp # System tdnn1a_sp tdnn1b_sp -# %WER devtest 53.07 52.54 -# %WER test 59.25 53.70 -# %WER native 54.47 48.76 -# %WER nonnative 63.01 57.66 -# Final train prob -0.0253 -0.0547 -# Final valid prob -0.0687 -0.0694 -# Final train prob (xent) -0.7715 -0.9502 -# Final valid prob (xent) -1.0719 -1.0849 -# Num-params 6567648 3321312 - +# %WER devtest 13.10 12.99 +# %WER test 15.53 9.64 +# %WER native 10.14 6.13 +# %WER nonnative 19.78 12.47 +# Final train prob -0.0233 -0.0442 +# Final valid prob -0.0720 -0.0726 +# Final train prob (xent) -0.8107 -0.9759 +# Final valid prob (xent) -0.9898 -0.9964 +# Num-params 6559440 3318224 # steps/info/chain_dir_info.pl exp/chain/tdnn1b_sp -# exp/chain/tdnn1b_sp: num-iters=34 nj=2..5 num-params=3.3M dim=40+100->1392 combine=-0.059->-0.059 (over 1) xent:train/valid[21,33,final]=(-1.28,-0.986,-0.950/-1.38,-1.10,-1.08) logprob:train/valid[21,33,final]=(-0.085,-0.063,-0.055/-0.090,-0.074,-0.069) +# exp/chain/tdnn1b_sp: num-iters=34 nj=2..5 num-params=3.3M dim=40+100->1384 combine=-0.044->-0.044 (over 1) xent:train/valid[21,33,final]=(-1.30,-0.993,-0.976/-1.28,-1.01,-0.996) logprob:train/valid[21,33,final]=(-0.071,-0.050,-0.044/-0.093,-0.076,-0.073) # Set -e here so that we catch if any executable fails immediately set -euo pipefail @@ -152,7 +151,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) affine_opts="l2-regularize=0.03 dropout-proportion=0.0 dropout-per-dim-continuous=true" tdnnf_opts="l2-regularize=0.03 dropout-proportion=0.0 
bypass-scale=0.66" linear_opts="l2-regularize=0.03 orthonormal-constraint=-1.0" diff --git a/egs/heroico/s5/local/heroico_answers_make_lists.pl b/egs/heroico/s5/local/heroico_answers_make_lists.pl index fb3c0ecb8d1..c1a3735b4f1 100755 --- a/egs/heroico/s5/local/heroico_answers_make_lists.pl +++ b/egs/heroico/s5/local/heroico_answers_make_lists.pl @@ -30,7 +30,7 @@ my $t = "$tmpdir/answers/text"; # initialize hash for prompts -my %p = (); +my %prompts = (); # store prompts in hash LINEA: while ( my $line = <> ) { @@ -40,9 +40,27 @@ my @dirs = split /\//, $directories; # get the speaker number my $s = $dirs[-1]; + # pad the speaker number with zeroes + my $spk = ""; + if ( $s < 10 ) { + $spk = '000' . $s; + } elsif ( $s < 100 ) { + $spk = '00' . $s; + } elsif ( $s < 1000 ) { + $spk = '0' . $s; + } + # pad the filename with zeroes + my $fn = ""; + if ( $file < 10 ) { + $fn = '000' . $file; + } elsif ( $file < 100 ) { + $fn = '00' . $file; + } elsif ( $file < 1000 ) { + $fn = '0' . $file; + } # the utterance name - my $i = $s . '_' . 'a' . '_' . $file; - $p{$i} = $sent; + my $utt = $spk . '_' . $fn; + $prompts{$utt} = $sent; } open my $W, '<', $w or croak "problem with $w $!"; @@ -58,18 +76,36 @@ my @dirs = split /\//, $directories; my $r = basename $line, ".wav"; my $s = $dirs[-1]; - my $rid = $s . '_' . 'a' . '_' . $r; - if ( exists $p{$rid} ) { - print $T "$rid $p{$rid}\n"; - } elsif ( defined $rid ) { - warn "warning: problem\t$rid"; + my $spk = ""; + # pad with zeroes + if ( $s < 10 ) { + $spk = '000' . $s; + } elsif ( $s < 100 ) { + $spk = '00' . $s; + } elsif ( $s < 1000 ) { + $spk = '0' . $s; + } + # pad the file name with zeroes + my $rec = ""; + if ( $r < 10 ) { + $rec = '000' . $r; + } elsif ( $r < 100 ) { + $rec = '00' . $r; + } elsif ( $r < 1000 ) { + $rec = '0' . $r; + } + my $rec_id = $spk . '_' . 
$rec; + if ( exists $prompts{$rec_id} ) { + print $T "$rec_id $prompts{$rec_id}\n"; + } elsif ( defined $rec_id ) { + warn "warning: problem\t$rec_id"; next LINE; } else { croak "$line"; } - print $O "$rid sox -r 22050 -e signed -b 16 $line -r 16000 -t wav - |\n"; - print $U "$rid ${s}_a\n"; + print $O "$rec_id sox -r 22050 -e signed -b 16 $line -r 16000 -t wav - |\n"; + print $U "$rec_id $spk\n"; } close $T; close $O; diff --git a/egs/heroico/s5/local/heroico_recordings_make_lists.pl b/egs/heroico/s5/local/heroico_recordings_make_lists.pl index 1d157665799..b9a3ab5a565 100755 --- a/egs/heroico/s5/local/heroico_recordings_make_lists.pl +++ b/egs/heroico/s5/local/heroico_recordings_make_lists.pl @@ -19,75 +19,102 @@ system "mkdir -p $tmpdir/recordings/devtest"; # input wav file list -my $w = "$tmpdir/wav_list.txt"; +my $input_wav_list = "$tmpdir/wav_list.txt"; # output temporary wav.scp files -my $o_train = "$tmpdir/recordings/train/wav.scp"; -my $o_test = "$tmpdir/recordings/devtest/wav.scp"; +my $train_wav_scp = "$tmpdir/recordings/train/wav.scp"; +my $test_wav_scp = "$tmpdir/recordings/devtest/wav.scp"; # output temporary utt2spk files -my $u_train = "$tmpdir/recordings/train/utt2spk"; -my $u_test = "$tmpdir/recordings/devtest/utt2spk"; +my $train_uttspk = "$tmpdir/recordings/train/utt2spk"; +my $test_uttspk = "$tmpdir/recordings/devtest/utt2spk"; # output temporary text files -my $t_train = "$tmpdir/recordings/train/text"; -my $t_test = "$tmpdir/recordings/devtest/text"; +my $train_text = "$tmpdir/recordings/train/text"; +my $test_text = "$tmpdir/recordings/devtest/text"; # initialize hash for prompts -my %p = (); +my %prompts = (); # store prompts in hash LINEA: while ( my $line = <> ) { chomp $line; - my ($s,$sent) = split /\t/, $line, 2; - $p{$s} = $sent; + my ($prompt_id,$prompt) = split /\t/, $line, 2; + # pad the prompt id with zeroes + my $pid = ""; + if ( $prompt_id < 10 ) { + $pid = '0000' . $prompt_id; + } elsif ( $prompt_id < 100 ) { + $pid = '000' . $prompt_id; + } elsif ( $prompt_id < 1000 ) { + $pid = '00' . $prompt_id; + } + $prompts{$pid} = $prompt; } -open my $W, '<', $w or croak "problem with $w $!"; -open my $OT, '+>', $o_train or croak "problem with $o_train $!"; -open my $OE, '+>', $o_test or croak "problem with $o_test $!"; -open my $UT, '+>', $u_train or croak "problem with $u_train $!"; -open my $UE, '+>', $u_test or croak "problem with $u_test $!"; -open my $TT, '+>', $t_train or croak "problem with $t_train $!"; -open my $TE, '+>', $t_test or croak "problem with $t_test $!"; +open my $WVL, '<', $input_wav_list or croak "problem with $input_wav_list $!"; +open my $TRNWSCP, '+>', $train_wav_scp or croak "problem with $train_wav_scp $!"; +open my $TSTWSCP, '+>', $test_wav_scp or croak "problem with $test_wav_scp $!"; +open my $TRNUTTSPK, '+>', $train_uttspk or croak "problem with $train_uttspk $!"; +open my $TSTUTTSPK, '+>', $test_uttspk or croak "problem with $test_uttspk $!"; +open my $TRNTXT, '+>', $train_text or croak "problem with $train_text $!"; +open my $TSTTXT, '+>', $test_text or croak "problem with $test_text $!"; - LINE: while ( my $line = <$W> ) { + LINE: while ( my $line = <$WVL> ) { chomp $line; next LINE if ($line =~ /Answers/ ); next LINE unless ( $line =~ /Recordings/ ); my ($volume,$directories,$file) = File::Spec->splitpath( $line ); my @dirs = split /\//, $directories; - my $r = basename $line, ".wav"; - my $s = $dirs[-1]; - my $rid = $s . '_r' . '_' . 
$r; - if ( ( $r >= 355 ) and ( $r < 561 ) ) { - if ( exists $p{$r} ) { - print $TE "$rid $p{$r}\n"; - } elsif ( defined $rid ) { - warn "problem\t$rid"; + my $utt_id = basename $line, ".wav"; + # pad the utterance id with zeroes + my $utt = ""; + if ( $utt_id < 10 ) { + $utt = '0000' . $utt_id; +} elsif ( $utt_id < 100 ) { + $utt = '000' . $utt_id; +} elsif ( $utt_id < 1000 ) { + $utt = '00' . $utt_id; +} + my $spk_id = $dirs[-1]; + # pad the speaker id with zeroes + my $spk = ""; + if ( $spk_id < 10 ) { + $spk = '000' . $spk_id; + } elsif ( $spk_id < 100 ) { + $spk = '00' . $spk_id; + } elsif ( $spk_id < 1000 ) { + $spk = '0' . $spk_id; + } + my $spk_utt_id = $spk . '_' . $utt; + if ( ( $utt_id >= 355 ) and ( $utt_id < 561 ) ) { +if ( exists $prompts{$utt} ) { + print $TSTTXT "$spk_utt_id $prompts{$utt}\n"; + } elsif ( defined $spk_utt_id ) { + warn "problem\t$spk_utt_id"; next LINE; } else { croak "$line"; } - print $OE "$rid sox -r 22050 -e signed -b 16 $line -r 16000 -t wav - |\n"; - print $UE "$rid ${s}_r\n"; - } elsif ( ( $r < 355 ) or ( $r > 560 ) ) { - if ( exists $p{$r} ) { - print $TT "$rid $p{$r}\n"; - } elsif ( defined $rid ) { - warn "problem\t$rid"; + print $TSTWSCP "$spk_utt_id sox -r 22050 -e signed -b 16 $line -r 16000 -t wav - |\n"; + print $TSTUTTSPK "$spk_utt_id $spk\n"; + } elsif ( ( $utt_id < 355 ) or ( $utt_id > 560 ) ) { + if ( exists $prompts{$utt} ) { + print $TRNTXT "$spk_utt_id $prompts{$utt}\n"; + } elsif ( defined $spk_utt_id ) { + warn "problem\t$spk_utt_id"; next LINE; } else { croak "$line"; } - print $OT "$rid sox -r 22050 -e signed -b 16 $line -r 16000 -t wav - |\n"; - print $UT "$rid ${s}_r\n"; - } + print $TRNWSCP "$spk_utt_id sox -r 22050 -e signed -b 16 $line -r 16000 -t wav - |\n"; + print $TRNUTTSPK "$spk_utt_id $spk\n"; + } } -close $TT; -close $OT; -close $UT; -close $TE; -close $OE; -close $UE; -close $W; +close $TRNTXT; +close $TRNWSCP; +close $TRNUTTSPK; +close $TSTTXT; +close $TSTWSCP; +close $TSTUTTSPK; +close $WVL; diff --git a/egs/heroico/s5/local/nnet3/run_ivector_common.sh b/egs/heroico/s5/local/nnet3/run_ivector_common.sh index 153f0073667..e882ce0c918 100755 --- a/egs/heroico/s5/local/nnet3/run_ivector_common.sh +++ b/egs/heroico/s5/local/nnet3/run_ivector_common.sh @@ -9,6 +9,9 @@ set -euo pipefail # of usage. 
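# A typical invocation from a run script might look like (illustrative only;
# the values shown are this script's defaults):
#   local/nnet3/run_ivector_common.sh --stage 0 --train-set train \
#     --test-sets "native nonnative devtest test" --gmm tri3b --nnet3-affix ""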
stage=0 +nj=56 +num_threads_ubm=2 + train_set=train test_sets="native nonnative devtest test" gmm=tri3b @@ -37,25 +40,17 @@ if [ $stage -le 1 ]; then utils/data/perturb_data_dir_speed_3way.sh \ data/${train_set} \ data/${train_set}_sp - echo "$0: making MFCC features for low-resolution speed-perturbed data" - steps/make_mfcc.sh \ - --cmd "$train_cmd" \ - --nj 10 \ - data/${train_set}_sp || exit 1; - steps/compute_cmvn_stats.sh \ - data/${train_set}_sp || exit 1; - utils/fix_data_dir.sh \ - data/${train_set}_sp + + echo "$0: making mfcc features for low-resolution speed-perturbed data" + steps/make_mfcc.sh --cmd "$train_cmd" --nj 10 data/${train_set}_sp || exit 1; + steps/compute_cmvn_stats.sh data/${train_set}_sp || exit 1; + utils/fix_data_dir.sh data/${train_set}_sp fi if [ $stage -le 2 ]; then echo "$0: aligning with the perturbed low-resolution data" steps/align_fmllr.sh \ - --nj 20 \ - --cmd "$train_cmd" \ - data/${train_set}_sp \ - data/lang \ - $gmm_dir \ + --nj 20 --cmd "$train_cmd" data/${train_set}_sp data/lang $gmm_dir \ $ali_dir || exit 1 fi diff --git a/egs/heroico/s5/local/prepare_data.sh b/egs/heroico/s5/local/prepare_data.sh index db2b990c07b..b78d9f1d1cb 100755 --- a/egs/heroico/s5/local/prepare_data.sh +++ b/egs/heroico/s5/local/prepare_data.sh @@ -4,17 +4,17 @@ # Apache 2.0. . ./cmd.sh - . ./path.sh stage=0 +datadir=$1 . ./utils/parse_options.sh set -e set -o pipefail -# the location of the LDC corpus -datadir=$1 +tmpdir=data/local/tmp + # acoustic models are trained on the heroico corpus # testing is done on the usma corpus # heroico consists of 2 parts: answers and recordings (recited) @@ -25,8 +25,6 @@ recordings_transcripts=$datadir/data/transcripts/heroico-recordings.txt # usma is all recited usma_transcripts=$datadir/data/transcripts/usma-prompts.txt -tmpdir=data/local/tmp - # make acoustic model training lists if [ $stage -le 0 ]; then mkdir -p $tmpdir/heroico $tmpdir/usma @@ -37,12 +35,12 @@ if [ $stage -le 0 ]; then # the transcripts are converted to UTF8 export LC_ALL=en_US.UTF-8 cat $answers_transcripts | iconv -f ISO-8859-1 -t UTF-8 | \ - sed -e 's/\r//' | local/heroico_answers_make_lists.pl + tr -d '\r' | local/heroico_answers_make_lists.pl utils/fix_data_dir.sh $tmpdir/heroico/answers cat $recordings_transcripts | iconv -f ISO-8859-1 -t UTF-8 | \ - sed -e 's/\r//' | local/heroico_recordings_make_lists.pl + tr -d '\r' | local/heroico_recordings_make_lists.pl utils/fix_data_dir.sh $tmpdir/heroico/recordings/train utils/fix_data_dir.sh $tmpdir/heroico/recordings/devtest @@ -52,11 +50,11 @@ if [ $stage -le 0 ]; then for x in wav.scp utt2spk text; do cat $tmpdir/heroico/answers/$x $tmpdir/heroico/recordings/train/$x | \ - sed -e 's/\r//' | sort -k1,1 -u >$tmpdir/heroico/lists/train/$x + tr -d '\r' | sort -k1,1 -u >$tmpdir/heroico/lists/train/$x done for x in wav.scp utt2spk text; do - cat $tmpdir/heroico/recordings/devtest/$x | sed -e 's/\r//' | \ + cat $tmpdir/heroico/recordings/devtest/$x | tr -d '\r' | \ sort -k1,1 -u >$tmpdir/heroico/lists/devtest/$x done @@ -67,10 +65,10 @@ fi if [ $stage -le 1 ]; then # make separate lists for usma (US military academy) native and nonnative cat $usma_transcripts | iconv -f ISO-8859-1 -t UTF-8 | \ - sed -e 's/\r//' | local/usma_native_make_lists.pl + tr -d '\r' | dos2unix | local/usma_native_make_lists.pl cat $usma_transcripts | iconv -f ISO-8859-1 -t UTF-8 | \ - sed -e 's/\r//' | local/usma_nonnative_make_lists.pl + tr -d '\r' | local/usma_nonnative_make_lists.pl for n in native nonnative; do mkdir -p $tmpdir/usma/$n/lists 
@@ -86,14 +84,14 @@ if [ $stage -le 1 ]; then # get training lists for x in wav.scp utt2spk text; do cat $tmpdir/heroico/answers/${x} $tmpdir/heroico/recordings/train/${x} | \ - sed -e 's/\r//' >$tmpdir/lists/train/$x + tr -d '\r' >$tmpdir/lists/train/$x sort $tmpdir/lists/train/$x >data/train/$x done # get devtest lists for x in wav.scp utt2spk text; do cat $tmpdir/heroico/lists/devtest/$x | \ - sed -e 's/\r//' >$tmpdir/lists/devtest/$x + tr -d '\r' >$tmpdir/lists/devtest/$x sort $tmpdir/lists/devtest/$x >data/devtest/$x done diff --git a/egs/heroico/s5/local/prepare_dict.sh b/egs/heroico/s5/local/prepare_dict.sh index a6d182a6852..9f498bc963a 100755 --- a/egs/heroico/s5/local/prepare_dict.sh +++ b/egs/heroico/s5/local/prepare_dict.sh @@ -13,12 +13,12 @@ fi export LC_ALL=C -cut -f2- data/local/tmp/dict/santiago.txt | \ +cut -f2- ./santiago.txt | \ tr -s '[:space:]' '[\n*]' | \ grep -v SPN | sort -u >data/local/dict/nonsilence_phones.txt # sed "1d" deletes the last line. -expand -t 1 data/local/tmp/dict/santiago.txt | sort -u | +expand -t 1 ./santiago.txt | sort -u | sed "1d" >data/local/dict/lexicon.txt echo " SPN" >> data/local/dict/lexicon.txt diff --git a/egs/heroico/s5/local/subs_download.sh b/egs/heroico/s5/local/subs_download.sh new file mode 100755 index 00000000000..98dcb42d4e0 --- /dev/null +++ b/egs/heroico/s5/local/subs_download.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Copyright 2017 John Morgan +# Apache 2.0. + +tmpdir=data/local/tmp +download_dir=$(pwd) +mkdir -p $download_dir +subs_src=$1 + +# download the subs corpus +if [ ! -f $download_dir/subs.zip ]; then + wget -O $download_dir/subs.zip $subs_src + ( + cd $download_dir + unzip subs.zip + ) + else + echo "$0: subs file already downloaded." +fi diff --git a/egs/heroico/s5/local/subs_prepare_data.pl b/egs/heroico/s5/local/subs_prepare_data.pl index 3cd906d4699..e39db79f610 100755 --- a/egs/heroico/s5/local/subs_prepare_data.pl +++ b/egs/heroico/s5/local/subs_prepare_data.pl @@ -1,4 +1,4 @@ -#!/usr/bin/perl -w +#!/usr/bin/env perl # Copyright 2017 John Morgan # Apache 2.0. 
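Note on the sed -> tr swap in prepare_data.sh above: for ordinary CRLF transcripts the two forms give the same result, since each line carries a single trailing carriage return. tr -d '\r' simply deletes every CR in the stream, whereas sed 's/\r//' (without the g flag) strips only the first match per line and relies on the \r escape, which not every sed implementation recognizes. A minimal sketch of the difference, assuming GNU sed and coreutils (hypothetical input, not from the recipe):

printf 'hola\r\n'   | sed -e 's/\r//' | od -c   # single trailing CR removed
printf 'ho\rla\r\n' | tr -d '\r'      | od -c   # every CR in the stream removed, not just the first per line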
@@ -12,69 +12,64 @@ use Encode; # set lower and upper bounds -my $lb = 8; -# only segments with at least $lb words will be written -my $ub = 16; -# only segments with fewer than $ub words will be written +my $low_bound = 8; +# only segments with at least $low_bound words will be written +my $up_bound = 16; +# only segments with fewer than $up_bound words will be written # input and output files -my $c = "data/local/tmp/subs/OpenSubtitles2016.en-es.es"; -my $symtab = "data/lang/words.txt"; -my $rl = "data/local/tmp/subs/lm/es.txt"; -my $oo = "data/local/tmp/subs/lm/oovs.txt"; + +my $corpus = "OpenSubtitles.en-es.es"; +my $symbol_table = "data/lang/words.txt"; +my $filtered = "data/local/tmp/subs/lm/es.txt"; +my $oovs = "data/local/tmp/subs/lm/oovs.txt"; my $iv = "data/local/tmp/subs/lm/in_vocabulary.txt"; -open my $C, '<', $c or croak "problems with $c $!"; +open my $C, '<', $corpus or croak "problems with $corpus $!"; system "mkdir -p data/local/tmp/subs/lm"; -open my $RL, '+>:utf8', $rl or croak "problems with $rl $!"; - -LINE: while ( my $line = <$C> ) { - $line = decode_utf8 $line; - chomp $line; - - my @tokens = split /\s+/, $line; - - next LINE if ( ($#tokens < $lb) or ($#tokens > $ub )); - - #remove control characters - #$line =~ s/(\p{Other})/ /g; - #$line =~ s/(\p{Control})/ /g; - #$line =~ s/(\p{Format})/ /g; - #$line =~ s/(\p{Private_Use})/ /g; - #$line =~ s/(\p{Surrogate})/ /g; - - # punctuation - $line =~ s/(\p{Punctuation}+|\p{Dash_Punctuation}+|\p{Close_Punctuation}+|\p{Open_Punctuation}+|\p{Initial_Punctuation}+|\p{Final_Punctuation}+|\p{Connector_Punctuation}+|\p{Other_Punctuation}+|[ ]+)/ /msxg; -#convert tabs to white space - $line =~ s/\t/ /g; - #hard to soft space - $line =~ s/ / /g; -#squeeze white space - $line =~ s/\s+/ /g; -#initial and final white space - $line =~ s/^\p{Separator}+//; - $line =~ s/\p{Separator}+$//; -#down case - $line = lc $line; - - - print $RL "$line\n"; - +if ( -e $filtered ) { + warn "$filtered already exists."; +} else { + open my $FLT, '+>:utf8', $filtered or croak "problems with $filtered $!"; + LINE: while ( my $line = <$C> ) { + $line = decode_utf8 $line; + chomp $line; + + my @tokens = split /\s+/, $line; + + next LINE if ( ($#tokens < $low_bound) or ($#tokens > $up_bound )); + + # remove punctuation + $line =~ s/(\p{Punctuation}+|\p{Dash_Punctuation}+|\p{Close_Punctuation}+|\p{Open_Punctuation}+|\p{Initial_Punctuation}+|\p{Final_Punctuation}+|\p{Connector_Punctuation}+|\p{Other_Punctuation}+|[ ]+)/ /msxg; + #convert tabs to white space + $line =~ s/\t/ /g; + #hard to soft space + $line =~ s/ / /g; + #squeeze white space + $line =~ s/\s+/ /g; + #initial and final white space + $line =~ s/^\p{Separator}+//; + $line =~ s/\p{Separator}+$//; + #down case + $line = lc $line; + + print $FLT "$line\n"; + } + close $FLT; } - close $C; -close $RL; + # find out of vocabulary words -# $symtab points to a file containing a map of symbols to integers +# $symbol_table points to a file containing a map of symbols to integers # hash for word to integer map my %sym2int = (); -open my $F, '<', $symtab or croak "problem with $symtab $!"; +open my $F, '<', $symbol_table or croak "problem with $symbol_table $!"; # store words to int map in hash while( my $line = <$F>) { @@ -84,33 +79,33 @@ } close $F; -open my $I, '<', $rl or croak "problem with $rl $!"; -open my $OO, '+>', $oo or croak "problems with $oo $!"; +open my $I, '<', $filtered or croak "problem with $filtered $!"; +open my $OOVS, '+>', $oovs or croak "problems with $oovs $!"; while ( my $line = 
<$I>) { chomp $line; my @A = split /\s/, $line; foreach my $a (@A) { if (!defined ($sym2int{$a})) { - print $OO "$a\n"; + print $OOVS "$a\n"; } } } -close $OO; +close $OOVS; close $I; # remove segments with OOVs # store OOVS in hash my %oov = (); -open my $V, '<', $oo or croak "problems with $oo $!"; +open my $V, '<', $oovs or croak "problems with $oovs $!"; while ( my $line = <$V> ) { chomp $line; $oov{$line} = 1; } close $V; -open my $L, '<', $rl or croak "problems with $rl $!"; +open my $L, '<', $filtered or croak "problems with $filtered $!"; open my $IV, '+>', $iv or croak "problems with $iv $!"; SEGMENT: while ( my $segment = <$L> ) { diff --git a/egs/heroico/s5/run.sh b/egs/heroico/s5/run.sh index 711bece3c66..4cc5617e985 100755 --- a/egs/heroico/s5/run.sh +++ b/egs/heroico/s5/run.sh @@ -1,83 +1,80 @@ #!/bin/bash . ./cmd.sh - . ./path.sh + stage=0 +# the location of the LDC corpus; this location works for the CLSP grid. +datadir=/export/corpora5/LDC/LDC2006S37 + +# The corpus and lexicon are on openslr.org +#speech_url="http://www.openslr.org/resources/39/LDC2006S37.tar.gz" +lexicon_url="http://www.openslr.org/resources/34/santiago.tar.gz" + +# Location of the Movie subtitles text corpus +subtitles_url="http://opus.lingfil.uu.se/download.php?f=OpenSubtitles2018/en-es.txt.zip" + . utils/parse_options.sh set -e set -o pipefail set -u -# the location of the LDC corpus; this location works for the CLSP grid. -datadir=/export/corpora5/LDC/LDC2006S37 -#datadir=/mnt/corpora/LDC2006S37 - -# location of subtitles text data -# note: this is not used so I'm commenting it out; dan. -#subsdata="http://opus.lingfil.uu.se/download.php?f=OpenSubtitles2016/en-es.txt.zip" -lexicon="http://www.openslr.org/resources/34/santiago.tar.gz" # don't change tmpdir, the location is used explicitly in scripts in local/. tmpdir=data/local/tmp if [ $stage -le 0 ]; then - # prepare the lists for acoustic model training and testing - mkdir -p $tmpdir/heroico - mkdir -p $tmpdir/usma - - [ ! -d "$datadir" ] && \ - echo "$0 Data directory (LDC corpus release) does not exist" && \ + if [ ! -d $datadir ]; then + echo "$0: please download and un-tar http://www.openslr.org/resources/39/LDC2006S37.tar.gz" + echo " and set $datadir to the directory where it is located." exit 1 - local/prepare_data.sh $datadir + fi + if [ ! -s santiago.txt ]; then + echo "$0: downloading the lexicon" + wget -c http://www.openslr.org/resources/34/santiago.tar.gz + tar -xvzf santiago.tar.gz + fi + # Get data for lm training + local/subs_download.sh $subtitles_url fi if [ $stage -le 1 ]; then - # prepare a dictionary - mkdir -p data/local/dict - mkdir -p data/local/tmp/dict - - # download the dictionary from openslr - if [ ! -f data/local/tmp/dict/santiago.tar.gz ]; then - wget -O data/local/tmp/dict/santiago.tar.gz $lexicon - fi - - ( - cd $tmpdir/dict - tar -xzf santiago.tar.gz - ) + echo "Making lists for building models." 
+ local/prepare_data.sh $datadir +fi +if [ $stage -le 2 ]; then + mkdir -p data/local/dict $tmpdir/dict local/prepare_dict.sh +fi - # prepare the lang directory +if [ $stage -le 3 ]; then utils/prepare_lang.sh \ data/local/dict "" \ data/local/lang data/lang fi -if [ $stage -le 2 ]; then - # use am training text to train lm - mkdir -p $tmpdir/heroico/lm +if [ $stage -le 4 ]; then + mkdir -p $tmpdir/subs/lm + local/subs_prepare_data.pl +fi + +if [ $stage -le 5 ]; then echo "point 1" - # get the text from data/train/text - cut -d " " -f 2- data/train/text > $tmpdir/heroico/lm/train.txt - echo "point 2" - # build lm - local/prepare_lm.sh $tmpdir/heroico/lm/train.txt + local/prepare_lm.sh $tmpdir/subs/lm/in_vocabulary.txt +fi - echo "point 3" +if [ $stage -le 6 ]; then + echo "point 2" utils/format_lm.sh \ data/lang data/local/lm/trigram.arpa.gz data/local/dict/lexicon.txt \ data/lang_test - - # delete temporary work - rm -rf data/local/tmp fi -if [ $stage -le 3 ]; then - # extract acoustic features +if [ $stage -le 7 ]; then + echo "$0: extracting acoustic features." mkdir -p exp for fld in native nonnative test devtest train; do @@ -92,7 +89,7 @@ if [ $stage -le 3 ]; then done fi -if [ $stage -le 4 ]; then +if [ $stage -le 8 ]; then echo "$0 monophone training" steps/train_mono.sh --nj 8 --cmd "$train_cmd" data/train data/lang exp/mono || exit 1; @@ -108,8 +105,7 @@ if [ $stage -le 4 ]; then ) & fi -if [ $stage -le 5 ]; then - +if [ $stage -le 9 ]; then # align with monophones steps/align_si.sh --nj 8 --cmd "$train_cmd" \ data/train data/lang exp/mono exp/mono_ali @@ -131,10 +127,8 @@ if [ $stage -le 5 ]; then fi -if [ $stage -le 6 ]; then +if [ $stage -le 10 ]; then echo "$0: Starting delta system alignment" - - # align with triphones steps/align_si.sh \ --nj 8 --cmd "$train_cmd" data/train data/lang exp/tri1 exp/tri1_ali @@ -156,10 +150,9 @@ if [ $stage -le 6 ]; then ) & fi -if [ $stage -le 7 ]; then +if [ $stage -le 11 ]; then echo "$0: Starting LDA+MLLT system alignment" - # align with lda and mllt adapted triphones steps/align_si.sh \ --use-graphs true --nj 8 --cmd "$train_cmd" \ data/train data/lang exp/tri2b exp/tri2b_ali @@ -169,7 +162,6 @@ if [ $stage -le 7 ]; then --cmd "$train_cmd" \ 3100 50000 data/train data/lang exp/tri2b_ali exp/tri3b - # align with tri3b models echo "$0 Starting exp/tri3b_ali" steps/align_fmllr.sh \ --nj 8 --cmd "$train_cmd" \ @@ -182,16 +174,16 @@ if [ $stage -le 7 ]; then utils/mkgraph.sh \ data/lang_test exp/tri3b exp/tri3b/graph || exit 1; - # decode test sets with tri3b models for x in native nonnative devtest test; do + echo "$0: decoding $x with tri3b models." steps/decode_fmllr.sh \ --nj 8 --cmd "$decode_cmd" exp/tri3b/graph data/$x exp/tri3b/decode_${x} done ) & fi -if [ $stage -le 9 ]; then - # train and test chain models +if [ $stage -le 12 ]; then + echo "$0: train and test chain models." local/chain/run_tdnn.sh fi diff --git a/egs/hkust/s5/RESULTS b/egs/hkust/s5/RESULTS index c419c9f6ddd..aac01fcb5af 100644 --- a/egs/hkust/s5/RESULTS +++ b/egs/hkust/s5/RESULTS @@ -1,3 +1,5 @@ +## Caution: these WERs are actually CERs. 
+ # for x in exp/*/decode; do [ -d $x ] && grep WER $x/cer_* | utils/best_wer.sh; done %WER 80.67 [ 45198 / 56027, 1607 ins, 10733 del, 32858 sub ] exp/mono0a/decode/cer_9_0.0 %WER 58.79 [ 32939 / 56027, 2662 ins, 6124 del, 24153 sub ] exp/tri1/decode/cer_13_0.0 @@ -41,3 +43,6 @@ exp/nnet2_convnet/decode/cer_10:%WER 41.19 [ 23129 / 56154, 2599 ins, 3782 del, # nnet3 mfcc results (using speed perturbed data) exp/nnet3/tdnn_sp/decode_dev/cer_10:%WER 33.79 [ 18977 / 56154, 2027 ins, 3485 del, 13465 sub ] exp/nnet3/lstm_sp_ld5/decode_dev/cer_9:%WER 33.51 [ 18815 / 56154, 1813 ins, 3249 del, 13753 sub ] + + +# For nnet3+chain results, which are significantly better, see scripts in local/chain/tuning/. diff --git a/egs/hkust/s5/local/chain/compare_wer.sh b/egs/hkust/s5/local/chain/compare_wer.sh index b3376871a69..27a6b783433 100755 --- a/egs/hkust/s5/local/chain/compare_wer.sh +++ b/egs/hkust/s5/local/chain/compare_wer.sh @@ -39,25 +39,25 @@ for x in $*; do done echo -# print decode WER results -echo -n "# WER(%) " +# print decode CER results +echo -n "# CER(%) " for x in $*; do set_names $x - wer=$([ -d $x ] && grep WER $x/decode/cer_* | utils/best_wer.sh | awk '{print $2}') + wer=$([ -d $x ] && grep CER $x/decode/cer_* | utils/best_wer.sh | awk '{print $2}') printf "% 10s" $wer done echo -# so how about online WER? +# so how about online CER? if $include_online; then - echo -n "# WER(%)[online] " + echo -n "# CER(%)[online] " for x in $*; do set_names $x wer=$(cat ${x}_online/decode/cer_* | utils/best_wer.sh | awk '{print $2}') printf "% 10s" $wer done echo - echo -n "# WER(%)[per-utt] " + echo -n "# CER(%)[per-utt] " for x in $*; do set_names $x wer_per_utt=$(cat ${x}_online/decode_per_utt/cer_* | utils/best_wer.sh | awk '{print $2}') diff --git a/egs/hkust/s5/local/chain/tuning/run_tdnn_2a.sh b/egs/hkust/s5/local/chain/tuning/run_tdnn_2a.sh old mode 100644 new mode 100755 index 0fc0de36a45..c62b776de2b --- a/egs/hkust/s5/local/chain/tuning/run_tdnn_2a.sh +++ b/egs/hkust/s5/local/chain/tuning/run_tdnn_2a.sh @@ -5,9 +5,9 @@ # Results # local/chain/compare_wer.sh --online exp/chain/tdnn_7h_chain_2b_sp # Model tdnn_7h_chain_2b_sp -# WER(%) 23.67 -# WER(%)[online] 23.69 -# WER(%)[per-utt] 24.67 +# CER(%) 23.67 +# CER(%)[online] 23.69 +# CER(%)[per-utt] 24.67 # Final train prob -0.0895 # Final valid prob -0.1251 # Final train prob (xent) -1.3628 @@ -109,7 +109,7 @@ if [ $stage -le 12 ]; then ivector_dim=$(feat-to-dim scp:exp/nnet3/ivectors_${train_set}/ivector_online.scp -) feat_dim=$(feat-to-dim scp:data/${train_set}_hires/feats.scp -) num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.004 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" linear_opts="orthonormal-constraint=-1.0 l2-regularize=0.004" output_opts="l2-regularize=0.002" diff --git a/egs/hub4_english/s5/local/data_prep/process_1995_bn_annotation.py b/egs/hub4_english/s5/local/data_prep/process_1995_bn_annotation.py index be0c7ad8e0d..5675dc3fbd9 100755 --- a/egs/hub4_english/s5/local/data_prep/process_1995_bn_annotation.py +++ b/egs/hub4_english/s5/local/data_prep/process_1995_bn_annotation.py @@ -31,9 +31,9 @@ def get_args(): parser = argparse.ArgumentParser("Process 1995 CSR-IV HUB4 transcripts") - parser.add_argument("--noise-word", type=str, default="", + parser.add_argument("--noise-word", default="", help="Word to add 
in-place of noise words") - parser.add_argument("--spoken-noise-word", type=str, + parser.add_argument("--spoken-noise-word", default="", help="Word to add in-place of speaker noise words") parser.add_argument("in_file", type=argparse.FileType('r'), @@ -230,7 +230,7 @@ def run(args): start_time = story_end_time segments = process_story_content( args, reco_id, - ' '.join([unicode(x) for x in s.children]), + ' '.join([str(x) for x in s.children]), start_time=story_begin_time, end_time=story_end_time) write_segments(segments, args) elif (s.name is not None and s.name != "language" @@ -240,9 +240,9 @@ def run(args): "or or ; got {0}".format(s)) elif s.name == "language" or s.name == "sung": non_story_contents.append( - ' '.join([unicode(x) for x in s.children])) + ' '.join([str(x) for x in s.children])) else: - non_story_contents.append(unicode(s)) + non_story_contents.append(str(s)) except RuntimeError: raise except Exception: diff --git a/egs/hub4_english/s5/local/data_prep/process_1996_csr_hub4_lm_filelist.py b/egs/hub4_english/s5/local/data_prep/process_1996_csr_hub4_lm_filelist.py index 95aa7ddb831..fb5ba7a64ee 100755 --- a/egs/hub4_english/s5/local/data_prep/process_1996_csr_hub4_lm_filelist.py +++ b/egs/hub4_english/s5/local/data_prep/process_1996_csr_hub4_lm_filelist.py @@ -36,9 +36,9 @@ def get_args(): corpus (LDC98T31).""") parser.add_argument("--verbose", choices=[0,1,2,3], type=int, default=0, help="Set higher for more verbose logging.") - parser.add_argument("file_list", type=str, + parser.add_argument("file_list", help="""List of compressed source files""") - parser.add_argument("dir", type=str, + parser.add_argument("dir", help="Output directory to dump processed files to") args = parser.parse_args() @@ -83,7 +83,7 @@ def process_file_lines(lines, out_file_handle): for x in para.contents: try: if x.name is None: - normalized_text = normalize_text(unicode(x)) + normalized_text = normalize_text(str(x)) if len(normalized_text) == 0: continue out_file_handle.write("{0}\n".format( diff --git a/egs/hub4_english/s5/local/data_prep/process_na_news_text.py b/egs/hub4_english/s5/local/data_prep/process_na_news_text.py index 94b02a766a9..08203f7ada1 100755 --- a/egs/hub4_english/s5/local/data_prep/process_na_news_text.py +++ b/egs/hub4_english/s5/local/data_prep/process_na_news_text.py @@ -38,10 +38,10 @@ def get_args(): parser = argparse.ArgumentParser("Prepare NA News Text corpus (LDC95T21).") parser.add_argument("--verbose", type=int, choices=[0, 1, 2, 3], default=0, help="Use larger verbosity for more verbose logging.") - parser.add_argument("file_list", type=str, + parser.add_argument("file_list", help="List of compressed source files for NA News Text. 
" "e.g: /export/corpora/LDC/LDC95T21/na_news_1/latwp/1994") - parser.add_argument("out_file", type=str, + parser.add_argument("out_file", help="Output file to write to.") args = parser.parse_args() @@ -85,7 +85,7 @@ def process_file_lines(lines, out_file_handle): continue for para in art.find_all('p'): assert para.name == 'p' - text = ' '.join([unicode(x).strip() for x in para.contents]) + text = ' '.join([str(x).strip() for x in para.contents]) normalized_text = normalize_text(text) out_file_handle.write("{0}\n".format( normalized_text.encode('ascii'))) diff --git a/egs/hub4_english/s5/local/lm/merge_word_counts.py b/egs/hub4_english/s5/local/lm/merge_word_counts.py index 6338cbbf875..85e15d8dc07 100755 --- a/egs/hub4_english/s5/local/lm/merge_word_counts.py +++ b/egs/hub4_english/s5/local/lm/merge_word_counts.py @@ -7,6 +7,7 @@ A min-count argument is required to only write counts that are above the specified minimum count. """ +from __future__ import print_function import sys @@ -21,7 +22,7 @@ def main(): parts = line.strip().split() words[parts[1]] = words.get(parts[1], 0) + int(parts[0]) - for word, count in words.iteritems(): + for word, count in words.items(): if count >= int(sys.argv[1]): print ("{0} {1}".format(count, word)) diff --git a/egs/hub4_spanish/s5/local/chain/compare_wer.sh b/egs/hub4_spanish/s5/local/chain/compare_wer.sh new file mode 100755 index 00000000000..0194b86ac69 --- /dev/null +++ b/egs/hub4_spanish/s5/local/chain/compare_wer.sh @@ -0,0 +1,135 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/tdnn_{c,d}_sp +# For use with discriminatively trained systems you specify the epochs after a colon: +# for instance, +# local/chain/compare_wer.sh exp/chain/tdnn_c_sp exp/chain/tdnn_c_sp_smbr:{1,2,3} + + +if [ $# == 0 ]; then + echo "Usage: $0: [--looped] [--online] [ ... ]" + echo "e.g.: $0 exp/chain/tdnn_{b,c}_sp" + echo "or (with epoch numbers for discriminative training):" + echo "$0 exp/chain/tdnn_b_sp_disc:{1,2,3}" + exit 1 +fi + +echo "# $0 $*" + +include_looped=false +if [ "$1" == "--looped" ]; then + include_looped=true + shift +fi +include_online=false +if [ "$1" == "--online" ]; then + include_online=true + shift +fi + + +used_epochs=false + +# this function set_names is used to separate the epoch-related parts of the name +# [for discriminative training] and the regular parts of the name. 
+# If called with a colon-free directory name, like: +# set_names exp/chain/tdnn_lstm1e_sp_bi_smbr +# it will set dir=exp/chain/tdnn_lstm1e_sp_bi_smbr and epoch_infix="" +# If called with something like: +# set_names exp/chain/tdnn_d_sp_smbr:3 +# it will set dir=exp/chain/tdnn_d_sp_smbr and epoch_infix="_epoch3" + + +set_names() { + if [ $# != 1 ]; then + echo "compare_wer_general.sh: internal error" + exit 1 # exit the program + fi + dirname=$(echo $1 | cut -d: -f1) + epoch=$(echo $1 | cut -s -d: -f2) + if [ -z $epoch ]; then + epoch_infix="" + else + used_epochs=true + epoch_infix=_epoch${epoch} + fi +} + + + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +strings=("#WER test ") + +for n in 0; do + echo -n "${strings[$n]}" + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + decode_names=(test) + + wer=$(cat $dirname/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + if $include_looped; then + echo -n "# [looped:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat $dirname/decode_looped_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi + if $include_online; then + echo -n "# [online:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat ${dirname}_online/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi +done + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent)" +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent)" +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Num-params " +for x in $*; do + printf "% 10s" $(grep num-parameters $x/log/progress.1.log | awk '{print $2}') +done +echo diff --git a/egs/hub4_spanish/s5/local/chain/run_cnn_tdnn.sh b/egs/hub4_spanish/s5/local/chain/run_cnn_tdnn.sh new file mode 120000 index 00000000000..ab83f3c43e8 --- /dev/null +++ b/egs/hub4_spanish/s5/local/chain/run_cnn_tdnn.sh @@ -0,0 +1 @@ +tuning/run_cnn_tdnn_1a.sh \ No newline at end of file diff --git a/egs/hub4_spanish/s5/local/chain/run_tdnn.sh b/egs/hub4_spanish/s5/local/chain/run_tdnn.sh index 211957092f9..61f8f499182 120000 --- a/egs/hub4_spanish/s5/local/chain/run_tdnn.sh +++ b/egs/hub4_spanish/s5/local/chain/run_tdnn.sh @@ -1 +1 @@ -./tuning/run_tdnn_1a.sh \ No newline at end of file +tuning/run_tdnn_1b.sh \ No newline at end of file diff --git a/egs/hub4_spanish/s5/local/chain/tuning/run_cnn_tdnn_1a.sh b/egs/hub4_spanish/s5/local/chain/tuning/run_cnn_tdnn_1a.sh new file mode 100755 index 00000000000..d1b657a2d74 --- /dev/null +++ b/egs/hub4_spanish/s5/local/chain/tuning/run_cnn_tdnn_1a.sh @@ -0,0 +1,287 @@ +#!/bin/bash + +## This is taken from mini_librispeech. 
+ +# local/chain/compare_wer.sh --online exp/chain/tdnn1a_sp exp/chain/cnn_tdnn1a_sp +# System tdnn1a_sp cnn_tdnn1a_sp +#WER test 14.19 13.47 +# [online:] 14.26 13.57 +# Final train prob -0.0707 -0.0911 +# Final valid prob -0.1225 -0.1145 +# Final train prob (xent) -1.1117 -1.3038 +# Final valid prob (xent) -1.3199 -1.3374 +# Num-params 6945216 4471200 + +# steps/info/chain_dir_info.pl exp/chain/cnn_tdnn1a_sp +# exp/chain/cnn_tdnn1a_sp: num-iters=102 nj=2..5 num-params=4.5M dim=40+100->2272 combine=-0.101->-0.097 (over 5) xent:train/valid[67,101,final]=(-1.46,-1.31,-1.30/-1.47,-1.34,-1.34) logprob:train/valid[67,101,final]=(-0.112,-0.097,-0.091/-0.129,-0.121,-0.114) + +# Set -e here so that we catch if any executable fails immediately +set -euo pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +decode_nj=10 +train_set=train +test_sets=eval +gmm=tri5 +nnet3_affix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=1a # affix for the TDNN directory name +tree_affix= +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +# training options +# training chunk-options +chunk_width=140,100,160 +dropout_schedule='0,0@0.20,0.3@0.50,0' +# we don't need extra left/right context for TDNN systems. +chunk_left_context=0 +chunk_right_context=0 +common_egs_dir= +xent_regularize=0.1 + +# training options +srand=0 +remove_egs=true +reporting_email= + +#decode options +test_online_decoding=true # if true, it will run the last decoding stage. + + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 11 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 75 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 12 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 13 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + + cnn_opts="l2-regularize=0.03" + ivector_affine_opts="l2-regularize=0.03" + tdnn_opts="l2-regularize=0.03 dropout-proportion=0.0 dropout-per-dim-continuous=true" + tdnnf_first_opts="l2-regularize=0.03 dropout-proportion=0.0 bypass-scale=0.0" + tdnnf_opts="l2-regularize=0.03 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.03 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.015" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # this takes the MFCCs and generates filterbank coefficients. The MFCCs + # are more compressible so we prefer to dump the MFCCs to disk rather + # than filterbanks. + idct-layer name=idct input=input dim=40 cepstral-lifter=22 affine-transform-file=$dir/configs/idct.mat + + linear-component name=ivector-linear $ivector_affine_opts dim=200 input=ReplaceIndex(ivector, t, 0) + batchnorm-component name=ivector-batchnorm target-rms=0.025 + + batchnorm-component name=idct-batchnorm input=idct + combine-feature-maps-layer name=combine_inputs input=Append(idct-batchnorm, ivector-batchnorm) num-filters1=1 num-filters2=5 height=40 + + conv-relu-batchnorm-layer name=cnn1 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=48 learning-rate-factor=0.333 max-change=0.25 + conv-relu-batchnorm-layer name=cnn2 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=48 + conv-relu-batchnorm-layer name=cnn3 $cnn_opts height-in=40 height-out=20 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn4 $cnn_opts height-in=20 height-out=20 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn5 $cnn_opts height-in=20 height-out=10 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn6 $cnn_opts height-in=10 height-out=5 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=128 + + # the first TDNN-F layer has no bypass (since dims don't match), and a larger bottleneck so the + # information bottleneck doesn't become a problem. (we use time-stride=0 so no splicing, to + # limit the num-parameters). 
+ tdnnf-layer name=tdnnf7 $tdnnf_first_opts dim=768 bottleneck-dim=192 time-stride=0 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf14 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf15 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + linear-component name=prefinal-l dim=192 $linear_opts + + ## adding the layers for chain branch + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts small-dim=192 big-dim=768 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + # adding the layers for xent branch + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts small-dim=192 big-dim=768 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 14 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/hub4_spanish-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.0 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=10 \ + --trainer.frames-per-iter=3000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=5 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.num-chunk-per-minibatch=256,128,64 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 15 ]; then + # Note: it's not important to give mkgraph.sh the lang directory with the + # matched topology (since it gets the topology file from the model). + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/langp_test \ + $tree_dir $dir/graph || exit 1; +fi + +if [ $stage -le 16 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + nspk=$(wc -l $dir/configs/network.xconfig @@ -179,7 +179,7 @@ fi if [ $stage -le 14 ]; then if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then utils/create_split_dir.pl \ - /export/b0{3,4,5,6}/$USER/kaldi-data/egs/mini_librispeech-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/hub4_spanish-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage fi steps/nnet3/chain/train.py --stage=$train_stage \ @@ -227,6 +227,16 @@ if [ $stage -le 15 ]; then $tree_dir $dir/graph || exit 1; fi +if [ $stage -le 16 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + nspk=$(wc -l 2272 combine=-0.105->-0.100 (over 6) xent:train/valid[67,101,final]=(-1.54,-1.34,-1.35/-1.56,-1.39,-1.39) logprob:train/valid[67,101,final]=(-0.116,-0.099,-0.094/-0.135,-0.123,-0.116) + +# Set -e here so that we catch if any executable fails immediately +set -euo pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +decode_nj=10 +train_set=train +test_sets=eval +gmm=tri5 +nnet3_affix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=1b # affix for the TDNN directory name +tree_affix= +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +# training options +# training chunk-options +chunk_width=140,100,160 +dropout_schedule='0,0@0.20,0.3@0.50,0' +# we don't need extra left/right context for TDNN systems. +chunk_left_context=0 +chunk_right_context=0 +common_egs_dir= +xent_regularize=0.1 + +# training options +srand=0 +remove_egs=true +reporting_email= + +#decode options +test_online_decoding=true # if true, it will run the last decoding stage. + + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 11 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 75 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 12 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 13 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + + tdnn_opts="l2-regularize=0.03 dropout-proportion=0.0 dropout-per-dim-continuous=true" + tdnnf_opts="l2-regularize=0.03 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.03 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.015" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-dropout-layer name=tdnn1 $tdnn_opts dim=768 + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=1 + tdnnf-layer name=tdnnf3 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=1 + tdnnf-layer name=tdnnf4 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=1 + tdnnf-layer name=tdnnf5 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=0 + tdnnf-layer name=tdnnf6 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf7 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + linear-component name=prefinal-l dim=192 $linear_opts + + ## adding the layers for chain branch + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts small-dim=192 big-dim=768 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + # adding the layers for xent branch + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts small-dim=192 big-dim=768 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 14 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/hub4_spanish-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.0 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=10 \ + --trainer.frames-per-iter=3000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=5 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.num-chunk-per-minibatch=256,128,64 \ + --egs.cmd="run.pl --max-jobs-run 12" \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 15 ]; then + # Note: it's not important to give mkgraph.sh the lang directory with the + # matched topology (since it gets the topology file from the model). + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/langp_test \ + $tree_dir $dir/graph || exit 1; +fi + +if [ $stage -le 16 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + nspk=$(wc -l is written as a word if(w[0].lower() == ""): - f.write("%s\t\n" % (unicode(w[0]))) + f.write("%s\t\n" % (str(w[0]))) else: - f.write("%s\t%s\n" % (unicode(w[0]), + f.write("%s\t%s\n" % (str(w[0]), encoded_transcription[idx])) if __name__ == "__main__": diff --git a/egs/hub4_spanish/s5/local/prepare_unicode_dict.py b/egs/hub4_spanish/s5/local/prepare_unicode_dict.py index 86fa4d60ba1..3b9dc1abd86 100755 --- a/egs/hub4_spanish/s5/local/prepare_unicode_dict.py +++ b/egs/hub4_spanish/s5/local/prepare_unicode_dict.py @@ -89,7 +89,7 @@ def extract_phonemes(lexicon): # Read all baseform units into dictionary with {a: [a, a_1, a_2], # b: [b_1, b_3], ...} phonemes_dict = {} - for word, pron in lexicon.iteritems(): + for word, pron in lexicon.items(): for p in pron.split(): try: base = p.split("_",1)[0] @@ -98,11 +98,11 @@ def extract_phonemes(lexicon): phonemes_dict[base] = [p] # Makes sure there are no repeats in the list - phonemes_dict = {k: set(v) for k, v in phonemes_dict.iteritems()} + phonemes_dict = {k: set(v) for k, v in phonemes_dict.items()} # Get all unique phonemes phonemes = [] - for v in phonemes_dict.itervalues(): + for v in phonemes_dict.values(): for p in v: phonemes.append(p) @@ -137,11 +137,11 @@ def write_extra_questions(nonsil_phonemes, nonsil_phonemes_dict, # Write all possible phone_tag combinations that occur in the lexicon for tag in tags: - for p in nonsil_phonemes_dict.iterkeys(): + for p in nonsil_phonemes_dict.keys(): tagged_phoneme = "_".join([p, tag]) if(tagged_phoneme in nonsil_phonemes_dict[p]): fp.write("%s " % tagged_phoneme) - 
for p in sil_phonemes_dict.iterkeys(): + for p in sil_phonemes_dict.keys(): tagged_phoneme = "_".join([p, tag]) if(tagged_phoneme in sil_phonemes_dict[p]): fp.write("%s " % tagged_phoneme) diff --git a/egs/iam/v1/RESULTS b/egs/iam/v1/RESULTS new file mode 100644 index 00000000000..b25cb3cd772 --- /dev/null +++ b/egs/iam/v1/RESULTS @@ -0,0 +1,42 @@ +Run_end2end.sh (WER using lang_test, lang_unk) +flat_start: + • %WER 14.41 [ 2671 / 18542, 262 ins, 561 del, 1848 sub ] exp/chain/e2e_cnn_1a/decode_test/wer_11_1.0 + • %WER 15.21 [ 2821 / 18542, 375 ins, 500 del, 1946 sub ] exp/chain/e2e_cnn_1a/decode_test/wer_11_1.0 + +cnn_e2eali_1a: + • %WER 11.94 [ 2214 / 18542, 267 ins, 380 del, 1567 sub ] exp/chain/cnn_e2eali_1a/decode_test/wer_9_1.0 + • %WER 13.30 [ 2467 / 18542, 441 ins, 330 del, 1696 sub ] exp/chain/cnn_e2eali_1a/decode_test/wer_9_0.5 + +cnn_e2eali_1b: + • %WER 11.20 [ 2076 / 18542, 260 ins, 335 del, 1481 sub ] exp/chain/cnn_e2eali_1b/decode_test/wer_9_1.0 + • %WER 12.46 [ 2311 / 18542, 371 ins, 326 del, 1614 sub ] exp/chain/cnn_e2eali_1b/decode_test/wer_9_1.0 + +cnn_e2eali_1c: + • %WER 9.90 [ 1836 / 18542, 257 ins, 227 del, 1352 sub ] exp/chain/cnn_e2eali_1c/decode_test/wer_10_1.0 + • %WER 12.10 [ 2243 / 18542, 411 ins, 269 del, 1563 sub ] exp/chain/cnn_e2eali_1c/decode_test/wer_12_0.5 + + +Run.sh (WER using lang_test, lang_unk) +cnn_1a: + • %WER 15.18 [ 2815 / 18542, 285 ins, 509 del, 2021 sub ] exp/chain/cnn_1a/decode_test/wer_11_0.0 + • %WER 16.88 [ 3130 / 18542, 444 ins, 611 del, 2075 sub ] exp/chain/cnn_1a/decode_test/wer_11_0.0 + +cnn_chainali_1a: + • %WER 14.09 [ 2612 / 18542, 245 ins, 505 del, 1862 sub ] exp/chain/cnn_chainali_1a/decode_test/wer_13_0.0 + • %WER 15.93 [ 2954 / 18542, 454 ins, 470 del, 2030 sub ] exp/chain/cnn_chainali_1a/decode_test/wer_10_0.0 + +cnn_chainali_1b: + • %WER 13.29 [ 2465 / 18542, 221 ins, 499 del, 1745 sub ] exp/chain/cnn_chainali_1b/decode_test/wer_12_0.5 + • %WER 15.09 [ 2798 / 18542, 418 ins, 468 del, 1912 sub ] exp/chain/cnn_chainali_1b/decode_test/wer_10_0.5 + +cnn_chainali_1c: + • %WER 11.59 [ 2149 / 18542, 276 ins, 362 del, 1511 sub ] exp/chain/cnn_chainali_1c/decode_test/wer_9_0.0 + • %WER 13.75 [ 2550 / 18542, 465 ins, 368 del, 1717 sub ] exp/chain/cnn_chainali_1c/decode_test/wer_8_0.0 + +cnn_chainali_1d: + • %WER 11.07 [ 2053 / 18542, 261 ins, 311 del, 1481 sub ] exp/chain/cnn_chainali_1c/decode_test/wer_9_0.0 + • %WER 12.95 [ 2402 / 18542, 436 ins, 313 del, 1653 sub ] exp/chain/cnn_chainali_1c/decode_test/wer_8_0.0 + +cnn_chainali_1e: + • %WER 10.03 [ 1859 / 18542, 226 ins, 291 del, 1342 sub ] exp/chain/cnn_chainali_1e/decode_test/wer_11_0.5 + %WER 12.15 [ 2253 / 18542, 406 ins, 282 del, 1565 sub ] exp/chain/cnn_chainali_1e/decode_test/wer_10_0.5 diff --git a/egs/iam/v1/local/augment_data.sh b/egs/iam/v1/local/augment_data.sh new file mode 100755 index 00000000000..31e4a8217ca --- /dev/null +++ b/egs/iam/v1/local/augment_data.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Copyright 2018 Hossein Hadian +# 2018 Ashish Arora + +# Apache 2.0 +# This script performs data augmentation. + +nj=4 +cmd=run.pl +feat_dim=40 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. 
./utils/parse_options.sh || exit 1; + +srcdir=$1 +outdir=$2 +datadir=$3 +aug_set=aug1 +mkdir -p $datadir/augmentations +echo "copying $srcdir to $datadir/augmentations/$aug_set, allowed length, creating feats.scp" + +for set in $aug_set; do + image/copy_data_dir.sh --spk-prefix $set- --utt-prefix $set- \ + $srcdir $datadir/augmentations/$set + cat $srcdir/allowed_lengths.txt > $datadir/augmentations/$set/allowed_lengths.txt + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim $feat_dim \ + --fliplr false --augment true $datadir/augmentations/$set +done + +echo " combine original data and data from different augmentations" +utils/combine_data.sh --extra-files images.scp $outdir $srcdir $datadir/augmentations/$aug_set +cat $srcdir/allowed_lengths.txt > $outdir/allowed_lengths.txt diff --git a/egs/iam/v1/local/chain/compare_wer.sh b/egs/iam/v1/local/chain/compare_wer.sh index ad90710b13f..4a2cc29481c 100755 --- a/egs/iam/v1/local/chain/compare_wer.sh +++ b/egs/iam/v1/local/chain/compare_wer.sh @@ -34,6 +34,20 @@ for x in $*; do done echo +echo -n "# WER val " +for x in $*; do + wer=$(cat $x/decode_val/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER val " +for x in $*; do + cer=$(cat $x/decode_val/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + if $used_epochs; then exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. fi diff --git a/egs/iam/v1/local/chain/run_cnn.sh b/egs/iam/v1/local/chain/run_cnn.sh new file mode 120000 index 00000000000..df6f0a468c1 --- /dev/null +++ b/egs/iam/v1/local/chain/run_cnn.sh @@ -0,0 +1 @@ +tuning/run_cnn_1a.sh \ No newline at end of file diff --git a/egs/iam/v1/local/chain/run_cnn_chainali.sh b/egs/iam/v1/local/chain/run_cnn_chainali.sh new file mode 120000 index 00000000000..41b712609c2 --- /dev/null +++ b/egs/iam/v1/local/chain/run_cnn_chainali.sh @@ -0,0 +1 @@ +tuning/run_cnn_chainali_1d.sh \ No newline at end of file diff --git a/egs/iam/v1/local/chain/run_cnn_e2eali.sh b/egs/iam/v1/local/chain/run_cnn_e2eali.sh new file mode 120000 index 00000000000..ad51803ab0e --- /dev/null +++ b/egs/iam/v1/local/chain/run_cnn_e2eali.sh @@ -0,0 +1 @@ +tuning/run_cnn_e2eali_1c.sh \ No newline at end of file diff --git a/egs/iam/v1/local/chain/run_e2e_cnn.sh b/egs/iam/v1/local/chain/run_e2e_cnn.sh new file mode 120000 index 00000000000..d26ba0182ce --- /dev/null +++ b/egs/iam/v1/local/chain/run_e2e_cnn.sh @@ -0,0 +1 @@ +tuning/run_e2e_cnn_1a.sh \ No newline at end of file diff --git a/egs/iam/v1/local/chain/run_cnn_1a.sh b/egs/iam/v1/local/chain/tuning/run_cnn_1a.sh similarity index 80% rename from egs/iam/v1/local/chain/run_cnn_1a.sh rename to egs/iam/v1/local/chain/tuning/run_cnn_1a.sh index 41a76920e37..ef1273f3961 100755 --- a/egs/iam/v1/local/chain/run_cnn_1a.sh +++ b/egs/iam/v1/local/chain/tuning/run_cnn_1a.sh @@ -4,23 +4,23 @@ # 2017 Chun Chieh Chang # 2017 Ashish Arora -# steps/info/chain_dir_info.pl exp/chain/cnn_1a/ -# exp/chain/cnn_1a/: num-iters=21 nj=2..4 num-params=4.4M dim=40->364 combine=-0.021->-0.015 xent:train/valid[13,20,final]=(-1.05,-0.701,-0.591/-1.30,-1.08,-1.00) logprob:train/valid[13,20,final]=(-0.061,-0.034,-0.030/-0.107,-0.101,-0.098) - # local/chain/compare_wer.sh exp/chain/cnn_1a/ -# System cnn_1a -# WER 18.52 -# CER 10.07 -# Final train prob -0.0077 -# Final valid prob -0.0970 -# Final train prob (xent) -0.5484 -# Final valid prob (xent) -0.9643 -# Parameters 4.36M +# System cnn_1a(dict_50k) cnn_1a(dict_50k + unk 
model) +# WER 16.88 15.18 +# CER 8.52 7.58 +# WER val 16.17 13.53 +# CER val 7.15 5.89 +# Final train prob -0.0299 +# Final valid prob -0.0574 +# Final train prob (xent) -0.3912 +# Final valid prob (xent) -0.6439 +# Parameters 4.36M -set -e -o pipefail +# steps/info/chain_dir_info.pl exp/chain/cnn_1a/ +# exp/chain/cnn_1a/: num-iters=42 nj=2..4 num-params=4.4M dim=40->368 combine=-0.029->-0.029 (over 2) xent:train/valid[27,41,final]=(-0.522,-0.394,-0.391/-0.695,-0.644,-0.644) logprob:train/valid[27,41,final]=(-0.035,-0.030,-0.030/-0.056,-0.057,-0.057) +set -e -o pipefail stage=0 - nj=30 train_set=train gmm=tri3 # this is the source gmm-dir that we'll use for alignments; it @@ -34,28 +34,21 @@ reporting_email= # chain options train_stage=-10 xent_regularize=0.1 -frame_subsampling_factor=4 -alignment_subsampling_factor=1 # training chunk-options chunk_width=340,300,200,100 num_leaves=500 -# we don't need extra left/right context for TDNN systems. -chunk_left_context=0 -chunk_right_context=0 tdnn_dim=450 # training options -srand=0 -remove_egs=false -lang_test=lang_unk +lang_decode=lang_unk +decode_val=true +if $decode_val; then maybe_val=val; else maybe_val= ; fi # End configuration section. echo "$0 $@" # Print the command line for logging - . ./cmd.sh . ./path.sh . ./utils/parse_options.sh - if ! cuda-compiled; then cat <368 combine=-0.020->-0.020 (over 2) xent:train/valid[27,41,final]=(-0.534,-0.425,-0.424/-0.659,-0.612,-0.612) logprob:train/valid[27,41,final]=(-0.026,-0.022,-0.022/-0.017,-0.016,-0.016) +set -e -o pipefail + +stage=0 +nj=30 +train_set=train +decode_val=true +gmm=tri3 # this is the source gmm-dir that we'll use for alignments; it + # should have alignments for the specified training data. +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +ali=tri3_ali +chain_model_dir=exp/chain${nnet3_affix}/cnn_1a +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +chunk_width=340,300,200,100 +num_leaves=500 +tdnn_dim=450 +lang_decode=lang_unk +if $decode_val; then maybe_val=val; else maybe_val= ; fi +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + $train_data_dir data/lang $chain_model_dir $lat_dir + cp $gmm_lat_dir/splice_opts $lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 4 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves $train_data_dir \ + $lang $ali_dir $tree_dir +fi + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + common1="height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + relu-batchnorm-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn4 input=Append(-4,0,4) dim=$tdnn_dim + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn4 dim=$tdnn_dim target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=500" \ + --chain.frame-subsampling-factor=4 \ + --chain.alignment-subsampling-factor=1 \ + --trainer.srand=0 \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=4 \ + --trainer.frames-per-iter=1000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=4 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=64,32 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=false \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/$lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + done +fi + +echo "$0 Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/iam/v1/local/chain/run_cnn_chainali_1b.sh b/egs/iam/v1/local/chain/tuning/run_cnn_chainali_1b.sh similarity index 79% rename from egs/iam/v1/local/chain/run_cnn_chainali_1b.sh rename to egs/iam/v1/local/chain/tuning/run_cnn_chainali_1b.sh index c6876fbafcb..401ffa14e19 100755 --- a/egs/iam/v1/local/chain/run_cnn_chainali_1b.sh +++ b/egs/iam/v1/local/chain/tuning/run_cnn_chainali_1b.sh @@ -1,27 +1,26 @@ #!/bin/bash # chainali_1b is as chainali_1a except it has 3 more cnn layers and 1 less tdnn layer. 
- -# local/chain/compare_wer.sh exp/chain/cnn_1a/ exp/chain/cnn_chainali_1b/ -# System cnn_1a cnn_chainali_1b -# WER 18.52 14.38 -# CER 10.07 7.14 -# Final train prob -0.0077 -0.0113 -# Final valid prob -0.0970 -0.0400 -# Final train prob (xent) -0.5484 -0.6043 -# Final valid prob (xent) -0.9643 -0.9030 -# Parameters 4.36M 3.96M +# local/chain/compare_wer.sh exp/chain/cnn_chainali_1b +# System cnn_chainali_1b(dict_50k) cnn_chainali_1b(dict_50k + unk_model) +# WER 15.09 13.29 +# CER 7.13 6.08 +# WER val 14.80 11.98 +# CER val 6.16 4.87 +# Final train prob -0.0225 +# Final valid prob -0.0132 +# Final train prob (xent) -0.4466 +# Final valid prob (xent) -0.6048 +# Parameters 3.96M # steps/info/chain_dir_info.pl exp/chain/chainali_cnn_1b/ -# exp/chain/chainali_cnn_1b/: num-iters=21 nj=2..4 num-params=4.0M dim=40->364 combine=-0.009->-0.005 xent:train/valid[13,20,final]=(-1.47,-0.728,-0.623/-1.69,-1.02,-0.940) logprob:train/valid[13,20,final]=(-0.068,-0.030,-0.011/-0.086,-0.056,-0.038) - +# exp/chain/cnn_chainali_1b: num-iters=42 nj=2..4 num-params=4.0M dim=40->368 combine=-0.019->-0.019 (over 2) xent:train/valid[27,41,final]=(-0.545,-0.448,-0.447/-0.645,-0.605,-0.605) logprob:train/valid[27,41,final]=(-0.026,-0.023,-0.023/-0.014,-0.013,-0.013) set -e -o pipefail - stage=0 - nj=30 train_set=train +decode_val=true gmm=tri3 # this is the source gmm-dir that we'll use for alignments; it # should have alignments for the specified training data. nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. @@ -31,31 +30,20 @@ chain_model_dir=exp/chain${nnet3_affix}/cnn_1a common_egs_dir= reporting_email= -# chain options train_stage=-10 xent_regularize=0.1 -frame_subsampling_factor=4 -alignment_subsampling_factor=1 -# training chunk-options chunk_width=340,300,200,100 num_leaves=500 -# we don't need extra left/right context for TDNN systems. -chunk_left_context=0 -chunk_right_context=0 tdnn_dim=450 -# training options -srand=0 -remove_egs=false -lang_test=lang_unk +lang_decode=lang_unk +if $decode_val; then maybe_val=val; else maybe_val= ; fi # End configuration section. echo "$0 $@" # Print the command line for logging - . ./cmd.sh . ./path.sh . ./utils/parse_options.sh - if ! cuda-compiled; then cat < $dir/configs/network.xconfig input dim=40 name=input - conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 @@ -160,7 +145,6 @@ if [ $stage -le 4 ]; then ## adding the layers for chain branch relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 - # adding the layers for xent branch # This block prints the configs for a separate output that will be # trained with a cross-entropy objective in the 'chain' mod?els... 
this @@ -191,9 +175,9 @@ if [ $stage -le 5 ]; then --chain.l2-regularize=0.00005 \ --chain.apply-deriv-weights=false \ --chain.lm-opts="--num-extra-lm-states=500" \ - --chain.frame-subsampling-factor=$frame_subsampling_factor \ - --chain.alignment-subsampling-factor=$alignment_subsampling_factor \ - --trainer.srand=$srand \ + --chain.frame-subsampling-factor=4 \ + --chain.alignment-subsampling-factor=1 \ + --trainer.srand=0 \ --trainer.max-param-change=2.0 \ --trainer.num-epochs=4 \ --trainer.frames-per-iter=1000000 \ @@ -203,15 +187,10 @@ if [ $stage -le 5 ]; then --trainer.optimization.final-effective-lrate=0.0001 \ --trainer.optimization.shrink-value=1.0 \ --trainer.num-chunk-per-minibatch=64,32 \ - --trainer.optimization.momentum=0.0 \ --egs.chunk-width=$chunk_width \ - --egs.chunk-left-context=$chunk_left_context \ - --egs.chunk-right-context=$chunk_right_context \ - --egs.chunk-left-context-initial=0 \ - --egs.chunk-right-context-final=0 \ --egs.dir="$common_egs_dir" \ --egs.opts="--frames-overlap-per-eg 0" \ - --cleanup.remove-egs=$remove_egs \ + --cleanup.remove-egs=false \ --use-gpu=true \ --reporting.email="$reporting_email" \ --feat-dir=$train_data_dir \ @@ -227,20 +206,20 @@ if [ $stage -le 6 ]; then # topology file from the model). So you could give it a different # lang directory, one that contained a wordlist and LM of your choice, # as long as phones.txt was compatible. - utils/mkgraph.sh \ - --self-loop-scale 1.0 data/$lang_test \ + --self-loop-scale 1.0 data/$lang_decode \ $dir $dir/graph || exit 1; fi if [ $stage -le 7 ]; then frames_per_chunk=$(echo $chunk_width | cut -d, -f1) - steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --extra-left-context $chunk_left_context \ - --extra-right-context $chunk_right_context \ - --extra-left-context-initial 0 \ - --extra-right-context-final 0 \ - --frames-per-chunk $frames_per_chunk \ - --nj $nj --cmd "$cmd" \ - $dir/graph data/test $dir/decode_test || exit 1; + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + done fi + +echo "$0 Done. Date: $(date). 
Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/iam/v1/local/chain/run_cnn_chainali_1c.sh b/egs/iam/v1/local/chain/tuning/run_cnn_chainali_1c.sh similarity index 80% rename from egs/iam/v1/local/chain/run_cnn_chainali_1c.sh rename to egs/iam/v1/local/chain/tuning/run_cnn_chainali_1c.sh index 54c52d913de..17209b9204f 100755 --- a/egs/iam/v1/local/chain/run_cnn_chainali_1c.sh +++ b/egs/iam/v1/local/chain/tuning/run_cnn_chainali_1c.sh @@ -1,25 +1,25 @@ #!/bin/bash # chainali_1c is as chainali_1b except it uses l2-regularize -# local/chain/compare_wer.sh exp/chain/cnn_chainali_1b exp/chain/cnn_chainali_1c -# System cnn_chainali_1b cnn_chainali_1c -# WER 14.38 12.72 -# CER 7.14 5.99 -# Final train prob -0.0113 -0.0291 -# Final valid prob -0.0400 -0.0359 -# Final train prob (xent) -0.6043 -0.9781 -# Final valid prob (xent) -0.9030 -1.1544 -# Parameters 3.96M 3.96M +# local/chain/compare_wer.sh exp/chain/cnn_chainali_1c +# System cnn_chainali_1c (dict_50k) cnn_chainali_1c(dict_50k + unk_model) +# WER 12.95 11.07 +# CER 6.04 4.91 +# WER val 12.75 9.78 +# CER val 5.15 3.74 +# Final train prob -0.0217 +# Final valid prob -0.0060 +# Final train prob (xent) -0.8303 +# Final valid prob (xent) -0.8665 +# Parameters 3.96M # steps/info/chain_dir_info.pl exp/chain/cnn_chainali_1c -# exp/chain/cnn_chainali_1c: num-iters=21 nj=2..4 num-params=4.0M dim=40->369 combine=-0.007->-0.007 (over 1) xent:train/valid[13,20,final]=(-1.44,-1.05,-0.997/-1.53,-1.19,-1.15) logprob:train/valid[13,20,final]=(-0.056,-0.020,-0.012/-0.056,-0.025,-0.020) - +# exp/chain/cnn_chainali_1c/: num-iters=42 nj=2..4 num-params=4.0M dim=40->368 combine=-0.018->-0.018 (over 1) xent:train/valid[27,41,final]=(-1.22,-0.847,-0.830/-1.19,-0.880,-0.867) logprob:train/valid[27,41,final]=(-0.045,-0.025,-0.022/-0.026,-0.010,-0.006) set -e -o pipefail - stage=0 - nj=30 train_set=train +decode_val=true gmm=tri3 # this is the source gmm-dir that we'll use for alignments; it # should have alignments for the specified training data. nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. @@ -29,30 +29,20 @@ chain_model_dir=exp/chain${nnet3_affix}/cnn_1a common_egs_dir= reporting_email= -# chain options train_stage=-10 xent_regularize=0.1 -frame_subsampling_factor=4 -# training chunk-options chunk_width=340,300,200,100 num_leaves=500 -# we don't need extra left/right context for TDNN systems. -chunk_left_context=0 -chunk_right_context=0 tdnn_dim=450 -# training options -srand=0 -remove_egs=false -lang_test=lang_unk +lang_decode=lang_unk +if $decode_val; then maybe_val=val; else maybe_val= ; fi # End configuration section. echo "$0 $@" # Print the command line for logging - . ./cmd.sh . ./path.sh . ./utils/parse_options.sh - if ! cuda-compiled; then cat <376 combine=-0.002->-0.002 (over 1) xent:train/valid[13,20,final]=(-1.66,-1.01,-0.865/-1.72,-1.12,-1.01) logprob:train/valid[13,20,final]=(-0.058,-0.019,-0.004/-0.055,-0.027,-0.013) - +# exp/chain/cnn_chainali_1d/: num-iters=42 nj=2..4 num-params=4.0M dim=40->368 combine=-0.018->-0.018 (over 1) xent:train/valid[27,41,final]=(-1.22,-0.847,-0.830/-1.19,-0.880,-0.867) logprob:train/valid[27,41,final]=(-0.045,-0.025,-0.022/-0.026,-0.010,-0.006) set -e -o pipefail stage=0 - nj=30 train_set=train gmm=tri3 # this is the source gmm-dir that we'll use for alignments; it # should have alignments for the specified training data. nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. -affix=_1c_uc #affix for TDNN+LSTM directory e.g. 
"1a" or "1b", in case we change the configuration. +affix=_1d #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. ali=tri3_ali -chain_model_dir=exp/chain${nnet3_affix}/cnn_1a_uc +chain_model_dir=exp/chain${nnet3_affix}/cnn_1a common_egs_dir= reporting_email= # chain options train_stage=-10 xent_regularize=0.1 -frame_subsampling_factor=4 # training chunk-options chunk_width=340,300,200,100 num_leaves=500 -# we don't need extra left/right context for TDNN systems. -chunk_left_context=0 -chunk_right_context=0 tdnn_dim=450 -# training options -srand=0 -remove_egs=false -lang_test=lang_unk +lang_decode=lang_unk +decode_val=true +if $decode_val; then maybe_val=val; else maybe_val= ; fi + # End configuration section. echo "$0 $@" # Print the command line for logging - . ./cmd.sh . ./path.sh . ./utils/parse_options.sh - if ! cuda-compiled; then cat < $dir/configs/network.xconfig input dim=40 name=input - conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 @@ -157,7 +147,6 @@ if [ $stage -le 4 ]; then relu-batchnorm-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim $tdnn_opts relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts - ## adding the layers for chain branch relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts @@ -192,11 +181,11 @@ if [ $stage -le 5 ]; then --chain.l2-regularize=0.00005 \ --chain.apply-deriv-weights=false \ --chain.lm-opts="--num-extra-lm-states=500" \ - --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.frame-subsampling-factor=4 \ --chain.alignment-subsampling-factor=1 \ --chain.left-tolerance 3 \ --chain.right-tolerance 3 \ - --trainer.srand=$srand \ + --trainer.srand=0 \ --trainer.max-param-change=2.0 \ --trainer.num-epochs=4 \ --trainer.frames-per-iter=1000000 \ @@ -206,15 +195,10 @@ if [ $stage -le 5 ]; then --trainer.optimization.final-effective-lrate=0.0001 \ --trainer.optimization.shrink-value=1.0 \ --trainer.num-chunk-per-minibatch=64,32 \ - --trainer.optimization.momentum=0.0 \ --egs.chunk-width=$chunk_width \ - --egs.chunk-left-context=$chunk_left_context \ - --egs.chunk-right-context=$chunk_right_context \ - --egs.chunk-left-context-initial=0 \ - --egs.chunk-right-context-final=0 \ --egs.dir="$common_egs_dir" \ --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ - --cleanup.remove-egs=$remove_egs \ + --cleanup.remove-egs=false \ --use-gpu=true \ --reporting.email="$reporting_email" \ --feat-dir=$train_data_dir \ @@ -230,20 +214,20 @@ if [ $stage -le 6 ]; then # topology file from the model). So you could give it a different # lang directory, one that contained a wordlist and LM of your choice, # as long as phones.txt was compatible. 
- utils/mkgraph.sh \ - --self-loop-scale 1.0 data/$lang_test \ + --self-loop-scale 1.0 data/$lang_decode \ $dir $dir/graph || exit 1; fi if [ $stage -le 7 ]; then frames_per_chunk=$(echo $chunk_width | cut -d, -f1) - steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --extra-left-context $chunk_left_context \ - --extra-right-context $chunk_right_context \ - --extra-left-context-initial 0 \ - --extra-right-context-final 0 \ - --frames-per-chunk $frames_per_chunk \ - --nj $nj --cmd "$cmd" \ - $dir/graph data/test $dir/decode_test || exit 1; + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + done fi + +echo "$0 Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/iam/v1/local/chain/tuning/run_cnn_e2eali_1a.sh b/egs/iam/v1/local/chain/tuning/run_cnn_e2eali_1a.sh new file mode 100755 index 00000000000..703d404159a --- /dev/null +++ b/egs/iam/v1/local/chain/tuning/run_cnn_e2eali_1a.sh @@ -0,0 +1,229 @@ +#!/bin/bash + +# local/chain/compare_wer.sh exp/chain/cnn_e2eali_1a +# System cnn_e2eali_1a_(dict_50k) cnn_e2eali_1a_(dict_50k + unk model) +# WER 13.30 11.94 +# CER 5.95 5.15 +# WER val 12.85 10.71 +# CER val 5.09 4.03 +# Final train prob -0.0562 +# Final valid prob -0.0634 +# Final train prob (xent) -0.8196 +# Final valid prob (xent) -0.8816 +# Parameters 3.96M + +# steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1a +# exp/chain/cnn_e2eali_1a: num-iters=42 nj=2..4 num-params=4.0M dim=40->368 combine=-0.058->-0.058 (over 1) xent:train/valid[27,41,final]=(-2.67,-0.841,-0.820/-2.71,-0.892,-0.882) logprob:train/valid[27,41,final]=(-0.240,-0.060,-0.056/-0.245,-0.068,-0.063) + +set -e -o pipefail + +stage=0 +nj=30 +train_set=train +decode_val=true +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +e2echain_model_dir=exp/chain/e2e_cnn_1a +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +chunk_width=340,300,200,100 +num_leaves=500 +tdnn_dim=450 +remove_egs=true +lang_decode=lang_unk +if $decode_val; then maybe_val=val; else maybe_val= ; fi +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
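  # (To re-run this stage from scratch, remove the existing tree directory
  # first, e.g. "rm -r $tree_dir", or skip past it with --stage 4.)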
+ exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 4 \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves $train_data_dir \ + $lang $ali_dir $tree_dir +fi + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-1,0,1 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-1,0,1 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=500" \ + --chain.frame-subsampling-factor=4 \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=0 \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=4 \ + --trainer.frames-per-iter=1000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=4 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=64,32 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/$lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + done +fi + +echo "$0 Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/iam/v1/local/chain/tuning/run_cnn_e2eali_1b.sh b/egs/iam/v1/local/chain/tuning/run_cnn_e2eali_1b.sh new file mode 100755 index 00000000000..905c4661477 --- /dev/null +++ b/egs/iam/v1/local/chain/tuning/run_cnn_e2eali_1b.sh @@ -0,0 +1,221 @@ +#!/bin/bash + +# e2eali_1b is the same as e2eali_1a but uses unconstrained egs +# local/chain/compare_wer.sh exp/chain/cnn_e2eali_1b +# System cnn_e2eali_1b (dict_50k) cnn_e2eali_1b (dict_50k + unk model) +# WER 12.46 11.20 +# CER 5.53 4.76 +# WER val 12.71 10.49 +# CER val 4.97 3.92 +# Final train prob -0.0381 +# Final valid prob -0.0443 +# Final train prob (xent) -0.7860 +# Final valid prob (xent) -0.8290 +# Parameters 3.96M + +# steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1b +# exp/chain/cnn_e2eali_1b: num-iters=42 nj=2..4 num-params=4.0M dim=40->368 combine=-0.039->-0.039 (over 2) xent:train/valid[27,41,final]=(-1.19,-0.805,-0.786/-1.19,-0.846,-0.829) logprob:train/valid[27,41,final]=(-0.060,-0.041,-0.038/-0.062,-0.048,-0.044) + +set -e -o pipefail +stage=0 +nj=30 +train_set=train +decode_val=true +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1b #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. 
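For illustration, the two decode-related options defined in this configuration block expand as follows at decode time (values are the defaults set in this script):

    decode_val=true
    if $decode_val; then maybe_val=val; else maybe_val= ; fi   # decode sets: test val
    chunk_width=340,300,200,100
    frames_per_chunk=$(echo $chunk_width | cut -d, -f1)        # -> 340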
+e2echain_model_dir=exp/chain/e2e_cnn_1a +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +chunk_width=340,300,200,100 +num_leaves=500 +tdnn_dim=450 +lang_decode=lang_unk +if $decode_val; then maybe_val=val; else maybe_val= ; fi +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + $train_data_dir data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 4 \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves $train_data_dir \ + $lang $ali_dir $tree_dir +fi + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-1,0,1 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-1,0,1 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy 
objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=500" \ + --chain.frame-subsampling-factor=4 \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=0 \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=4 \ + --trainer.frames-per-iter=1000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=4 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=64,32 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=true \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/$lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + done +fi + +echo "$0 Done. Date: $(date). 
Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/iam/v1/local/chain/tuning/run_cnn_e2eali_1c.sh b/egs/iam/v1/local/chain/tuning/run_cnn_e2eali_1c.sh new file mode 100755 index 00000000000..26b1aca0929 --- /dev/null +++ b/egs/iam/v1/local/chain/tuning/run_cnn_e2eali_1c.sh @@ -0,0 +1,224 @@ +#!/bin/bash + +# e2eali_1c is the same as e2eali_1b but has more CNN layers, different filter size +# smaller lm-opts, minibatch, frams-per-iter, less epochs and more initial/finaljobs. +# local/chain/compare_wer.sh exp/chain/cnn_e2eali_1c +# System cnn_e2eali_1c (dict_50k) cnn_e2eali_1c(dict_50k + unk_model) +# WER 12.10 9.90 +# CER 5.23 4.16 +# WER val 12.15 9.60 +# CER val 4.78 3.56 +# Final train prob -0.0470 +# Final valid prob -0.0657 +# Final train prob (xent) -0.4713 +# Final valid prob (xent) -0.5437 +# Parameters 4.32M + +# steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1c +# exp/chain/cnn_e2eali_1c: num-iters=30 nj=3..5 num-params=4.3M dim=40->368 combine=-0.051->-0.051 (over 1) xent:train/valid[19,29,final]=(-0.722,-0.500,-0.471/-0.748,-0.568,-0.544) logprob:train/valid[19,29,final]=(-0.090,-0.053,-0.047/-0.106,-0.071,-0.066) +set -e -o pipefail + +stage=0 +nj=30 +train_set=train +decode_val=true +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1c #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +e2echain_model_dir=exp/chain/e2e_cnn_1a +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +chunk_width=340,300,200,100 +num_leaves=500 +tdnn_dim=550 +lang_decode=data/lang_unk +if $decode_val; then maybe_val=val; else maybe_val= ; fi +dropout_schedule='0,0@0.20,0.2@0.50,0' +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + $train_data_dir data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 4 \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves $train_data_dir \ + $lang $ali_dir $tree_dir +fi + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.03 dropout-proportion=0.0" + tdnn_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.04" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-dropout-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-dropout-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common3 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=true \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ + --chain.frame-subsampling-factor=4 \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=0 \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=5 \ + --trainer.frames-per-iter=1500000 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=5 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=true \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + done +fi + +echo "$0 Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/iam/v1/local/chain/tuning/run_e2e_cnn_1a.sh b/egs/iam/v1/local/chain/tuning/run_e2e_cnn_1a.sh new file mode 100755 index 00000000000..462ad0522de --- /dev/null +++ b/egs/iam/v1/local/chain/tuning/run_e2e_cnn_1a.sh @@ -0,0 +1,154 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +# This script does end2end chain training (i.e. 
from scratch) +# local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ +# System e2e_cnn_1a (dict_50k) e2e_cnn_1a (dict_50k + unk_model) +# WER 15.21 14.41 +# CER 7.43 6.82 +# WER val 14.84 13.51 +# CER val 6.41 5.60 +# Final train prob -0.0206 +# Final valid prob -0.0393 +# Final train prob (xent) +# Final valid prob (xent) +# Parameters 9.52M + +# steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1a +# exp/chain/e2e_cnn_1a: num-iters=42 nj=2..4 num-params=9.5M dim=40->12640 combine=-0.020->-0.020 (over 1) logprob:train/valid[27,41,final]=(-0.025,-0.021,-0.021/-0.044,-0.040,-0.039) + +set -e +stage=0 +train_stage=-10 +get_egs_stage=-10 +affix=1a +nj=30 + +# training options +tdnn_dim=450 +minibatch_size=150=100,64/300=50,32/600=25,16/1200=16,8 +common_egs_dir= +train_set=train +decode_val=true +lang_decode=data/lang_unk +if $decode_val; then maybe_val=val; else maybe_val= ; fi +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj 30 --cmd "$cmd" \ + --shared-phones true \ + --type biphone \ + data/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat data/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py data/lang \| \ + utils/sym2int.pl -f 2- data/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + common1="height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="height-offsets=-2,-1,0,1,2 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + relu-batchnorm-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn4 input=Append(-4,0,4) dim=$tdnn_dim + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $output_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + # no need to store the egs in a shared storage because we always + # remove them. Anyway, it takes only 5 minutes to generate them. 
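  # Note: the flat-start model trained here (exp/chain/e2e_cnn_1a) is the
  # e2echain_model_dir that the cnn_e2eali_* tuning scripts align against with
  # steps/nnet3/align_lats.sh when building their lattice supervision.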
+ steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1000000 \ + --trainer.num-epochs 4 \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 4 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir data/$train_set \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi + +if [ $stage -le 4 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 5 ]; then + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + done +fi + +echo "$0 Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/iam/v1/local/extract_features.sh b/egs/iam/v1/local/extract_features.sh new file mode 100755 index 00000000000..1741ad3f9b2 --- /dev/null +++ b/egs/iam/v1/local/extract_features.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Copyright 2017 Yiwen Shao +# 2018 Ashish Arora + +# Apache 2.0 +# This script runs the make features script in parallel. + +nj=4 +cmd=run.pl +feat_dim=40 +augment=false +fliplr=false +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +data=$1 +featdir=$data/data +scp=$data/images.scp +logdir=$data/log + +mkdir -p $logdir +mkdir -p $featdir + +# make $featdir an absolute pathname +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +for n in $(seq $nj); do + split_scps="$split_scps $logdir/images.$n.scp" +done + +# split images.scp +utils/split_scp.pl $scp $split_scps || exit 1; + +$cmd JOB=1:$nj $logdir/extract_features.JOB.log \ + local/make_features.py $logdir/images.JOB.scp \ + --allowed_len_file_path $data/allowed_lengths.txt \ + --feat-dim $feat_dim --fliplr $fliplr --augment $augment \| \ + copy-feats --compress=true --compression-method=7 \ + ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp + +## aggregates the output scp's to get feats.scp +for n in $(seq $nj); do + cat $featdir/images.$n.scp || exit 1; +done > $data/feats.scp || exit 1 diff --git a/egs/iam/v1/local/gen_topo.py b/egs/iam/v1/local/gen_topo.py new file mode 100755 index 00000000000..6fae276d542 --- /dev/null +++ b/egs/iam/v1/local/gen_topo.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +# Copyright 2017 (author: Chun-Chieh Chang) + +# Generate a topology file. 
This allows control of the number of states in the +# non-silence HMMs, and in the silence HMMs. This is a modified version of +# 'utils/gen_topo.pl'. The difference is that this creates two topologies for +# the non-silence HMMs. The number of states for punctuations is different than +# the number of states for other characters. + +from __future__ import print_function +import argparse +import string + +parser = argparse.ArgumentParser(description="Usage: steps/nnet3/chain/gen_topo.py " + "<num-nonsil-states> <num-sil-states> <num-punctuation-states> <nonsilence-phones> <silence-phones> <phone-list> " + "e.g.: steps/nnet3/chain/gen_topo.pl 4:5:6:7:8:9:10 1:2:3\n", + epilog="See egs/swbd/s5c/local/chain/train_tdnn_a.sh for example of usage."); +parser.add_argument("num_nonsil_states", type=int, help="number of states for nonsilence phones"); +parser.add_argument("num_sil_states", type=int, help="number of states for silence phones"); +parser.add_argument("num_punctuation_states", type=int, help="number of states for punctuation"); +parser.add_argument("nonsilence_phones", type=str, + help="List of non-silence phones as integers, separated by colons, e.g. 4:5:6:7:8:9"); +parser.add_argument("silence_phones", type=str, + help="List of silence phones as integers, separated by colons, e.g. 1:2:3"); +parser.add_argument("phone_list", type=str, help="file containing all phones and their corresponding number."); + +args = parser.parse_args() + +silence_phones = [ int(x) for x in args.silence_phones.split(":") ] +nonsilence_phones = [ int(x) for x in args.nonsilence_phones.split(":") ] +all_phones = silence_phones + nonsilence_phones + +punctuation_phones = [] +exclude = set("!(),.?;:'-\"") +with open(args.phone_list) as f: + for line in f: + line = line.strip() + phone = line.split('_')[0] + if len(phone) == 1 and phone in exclude: + punctuation_phones.append(int(line.split(' ')[1])) +# For nonsilence phones that are not punctuations +print("<Topology>") +print("<TopologyEntry>") +print("<ForPhones>") +print(" ".join([str(x) for x in nonsilence_phones if x not in punctuation_phones])) +print("</ForPhones>") +for x in range(0, args.num_nonsil_states): + xp1 = x + 1 + print("<State> " + str(x) + " <PdfClass> " + str(x) + " <Transition> " + str(x) + " 0.75 <Transition> " + str(xp1) + " 0.25 </State>") +print("<State> " + str(args.num_nonsil_states) + " </State>") +print("</TopologyEntry>") + +# For nonsilence phones that are punctuations +print("<TopologyEntry>") +print("<ForPhones>") +print(" ".join([str(x) for x in nonsilence_phones if x in punctuation_phones])) +print("</ForPhones>") +for x in range(0, args.num_punctuation_states): + xp1 = x + 1 + print("<State> " + str(x) + " <PdfClass> " + str(x) + " <Transition> " + str(x) + " 0.75 <Transition> " + str(xp1) + " 0.25 </State>") +print("<State> " + str(args.num_punctuation_states) + " </State>") +print("</TopologyEntry>") + +# For silence phones +print("<TopologyEntry>") +print("<ForPhones>") +print(" ".join([str(x) for x in silence_phones])) +print("</ForPhones>") +if(args.num_sil_states > 1): + transp = 1.0 / (args.num_sil_states - 1) + + state_str = "<State> 0 <PdfClass> 0 " + for x in range(0, (args.num_sil_states - 1)): + state_str = state_str + "<Transition> " + str(x) + " " + str(transp) + " " + state_str = state_str + "</State>" + print(state_str) + + for x in range(1, (args.num_sil_states - 1)): + state_str = "<State> " + str(x) + " <PdfClass> " + str(x) + " " + for y in range(1, args.num_sil_states): + state_str = state_str + "<Transition> " + str(y) + " " + str(transp) + " " + state_str = state_str + "</State>" + print(state_str) + second_last = args.num_sil_states - 1 + print("<State> " + str(second_last) + " <PdfClass> " + str(second_last) + " <Transition> " + str(second_last) + " 0.75 <Transition> " + str(args.num_sil_states) + " 0.25 </State>") + print("<State> " + str(args.num_sil_states) + " </State>") +else: + print("<State> 0 <PdfClass> 0 <Transition> 0 0.75 <Transition> 1 0.25 </State>") + print("<State> " + str(args.num_sil_states) + " </State>") +print("</TopologyEntry>") +print("</Topology>") diff --git a/egs/iam/v1/local/make_features.py
b/egs/iam/v1/local/make_features.py index 84e012daedb..3ce501732cf 100755 --- a/egs/iam/v1/local/make_features.py +++ b/egs/iam/v1/local/make_features.py @@ -2,6 +2,7 @@ # Copyright 2017 Chun Chieh Chang # 2017 Ashish Arora +# 2017 Yiwen Shao # 2018 Hossein Hadian """ This script converts images to Kaldi-format feature matrices. The input to @@ -14,20 +15,27 @@ to enforce the images to have the specified length in that file by padding white pixels (the --padding option will be ignored in this case). This relates to end2end chain training. - eg. local/make_features.py data/train --feat-dim 40 """ - +import random import argparse import os import sys +import scipy.io as sio import numpy as np from scipy import misc +from scipy.ndimage.interpolation import affine_transform +import math +from signal import signal, SIGPIPE, SIG_DFL +signal(SIGPIPE, SIG_DFL) parser = argparse.ArgumentParser(description="""Converts images (in 'dir'/images.scp) to features and writes them to standard output in text format.""") -parser.add_argument('dir', type=str, - help='Source data directory (containing images.scp)') +parser.add_argument('images_scp_path', type=str, + help='Path of images.scp file') +parser.add_argument('--allowed_len_file_path', type=str, default=None, + help='If supplied, each images will be padded to reach the ' + 'target length (this overrides --padding).') parser.add_argument('--out-ark', type=str, default='-', help='Where to write the output feature file') parser.add_argument('--feat-dim', type=int, default=40, @@ -35,8 +43,10 @@ parser.add_argument('--padding', type=int, default=5, help='Number of white pixels to pad on the left' 'and right side of the image.') - - +parser.add_argument('--fliplr', type=lambda x: (str(x).lower()=='true'), default=False, + help="Flip the image left-right for right to left languages") +parser.add_argument("--augment", type=lambda x: (str(x).lower()=='true'), default=False, + help="performs image augmentation") args = parser.parse_args() @@ -56,18 +66,12 @@ def write_kaldi_matrix(file_handle, matrix, key): file_handle.write("\n") file_handle.write(" ]\n") -def get_scaled_image(im, allowed_lengths = None): - scale_size = args.feat_dim - sx = im.shape[1] - sy = im.shape[0] - scale = (1.0 * scale_size) / sy - nx = int(scale_size) - ny = int(scale * sx) - im = misc.imresize(im, (nx, ny)) + +def horizontal_pad(im, allowed_lengths = None): if allowed_lengths is None: left_padding = right_padding = args.padding else: # Find an allowed length for the image - imlen = im.shape[1] + imlen = im.shape[1] # width allowed_len = 0 for l in allowed_lengths: if l > imlen: @@ -77,28 +81,153 @@ def get_scaled_image(im, allowed_lengths = None): # No allowed length was found for the image (the image is too long) return None padding = allowed_len - imlen - left_padding = padding // 2 + left_padding = int(padding // 2) right_padding = padding - left_padding - dim_y = im.shape[0] + dim_y = im.shape[0] # height im_pad = np.concatenate((255 * np.ones((dim_y, left_padding), dtype=int), im), axis=1) im_pad1 = np.concatenate((im_pad, 255 * np.ones((dim_y, right_padding), dtype=int)), axis=1) return im_pad1 -### main ### -data_list_path = os.path.join(args.dir, 'images.scp') +def get_scaled_image_aug(im, mode='normal'): + scale_size = args.feat_dim + sx = im.shape[1] + sy = im.shape[0] + scale = (1.0 * scale_size) / sy + nx = int(scale_size) + ny = int(scale * sx) + scale_size = random.randint(10, 30) + scale = (1.0 * scale_size) / sy + down_nx = int(scale_size) + down_ny = int(scale * sx) 
+ if mode == 'normal': + im = misc.imresize(im, (nx, ny)) + return im + else: + im_scaled_down = misc.imresize(im, (down_nx, down_ny)) + im_scaled_up = misc.imresize(im_scaled_down, (nx, ny)) + return im_scaled_up + return im + +def contrast_normalization(im, low_pct, high_pct): + element_number = im.size + rows = im.shape[0] + cols = im.shape[1] + im_contrast = np.zeros(shape=im.shape) + low_index = int(low_pct * element_number) + high_index = int(high_pct * element_number) + sorted_im = np.sort(im, axis=None) + low_thred = sorted_im[low_index] + high_thred = sorted_im[high_index] + for i in range(rows): + for j in range(cols): + if im[i, j] > high_thred: + im_contrast[i, j] = 255 # lightest to white + elif im[i, j] < low_thred: + im_contrast[i, j] = 0 # darkest to black + else: + # linear normalization + im_contrast[i, j] = (im[i, j] - low_thred) * \ + 255 / (high_thred - low_thred) + return im_contrast + + +def geometric_moment(frame, p, q): + m = 0 + for i in range(frame.shape[1]): + for j in range(frame.shape[0]): + m += (i ** p) * (j ** q) * frame[i][i] + return m + + +def central_moment(frame, p, q): + u = 0 + x_bar = geometric_moment(frame, 1, 0) / \ + geometric_moment(frame, 0, 0) # m10/m00 + y_bar = geometric_moment(frame, 0, 1) / \ + geometric_moment(frame, 0, 0) # m01/m00 + for i in range(frame.shape[1]): + for j in range(frame.shape[0]): + u += ((i - x_bar)**p) * ((j - y_bar)**q) * frame[i][j] + return u + + +def height_normalization(frame, w, h): + frame_normalized = np.zeros(shape=(h, w)) + alpha = 4 + x_bar = geometric_moment(frame, 1, 0) / \ + geometric_moment(frame, 0, 0) # m10/m00 + y_bar = geometric_moment(frame, 0, 1) / \ + geometric_moment(frame, 0, 0) # m01/m00 + sigma_x = (alpha * ((central_moment(frame, 2, 0) / + geometric_moment(frame, 0, 0)) ** .5)) # alpha * sqrt(u20/m00) + sigma_y = (alpha * ((central_moment(frame, 0, 2) / + geometric_moment(frame, 0, 0)) ** .5)) # alpha * sqrt(u02/m00) + for x in range(w): + for y in range(h): + i = int((x / w - 0.5) * sigma_x + x_bar) + j = int((y / h - 0.5) * sigma_y + y_bar) + frame_normalized[x][y] = frame[i][j] + return frame_normalized + +def find_slant_project(im): + rows = im.shape[0] + cols = im.shape[1] + std_max = 0 + alpha_max = 0 + col_disp = np.zeros(90, int) + proj = np.zeros(shape=(90, cols + 2 * rows), dtype=int) + for r in range(rows): + for alpha in range(-45, 45, 1): + col_disp[alpha] = int(r * math.tan(alpha / 180.0 * math.pi)) + for c in range(cols): + if im[r, c] < 100: + for alpha in range(-45, 45, 1): + proj[alpha + 45, c + col_disp[alpha] + rows] += 1 + for alpha in range(-45, 45, 1): + proj_histogram, bin_array = np.histogram(proj[alpha + 45, :], bins=10) + proj_std = np.std(proj_histogram) + if proj_std > std_max: + std_max = proj_std + alpha_max = alpha + proj_std = np.std(proj, axis=1) + return -alpha_max + + +def horizontal_shear(im, degree): + rad = degree / 180.0 * math.pi + padding_x = int(abs(np.tan(rad)) * im.shape[0]) + padding_y = im.shape[0] + if rad > 0: + im_pad = np.concatenate( + (255 * np.ones((padding_y, padding_x), dtype=int), im), axis=1) + elif rad < 0: + im_pad = np.concatenate( + (im, 255 * np.ones((padding_y, padding_x), dtype=int)), axis=1) + else: + im_pad = im + shear_matrix = np.array([[1, 0], + [np.tan(rad), 1]]) + sheared_im = affine_transform(im_pad, shear_matrix, cval=255.0) + return sheared_im + + +### main ### +random.seed(1) +data_list_path = args.images_scp_path if args.out_ark == '-': out_fh = sys.stdout else: - out_fh = open(args.out_ark,'wb') + out_fh = 
open(args.out_ark,'w') allowed_lengths = None -if os.path.isfile(os.path.join(args.dir, 'allowed_lengths.txt')): +allowed_len_handle = args.allowed_len_file_path +if os.path.isfile(allowed_len_handle): print("Found 'allowed_lengths.txt' file...", file=sys.stderr) allowed_lengths = [] - with open(os.path.join(args.dir,'allowed_lengths.txt')) as f: + with open(allowed_len_handle) as f: for line in f: allowed_lengths.append(int(line.strip())) print("Read {} allowed lengths and will apply them to the " @@ -106,6 +235,7 @@ def get_scaled_image(im, allowed_lengths = None): num_fail = 0 num_ok = 0 +aug_setting = ['normal', 'scaled'] with open(data_list_path) as f: for line in f: line = line.strip() @@ -113,15 +243,24 @@ def get_scaled_image(im, allowed_lengths = None): image_id = line_vect[0] image_path = line_vect[1] im = misc.imread(image_path) - im_scaled = get_scaled_image(im, allowed_lengths) - - if im_scaled is None: + if args.fliplr: + im = np.fliplr(im) + if args.augment: + im_aug = get_scaled_image_aug(im, aug_setting[0]) + im_contrast = contrast_normalization(im_aug, 0.05, 0.2) + slant_degree = find_slant_project(im_contrast) + im_sheared = horizontal_shear(im_contrast, slant_degree) + im_aug = im_sheared + else: + im_aug = get_scaled_image_aug(im, aug_setting[0]) + im_horizontal_padded = horizontal_pad(im_aug, allowed_lengths) + if im_horizontal_padded is None: num_fail += 1 continue - data = np.transpose(im_scaled, (1, 0)) + data = np.transpose(im_horizontal_padded, (1, 0)) data = np.divide(data, 255.0) num_ok += 1 write_kaldi_matrix(out_fh, data, image_id) -print('Generated features for {} images. Failed for {} (iamge too ' +print('Generated features for {} images. Failed for {} (image too ' 'long).'.format(num_ok, num_fail), file=sys.stderr) diff --git a/egs/iam/v1/local/prepare_data.sh b/egs/iam/v1/local/prepare_data.sh index 73d711c73f0..dc07f07e318 100755 --- a/egs/iam/v1/local/prepare_data.sh +++ b/egs/iam/v1/local/prepare_data.sh @@ -18,6 +18,7 @@ stage=0 download_dir=data/download +process_aachen_split=false wellington_dir= username= password= # username and password for downloading the IAM database @@ -53,6 +54,8 @@ ascii_url=http://www.fki.inf.unibe.ch/DBs/iamDB/data/ascii/ascii.tgz brown_corpus_url=http://www.sls.hawaii.edu/bley-vroman/brown.txt lob_corpus_url=http://ota.ox.ac.uk/text/0167.zip wellington_corpus_loc=/export/corpora5/Wellington/WWC/ +aachen_split_url=http://www.openslr.org/resources/56/splits.zip +aachen_splits=data/local/aachensplits mkdir -p $download_dir data/local # download and extact images and transcription @@ -144,6 +147,18 @@ else echo "$0: Wellington Corpus not included because wellington_dir not provided" fi +if [ -d $aachen_splits ]; then + echo "$0: Not downloading the Aachen splits as it is already there." +else + if [ ! -f $aachen_splits/splits.zip ]; then + echo "$0: Downloading Aachen splits ..." 
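  # The zip is expected to unpack to per-subset utterance lists
  # (splits/train.uttlist, splits/test.uttlist, splits/validation.uttlist),
  # which local/process_aachen_splits.py reads via its split_path and
  # --dataset arguments when process_aachen_split=true below.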
+ mkdir -p $aachen_splits + wget -P $aachen_splits/ $aachen_split_url || exit 1; + fi + unzip $aachen_splits/splits.zip -d $aachen_splits || exit 1; + echo "$0: Done downloading and extracting Aachen splits" +fi + mkdir -p data/{train,test,val} file_name=largeWriterIndependentTextLineRecognitionTask @@ -160,11 +175,17 @@ cat $train_old > $train_new cat $test_old > $test_new cat $val1_old $val2_old > $val_new -if [ $stage -le 0 ]; then - local/process_data.py data/local data/train --dataset train || exit 1 - local/process_data.py data/local data/test --dataset test || exit 1 - local/process_data.py data/local data/val --dataset validation || exit 1 - - utils/utt2spk_to_spk2utt.pl data/train/utt2spk > data/train/spk2utt - utils/utt2spk_to_spk2utt.pl data/test/utt2spk > data/test/spk2utt +if $process_aachen_split; then + local/process_aachen_splits.py data/local $aachen_splits/splits data/train --dataset train || exit 1 + local/process_aachen_splits.py data/local $aachen_splits/splits data/test --dataset test || exit 1 + local/process_aachen_splits.py data/local $aachen_splits/splits data/val --dataset validation || exit 1 +else + local/process_data.py data/local data/train --dataset train || exit 1 + local/process_data.py data/local data/test --dataset test || exit 1 + local/process_data.py data/local data/val --dataset validation || exit 1 fi + +image/fix_data_dir.sh data/train +image/fix_data_dir.sh data/test +image/fix_data_dir.sh data/val + diff --git a/egs/iam/v1/local/prepare_dict.sh b/egs/iam/v1/local/prepare_dict.sh index f691d577fba..7451f6b85f7 100755 --- a/egs/iam/v1/local/prepare_dict.sh +++ b/egs/iam/v1/local/prepare_dict.sh @@ -38,7 +38,7 @@ while(<>){ }' | sort -u > $dir/lexicon.txt -sed -i "s/#//" $dir/nonsilence_phones.txt +perl -i -pe "s/#//" $dir/nonsilence_phones.txt echo ' SIL' >> $dir/lexicon.txt echo ' SIL' >> $dir/lexicon.txt diff --git a/egs/iam/v1/local/process_aachen_splits.py b/egs/iam/v1/local/process_aachen_splits.py new file mode 100755 index 00000000000..cb6a6d4f0d8 --- /dev/null +++ b/egs/iam/v1/local/process_aachen_splits.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +""" This script reads the extracted IAM database files and creates + the following files (for the data subset selected via --dataset): + text, utt2spk, images.scp. + + Eg. local/process_aachen_splits.py data/local data/train data --dataset train + Eg. text file: 000_a01-000u-00 A MOVE to stop Mr. 
Gaitskell from + utt2spk file: 000_a01-000u-00 000 + images.scp file: 000_a01-000u-00 data/local/lines/a01/a01-000u/a01-000u-00.png +""" + +import argparse +import os +import sys +import xml.dom.minidom as minidom + +parser = argparse.ArgumentParser(description="""Creates text, utt2spk + and images.scp files.""") +parser.add_argument('database_path', type=str, + help='Path to the downloaded (and extracted) IAM data') +parser.add_argument('split_path', type=str, + help='location of the train/test/val set') +parser.add_argument('out_dir', type=str, + help='location to write output files.') +parser.add_argument('--dataset', type=str, default='train', + choices=['train', 'test','validation'], + help='Subset of data to process.') +args = parser.parse_args() + +text_file = os.path.join(args.out_dir + '/', 'text') +text_fh = open(text_file, 'w') + +utt2spk_file = os.path.join(args.out_dir + '/', 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w') + +image_file = os.path.join(args.out_dir + '/', 'images.scp') +image_fh = open(image_file, 'w') + +dataset_path = os.path.join(args.split_path, + args.dataset + '.uttlist') + +text_file_path = os.path.join(args.database_path, + 'ascii','lines.txt') +text_dict = {} +def process_text_file_for_word_model(): + with open (text_file_path, 'rt') as in_file: + for line in in_file: + if line[0]=='#': + continue + line = line.strip() + utt_id = line.split(' ')[0] + text_vect = line.split(' ')[8:] + text = "".join(text_vect) + text = text.replace("|", " ") + text_dict[utt_id] = text + + +### main ### + +print("Processing '{}' data...".format(args.dataset)) +process_text_file_for_word_model() + +with open(dataset_path) as f: + for line in f: + line = line.strip() + line_vect = line.split('-') + xml_file = line_vect[0] + '-' + line_vect[1] + xml_path = os.path.join(args.database_path, 'xml', xml_file + '.xml') + doc = minidom.parse(xml_path) + form_elements = doc.getElementsByTagName('form')[0] + writer_id = form_elements.getAttribute('writer-id') + outerfolder = form_elements.getAttribute('id')[0:3] + innerfolder = form_elements.getAttribute('id') + lines_path = os.path.join(args.database_path, 'lines', + outerfolder, innerfolder) + for file in os.listdir(lines_path): + if file.endswith(".png"): + image_file_path = os.path.join(lines_path, file) + base_name = os.path.splitext(os.path.basename(image_file_path))[0] + text = text_dict[base_name] + utt_id = writer_id + '_' + base_name + text_fh.write(utt_id + ' ' + text + '\n') + utt2spk_fh.write(utt_id + ' ' + writer_id + '\n') + image_fh.write(utt_id + ' ' + image_file_path + '\n') diff --git a/egs/iam/v1/local/train_lm.sh b/egs/iam/v1/local/train_lm.sh index a15fbea2af3..3e8c838efdb 100755 --- a/egs/iam/v1/local/train_lm.sh +++ b/egs/iam/v1/local/train_lm.sh @@ -58,9 +58,12 @@ if [ $stage -le 0 ]; then rm ${dir}/data/text/* 2>/dev/null || true # Using LOB and brown corpus. - cat data/local/lobcorpus/0167/download/LOB_COCOA/lob.txt | \ - local/remove_test_utterances_from_lob.py data/test/text data/val/text \ - > ${dir}/data/text/lob.txt + if [ ! 
-f data/local/lob-train-only.txt ]; then
+    cat data/local/lobcorpus/0167/download/LOB_COCOA/lob.txt | \
+      local/remove_test_utterances_from_lob.py data/test/text data/val/text \
+      > data/local/lob-train-only.txt
+  fi
+  cat data/local/lob-train-only.txt > ${dir}/data/text/lob.txt
   cat data/local/browncorpus/brown.txt > ${dir}/data/text/brown.txt
   if [ -d "data/local/wellingtoncorpus" ]; then
     cat data/local/wellingtoncorpus/Wellington_annotation_removed.txt > ${dir}/data/text/wellington.txt
diff --git a/egs/iam/v1/local/unk_arc_post_to_transcription.py b/egs/iam/v1/local/unk_arc_post_to_transcription.py
index c5ad1235427..f8b69820601 100755
--- a/egs/iam/v1/local/unk_arc_post_to_transcription.py
+++ b/egs/iam/v1/local/unk_arc_post_to_transcription.py
@@ -1,88 +1,107 @@
 #!/usr/bin/env python3
-# Copyright 2017 Ashish Arora
+#Copyright 2017 Ashish Arora
+""" This module will be used by scripts for open vocabulary setup.
+ If the hypothesis transcription contains <unk>, then it will replace the
+ <unk> with the word predicted by the model, obtained by concatenating the phones decoded
+ from the unk-model. It is currently supported only for triphone setup.
+ Args:
+  phones: File name of a file that contains the phones.txt, (symbol-table for phones).
+          phone and phoneID, Eg. a 217, phoneID of 'a' is 217.
+  words: File name of a file that contains the words.txt, (symbol-table for words).
+         word and wordID. Eg. ACCOUNTANCY 234, wordID of 'ACCOUNTANCY' is 234.
+  unk: ID of <unk>. Eg. 231.
+  one-best-arc-post: A file in arc-post format, which is a list of timing info and posterior
+                     of arcs along the one-best path from the lattice.
+                     E.g. 506_m01-049-00 8 12 1 7722 282 272 288 231
+                     []
+                     [ ...]
+  output-text: File containing hypothesis transcription with <unk> recognized by the
+               unk-model.
+               E.g. A move to stop mr. gaitskell.
+
+ Eg. local/unk_arc_post_to_transcription.py lang/phones.txt lang/words.txt
+     data/lang/oov.int
+"""
 import argparse
+import os
 import sys
-
 parser = argparse.ArgumentParser(description="""uses phones to convert unk to word""")
-parser.add_argument('phones', type=str, help='phones and phonesID')
-parser.add_argument('words', type=str, help='word and wordID')
-parser.add_argument('unk', type=str, default='-', help='location of unk file')
-parser.add_argument('--input-ark', type=str, default='-', help='where to read the input data')
-parser.add_argument('--out-ark', type=str, default='-', help='where to write the output data')
+parser.add_argument('phones', type=str, help='File name of a file that contains the'
+                    'symbol-table for phones. Each line must be: <phone> <phone-id>')
+parser.add_argument('words', type=str, help='File name of a file that contains the'
+                    'symbol-table for words. Each line must be: <word> <word-id>')
+parser.add_argument('unk', type=str, default='-', help='File name of a file that'
+                    'contains the ID of <unk>. The content must be: <id of unk>, e.g. 
231') +parser.add_argument('--one-best-arc-post', type=str, default='-', help='A file in arc-post' + 'format, which is a list of timing info and posterior of arcs' + 'along the one-best path from the lattice') +parser.add_argument('--output-text', type=str, default='-', help='File containing' + 'hypothesis transcription with recognized by the unk-model') args = parser.parse_args() - ### main ### -phone_fh = open(args.phones, 'r', encoding='latin-1') -word_fh = open(args.words, 'r', encoding='latin-1') -unk_fh = open(args.unk, 'r', encoding='latin-1') -if args.input_ark == '-': - input_fh = sys.stdin +phone_handle = open(args.phones, 'r', encoding='latin-1') # Create file handles +word_handle = open(args.words, 'r', encoding='latin-1') +unk_handle = open(args.unk,'r', encoding='latin-1') +if args.one_best_arc_post == '-': + arc_post_handle = sys.stdin else: - input_fh = open(args.input_ark, 'r', encoding='latin-1') -if args.out_ark == '-': - out_fh = sys.stdout + arc_post_handle = open(args.one_best_arc_post, 'r', encoding='latin-1') +if args.output_text == '-': + output_text_handle = sys.stdout else: - out_fh = open(args.out_ark, 'w', encoding='latin-1') + output_text_handle = open(args.output_text, 'w', encoding='latin-1') -phone_dict = dict() # Stores phoneID and phone mapping -phone_data_vect = phone_fh.read().strip().split("\n") -for key_val in phone_data_vect: +id2phone = dict() # Stores the mapping from phone_id (int) to phone (char) +phones_data = phone_handle.read().strip().split("\n") + +for key_val in phones_data: key_val = key_val.split(" ") - phone_dict[key_val[1]] = key_val[0] + id2phone[key_val[1]] = key_val[0] + word_dict = dict() -word_data_vect = word_fh.read().strip().split("\n") +word_data_vect = word_handle.read().strip().split("\n") + for key_val in word_data_vect: key_val = key_val.split(" ") word_dict[key_val[1]] = key_val[0] -unk_val = unk_fh.read().strip().split(" ")[0] +unk_val = unk_handle.read().strip().split(" ")[0] -utt_word_dict = dict() -utt_phone_dict = dict() # Stores utteranceID and phoneID -unk_word_dict = dict() -count=0 -for line in input_fh: +utt_word_dict = dict() # Dict of list, stores mapping from utteranceID(int) to words(str) +for line in arc_post_handle: line_vect = line.strip().split("\t") - if len(line_vect) < 6: - print("Bad line: '{}' Expecting 6 fields. Skipping...".format(line), + if len(line_vect) < 6: # Check for 1best-arc-post output + print("Error: Bad line: '{}' Expecting 6 fields. Skipping...".format(line), file=sys.stderr) continue - uttID = line_vect[0] + utt_id = line_vect[0] word = line_vect[4] phones = line_vect[5] - if uttID in utt_word_dict.keys(): - utt_word_dict[uttID][count] = word - utt_phone_dict[uttID][count] = phones - else: - count = 0 - utt_word_dict[uttID] = dict() - utt_phone_dict[uttID] = dict() - utt_word_dict[uttID][count] = word - utt_phone_dict[uttID][count] = phones - if word == unk_val: # Get character sequence for unk - phone_key_vect = phones.split(" ") - phone_val_vect = list() - for pkey in phone_key_vect: - phone_val_vect.append(phone_dict[pkey]) + if utt_id not in list(utt_word_dict.keys()): + utt_word_dict[utt_id] = list() + + if word == unk_val: # Get the 1best phone sequence given by the unk-model + phone_id_seq = phones.split(" ") + phone_seq = list() + for pkey in phone_id_seq: + phone_seq.append(id2phone[pkey]) # Convert the phone-id sequence to a phone sequence. 
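        # Illustration (phone IDs invented for the example): if the unk-model produced
        # "217 233 240" and phones.txt maps these to "a_B i_I r_E", the lines below strip
        # the word-position suffixes and emit the word "air" in place of <unk>.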
        phone_2_word = list()
-        for phone_val in phone_val_vect:
-            phone_2_word.append(phone_val.split('_')[0])
-        phone_2_word = ''.join(phone_2_word)
-        utt_word_dict[uttID][count] = phone_2_word
+        for phone_val in phone_seq:
+            phone_2_word.append(phone_val.split('_')[0])  # Remove the word-position markers (e.g. _B)
+        phone_2_word = ''.join(phone_2_word)  # Concatenate the phone sequence
+        utt_word_dict[utt_id].append(phone_2_word)  # Store word from unk-model
     else:
-        if word == '0':
+        if word == '0':  # Store space/silence
             word_val = ' '
         else:
             word_val = word_dict[word]
-        utt_word_dict[uttID][count] = word_val
-    count += 1
+        utt_word_dict[utt_id].append(word_val)  # Store word from 1best-arc-post

-transcription = ""
-for key in sorted(utt_word_dict.keys()):
-    transcription = key
-    for index in sorted(utt_word_dict[key].keys()):
-        value = utt_word_dict[key][index]
-        transcription = transcription + " " + value
-    out_fh.write(transcription + '\n')
+transcription = ""  # Output transcription
+for utt_key in sorted(utt_word_dict.keys()):
+    transcription = utt_key
+    for word in utt_word_dict[utt_key]:
+        transcription = transcription + " " + word
+    output_text_handle.write(transcription + '\n')
diff --git a/egs/iam/v1/run.sh b/egs/iam/v1/run.sh
index b943870f530..85811b6cb3d 100755
--- a/egs/iam/v1/run.sh
+++ b/egs/iam/v1/run.sh
@@ -20,6 +20,9 @@ iam_database=/export/corpora5/handwriting_ocr/IAM
 # This corpus is of written NZ English that can be purchased here:
 # "https://www.victoria.ac.nz/lals/resources/corpora-default"
 wellington_database=/export/corpora5/Wellington/WWC/
+train_set=train_aug
+process_aachen_split=false
+overwrite=false

 . ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system.
            ## This relates to the queue.
@@ -30,39 +33,63 @@ wellington_database=/export/corpora5/Wellington/WWC/
 ./local/check_tools.sh

 if [ $stage -le 0 ]; then
+  if [ -f data/train/text ] && ! $overwrite; then
+    echo "$0: Not processing: the script has probably been run from the wrong stage"
+    echo "Exiting with status 1 to avoid data corruption"
+    exit 1;
+  fi
+
   echo "$0: Preparing data..."
   local/prepare_data.sh --download-dir "$iam_database" \
     --wellington-dir "$wellington_database" \
-    --username "$username" --password "$password"
+    --username "$username" --password "$password" \
+    --process_aachen_split $process_aachen_split
 fi
-mkdir -p data/{train,test}/data
+mkdir -p data/{train,test,val}/data

 if [ $stage -le 1 ]; then
-  echo "$0: Preparing the test and train feature files..."
-  for dataset in train test; do
-    local/make_features.py data/$dataset --feat-dim 40 | \
-      copy-feats --compress=true --compression-method=7 \
-      ark:- ark,scp:data/$dataset/data/images.ark,data/$dataset/feats.scp
-    steps/compute_cmvn_stats.sh data/$dataset
+  echo "$0: $(date) stage 1: getting allowed image widths for e2e training..."
+  image/get_image2num_frames.py --feat-dim 40 data/train  # This will be needed for the next command
+  # The next command creates an "allowed_lengths.txt" file in data/train
+  # which will be used by local/make_features.py to enforce that the images
+  # have allowed lengths. The allowed lengths will be spaced by 10% difference in length.
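  # A quick, purely illustrative sanity check once this stage has run (not part of the recipe):
  #   head -3 data/train/allowed_lengths.txt
  #   awk 'NR>1 {printf "%.2f\n", $1/prev} {prev=$1}' data/train/allowed_lengths.txt | head -3
  # The printed ratios should be close to 1.10, i.e. roughly 10% spacing between allowed widths.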
+ image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train + echo "$0: $(date) Extracting features, creating feats.scp file" + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/train + steps/compute_cmvn_stats.sh data/train || exit 1; + for set in val test; do + local/extract_features.sh --nj $nj --cmd "$cmd" --augment true \ + --feat-dim 40 data/${set} + steps/compute_cmvn_stats.sh data/${set} || exit 1; done + utils/fix_data_dir.sh data/train fi if [ $stage -le 2 ]; then + for set in train; do + echo "$0: $(date) stage 2: Performing augmentation, it will double training data" + local/augment_data.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/${set} data/${set}_aug data + steps/compute_cmvn_stats.sh data/${set}_aug || exit 1; + done +fi + +if [ $stage -le 3 ]; then echo "$0: Estimating a language model for decoding..." # We do this stage before dict preparation because prepare_dict.sh # generates the lexicon from pocolm's wordlist local/train_lm.sh --vocab-size 50k fi -if [ $stage -le 3 ]; then +if [ $stage -le 4 ]; then echo "$0: Preparing dictionary and lang..." - # This is for training. Use a large vocab size, e.g. 500k to include all the # training words: local/prepare_dict.sh --vocab-size 500k --dir data/local/dict # this is for training utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.95 \ data/local/dict "" data/lang/temp data/lang - + silphonelist=`cat data/lang/phones/silence.csl` + nonsilphonelist=`cat data/lang/phones/nonsilence.csl` + local/gen_topo.py 8 4 4 $nonsilphonelist $silphonelist data/lang/phones.txt >data/lang/topo # This is for decoding. We use a 50k lexicon to be consistent with the papers # reporting WERs on IAM: local/prepare_dict.sh --vocab-size 50k --dir data/local/dict_50k # this is for decoding @@ -77,11 +104,14 @@ if [ $stage -le 3 ]; then utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 \ --unk-fst exp/unk_lang_model/unk_fst.txt \ data/local/dict_50k "" data/lang_unk/temp data/lang_unk + silphonelist=`cat data/lang/phones/silence.csl` + nonsilphonelist=`cat data/lang/phones/nonsilence.csl` + local/gen_topo.py 8 4 4 $nonsilphonelist $silphonelist data/lang_unk/phones.txt >data/lang_unk/topo cp data/lang_test/G.fst data/lang_unk/G.fst fi if [ $stage -le 4 ]; then - steps/train_mono.sh --nj $nj --cmd $cmd --totgauss 10000 data/train \ + steps/train_mono.sh --nj $nj --cmd $cmd --totgauss 10000 data/$train_set \ data/lang exp/mono fi @@ -93,10 +123,10 @@ if [ $stage -le 5 ] && $decode_gmm; then fi if [ $stage -le 6 ]; then - steps/align_si.sh --nj $nj --cmd $cmd data/train data/lang \ + steps/align_si.sh --nj $nj --cmd $cmd data/$train_set data/lang \ exp/mono exp/mono_ali - steps/train_deltas.sh --cmd $cmd 500 20000 data/train data/lang \ + steps/train_deltas.sh --cmd $cmd 500 20000 data/$train_set data/lang \ exp/mono_ali exp/tri fi @@ -108,12 +138,12 @@ if [ $stage -le 7 ] && $decode_gmm; then fi if [ $stage -le 8 ]; then - steps/align_si.sh --nj $nj --cmd $cmd data/train data/lang \ + steps/align_si.sh --nj $nj --cmd $cmd data/$train_set data/lang \ exp/tri exp/tri_ali steps/train_lda_mllt.sh --cmd $cmd \ --splice-opts "--left-context=3 --right-context=3" 500 20000 \ - data/train data/lang exp/tri_ali exp/tri2 + data/$train_set data/lang exp/tri_ali exp/tri2 fi if [ $stage -le 9 ] && $decode_gmm; then @@ -125,10 +155,10 @@ fi if [ $stage -le 10 ]; then steps/align_fmllr.sh --nj $nj --cmd $cmd --use-graphs true \ - data/train data/lang exp/tri2 exp/tri2_ali + data/$train_set data/lang 
exp/tri2 exp/tri2_ali steps/train_sat.sh --cmd $cmd 500 20000 \ - data/train data/lang exp/tri2_ali exp/tri3 + data/$train_set data/lang exp/tri2_ali exp/tri3 fi if [ $stage -le 11 ] && $decode_gmm; then @@ -140,13 +170,13 @@ fi if [ $stage -le 12 ]; then steps/align_fmllr.sh --nj $nj --cmd $cmd --use-graphs true \ - data/train data/lang exp/tri3 exp/tri3_ali + data/$train_set data/lang exp/tri3 exp/tri3_ali fi if [ $stage -le 13 ]; then - local/chain/run_cnn_1a.sh --lang-test lang_unk + local/chain/run_cnn.sh --lang-test lang_unk --train_set $train_set fi if [ $stage -le 14 ]; then - local/chain/run_cnn_chainali_1c.sh --chain-model-dir exp/chain/cnn_1a --stage 2 + local/chain/run_cnn_chainali.sh --chain-model-dir exp/chain/cnn_1a --stage 2 --train_set $train_set fi diff --git a/egs/iam/v1/run_end2end.sh b/egs/iam/v1/run_end2end.sh index 6df93e739f4..0a8b014715f 100755 --- a/egs/iam/v1/run_end2end.sh +++ b/egs/iam/v1/run_end2end.sh @@ -6,6 +6,8 @@ stage=0 nj=20 username= password= +process_aachen_split=false +overwrite=false # iam_database points to the database path on the JHU grid. If you have not # already downloaded the database you can set it to a local directory # like "data/download" and follow the instructions @@ -16,61 +18,78 @@ iam_database=/export/corpora5/handwriting_ocr/IAM # This corpus is of written NZ English that can be purchased here: # "https://www.victoria.ac.nz/lals/resources/corpora-default" wellington_database=/export/corpora5/Wellington/WWC/ - +train_set=train_aug . ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. ## This relates to the queue. . ./path.sh . ./utils/parse_options.sh # e.g. this parses the above options # if supplied. - - ./local/check_tools.sh - if [ $stage -le 0 ]; then + + if [ -f data/train/text ] && ! $overwrite; then + echo "$0: Not processing, probably script have run from wrong stage" + echo "Exiting with status 1 to avoid data corruption" + exit 1; + fi + echo "$0: Preparing data..." local/prepare_data.sh --download-dir "$iam_database" \ --wellington-dir "$wellington_database" \ - --username "$username" --password "$password" + --username "$username" --password "$password" \ + --process_aachen_split $process_aachen_split fi -mkdir -p data/{train,test}/data +mkdir -p data/{train,test,val}/data if [ $stage -le 1 ]; then - image/get_image2num_frames.py data/train # This will be needed for the next command + echo "$0: $(date) stage 1: getting allowed image widths for e2e training..." + image/get_image2num_frames.py --feat-dim 40 data/train # This will be needed for the next command # The next command creates a "allowed_lengths.txt" file in data/train # which will be used by local/make_features.py to enforce the images to # have allowed lengths. The allowed lengths will be spaced by 10% difference in length. image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train - echo "$0: Preparing the test and train feature files..." 
- for dataset in train test; do - local/make_features.py data/$dataset --feat-dim 40 | \ - copy-feats --compress=true --compression-method=7 \ - ark:- ark,scp:data/$dataset/data/images.ark,data/$dataset/feats.scp - steps/compute_cmvn_stats.sh data/$dataset + echo "$0: $(date) Extracting features, creating feats.scp file" + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/train + steps/compute_cmvn_stats.sh data/train || exit 1; + for set in val test; do + local/extract_features.sh --nj $nj --cmd "$cmd" --augment true \ + --feat-dim 40 data/${set} + steps/compute_cmvn_stats.sh data/${set} || exit 1; done utils/fix_data_dir.sh data/train fi if [ $stage -le 2 ]; then + for set in train; do + echo "$0: $(date) stage 2: Performing augmentation, it will double training data" + local/augment_data.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/${set} data/${set}_aug data + steps/compute_cmvn_stats.sh data/${set}_aug || exit 1; + done +fi + +if [ $stage -le 3 ]; then echo "$0: Estimating a language model for decoding..." # We do this stage before dict preparation because prepare_dict.sh # generates the lexicon from pocolm's wordlist local/train_lm.sh --vocab-size 50k fi -if [ $stage -le 3 ]; then +if [ $stage -le 4 ]; then echo "$0: Preparing dictionary and lang..." - # This is for training. Use a large vocab size, e.g. 500k to include all the # training words: local/prepare_dict.sh --vocab-size 500k --dir data/local/dict - utils/prepare_lang.sh --sil-prob 0.95 \ + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.95 \ data/local/dict "" data/lang/temp data/lang + silphonelist=`cat data/lang/phones/silence.csl` + nonsilphonelist=`cat data/lang/phones/nonsilence.csl` + local/gen_topo.py 8 4 4 $nonsilphonelist $silphonelist data/lang/phones.txt >data/lang/topo # This is for decoding. We use a 50k lexicon to be consistent with the papers # reporting WERs on IAM. local/prepare_dict.sh --vocab-size 50k --dir data/local/dict_50k - utils/prepare_lang.sh --sil-prob 0.95 data/local/dict_50k \ - "" data/lang_test/temp data/lang_test + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.95 \ + data/local/dict_50k "" data/lang_test/temp data/lang_test utils/format_lm.sh data/lang_test data/local/local_lm/data/arpa/3gram_big.arpa.gz \ data/local/dict_50k/lexicon.txt data/lang_test @@ -79,23 +98,27 @@ if [ $stage -le 3 ]; then data/local/dict_50k exp/unk_lang_model utils/prepare_lang.sh --unk-fst exp/unk_lang_model/unk_fst.txt \ data/local/dict_50k "" data/lang_unk/temp data/lang_unk + + silphonelist=`cat data/lang/phones/silence.csl` + nonsilphonelist=`cat data/lang/phones/nonsilence.csl` + local/gen_topo.py 8 4 4 $nonsilphonelist $silphonelist data/lang_unk/phones.txt >data/lang_unk/topo cp data/lang_test/G.fst data/lang_unk/G.fst fi -if [ $stage -le 4 ]; then +if [ $stage -le 5 ]; then echo "$0: Calling the flat-start chain recipe..." - local/chain/run_flatstart_cnn1a.sh + local/chain/run_e2e_cnn.sh --train_set $train_set fi -if [ $stage -le 5 ]; then +if [ $stage -le 6 ]; then echo "$0: Aligning the training data using the e2e chain model..." 
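  # (A quick check once this stage finishes: the aligner writes one archive per job, so
  #  `ls exp/chain/e2e_ali_train/ali.*.gz | wc -l` should match the --nj value used below, 50.)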
steps/nnet3/align.sh --nj 50 --cmd "$cmd" \ --use-gpu false \ --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0 --acoustic-scale=1.0' \ - data/train data/lang exp/chain/e2e_cnn_1a exp/chain/e2e_ali_train + data/$train_set data/lang exp/chain/e2e_cnn_1a exp/chain/e2e_ali_train fi -if [ $stage -le 6 ]; then +if [ $stage -le 7 ]; then echo "$0: Building a tree and training a regular chain model using the e2e alignments..." - local/chain/run_cnn_e2eali_1a.sh + local/chain/run_cnn_e2eali.sh --train_set $train_set fi diff --git a/egs/iam/v2/cmd.sh b/egs/iam/v2/cmd.sh new file mode 100755 index 00000000000..3c8eb9f93a5 --- /dev/null +++ b/egs/iam/v2/cmd.sh @@ -0,0 +1,13 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export cmd="queue.pl" diff --git a/egs/iam/v2/image b/egs/iam/v2/image new file mode 120000 index 00000000000..1668ee99922 --- /dev/null +++ b/egs/iam/v2/image @@ -0,0 +1 @@ +../../cifar/v1/image/ \ No newline at end of file diff --git a/egs/iam/v2/local/augment_data.sh b/egs/iam/v2/local/augment_data.sh new file mode 100755 index 00000000000..31e4a8217ca --- /dev/null +++ b/egs/iam/v2/local/augment_data.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Copyright 2018 Hossein Hadian +# 2018 Ashish Arora + +# Apache 2.0 +# This script performs data augmentation. + +nj=4 +cmd=run.pl +feat_dim=40 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +srcdir=$1 +outdir=$2 +datadir=$3 +aug_set=aug1 +mkdir -p $datadir/augmentations +echo "copying $srcdir to $datadir/augmentations/$aug_set, allowed length, creating feats.scp" + +for set in $aug_set; do + image/copy_data_dir.sh --spk-prefix $set- --utt-prefix $set- \ + $srcdir $datadir/augmentations/$set + cat $srcdir/allowed_lengths.txt > $datadir/augmentations/$set/allowed_lengths.txt + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim $feat_dim \ + --fliplr false --augment true $datadir/augmentations/$set +done + +echo " combine original data and data from different augmentations" +utils/combine_data.sh --extra-files images.scp $outdir $srcdir $datadir/augmentations/$aug_set +cat $srcdir/allowed_lengths.txt > $outdir/allowed_lengths.txt diff --git a/egs/iam/v2/local/chain/compare_wer.sh b/egs/iam/v2/local/chain/compare_wer.sh new file mode 100755 index 00000000000..2ce14e13694 --- /dev/null +++ b/egs/iam/v2/local/chain/compare_wer.sh @@ -0,0 +1,120 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/cnn{1a,1b} + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... ]" + echo "e.g.: $0 exp/chain/cnn{1a,1b}" + exit 1 +fi +. 
./path.sh + +echo "# $0 $*" +used_epochs=false + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +echo -n "# WER " +for x in $*; do + wer=$(cat $x/decode_test/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# WER (rescored) " +for x in $*; do + wer="--" + [ -d $x/decode_test_rescored ] && wer=$(cat $x/decode_test_rescored/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER " +for x in $*; do + cer=$(cat $x/decode_test/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +echo -n "# CER (rescored) " +for x in $*; do + cer="--" + [ -d $x/decode_test_rescored ] && cer=$(cat $x/decode_test_rescored/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +echo -n "# WER val " +for x in $*; do + wer=$(cat $x/decode_val/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# WER (rescored) val " +for x in $*; do + wer="--" + [ -d $x/decode_val_rescored ] && wer=$(cat $x/decode_val_rescored/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER val " +for x in $*; do + cer=$(cat $x/decode_val/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +echo -n "# CER (rescored) val " +for x in $*; do + cer="--" + [ -d $x/decode_val_rescored ] && cer=$(cat $x/decode_val_rescored/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Parameters " +for x in $*; do + params=$(nnet3-info $x/final.mdl 2>/dev/null | grep num-parameters | cut -d' ' -f2 | awk '{printf "%0.2fM\n",$1/1000000}') + printf "% 10s" $params +done +echo diff --git a/egs/iam/v2/local/chain/run_cnn_e2eali.sh b/egs/iam/v2/local/chain/run_cnn_e2eali.sh new file mode 120000 index 00000000000..da731bcb0b1 --- /dev/null +++ b/egs/iam/v2/local/chain/run_cnn_e2eali.sh @@ -0,0 +1 @@ +tuning/run_cnn_e2eali_1d.sh \ No newline at end of file diff --git a/egs/iam/v2/local/chain/run_e2e_cnn.sh b/egs/iam/v2/local/chain/run_e2e_cnn.sh new file mode 120000 index 00000000000..7dca9c30e23 --- /dev/null +++ b/egs/iam/v2/local/chain/run_e2e_cnn.sh @@ -0,0 +1 @@ +tuning/run_e2e_cnn_1b.sh \ No newline at end of file diff --git a/egs/iam/v1/local/chain/run_cnn_e2eali_1a.sh b/egs/iam/v2/local/chain/tuning/run_cnn_e2eali_1a.sh similarity index 91% rename from egs/iam/v1/local/chain/run_cnn_e2eali_1a.sh rename to egs/iam/v2/local/chain/tuning/run_cnn_e2eali_1a.sh index ba28f681708..9a01688ba35 100755 --- a/egs/iam/v1/local/chain/run_cnn_e2eali_1a.sh +++ 
b/egs/iam/v2/local/chain/tuning/run_cnn_e2eali_1a.sh @@ -22,6 +22,7 @@ stage=0 nj=30 train_set=train +decode_val=true nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. e2echain_model_dir=exp/chain/e2e_cnn_1a @@ -42,7 +43,9 @@ tdnn_dim=450 # training options srand=0 remove_egs=true -lang_test=lang_unk +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +if $decode_val; then maybe_val=val; else maybe_val= ; fi # End configuration section. echo "$0 $@" # Print the command line for logging @@ -132,7 +135,7 @@ if [ $stage -le 4 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) cnn_opts="l2-regularize=0.075" tdnn_opts="l2-regularize=0.075" output_opts="l2-regularize=0.1" @@ -228,18 +231,26 @@ if [ $stage -le 6 ]; then # as long as phones.txt was compatible. utils/mkgraph.sh \ - --self-loop-scale 1.0 data/$lang_test \ + --self-loop-scale 1.0 $lang_decode \ $dir $dir/graph || exit 1; fi if [ $stage -le 7 ]; then frames_per_chunk=$(echo $chunk_width | cut -d, -f1) - steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --extra-left-context $chunk_left_context \ - --extra-right-context $chunk_right_context \ - --extra-left-context-initial 0 \ - --extra-right-context-final 0 \ - --frames-per-chunk $frames_per_chunk \ - --nj $nj --cmd "$cmd" \ - $dir/graph data/test $dir/decode_test || exit 1; + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --extra-left-context $chunk_left_context \ + --extra-right-context $chunk_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/$decode_set $dir/decode_${decode_set}{,_rescored} || exit 1 + done fi + +echo "Done. Date: $(date). 
Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/iam/v1/local/chain/run_cnn_e2eali_1b.sh b/egs/iam/v2/local/chain/tuning/run_cnn_e2eali_1b.sh similarity index 86% rename from egs/iam/v1/local/chain/run_cnn_e2eali_1b.sh rename to egs/iam/v2/local/chain/tuning/run_cnn_e2eali_1b.sh index 6d8cca876bf..28aa246f334 100755 --- a/egs/iam/v1/local/chain/run_cnn_e2eali_1b.sh +++ b/egs/iam/v2/local/chain/tuning/run_cnn_e2eali_1b.sh @@ -2,15 +2,17 @@ # e2eali_1b is the same as e2eali_1a but uses unconstrained egs -# local/chain/compare_wer.sh /home/hhadian/kaldi-rnnlm/egs/iam/v1/exp/chain/cnn_e2eali_1a exp/chain/cnn_e2eali_1b +# local/chain/compare_wer.sh exp/chain/cnn_e2eali_1a exp/chain/cnn_e2eali_1b # System cnn_e2eali_1a cnn_e2eali_1b -# WER 12.79 12.23 -# CER 5.73 5.48 -# Final train prob -0.0556 -0.0367 -# Final valid prob -0.0795 -0.0592 -# Final train prob (xent) -0.9178 -0.8382 -# Final valid prob (xent) -1.0604 -0.9853 -# Parameters 3.95M 3.95M +# WER 10.40 10.33 +# WER (rescored) 10.02 10.10 +# CER 4.97 5.00 +# CER (rescored) 4.83 4.88 +# Final train prob -0.0612 -0.0428 +# Final valid prob -0.0857 -0.0666 +# Final train prob (xent) -0.8990 -0.9210 +# Final valid prob (xent) -1.0024 -1.0264 +# Parameters 3.98M 3.98M # steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1b # exp/chain/cnn_e2eali_1b: num-iters=21 nj=2..4 num-params=4.0M dim=40->360 combine=-0.038->-0.038 (over 1) xent:train/valid[13,20,final]=(-1.34,-0.967,-0.838/-1.40,-1.07,-0.985) logprob:train/valid[13,20,final]=(-0.075,-0.054,-0.037/-0.083,-0.072,-0.059) @@ -21,6 +23,7 @@ stage=0 nj=30 train_set=train +decode_val=true nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. affix=_1b #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. e2echain_model_dir=exp/chain/e2e_cnn_1a @@ -41,7 +44,10 @@ tdnn_dim=450 # training options srand=0 remove_egs=true -lang_test=lang_unk +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +if $decode_val; then maybe_val=val; else maybe_val= ; fi + # End configuration section. echo "$0 $@" # Print the command line for logging @@ -131,7 +137,7 @@ if [ $stage -le 4 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) cnn_opts="l2-regularize=0.075" tdnn_opts="l2-regularize=0.075" output_opts="l2-regularize=0.1" @@ -227,18 +233,26 @@ if [ $stage -le 6 ]; then # as long as phones.txt was compatible. 
utils/mkgraph.sh \ - --self-loop-scale 1.0 data/$lang_test \ + --self-loop-scale 1.0 $lang_decode \ $dir $dir/graph || exit 1; fi if [ $stage -le 7 ]; then frames_per_chunk=$(echo $chunk_width | cut -d, -f1) - steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --extra-left-context $chunk_left_context \ - --extra-right-context $chunk_right_context \ - --extra-left-context-initial 0 \ - --extra-right-context-final 0 \ - --frames-per-chunk $frames_per_chunk \ - --nj $nj --cmd "$cmd" \ - $dir/graph data/test $dir/decode_test || exit 1; + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --extra-left-context $chunk_left_context \ + --extra-right-context $chunk_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/$decode_set $dir/decode_${decode_set}{,_rescored} || exit 1 + done fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/iam/v2/local/chain/tuning/run_cnn_e2eali_1c.sh b/egs/iam/v2/local/chain/tuning/run_cnn_e2eali_1c.sh new file mode 100755 index 00000000000..f158317950a --- /dev/null +++ b/egs/iam/v2/local/chain/tuning/run_cnn_e2eali_1c.sh @@ -0,0 +1,259 @@ +#!/bin/bash + +# e2eali_1c is the same as e2eali_1b but has fewer CNN layers, smaller +# l2-regularize, more epochs and uses dropout. + + +# local/chain/compare_wer.sh exp/chain/cnn_e2eali_1b exp/chain/cnn_e2eali_1c +# System cnn_e2eali_1b cnn_e2eali_1c +# WER 10.33 10.05 +# WER (rescored) 10.10 9.75 +# CER 5.00 4.76 +# CER (rescored) 4.88 4.68 +# Final train prob -0.0428 -0.0317 +# Final valid prob -0.0666 -0.0630 +# Final train prob (xent) -0.9210 -0.5413 +# Final valid prob (xent) -1.0264 -0.7096 +# Parameters 3.98M 5.12M + +# steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1c +# exp/chain/cnn_e2eali_1c: num-iters=21 nj=2..4 num-params=5.1M dim=40->392 combine=-0.034->-0.034 (over 1) xent:train/valid[13,20,final]=(-0.953,-0.800,-0.541/-1.03,-0.933,-0.710) logprob:train/valid[13,20,final]=(-0.069,-0.048,-0.032/-0.091,-0.078,-0.063) + +set -e -o pipefail + +stage=0 + +nj=30 +train_set=train +decode_val=true +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1c #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +e2echain_model_dir=exp/chain/e2e_cnn_1a +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=500 +# we don't need extra left/right context for TDNN systems. +chunk_left_context=0 +chunk_right_context=0 +tdnn_dim=550 +# training options +srand=0 +remove_egs=true +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +if $decode_val; then maybe_val=val; else maybe_val= ; fi +dropout_schedule='0,0@0.20,0.2@0.50,0' +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). 
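  # ("More freedom" here means the supervision keeps alternative alignments: unlike
  #  steps/nnet3/align.sh, align_lats.sh writes lattices (lat.*.gz) rather than a
  #  single best alignment path.)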
+ # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.03 dropout-proportion=0.0" + tdnn_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.04" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + + conv-relu-batchnorm-dropout-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-dropout-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. 
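  # Worked example: with xent_regularize=0.1 (set at the top of this script),
  # learning_rate_factor = 0.5 / 0.1 = 5.0, so the xent output layer trains with a
  # 5x larger learning-rate factor than the chain output layer.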
+ relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=500" \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=8 \ + --trainer.frames-per-iter=2000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=4 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=64,32 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --extra-left-context $chunk_left_context \ + --extra-right-context $chunk_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/$decode_set $dir/decode_${decode_set}{,_rescored} || exit 1 + done +fi + +echo "Done. Date: $(date). 
Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/iam/v2/local/chain/tuning/run_cnn_e2eali_1d.sh b/egs/iam/v2/local/chain/tuning/run_cnn_e2eali_1d.sh new file mode 100755 index 00000000000..1c44057454a --- /dev/null +++ b/egs/iam/v2/local/chain/tuning/run_cnn_e2eali_1d.sh @@ -0,0 +1,259 @@ +#!/bin/bash + +# e2eali_1d is the same as e2eali_1c but has more CNN layers, different filter size +# smaller lm-opts, minibatch, frams-per-iter, less epochs and more initial/finaljobs. + +# local/chain/compare_wer.sh exp/chain/e2e_cnn_1b/ exp/chain/cnn_e2eali_1d +# System e2e_cnn_1b cnn_e2eali_1d +# WER 13.91 8.80 +# WER (rescored) 13.64 8.52 +# CER 7.08 4.06 +# CER (rescored) 6.82 3.98 +# Final train prob 0.0148 -0.0524 +# Final valid prob 0.0105 -0.0713 +# Final train prob (xent) -0.4695 +# Final valid prob (xent) -0.5310 +# Parameters 9.52M 4.36M + +# steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1d +# exp/chain/cnn_e2eali_1d: num-iters=30 nj=3..5 num-params=4.4M dim=40->400 combine=-0.055->-0.055 (over 1) xent:train/valid[19,29,final]=(-0.683,-0.489,-0.469/-0.703,-0.544,-0.531) logprob:train/valid[19,29,final]=(-0.090,-0.057,-0.052/-0.107,-0.076,-0.071) +set -e -o pipefail + +stage=0 + +nj=30 +train_set=train +decode_val=true +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1d #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +e2echain_model_dir=exp/chain/e2e_cnn_1b +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=500 +# we don't need extra left/right context for TDNN systems. +chunk_left_context=0 +chunk_right_context=0 +tdnn_dim=550 +# training options +srand=0 +remove_egs=true +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +if $decode_val; then maybe_val=val; else maybe_val= ; fi +dropout_schedule='0,0@0.20,0.2@0.50,0' +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.03 dropout-proportion=0.0" + tdnn_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.04" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + + conv-relu-batchnorm-dropout-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-dropout-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common3 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=true \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=5 \ + --trainer.frames-per-iter=1500000 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=5 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --extra-left-context $chunk_left_context \ + --extra-right-context $chunk_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/$decode_set $dir/decode_${decode_set}{,_rescored} || exit 1 + done +fi + + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/iam/v1/local/chain/run_flatstart_cnn1a.sh b/egs/iam/v2/local/chain/tuning/run_e2e_cnn_1a.sh similarity index 84% rename from egs/iam/v1/local/chain/run_flatstart_cnn1a.sh rename to egs/iam/v2/local/chain/tuning/run_e2e_cnn_1a.sh index 56c897137f4..cb2bfa0a82d 100755 --- a/egs/iam/v1/local/chain/run_flatstart_cnn1a.sh +++ b/egs/iam/v2/local/chain/tuning/run_e2e_cnn_1a.sh @@ -2,19 +2,21 @@ # Copyright 2017 Hossein Hadian # This script does end2end chain training (i.e. 
from scratch) - -# local/chain/compare_wer.sh exp/chain/cnn_1a exp/chain/cnn_chainali_1c exp/chain/e2e_cnn_1a -# System cnn_1a cnn_chainali_1c e2e_cnn_1a -# WER 18.52 12.72 13.87 -# CER 10.07 5.99 6.54 -# Final train prob -0.0077 -0.0291 -0.0371 -# Final valid prob -0.0970 -0.0359 -0.0636 -# Final train prob (xent) -0.5484 -0.9781 -# Final valid prob (xent) -0.9643 -1.1544 -# Parameters 4.36M 3.96M 9.13M +# ./local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ +# System e2e_cnn_1a +# WER 11.24 +# WER (rescored) 10.80 +# CER 5.32 +# CER (rescored) 5.24 +# Final train prob 0.0568 +# Final valid prob 0.0381 +# Final train prob (xent) +# Final valid prob (xent) +# Parameters 9.13M # steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1a -# exp/chain/e2e_cnn_1a: num-iters=21 nj=2..4 num-params=9.1M dim=40->12640 combine=-0.033->-0.033 (over 1) logprob:train/valid[13,20,final]=(-0.058,-0.042,-0.035/-0.070,-0.064,-0.059) +# exp/chain/e2e_cnn_1a: num-iters=42 nj=2..4 num-params=9.1M dim=40->12640 combine=0.049->0.049 (over 1) logprob:train/valid[27,41,final]=(0.035,0.055,0.057/0.016,0.037,0.038) + set -e @@ -23,6 +25,7 @@ stage=0 train_stage=-10 get_egs_stage=-10 affix=1a +nj=30 # training options tdnn_dim=450 @@ -35,7 +38,9 @@ l2_regularize=0.00005 frames_per_iter=1000000 cmvn_opts="--norm-means=true --norm-vars=true" train_set=train -lang_test=lang_unk +decode_val=true +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g # End configuration section. echo "$0 $@" # Print the command line for logging @@ -95,7 +100,6 @@ if [ $stage -le 2 ]; then mkdir -p $dir/configs cat < $dir/configs/network.xconfig input dim=40 name=input - conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 @@ -106,7 +110,6 @@ if [ $stage -le 2 ]; then relu-batchnorm-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim $tdnn_opts relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts - ## adding the layers for chain branch relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $output_opts output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts @@ -155,15 +158,19 @@ if [ $stage -le 4 ]; then # as long as phones.txt was compatible. utils/mkgraph.sh \ - --self-loop-scale 1.0 data/$lang_test \ + --self-loop-scale 1.0 $lang_decode \ $dir $dir/graph || exit 1; fi if [ $stage -le 5 ]; then - frames_per_chunk=$(echo $chunk_width | cut -d, -f1) - steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --nj 30 --cmd "$cmd" \ - $dir/graph data/test $dir/decode_test || exit 1; + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/$decode_set $dir/decode_${decode_set}{,_rescored} || exit 1 + done fi echo "Done. Date: $(date). 
Results:" diff --git a/egs/iam/v2/local/chain/tuning/run_e2e_cnn_1b.sh b/egs/iam/v2/local/chain/tuning/run_e2e_cnn_1b.sh new file mode 100755 index 00000000000..d5f79602695 --- /dev/null +++ b/egs/iam/v2/local/chain/tuning/run_e2e_cnn_1b.sh @@ -0,0 +1,163 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +# This script does end2end chain training (i.e. from scratch) +# ./local/chain/compare_wer.sh exp/chain/e2e_cnn_1b/ +# System e2e_cnn_1b +# WER 13.59 +# WER (rescored) 13.27 +# CER 6.92 +# CER (rescored) 6.71 +# Final train prob 0.0345 +# Final valid prob 0.0269 +# Final train prob (xent) +# Final valid prob (xent) +# Parameters 9.52M + +# steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1b +# exp/chain/e2e_cnn_1b: num-iters=42 nj=2..4 num-params=9.5M dim=40->12640 combine=0.041->0.041 (over 2) logprob:train/valid[27,41,final]=(0.032,0.035,0.035/0.025,0.026,0.027) +set -e + +# configs for 'chain' +stage=0 +train_stage=-10 +get_egs_stage=-10 +affix=1b +nj=30 + +# training options +tdnn_dim=450 +minibatch_size=150=100,64/300=50,32/600=25,16/1200=16,8 +common_egs_dir= +train_set=train +decode_val=true +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +if $decode_val; then maybe_val=val; else maybe_val= ; fi +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj 30 --cmd "$cmd" \ + --shared-phones true \ + --type biphone \ + data/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat data/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py data/lang \| \ + utils/sym2int.pl -f 2- data/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + common1="height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="height-offsets=-2,-1,0,1,2 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + relu-batchnorm-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn4 input=Append(-4,0,4) dim=$tdnn_dim + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $output_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + # no need to store the egs in a shared storage because we always + # remove them. Anyway, it takes only 5 minutes to generate them. 
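As an aside on the xconfig in stage 2 above: the height-in/height-out values of the four conv-relu-batchnorm layers have to stay consistent with their height-subsample-out settings (40 -> 40 -> 20 -> 20 -> 10). A minimal Python sketch of that bookkeeping, with the layer parameters copied from the xconfig above (illustrative only; Kaldi performs this check itself when parsing the config):

# Walk the feature-map height through the CNN stack defined in stage 2
# above; the numbers are copied from the xconfig (cnn2 and cnn4 use
# height-subsample-out=2).  Illustrative sketch, not recipe code.
layers = [
    # (name, height_in, height_subsample_out)
    ("cnn1", 40, 1),
    ("cnn2", 40, 2),
    ("cnn3", 20, 1),
    ("cnn4", 20, 2),
]
height = 40  # input dim=40, i.e. 40-pixel-high line images
for name, height_in, subsample in layers:
    assert height == height_in, (name, height, height_in)
    height = height_in // subsample
    print("{}: height {} -> {}".format(name, height_in, height))
# Ends at height 10 (with 70 filters), which feeds the TDNN layers.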
+ + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1000000 \ + --trainer.num-epochs 4 \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 4 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir data/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi + +if [ $stage -le 4 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 5 ]; then + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/$decode_set $dir/decode_${decode_set}{,_rescored} || exit 1 + done +fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/iam/v2/local/check_tools.sh b/egs/iam/v2/local/check_tools.sh new file mode 100755 index 00000000000..5b4d3107d3b --- /dev/null +++ b/egs/iam/v2/local/check_tools.sh @@ -0,0 +1,43 @@ +#!/bin/bash -u + +# Copyright 2015 (c) Johns Hopkins University (Jan Trmal ) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +[ -f ./path.sh ] && . ./path.sh +set +e + +command -v python3 >&/dev/null \ + || { echo >&2 "python3 not found on PATH. You will have to install Python3, preferably >= 3.6"; exit 1; } + +python3 -c "import numpy" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs numpy installed." + exit 1 +fi + +python3 -c "import scipy" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy installed." + exit 1 +fi + +python3 -c "import scipy.misc; scipy.misc.__dict__['imread']" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy-image and Pillow installed." 
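local/check_tools.sh probes each Python dependency of the image pipeline with a separate python3 -c import. A compact sketch of the same idea as a single snippet (an illustration, not part of the recipe; the module list is an assumption based on the checks above):

# Check that the packages used by the feature-extraction scripts import
# cleanly, mirroring the python3 -c probes in local/check_tools.sh.
import importlib
import sys

required = ["numpy", "scipy", "scipy.misc"]  # scipy.misc.imread also needs Pillow
missing = []
for module in required:
    try:
        importlib.import_module(module)
    except ImportError:
        missing.append(module)
if missing:
    sys.exit("missing Python packages: " + ", ".join(missing))
print("all required packages found")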
+ exit 1 +fi + + +exit 0 diff --git a/egs/iam/v2/local/extract_features.sh b/egs/iam/v2/local/extract_features.sh new file mode 100755 index 00000000000..1741ad3f9b2 --- /dev/null +++ b/egs/iam/v2/local/extract_features.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Copyright 2017 Yiwen Shao +# 2018 Ashish Arora + +# Apache 2.0 +# This script runs the make features script in parallel. + +nj=4 +cmd=run.pl +feat_dim=40 +augment=false +fliplr=false +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +data=$1 +featdir=$data/data +scp=$data/images.scp +logdir=$data/log + +mkdir -p $logdir +mkdir -p $featdir + +# make $featdir an absolute pathname +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +for n in $(seq $nj); do + split_scps="$split_scps $logdir/images.$n.scp" +done + +# split images.scp +utils/split_scp.pl $scp $split_scps || exit 1; + +$cmd JOB=1:$nj $logdir/extract_features.JOB.log \ + local/make_features.py $logdir/images.JOB.scp \ + --allowed_len_file_path $data/allowed_lengths.txt \ + --feat-dim $feat_dim --fliplr $fliplr --augment $augment \| \ + copy-feats --compress=true --compression-method=7 \ + ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp + +## aggregates the output scp's to get feats.scp +for n in $(seq $nj); do + cat $featdir/images.$n.scp || exit 1; +done > $data/feats.scp || exit 1 diff --git a/egs/iam/v2/local/gen_topo.py b/egs/iam/v2/local/gen_topo.py new file mode 100755 index 00000000000..8ffc59c5788 --- /dev/null +++ b/egs/iam/v2/local/gen_topo.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python + +# Copyright 2017 (author: Chun-Chieh Chang) + +# Generate a topology file. This allows control of the number of states in the +# non-silence HMMs, and in the silence HMMs. This is a modified version of +# 'utils/gen_topo.pl'. The difference is that this creates two topologies for +# the non-silence HMMs. The number of states for punctuations is different than +# the number of states for other characters. + +from __future__ import print_function +from __future__ import division +import argparse +import string + +parser = argparse.ArgumentParser(description="Usage: steps/nnet3/chain/gen_topo.py " + " " + "e.g.: steps/nnet3/chain/gen_topo.pl 4:5:6:7:8:9:10 1:2:3\n", + epilog="See egs/swbd/s5c/local/chain/train_tdnn_a.sh for example of usage."); +parser.add_argument("num_nonsil_states", type=int, help="number of states for nonsilence phones"); +parser.add_argument("num_sil_states", type=int, help="number of states for silence phones"); +parser.add_argument("num_punctuation_states", type=int, help="number of states for punctuation"); +parser.add_argument("nonsilence_phones", + help="List of non-silence phones as integers, separated by colons, e.g. 4:5:6:7:8:9"); +parser.add_argument("silence_phones", + help="List of silence phones as integers, separated by colons, e.g. 
1:2:3"); +parser.add_argument("phone_list", help="file containing all phones and their corresponding number."); + +args = parser.parse_args() + +silence_phones = [ int(x) for x in args.silence_phones.split(":") ] +nonsilence_phones = [ int(x) for x in args.nonsilence_phones.split(":") ] +all_phones = silence_phones + nonsilence_phones + +punctuation_phones = [] +exclude = set("!(),.?;:'-\"") +with open(args.phone_list) as f: + for line in f: + line = line.strip() + phone = line.split(' ')[0] + if len(phone) == 1 and phone in exclude: + punctuation_phones.append(int(line.split(' ')[1])) +# For nonsilence phones that are not punctuations +print("") +print("") +print("") +print(" ".join([str(x) for x in nonsilence_phones if x not in punctuation_phones])) +print("") +for x in range(0, args.num_nonsil_states): + xp1 = x + 1 + print(" {0} {0} {0} 0.75 {1} 0.25 ".format(x, xp1)) +print(" {} ".format(args.num_nonsil_states)) +print("") + +# For nonsilence phones that ar punctuations +print("") +print("") +print(" ".join([str(x) for x in nonsilence_phones if x in punctuation_phones])) +print("") +for x in range(0, args.num_punctuation_states): + xp1 = x + 1 + print(" {0} {0} {0} 0.75 {1} 0.25 ".format(x, xp1)) +print(" {} ".format(args.num_punctuation_states)) +print("") + +# For silence phones +print("") +print("") +print(" ".join([str(x) for x in silence_phones])) +print("") +if(args.num_sil_states > 1): + transp = 1.0 / (args.num_sil_states - 1) + + state_str = " 0 0 " + for x in range(0, (args.num_sil_states - 1)): + state_str = "{} {} {} ".format(state_str, x, transp)) + state_str = state_str + "" + print(state_str) + + for x in range(1, (args.num_sil_states - 1)): + state_str = " {0} " + print(state_str) + second_last = args.num_sil_states - 1 + print(" {0} {0} {0} 0.75 {1} 0.25 ".format(second_last, args.num_sil_states)) + print(" {} ".format(args.num_sil_states)) +else: + print(" 0 0 0 0.75 1 0.25 ") + print(" {} ".format(args.num_sil_states)) +print("") +print("") diff --git a/egs/iam/v2/local/make_features.py b/egs/iam/v2/local/make_features.py new file mode 100755 index 00000000000..3ce501732cf --- /dev/null +++ b/egs/iam/v2/local/make_features.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora +# 2017 Yiwen Shao +# 2018 Hossein Hadian + +""" This script converts images to Kaldi-format feature matrices. The input to + this script is the path to a data directory, e.g. "data/train". This script + reads the images listed in images.scp and writes them to standard output + (by default) as Kaldi-formatted matrices (in text form). It also scales the + images so they have the same height (via --feat-dim). It can optionally pad + the images (on left/right sides) with white pixels. + If an 'image2num_frames' file is found in the data dir, it will be used + to enforce the images to have the specified length in that file by padding + white pixels (the --padding option will be ignored in this case). This relates + to end2end chain training. + eg. 
local/make_features.py data/train --feat-dim 40 +""" +import random +import argparse +import os +import sys +import scipy.io as sio +import numpy as np +from scipy import misc +from scipy.ndimage.interpolation import affine_transform +import math +from signal import signal, SIGPIPE, SIG_DFL +signal(SIGPIPE, SIG_DFL) + +parser = argparse.ArgumentParser(description="""Converts images (in 'dir'/images.scp) to features and + writes them to standard output in text format.""") +parser.add_argument('images_scp_path', type=str, + help='Path of images.scp file') +parser.add_argument('--allowed_len_file_path', type=str, default=None, + help='If supplied, each images will be padded to reach the ' + 'target length (this overrides --padding).') +parser.add_argument('--out-ark', type=str, default='-', + help='Where to write the output feature file') +parser.add_argument('--feat-dim', type=int, default=40, + help='Size to scale the height of all images') +parser.add_argument('--padding', type=int, default=5, + help='Number of white pixels to pad on the left' + 'and right side of the image.') +parser.add_argument('--fliplr', type=lambda x: (str(x).lower()=='true'), default=False, + help="Flip the image left-right for right to left languages") +parser.add_argument("--augment", type=lambda x: (str(x).lower()=='true'), default=False, + help="performs image augmentation") +args = parser.parse_args() + + +def write_kaldi_matrix(file_handle, matrix, key): + file_handle.write(key + " [ ") + num_rows = len(matrix) + if num_rows == 0: + raise Exception("Matrix is empty") + num_cols = len(matrix[0]) + + for row_index in range(len(matrix)): + if num_cols != len(matrix[row_index]): + raise Exception("All the rows of a matrix are expected to " + "have the same length") + file_handle.write(" ".join(map(lambda x: str(x), matrix[row_index]))) + if row_index != num_rows - 1: + file_handle.write("\n") + file_handle.write(" ]\n") + + +def horizontal_pad(im, allowed_lengths = None): + if allowed_lengths is None: + left_padding = right_padding = args.padding + else: # Find an allowed length for the image + imlen = im.shape[1] # width + allowed_len = 0 + for l in allowed_lengths: + if l > imlen: + allowed_len = l + break + if allowed_len == 0: + # No allowed length was found for the image (the image is too long) + return None + padding = allowed_len - imlen + left_padding = int(padding // 2) + right_padding = padding - left_padding + dim_y = im.shape[0] # height + im_pad = np.concatenate((255 * np.ones((dim_y, left_padding), + dtype=int), im), axis=1) + im_pad1 = np.concatenate((im_pad, 255 * np.ones((dim_y, right_padding), + dtype=int)), axis=1) + return im_pad1 + +def get_scaled_image_aug(im, mode='normal'): + scale_size = args.feat_dim + sx = im.shape[1] + sy = im.shape[0] + scale = (1.0 * scale_size) / sy + nx = int(scale_size) + ny = int(scale * sx) + scale_size = random.randint(10, 30) + scale = (1.0 * scale_size) / sy + down_nx = int(scale_size) + down_ny = int(scale * sx) + if mode == 'normal': + im = misc.imresize(im, (nx, ny)) + return im + else: + im_scaled_down = misc.imresize(im, (down_nx, down_ny)) + im_scaled_up = misc.imresize(im_scaled_down, (nx, ny)) + return im_scaled_up + return im + +def contrast_normalization(im, low_pct, high_pct): + element_number = im.size + rows = im.shape[0] + cols = im.shape[1] + im_contrast = np.zeros(shape=im.shape) + low_index = int(low_pct * element_number) + high_index = int(high_pct * element_number) + sorted_im = np.sort(im, axis=None) + low_thred = sorted_im[low_index] + 
high_thred = sorted_im[high_index] + for i in range(rows): + for j in range(cols): + if im[i, j] > high_thred: + im_contrast[i, j] = 255 # lightest to white + elif im[i, j] < low_thred: + im_contrast[i, j] = 0 # darkest to black + else: + # linear normalization + im_contrast[i, j] = (im[i, j] - low_thred) * \ + 255 / (high_thred - low_thred) + return im_contrast + + +def geometric_moment(frame, p, q): + m = 0 + for i in range(frame.shape[1]): + for j in range(frame.shape[0]): + m += (i ** p) * (j ** q) * frame[i][i] + return m + + +def central_moment(frame, p, q): + u = 0 + x_bar = geometric_moment(frame, 1, 0) / \ + geometric_moment(frame, 0, 0) # m10/m00 + y_bar = geometric_moment(frame, 0, 1) / \ + geometric_moment(frame, 0, 0) # m01/m00 + for i in range(frame.shape[1]): + for j in range(frame.shape[0]): + u += ((i - x_bar)**p) * ((j - y_bar)**q) * frame[i][j] + return u + + +def height_normalization(frame, w, h): + frame_normalized = np.zeros(shape=(h, w)) + alpha = 4 + x_bar = geometric_moment(frame, 1, 0) / \ + geometric_moment(frame, 0, 0) # m10/m00 + y_bar = geometric_moment(frame, 0, 1) / \ + geometric_moment(frame, 0, 0) # m01/m00 + sigma_x = (alpha * ((central_moment(frame, 2, 0) / + geometric_moment(frame, 0, 0)) ** .5)) # alpha * sqrt(u20/m00) + sigma_y = (alpha * ((central_moment(frame, 0, 2) / + geometric_moment(frame, 0, 0)) ** .5)) # alpha * sqrt(u02/m00) + for x in range(w): + for y in range(h): + i = int((x / w - 0.5) * sigma_x + x_bar) + j = int((y / h - 0.5) * sigma_y + y_bar) + frame_normalized[x][y] = frame[i][j] + return frame_normalized + + +def find_slant_project(im): + rows = im.shape[0] + cols = im.shape[1] + std_max = 0 + alpha_max = 0 + col_disp = np.zeros(90, int) + proj = np.zeros(shape=(90, cols + 2 * rows), dtype=int) + for r in range(rows): + for alpha in range(-45, 45, 1): + col_disp[alpha] = int(r * math.tan(alpha / 180.0 * math.pi)) + for c in range(cols): + if im[r, c] < 100: + for alpha in range(-45, 45, 1): + proj[alpha + 45, c + col_disp[alpha] + rows] += 1 + for alpha in range(-45, 45, 1): + proj_histogram, bin_array = np.histogram(proj[alpha + 45, :], bins=10) + proj_std = np.std(proj_histogram) + if proj_std > std_max: + std_max = proj_std + alpha_max = alpha + proj_std = np.std(proj, axis=1) + return -alpha_max + + +def horizontal_shear(im, degree): + rad = degree / 180.0 * math.pi + padding_x = int(abs(np.tan(rad)) * im.shape[0]) + padding_y = im.shape[0] + if rad > 0: + im_pad = np.concatenate( + (255 * np.ones((padding_y, padding_x), dtype=int), im), axis=1) + elif rad < 0: + im_pad = np.concatenate( + (im, 255 * np.ones((padding_y, padding_x), dtype=int)), axis=1) + else: + im_pad = im + shear_matrix = np.array([[1, 0], + [np.tan(rad), 1]]) + sheared_im = affine_transform(im_pad, shear_matrix, cval=255.0) + return sheared_im + + +### main ### +random.seed(1) +data_list_path = args.images_scp_path +if args.out_ark == '-': + out_fh = sys.stdout +else: + out_fh = open(args.out_ark,'w') + +allowed_lengths = None +allowed_len_handle = args.allowed_len_file_path +if os.path.isfile(allowed_len_handle): + print("Found 'allowed_lengths.txt' file...", file=sys.stderr) + allowed_lengths = [] + with open(allowed_len_handle) as f: + for line in f: + allowed_lengths.append(int(line.strip())) + print("Read {} allowed lengths and will apply them to the " + "features.".format(len(allowed_lengths)), file=sys.stderr) + +num_fail = 0 +num_ok = 0 +aug_setting = ['normal', 'scaled'] +with open(data_list_path) as f: + for line in f: + line = line.strip() + 
line_vect = line.split(' ') + image_id = line_vect[0] + image_path = line_vect[1] + im = misc.imread(image_path) + if args.fliplr: + im = np.fliplr(im) + if args.augment: + im_aug = get_scaled_image_aug(im, aug_setting[0]) + im_contrast = contrast_normalization(im_aug, 0.05, 0.2) + slant_degree = find_slant_project(im_contrast) + im_sheared = horizontal_shear(im_contrast, slant_degree) + im_aug = im_sheared + else: + im_aug = get_scaled_image_aug(im, aug_setting[0]) + im_horizontal_padded = horizontal_pad(im_aug, allowed_lengths) + if im_horizontal_padded is None: + num_fail += 1 + continue + data = np.transpose(im_horizontal_padded, (1, 0)) + data = np.divide(data, 255.0) + num_ok += 1 + write_kaldi_matrix(out_fh, data, image_id) + +print('Generated features for {} images. Failed for {} (image too ' + 'long).'.format(num_ok, num_fail), file=sys.stderr) diff --git a/egs/iam/v2/local/prepare_data.sh b/egs/iam/v2/local/prepare_data.sh new file mode 100755 index 00000000000..cf729d9a939 --- /dev/null +++ b/egs/iam/v2/local/prepare_data.sh @@ -0,0 +1,191 @@ +#!/bin/bash + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 + +# This script downloads the IAM handwriting database and prepares the training +# and test data (i.e text, images.scp, utt2spk and spk2utt) by calling process_data.py. +# It also downloads the LOB and Brown text corpora. It downloads the database files +# only if they do not already exist in download directory. + +# Eg. local/prepare_data.sh +# Eg. text file: 000_a01-000u-00 A MOVE to stop Mr. Gaitskell from +# utt2spk file: 000_a01-000u-00 000 +# images.scp file: 000_a01-000u-00 data/local/lines/a01/a01-000u/a01-000u-00.png +# spk2utt file: 000 000_a01-000u-00 000_a01-000u-01 000_a01-000u-02 000_a01-000u-03 + +stage=0 +download_dir=data/download +process_aachen_split=false +wellington_dir= +username= +password= # username and password for downloading the IAM database + # if you have not already downloaded the database, please + # register at http://www.fki.inf.unibe.ch/databases/iam-handwriting-database + # and provide this script with your username and password. + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +if [[ ! -f $download_dir/lines.tgz && -z $username ]]; then + echo "$0: Warning: Couldn't find lines.tgz in $download_dir. Unless the extracted dataset files" + echo "exist in your data/local directory this script will fail because the required files" + echo "can't be downloaded automatically (it needs registration)." + echo "Please register at http://www.fki.inf.unibe.ch/databases/iam-handwriting-database" + echo "... 
and then call this script again with --username --password " + echo "" + exit 1 +fi + +lines=data/local/lines +xml=data/local/xml +ascii=data/local/ascii +bcorpus=data/local/browncorpus +lobcorpus=data/local/lobcorpus +wcorpus=data/local/wellingtoncorpus +data_split_info=data/local/largeWriterIndependentTextLineRecognitionTask +lines_url=http://www.fki.inf.unibe.ch/DBs/iamDB/data/lines/lines.tgz +xml_url=http://www.fki.inf.unibe.ch/DBs/iamDB/data/xml/xml.tgz +data_split_info_url=http://www.fki.inf.unibe.ch/DBs/iamDB/tasks/largeWriterIndependentTextLineRecognitionTask.zip +ascii_url=http://www.fki.inf.unibe.ch/DBs/iamDB/data/ascii/ascii.tgz +brown_corpus_url=http://www.sls.hawaii.edu/bley-vroman/brown.txt +lob_corpus_url=http://ota.ox.ac.uk/text/0167.zip +wellington_corpus_loc=/export/corpora5/Wellington/WWC/ +aachen_split_url=http://www.openslr.org/resources/56/splits.zip +aachen_splits=data/local/aachensplits +mkdir -p $download_dir data/local + +# download and extact images and transcription +if [ -d $lines ]; then + echo "$0: Not downloading lines images as it is already there." +else + if [ ! -f $download_dir/lines.tgz ]; then + echo "$0: Trying to download lines images..." + wget -P $download_dir --user "$username" --password "$password" $lines_url || exit 1; + fi + mkdir -p $lines + tar -xzf $download_dir/lines.tgz -C $lines || exit 1; + echo "$0: Done downloading and extracting lines images" +fi + +if [ -d $xml ]; then + echo "$0: Not downloading transcriptions as it is already there." +else + if [ ! -f $download_dir/xml.tgz ]; then + echo "$0: Trying to download transcriptions..." + wget -P $download_dir --user "$username" --password "$password" $xml_url || exit 1; + fi + mkdir -p $xml + tar -xzf $download_dir/xml.tgz -C $xml || exit 1; + echo "$0: Done downloading and extracting transcriptions." +fi + +if [ -d $data_split_info ]; then + echo "$0: Not downloading data split information as it is already there." +else + if [ ! -f $download_dir/largeWriterIndependentTextLineRecognitionTask.zip ]; then + echo "$0: Trying to download training and testing data split information..." + wget -P $download_dir --user "$username" --password "$password" $data_split_info_url || exit 1; + fi + mkdir -p $data_split_info + unzip $download_dir/largeWriterIndependentTextLineRecognitionTask.zip -d $data_split_info || exit 1; + echo "$0: Done downloading and extracting training and testing data split information" +fi + +if [ -d $ascii ]; then + echo "$0: Not downloading ascii.tgz as it is already there." +else + if [ ! -f $download_dir/ascii.tgz ]; then + echo "$0: trying to download ascii.tgz..." + wget -P $download_dir --user "$username" --password "$password" $ascii_url || exit 1; + fi + mkdir -p $ascii + tar -xzf $download_dir/ascii.tgz -C $ascii || exit 1; + echo "$0: Done downloading and extracting ascii.tgz" +fi + +if [ -d $lobcorpus ]; then + echo "$0: Not downloading the LOB text corpus as it is already there." +else + if [ ! -f $lobcorpus/0167.zip ]; then + echo "$0: Downloading the LOB text corpus ..." + mkdir -p $lobcorpus + wget -P $lobcorpus/ $lob_corpus_url || exit 1; + fi + unzip $lobcorpus/0167.zip -d $lobcorpus || exit 1; + echo "$0: Done downloading and extracting LOB corpus" +fi + +if [ -d $bcorpus ]; then + echo "$0: Not downloading the Brown corpus as it is already there." +else + if [ ! -f $bcorpus/brown.txt ]; then + mkdir -p $bcorpus + echo "$0: Downloading the Brown text corpus..." 
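As the header comment of this prepare_data.sh diff shows, the data preparation ends up producing the usual Kaldi data-directory files (text, utt2spk, images.scp, spk2utt). The spk2utt file is simply the inversion of utt2spk; Kaldi provides utils/utt2spk_to_spk2utt.pl for exactly that, and the recipe keeps the files consistent with image/fix_data_dir.sh. A toy Python sketch of the mapping, using made-up utterance IDs in the format shown in the header comment:

# Toy sketch of the utt2spk -> spk2utt inversion used by Kaldi data
# directories (the recipe relies on utils/utt2spk_to_spk2utt.pl and
# image/fix_data_dir.sh for this).  IDs below are illustrative only.
utt2spk = {
    "000_a01-000u-00": "000",
    "000_a01-000u-01": "000",
    "001_a01-003u-00": "001",
}

spk2utt = {}
for utt in sorted(utt2spk):
    spk2utt.setdefault(utt2spk[utt], []).append(utt)

for spk in sorted(spk2utt):
    print(spk, " ".join(spk2utt[spk]))
# -> 000 000_a01-000u-00 000_a01-000u-01
#    001 001_a01-003u-00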
+ wget -P $bcorpus $brown_corpus_url || exit 1; + fi + echo "$0: Done downloading the Brown text corpus" +fi + +if [ -d $wcorpus ]; then + echo "$0: Not copying Wellington corpus as it is already there." +elif [ ! -z $wellington_dir ]; then + mkdir -p $wcorpus + cp -r $wellington_dir/. $wcorpus + + # Combine Wellington corpora and replace some of their annotations + cat data/local/wellingtoncorpus/Section{A,B,C,D,E,F,G,H,J,K,L}.txt | \ + cut -d' ' -f3- | sed "s/^[ \t]*//" > data/local/wellingtoncorpus/Wellington_annotated.txt + + cat data/local/wellingtoncorpus/Wellington_annotated.txt | local/remove_wellington_annotations.py > data/local/wellingtoncorpus/Wellington_annotation_removed.txt + + echo "$0: Done copying Wellington corpus" +else + echo "$0: Wellington Corpus not included because wellington_dir not provided" +fi + +if [ -d $aachen_splits ]; then + echo "$0: Not downloading the Aachen splits as it is already there." +else + if [ ! -f $aachen_splits/splits.zip ]; then + echo "$0: Downloading Aachen splits ..." + mkdir -p $aachen_splits + wget -P $aachen_splits/ $aachen_split_url || exit 1; + fi + unzip $aachen_splits/splits.zip -d $aachen_splits || exit 1; + echo "$0: Done downloading and extracting Aachen splits" +fi + + +mkdir -p data/{train,test,val} +file_name=largeWriterIndependentTextLineRecognitionTask + +train_old="data/local/$file_name/trainset.txt" +test_old="data/local/$file_name/testset.txt" +val1_old="data/local/$file_name/validationset1.txt" +val2_old="data/local/$file_name/validationset2.txt" + +train_new="data/local/train.uttlist" +test_new="data/local/test.uttlist" +val_new="data/local/validation.uttlist" + +cat $train_old > $train_new +cat $test_old > $test_new +cat $val1_old $val2_old > $val_new + +if $process_aachen_split; then + local/process_aachen_splits.py data/local $aachen_splits/splits data/train --dataset train || exit 1 + local/process_aachen_splits.py data/local $aachen_splits/splits data/test --dataset test || exit 1 + local/process_aachen_splits.py data/local $aachen_splits/splits data/val --dataset validation || exit 1 +else + local/process_data.py data/local data/train --dataset train || exit 1 + local/process_data.py data/local data/test --dataset test || exit 1 + local/process_data.py data/local data/val --dataset validation || exit 1 +fi + +image/fix_data_dir.sh data/train +image/fix_data_dir.sh data/test +image/fix_data_dir.sh data/val diff --git a/egs/iam/v2/local/prepare_dict.sh b/egs/iam/v2/local/prepare_dict.sh new file mode 100755 index 00000000000..714b5b51788 --- /dev/null +++ b/egs/iam/v2/local/prepare_dict.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash + +# Copyright 2017 Hossein Hadian +# 2017 Chun Chieh Chang +# 2017 Ashish Arora + +# This script prepares the dictionary. + +set -e +dir=data/local/dict +vocab_size=50000 +. ./utils/parse_options.sh + +mkdir -p $dir + +# First get the set of all letters that occur in data/train/text +cat data/train/text | \ + perl -ne '@A = split; shift @A; for(@A) {print join("\n", split(//)), "\n";}' | \ + sort -u | grep -v "|" > $dir/nonsilence_phones.txt + +# Now use the pocolm's wordlist which is the most N frequent words in +# in data/train/text and LOB+Brown corpora (dev and test excluded) with their comprising +# letters as their transcription. Only include words that use the above letters. 
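The Perl one-liner just below turns the top $vocab_size words of the pocolm word counts into a grapheme lexicon: a word is kept only if it is spelled entirely with the letters collected above, its transcription is those letters separated by spaces, '#' is dropped, and the BPE word boundary '|' becomes SIL. A rough Python equivalent with a toy letter set and wordlist (the real inputs come from nonsilence_phones.txt and data/local/local_lm/data/word_count):

# Rough Python equivalent of the Perl lexicon builder below.  The letter
# set and wordlist are toy stand-ins for nonsilence_phones.txt and the
# pocolm word_count file.
letters = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'#|")
wordlist = ["the", "MOVE", "Gaitskell", "stop|", "ca$h"]

for word in wordlist:
    if not word or any(ch not in letters for ch in word):
        continue  # skip words containing characters outside the letter set
    trans = " ".join(word)          # spell the word out letter by letter
    trans = trans.replace("#", "")  # '#' never appears in the lexicon
    trans = trans.replace("|", "SIL")
    print("{} {}".format(word, trans))
# e.g. "the t h e" and "stop| s t o p SIL"; "ca$h" is skipped.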
+# (Letter # is replaced with ) + +export letters=$(cat $dir/nonsilence_phones.txt | tr -d "\n") + +head -n $vocab_size data/local/local_lm/data/word_count | awk '{print $2}' | \ + perl -e '$letters=$ENV{letters}; $letters=$letters . "|"; +while(<>){ + chop; + $w = $_; + if($w =~ m/^[$letters]+$/){ + $trans = join(" ", split(//, $w)); + $trans =~ s/#//g; + $trans =~ s/\|/SIL/g; + print "$w $trans\n"; + } +}' | sort -u > $dir/lexicon.txt + + +perl -i -pe "s/#//" $dir/nonsilence_phones.txt + +echo ' SIL' >> $dir/lexicon.txt + +echo SIL > $dir/silence_phones.txt + +echo SIL >$dir/optional_silence.txt + +echo -n "" >$dir/extra_questions.txt diff --git a/egs/iam/v2/local/process_aachen_splits.py b/egs/iam/v2/local/process_aachen_splits.py new file mode 100755 index 00000000000..cb6a6d4f0d8 --- /dev/null +++ b/egs/iam/v2/local/process_aachen_splits.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +""" This script reads the extracted IAM database files and creates + the following files (for the data subset selected via --dataset): + text, utt2spk, images.scp. + + Eg. local/process_aachen_splits.py data/local data/train data --dataset train + Eg. text file: 000_a01-000u-00 A MOVE to stop Mr. Gaitskell from + utt2spk file: 000_a01-000u-00 000 + images.scp file: 000_a01-000u-00 data/local/lines/a01/a01-000u/a01-000u-00.png +""" + +import argparse +import os +import sys +import xml.dom.minidom as minidom + +parser = argparse.ArgumentParser(description="""Creates text, utt2spk + and images.scp files.""") +parser.add_argument('database_path', type=str, + help='Path to the downloaded (and extracted) IAM data') +parser.add_argument('split_path', type=str, + help='location of the train/test/val set') +parser.add_argument('out_dir', type=str, + help='location to write output files.') +parser.add_argument('--dataset', type=str, default='train', + choices=['train', 'test','validation'], + help='Subset of data to process.') +args = parser.parse_args() + +text_file = os.path.join(args.out_dir + '/', 'text') +text_fh = open(text_file, 'w') + +utt2spk_file = os.path.join(args.out_dir + '/', 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w') + +image_file = os.path.join(args.out_dir + '/', 'images.scp') +image_fh = open(image_file, 'w') + +dataset_path = os.path.join(args.split_path, + args.dataset + '.uttlist') + +text_file_path = os.path.join(args.database_path, + 'ascii','lines.txt') +text_dict = {} +def process_text_file_for_word_model(): + with open (text_file_path, 'rt') as in_file: + for line in in_file: + if line[0]=='#': + continue + line = line.strip() + utt_id = line.split(' ')[0] + text_vect = line.split(' ')[8:] + text = "".join(text_vect) + text = text.replace("|", " ") + text_dict[utt_id] = text + + +### main ### + +print("Processing '{}' data...".format(args.dataset)) +process_text_file_for_word_model() + +with open(dataset_path) as f: + for line in f: + line = line.strip() + line_vect = line.split('-') + xml_file = line_vect[0] + '-' + line_vect[1] + xml_path = os.path.join(args.database_path, 'xml', xml_file + '.xml') + doc = minidom.parse(xml_path) + form_elements = doc.getElementsByTagName('form')[0] + writer_id = form_elements.getAttribute('writer-id') + outerfolder = form_elements.getAttribute('id')[0:3] + innerfolder = form_elements.getAttribute('id') + lines_path = os.path.join(args.database_path, 'lines', + outerfolder, innerfolder) + for file in os.listdir(lines_path): + if file.endswith(".png"): + image_file_path = 
os.path.join(lines_path, file) + base_name = os.path.splitext(os.path.basename(image_file_path))[0] + text = text_dict[base_name] + utt_id = writer_id + '_' + base_name + text_fh.write(utt_id + ' ' + text + '\n') + utt2spk_fh.write(utt_id + ' ' + writer_id + '\n') + image_fh.write(utt_id + ' ' + image_file_path + '\n') diff --git a/egs/iam/v2/local/process_data.py b/egs/iam/v2/local/process_data.py new file mode 100755 index 00000000000..2adae7bf7be --- /dev/null +++ b/egs/iam/v2/local/process_data.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +""" This script reads the extracted IAM database files and creates + the following files (for the data subset selected via --dataset): + text, utt2spk, images.scp. + + Eg. local/process_data.py data/local data/train data --dataset train + Eg. text file: 000_a01-000u-00 A MOVE to stop Mr. Gaitskell from + utt2spk file: 000_a01-000u-00 000 + images.scp file: 000_a01-000u-00 data/local/lines/a01/a01-000u/a01-000u-00.png +""" + +import argparse +import os +import sys +import xml.dom.minidom as minidom + +parser = argparse.ArgumentParser(description="""Creates text, utt2spk + and images.scp files.""") +parser.add_argument('database_path', type=str, + help='Path to the downloaded (and extracted) IAM data') +parser.add_argument('out_dir', type=str, + help='Where to write output files.') +parser.add_argument('--dataset', type=str, default='train', + choices=['train', 'test','validation'], + help='Subset of data to process.') +args = parser.parse_args() + +text_file = os.path.join(args.out_dir + '/', 'text') +text_fh = open(text_file, 'w') + +utt2spk_file = os.path.join(args.out_dir + '/', 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w') + +image_file = os.path.join(args.out_dir + '/', 'images.scp') +image_fh = open(image_file, 'w') + +dataset_path = os.path.join(args.database_path, + args.dataset + '.uttlist') + +text_file_path = os.path.join(args.database_path, + 'ascii','lines.txt') +text_dict = {} +def process_text_file_for_word_model(): + with open (text_file_path, 'rt') as in_file: + for line in in_file: + if line[0]=='#': + continue + line = line.strip() + utt_id = line.split(' ')[0] + text_vect = line.split(' ')[8:] + text = "".join(text_vect) + text = text.replace("|", " ") + text_dict[utt_id] = text + +print("Processing '{}' data...".format(args.dataset)) +process_text_file_for_word_model() + +with open(dataset_path) as f: + for line in f: + line = line.strip() + line_vect = line.split('-') + xml_file = line_vect[0] + '-' + line_vect[1] + xml_path = os.path.join(args.database_path, 'xml', xml_file + '.xml') + img_num = line[-3:] + doc = minidom.parse(xml_path) + form_elements = doc.getElementsByTagName('form')[0] + writer_id = form_elements.getAttribute('writer-id') + outerfolder = form_elements.getAttribute('id')[0:3] + innerfolder = form_elements.getAttribute('id') + lines_path = os.path.join(args.database_path, 'lines', + outerfolder, innerfolder, innerfolder) + image_file_path = lines_path + img_num + '.png' + text = text_dict[line] + utt_id = writer_id + '_' + line + text_fh.write(utt_id + ' ' + text + '\n') + utt2spk_fh.write(utt_id + ' ' + writer_id + '\n') + image_fh.write(utt_id + ' ' + image_file_path + '\n') diff --git a/egs/iam/v2/local/remove_test_utterances_from_lob.py b/egs/iam/v2/local/remove_test_utterances_from_lob.py new file mode 100755 index 00000000000..5e5dac52818 --- /dev/null +++ b/egs/iam/v2/local/remove_test_utterances_from_lob.py @@ -0,0 +1,142 @@ +#!/usr/bin/env 
python3 +# Copyright 2018 Ashish Arora + +import argparse +import os +import numpy as np +import sys +import re + +parser = argparse.ArgumentParser(description="""Removes dev/test set lines + from the LOB corpus. Reads the + corpus from stdin, and writes it to stdout.""") +parser.add_argument('dev_text', type=str, + help='dev transcription location.') +parser.add_argument('test_text', type=str, + help='test transcription location.') +args = parser.parse_args() + +def remove_punctuations(transcript): + char_list = [] + for char in transcript: + if char.isdigit() or char == '+' or char == '~' or char == '?': + continue + if char == '#' or char == '=' or char == '-' or char == '!': + continue + if char == ',' or char == '.' or char == ')' or char == '\'': + continue + if char == '(' or char == ':' or char == ';' or char == '"': + continue + if char == '*': + continue + char_list.append(char) + return char_list + + +def remove_special_words(words): + word_list = [] + for word in words: + if word == '' or word == '#': + continue + word_list.append(word) + return word_list + + +# process and add dev/eval transcript in a list +# remove special words, punctuations, spaces between words +# lowercase the characters +def read_utterances(text_file_path): + with open(text_file_path, 'rt') as in_file: + for line in in_file: + words = line.strip().split() + words_wo_sw = remove_special_words(words) + transcript = ''.join(words_wo_sw[1:]) + transcript = transcript.lower() + trans_wo_punct = remove_punctuations(transcript) + transcript = ''.join(trans_wo_punct) + utterance_dict[words_wo_sw[0]] = transcript + + +### main ### + +# read utterances and add it to utterance_dict +utterance_dict = dict() +read_utterances(args.dev_text) +read_utterances(args.test_text) + +# read corpus and add it to below lists +corpus_text_lowercase_wo_sc = list() +corpus_text_wo_sc = list() +original_corpus_text = list() +for line in sys.stdin: + original_corpus_text.append(line) + words = line.strip().split() + words_wo_sw = remove_special_words(words) + + transcript = ''.join(words_wo_sw) + transcript = transcript.lower() + trans_wo_punct = remove_punctuations(transcript) + transcript = ''.join(trans_wo_punct) + corpus_text_lowercase_wo_sc.append(transcript) + + transcript = ''.join(words_wo_sw) + trans_wo_punct = remove_punctuations(transcript) + transcript = ''.join(trans_wo_punct) + corpus_text_wo_sc.append(transcript) + +# find majority of utterances below +# for utterances which were not found +# add them to remaining_utterances +row_to_keep = [True for i in range(len(original_corpus_text))] +remaining_utterances = dict() +for line_id, line_to_find in utterance_dict.items(): + found_line = False + # avoiding very small utterance, it causes removing + # complete lob text + if len(line_to_find) < 10: + remaining_utterances[line_id] = line_to_find + else: + for i in range(1, (len(corpus_text_lowercase_wo_sc) - 2)): + # Combine 3 consecutive lines of the corpus into a single line + prev_words = corpus_text_lowercase_wo_sc[i - 1].strip() + curr_words = corpus_text_lowercase_wo_sc[i].strip() + next_words = corpus_text_lowercase_wo_sc[i + 1].strip() + new_line = prev_words + curr_words + next_words + transcript = ''.join(new_line) + if line_to_find in transcript: + found_line = True + row_to_keep[i-1] = False + row_to_keep[i] = False + row_to_keep[i+1] = False + if not found_line: + remaining_utterances[line_id] = line_to_find + +# removing long utterances not found above +row_to_keep[87530] = False; row_to_keep[87531] = False; 
row_to_keep[87532] = False; +row_to_keep[31724] = False; row_to_keep[31725] = False; row_to_keep[31726] = False; +row_to_keep[16704] = False; row_to_keep[16705] = False; row_to_keep[16706] = False; +row_to_keep[94181] = False; row_to_keep[94182] = False; row_to_keep[94183] = False; +row_to_keep[20171] = False; row_to_keep[20172] = False; row_to_keep[20173] = False; +row_to_keep[16734] = False; row_to_keep[16733] = False; row_to_keep[16732] = False; +row_to_keep[20576] = False; row_to_keep[20577] = False; row_to_keep[20578] = False; +row_to_keep[31715] = False; row_to_keep[31716] = False; row_to_keep[31717] = False; +row_to_keep[31808] = False; row_to_keep[31809] = False; row_to_keep[31810] = False; +row_to_keep[31822] = False; row_to_keep[31823] = False; row_to_keep[31824] = False; +row_to_keep[88791] = False; row_to_keep[88792] = False; row_to_keep[88793] = False; +row_to_keep[31745] = False; row_to_keep[31746] = False; row_to_keep[31825] = False; +row_to_keep[94256] = False; row_to_keep[94257] = False; row_to_keep[88794] = False; +row_to_keep[88665] = False; row_to_keep[17093] = False; row_to_keep[17094] = False; +row_to_keep[20586] = False; row_to_keep[87228] = False; row_to_keep[87229] = False; +row_to_keep[16744] = False; row_to_keep[87905] = False; row_to_keep[87906] = False; +row_to_keep[16669] = False; row_to_keep[16670] = False; row_to_keep[16719] = False; +row_to_keep[87515] = False; row_to_keep[20090] = False; row_to_keep[31748] = False; +for i in range(len(original_corpus_text)): + transcript = original_corpus_text[i].strip() + if row_to_keep[i]: + print(transcript) + +print('Sentences not removed from LOB: {}'.format(remaining_utterances), file=sys.stderr) +print('Total test+dev sentences: {}'.format(len(utterance_dict)), file=sys.stderr) +print('Number of sentences not removed from LOB: {}'. 
format(len(remaining_utterances)), file=sys.stderr) +print('LOB lines: Before: {} After: {}'.format(len(original_corpus_text), + row_to_keep.count(True)), file=sys.stderr) diff --git a/egs/iam/v2/local/remove_wellington_annotations.py b/egs/iam/v2/local/remove_wellington_annotations.py new file mode 100755 index 00000000000..260a3542985 --- /dev/null +++ b/egs/iam/v2/local/remove_wellington_annotations.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright 2018 Chun-Chieh Chang + +import sys +import io +import re +from collections import OrderedDict + +sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf8"); +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf8"); + +prev2_line = " "; +prev_line = " "; +for line in sys.stdin: + line = line.strip() + pattern = re.compile("\\*\\*\\[.*?\\*\\*\\]|\\*[0-9]|\\\\[0-9]{0,2}|\\*\\*?[\|,\?,\#,\=,\;,\:,\<,\>]|\||\^") + line_fixed = pattern.sub("", line) + dict=OrderedDict([("*+$","$"), ("*+","£"), ("*-","-"), ("*/","*"), ("*{","{"), ("*}","}"), + ("**\"","\""), ("*\"","\""), ("**'","'"), ("*'","'"), ("*@","°")]) + pattern = re.compile("|".join(re.escape(key) for key in dict.keys())); + line_fixed = pattern.sub(lambda x: dict[x.group()], line_fixed) + + line_fixed = prev2_line + "\n" + prev_line + "\n" + line_fixed + + pattern = re.compile("\{[0-9]{0,2}(.*?)\}", re.DOTALL) + line_fixed = pattern.sub(lambda x: x.group(1), line_fixed) + + output, prev2_line, prev_line = line_fixed.split("\n") + + sys.stdout.write(output + "\n") +sys.stdout.write(prev2_line + "\n") +sys.stdout.write(prev_line + "\n") diff --git a/egs/iam/v2/local/score.sh b/egs/iam/v2/local/score.sh new file mode 100755 index 00000000000..1d84815fc69 --- /dev/null +++ b/egs/iam/v2/local/score.sh @@ -0,0 +1,6 @@ + +#!/bin/bash + + +steps/scoring/score_kaldi_wer.sh "$@" +steps/scoring/score_kaldi_cer.sh --stage 2 "$@" diff --git a/egs/iam/v2/local/train_lm.sh b/egs/iam/v2/local/train_lm.sh new file mode 100755 index 00000000000..cc0119eb748 --- /dev/null +++ b/egs/iam/v2/local/train_lm.sh @@ -0,0 +1,156 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains an LM on the LOB+Brown text data and IAM training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +vocab_size=50000 + +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +dir=data/local/local_lm +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. 
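The metaparameter bypass described above is judged by the dev-set perplexity that train_lm.py and get_data_prob.py report. As a reminder of what that number is, here is a toy perplexity computation under a trivial add-one-smoothed unigram model (nothing to do with pocolm's actual estimator; the sentences are made up):

# Perplexity = exp(-(1/N) * sum_i log p(w_i)), shown here with a toy
# add-one-smoothed unigram model over made-up text.  Pocolm's models and
# estimation are far more involved; this only illustrates the metric.
import math
from collections import Counter

train_words = "a move to stop mr gaitskell from".split()
dev_words = "a move to stop".split()

counts = Counter(train_words)
vocab = set(train_words) | set(dev_words)

def prob(word):
    return (counts[word] + 1.0) / (len(train_words) + len(vocab))

log_prob = sum(math.log(prob(w)) for w in dev_words)
perplexity = math.exp(-log_prob / len(dev_words))
print("dev log-prob {:.2f}, perplexity {:.2f}".format(log_prob, perplexity))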
+#bypass_metaparam_optim_opt= +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # Using LOB and brown corpus. + if [ ! -f data/local/lob-train-only.txt ]; then + cat data/local/lobcorpus/0167/download/LOB_COCOA/lob.txt | \ + local/remove_test_utterances_from_lob.py data/test/text.old data/val/text.old \ + > data/local/lob-train-only.txt + fi + cat data/local/lob-train-only.txt | \ + utils/lang/bpe/prepend_words.py | utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > ${dir}/data/text/lob.txt + cat data/local/browncorpus/brown.txt | \ + utils/lang/bpe/prepend_words.py | utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > ${dir}/brown.txt + tail -n +5000 ${dir}/brown.txt > ${dir}/data/text/brown.txt + if [ -d "data/local/wellingtoncorpus" ]; then + cat data/local/wellingtoncorpus/Wellington_annotation_removed.txt | \ + utils/lang/bpe/prepend_words.py | utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > ${dir}/data/text/wellington.txt + fi + + # use the validation data as the dev set. + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + head -5000 ${dir}/brown.txt > ${dir}/data/text/dev.txt + + # use the training data as an additional data source. + # we can later fold the dev data into this. + cat data/train/text | cut -d " " -f 2- > ${dir}/data/text/iam.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < data/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from IAM text + if [ -d "data/local/wellingtoncorpus" ]; then + cat ${dir}/data/text/{iam,lob,brown,wellington}.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + else + echo "$0: Wellington Corpus not found. Proceeding without using that corpus." + cat ${dir}/data/text/{iam,lob,brown}.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + fi + head -n $vocab_size ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +order=6 + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. 
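Stage 0 above builds ${dir}/data/word_count with a tr | sort | uniq -c | sort -bnr pipeline and keeps the first $vocab_size entries as the wordlist. A small Python sketch of that counting step on toy input (the recipe itself does this with shell tools over the IAM, LOB, Brown and optional Wellington text):

# Sketch of the word_count / wordlist step above: count word frequencies
# over the combined corpora and keep the vocab_size most frequent words.
# Toy input; the recipe uses tr | sort | uniq -c | sort -bnr | head.
from collections import Counter

vocab_size = 5
corpus_lines = [
    "a move to stop mr gaitskell",
    "a move to stop the motion",
]

counts = Counter(word for line in corpus_lines for word in line.split())
wordlist = [word for word, _ in counts.most_common(vocab_size)]
print(wordlist)  # e.g. ['a', 'move', 'to', 'stop', 'mr']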
+ # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='brown=2 lob=2 iam=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + + train_lm.py --wordlist=${wordlist} --num-splits=10 --warm-start-ratio=20 \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' +fi + +if [ $stage -le 2 ]; then + echo "$0: pruning the LM (to larger size)" + # Using 1 million n-grams for a big LM for rescoring purposes. + size=1000000 + prune_lm_dir.py --target-num-ngrams=$size --initial-threshold=0.02 ${unpruned_lm_dir} ${dir}/data/lm_${order}_prune_big + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_big 2>&1 | grep -F '[perplexity' + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${dir}/data/lm_${order}_prune_big | gzip -c > ${dir}/data/arpa/${order}gram_big.arpa.gz +fi + +if [ $stage -le 3 ]; then + echo "$0: pruning the LM (to smaller size)" + # Using 500,000 n-grams for a smaller LM for graph building. Prune from the + # bigger-pruned LM, it'll be faster. + size=500000 + prune_lm_dir.py --target-num-ngrams=$size ${dir}/data/lm_${order}_prune_big ${dir}/data/lm_${order}_prune_small + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_small 2>&1 | grep -F '[perplexity' + + format_arpa_lm.py ${dir}/data/lm_${order}_prune_small | gzip -c > ${dir}/data/arpa/${order}gram_small.arpa.gz +fi diff --git a/egs/iam/v2/local/wer_output_filter b/egs/iam/v2/local/wer_output_filter new file mode 100755 index 00000000000..24691a160a9 --- /dev/null +++ b/egs/iam/v2/local/wer_output_filter @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# Copyright 2017 Hossein Hadian + +# This is a filter used in scoring. It separates all +# punctuations from words. For e.g. this sentence: + +# "They have come!" he said reverently, gripping his +# hands. "Isn't it a glorious thing! Long awaited." + +# is converted to this: + +# " They have come ! " he said reverently , gripping his +# hands . " Isn ' t it a glorious thing ! Long awaited . " + +# Sample BPE-based output: +# |He |ro se |from |his |b re ak f as t - s ch oo l |b en ch + +import sys +import re + +punctuations = "!(),.?;:'-\"" +escaped_punctuations = re.escape(punctuations) + +for line in sys.stdin: + words = line.strip().split() + uttid = words[0] + transcript = ''.join(words[1:]) + transcript = transcript.replace('|', ' ') + split_transcript = " ".join(re.split("([{}])".format(escaped_punctuations), + transcript)).strip() + print("{} {}".format(uttid, split_transcript)) diff --git a/egs/iam/v2/path.sh b/egs/iam/v2/path.sh new file mode 100755 index 00000000000..7e458144624 --- /dev/null +++ b/egs/iam/v2/path.sh @@ -0,0 +1,9 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! 
-f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh + +export LD_LIBRARY_PATH=$KALDI_ROOT/tools/openfst/lib:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=/home/dpovey/libs:$LD_LIBRARY_PATH +export LC_ALL=C diff --git a/egs/iam/v2/run_end2end.sh b/egs/iam/v2/run_end2end.sh new file mode 100755 index 00000000000..c515c85fc72 --- /dev/null +++ b/egs/iam/v2/run_end2end.sh @@ -0,0 +1,146 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +set -e +stage=0 +nj=20 +username= +password= +process_aachen_split=false +overwrite=false +# iam_database points to the database path on the JHU grid. If you have not +# already downloaded the database you can set it to a local directory +# like "data/download" and follow the instructions +# in "local/prepare_data.sh" to download the database: +iam_database=/export/corpora5/handwriting_ocr/IAM +# wellington_database points to the database path on the JHU grid. The Wellington +# corpus contains two directories WWC and WSC (Wellington Written and Spoken Corpus). +# This corpus is of written NZ English that can be purchased here: +# "https://www.victoria.ac.nz/lals/resources/corpora-default" +wellington_database=/export/corpora5/Wellington/WWC/ + +. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. ## This relates to the queue. +. ./path.sh +. ./utils/parse_options.sh # e.g. this parses the above options # if supplied. + + +./local/check_tools.sh + +if [ $stage -le 0 ]; then + + if [ -f data/train/text ] && ! $overwrite; then + echo "$0: Not processing; the script has probably been run from the wrong stage" + echo "Exiting with status 1 to avoid data corruption" + exit 1; + fi + + echo "$0: Preparing data..." + local/prepare_data.sh --download-dir "$iam_database" \ + --wellington-dir "$wellington_database" \ + --username "$username" --password "$password" \ + --process_aachen_split $process_aachen_split +fi + +mkdir -p data/{train,test}/data +if [ $stage -le 1 ]; then + echo "$(date) stage 1: getting allowed image widths for e2e training..." + image/get_image2num_frames.py --feat-dim 40 data/train # This will be needed for the next command + # The next command creates an "allowed_lengths.txt" file in data/train + # which will be used by local/make_features.py to enforce the images to + # have allowed lengths. The allowed lengths will be spaced by 10% difference in length. + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train + echo "$(date) Extracting features, creating feats.scp file" + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/train + steps/compute_cmvn_stats.sh data/train || exit 1; + for set in val test; do + local/extract_features.sh --nj $nj --cmd "$cmd" --augment true \ + --feat-dim 40 data/${set} + steps/compute_cmvn_stats.sh data/${set} || exit 1; + done + utils/fix_data_dir.sh data/train +fi + +if [ $stage -le 2 ]; then + for set in train; do + echo "$(date) stage 2: Performing augmentation; it will double the training data" + local/augment_data.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/${set} data/${set}_aug data + steps/compute_cmvn_stats.sh data/${set}_aug || exit 1; + done +fi + +if [ $stage -le 3 ]; then + echo "$0: Preparing BPE..." + # getting non-silence phones.
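Stage 3 learns 700 merge operations with utils/lang/bpe/learn_bpe.py -s 700 and applies them with apply_bpe.py further down. Conceptually, BPE learning repeatedly merges the most frequent adjacent pair of symbols; the toy sketch below shows that loop on a few made-up word counts (it is not the actual learn_bpe.py, which works on corpus word frequencies and handles many edge cases):

# Toy byte-pair-encoding learner: start from characters and repeatedly
# merge the most frequent adjacent symbol pair.  The recipe learns 700
# merges over the training text; here we do 3 merges on made-up counts.
from collections import Counter

def learn_bpe(word_counts, num_merges):
    vocab = {tuple(word): count for word, count in word_counts.items()}
    merges = []
    for _ in range(num_merges):
        pair_counts = Counter()
        for symbols, count in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += count
        if not pair_counts:
            break
        best = pair_counts.most_common(1)[0][0]
        merges.append(best)
        new_vocab = {}
        for symbols, count in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] = count
        vocab = new_vocab
    return merges

print(learn_bpe({"lower": 5, "low": 7, "lowest": 2}, num_merges=3))
# -> [('l', 'o'), ('lo', 'w'), ('low', 'e')]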
+ cut -d' ' -f2- data/train/text | \ +python3 <( +cat << "END" +import os, sys, io; +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8'); +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8'); +phone_dict = dict(); +for line in infile: + line_vect = line.strip().split(); + for word in line_vect: + for phone in word: + phone_dict[phone] = phone; +for phone in phone_dict.keys(): + output.write(phone+ '\n'); +END + ) > data/local/phones.txt + + cut -d' ' -f2- data/train/text > data/local/train_data.txt + cat data/local/phones.txt data/local/train_data.txt | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/learn_bpe.py -s 700 > data/local/bpe.txt + for set in test train val train_aug; do + cut -d' ' -f1 data/$set/text > data/$set/ids + cut -d' ' -f2- data/$set/text | \ + utils/lang/bpe/prepend_words.py | utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > data/$set/bpe_text + mv data/$set/text data/$set/text.old + paste -d' ' data/$set/ids data/$set/bpe_text > data/$set/text + done +fi + +if [ $stage -le 4 ]; then + echo "$0: Estimating a language model for decoding..." + local/train_lm.sh +fi + +if [ $stage -le 5 ]; then + echo "$0: Preparing dictionary and lang..." + local/prepare_dict.sh + # This recipe uses byte-pair encoding, the silences are part of the words' pronunciations. + # So we set --sil-prob to 0.0 + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ + data/local/dict "" data/lang/temp data/lang + silphonelist=`cat data/lang/phones/silence.csl` + nonsilphonelist=`cat data/lang/phones/nonsilence.csl` + local/gen_topo.py 8 4 4 $nonsilphonelist $silphonelist data/lang/phones.txt >data/lang/topo + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 data/lang + + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/6gram_big.arpa.gz \ + data/local/dict/lexicon.txt data/lang + utils/build_const_arpa_lm.sh data/local/local_lm/data/arpa/6gram_unpruned.arpa.gz \ + data/lang data/lang_rescore_6g +fi + +if [ $stage -le 6 ]; then + echo "$0: Calling the flat-start chain recipe..." + local/chain/run_e2e_cnn.sh --train_set train_aug +fi + +if [ $stage -le 7 ]; then + echo "$0: Aligning the training data using the e2e chain model..." + steps/nnet3/align.sh --nj 50 --cmd "$cmd" \ + --use-gpu false \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0 --acoustic-scale=1.0' \ + data/train_aug data/lang exp/chain/e2e_cnn_1b exp/chain/e2e_ali_train +fi + +if [ $stage -le 8 ]; then + echo "$0: Building a tree and training a regular chain model using the e2e alignments..." 
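Stage 5 above writes a custom HMM topology into data/lang/topo via local/gen_topo.py 8 4 4, which takes separate state counts for non-silence, silence and punctuation phones. The sketch below emits one simplified <TopologyEntry> in the same text format, just to show the shape of that file (the real script writes several entries and somewhat different silence transitions):

# Simplified sketch of the Kaldi topology text format produced by
# local/gen_topo.py in stage 5.  Each emitting state has a self-loop
# (prob 0.75) and a forward transition (prob 0.25); the last state is
# final and non-emitting.  The real script writes separate entries for
# non-silence, punctuation and silence phones.
def topology_entry(phone_ids, num_states, self_loop_prob=0.75):
    lines = ["<TopologyEntry>", "<ForPhones>",
             " ".join(str(p) for p in phone_ids), "</ForPhones>"]
    for s in range(num_states):
        lines.append("<State> {0} <PdfClass> {0} <Transition> {0} {1} "
                     "<Transition> {2} {3} </State>".format(
                         s, self_loop_prob, s + 1, 1.0 - self_loop_prob))
    lines.append("<State> {} </State>".format(num_states))
    lines.append("</TopologyEntry>")
    return "\n".join(lines)

print("<Topology>")
print(topology_entry([4, 5, 6], num_states=8))  # e.g. ordinary letter phones
print(topology_entry([1, 2, 3], num_states=4))  # e.g. silence phones
print("</Topology>")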
+ local/chain/run_cnn_e2eali.sh --train_set train_aug +fi diff --git a/egs/iam/v2/steps b/egs/iam/v2/steps new file mode 120000 index 00000000000..6e99bf5b5ad --- /dev/null +++ b/egs/iam/v2/steps @@ -0,0 +1 @@ +../../wsj/s5/steps \ No newline at end of file diff --git a/egs/iam/v2/utils b/egs/iam/v2/utils new file mode 120000 index 00000000000..b240885218f --- /dev/null +++ b/egs/iam/v2/utils @@ -0,0 +1 @@ +../../wsj/s5/utils \ No newline at end of file diff --git a/egs/iban/s5/local/chain/tuning/run_tdnn_1a.sh b/egs/iban/s5/local/chain/tuning/run_tdnn_1a.sh index d320f49d3aa..10650a18269 100755 --- a/egs/iban/s5/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/iban/s5/local/chain/tuning/run_tdnn_1a.sh @@ -136,7 +136,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.08 dropout-per-dim-continuous=true" output_opts="l2-regularize=0.02 bottleneck-dim=256" diff --git a/egs/iban/s5/local/chain/tuning/run_tdnn_1b.sh b/egs/iban/s5/local/chain/tuning/run_tdnn_1b.sh index 56f5255288c..db62e6f8a55 100755 --- a/egs/iban/s5/local/chain/tuning/run_tdnn_1b.sh +++ b/egs/iban/s5/local/chain/tuning/run_tdnn_1b.sh @@ -136,7 +136,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.08 dropout-per-dim=true dropout-per-dim-continuous=true" linear_opts="orthonormal-constraint=-1.0" output_opts="l2-regularize=0.04" diff --git a/egs/iban/s5/run.sh b/egs/iban/s5/run.sh index 991d32505bf..278a8177c0e 100755 --- a/egs/iban/s5/run.sh +++ b/egs/iban/s5/run.sh @@ -68,7 +68,7 @@ if [ $stage -le 4 ]; then echo "Starting triphone training." steps/align_si.sh --nj $nj --cmd "$train_cmd" \ data/train data/lang exp/mono exp/mono_ali - steps/train_deltas.sh --boost-silence 1.25 --cmd "$train_cmd" \ + steps/train_deltas.sh --boost-silence 1.25 --cmd "$train_cmd" \ 3200 30000 data/train data/lang exp/mono_ali exp/tri1 echo "Triphone training done." @@ -78,7 +78,7 @@ if [ $stage -le 4 ]; then steps/decode.sh --nj $dev_nj --cmd "$decode_cmd" \ exp/tri1/graph data/dev exp/tri1/decode_dev - steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ data/lang_test/ data/lang_big/ data/dev \ exp/tri1/decode_dev exp/tri1/decode_dev.rescored echo "Triphone decoding done." @@ -89,7 +89,7 @@ if [ $stage -le 5 ]; then ## Triphones + delta delta # Training echo "Starting (larger) triphone training." - steps/align_si.sh --nj $nj --cmd "$train_cmd" --use-graphs true \ + steps/align_si.sh --nj $nj --cmd "$train_cmd" --use-graphs true \ data/train data/lang exp/tri1 exp/tri1_ali steps/train_deltas.sh --cmd "$train_cmd" \ 4200 40000 data/train data/lang exp/tri1_ali exp/tri2a @@ -97,11 +97,11 @@ if [ $stage -le 5 ]; then ( echo "Decoding the dev set using triphone(large) models." 
- utils/mkgraph.sh data/lang_test exp/tri2a exp/tri2a/graph + utils/mkgraph.sh data/lang_test exp/tri2a exp/tri2a/graph steps/decode.sh --nj $dev_nj --cmd "$decode_cmd" \ - exp/tri2a/graph data/dev exp/tri2a/decode_dev + exp/tri2a/graph data/dev exp/tri2a/decode_dev - steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ data/lang_test/ data/lang_big/ data/dev \ exp/tri2a/decode_dev exp/tri2a/decode_dev.rescored echo "Triphone(large) decoding done." @@ -112,21 +112,21 @@ if [ $stage -le 6 ]; then ### Triphone + LDA and MLLT # Training echo "Starting LDA+MLLT training." - steps/align_si.sh --nj $nj --cmd "$train_cmd" \ + steps/align_si.sh --nj $nj --cmd "$train_cmd" \ data/train data/lang exp/tri2a exp/tri2a_ali - steps/train_lda_mllt.sh --cmd "$train_cmd" \ + steps/train_lda_mllt.sh --cmd "$train_cmd" \ --splice-opts "--left-context=3 --right-context=3" \ - 4200 40000 data/train data/lang exp/tri2a_ali exp/tri2b + 4200 40000 data/train data/lang exp/tri2a_ali exp/tri2b echo "LDA+MLLT training done." ( echo "Decoding the dev set using LDA+MLLT models." - utils/mkgraph.sh data/lang_test exp/tri2b exp/tri2b/graph - steps/decode.sh --nj $dev_nj --cmd "$decode_cmd" \ - exp/tri2b/graph data/dev exp/tri2b/decode_dev + utils/mkgraph.sh data/lang_test exp/tri2b exp/tri2b/graph + steps/decode.sh --nj $dev_nj --cmd "$decode_cmd" \ + exp/tri2b/graph data/dev exp/tri2b/decode_dev - steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ data/lang_test/ data/lang_big/ data/dev \ exp/tri2b/decode_dev exp/tri2b/decode_dev.rescored echo "LDA+MLLT decoding done." @@ -138,7 +138,7 @@ if [ $stage -le 7 ]; then ### Triphone + LDA and MLLT + SAT and FMLLR # Training echo "Starting SAT+FMLLR training." - steps/align_si.sh --nj $nj --cmd "$train_cmd" \ + steps/align_si.sh --nj $nj --cmd "$train_cmd" \ --use-graphs true data/train data/lang exp/tri2b exp/tri2b_ali steps/train_sat.sh --cmd "$train_cmd" 4200 40000 \ data/train data/lang exp/tri2b_ali exp/tri3b @@ -150,7 +150,7 @@ if [ $stage -le 7 ]; then steps/decode_fmllr.sh --nj $dev_nj --cmd "$decode_cmd" \ exp/tri3b/graph data/dev exp/tri3b/decode_dev - steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ data/lang_test/ data/lang_big/ data/dev \ exp/tri3b/decode_dev exp/tri3b/decode_dev.rescored echo "SAT+FMLLR decoding done." @@ -163,10 +163,10 @@ if [ $stage -le 8 ]; then steps/align_fmllr.sh --nj $nj --cmd "$train_cmd" \ data/train data/lang exp/tri3b exp/tri3b_ali - steps/train_ubm.sh --cmd "$train_cmd" \ + steps/train_ubm.sh --cmd "$train_cmd" \ 600 data/train data/lang exp/tri3b_ali exp/ubm5b2 - steps/train_sgmm2.sh --cmd "$train_cmd" \ + steps/train_sgmm2.sh --cmd "$train_cmd" \ 5200 12000 data/train data/lang exp/tri3b_ali exp/ubm5b2/final.ubm exp/sgmm2_5b2 echo "SGMM training done." 
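One change that recurs throughout this patch is rewriting the learning-rate-factor one-liner from "print 0.5/$xent_regularize" to "print (0.5/$xent_regularize)". With the parentheses the snippet behaves the same whether the "python" on the PATH is Python 2, where "print (x)" is the print statement applied to a parenthesized expression, or Python 3, where print is a function; the division stays floating-point in both because 0.5 is a float. The shell substitution is equivalent to:

    # hypothetical value; each recipe substitutes its own $xent_regularize
    xent_regularize = 0.1
    learning_rate_factor = 0.5 / xent_regularize
    print(learning_rate_factor)   # 5.0 under both Python 2 and Python 3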
@@ -180,7 +180,7 @@ if [ $stage -le 8 ]; then --transform-dir exp/tri3b/decode_dev \ exp/sgmm2_5b2/graph data/dev exp/sgmm2_5b2/decode_dev - steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ data/lang_test/ data/lang_big/ data/dev \ exp/sgmm2_5b2/decode_dev exp/sgmm2_5b2/decode_dev.rescored diff --git a/egs/ifnenit/v1/README.txt b/egs/ifnenit/README.txt similarity index 100% rename from egs/ifnenit/v1/README.txt rename to egs/ifnenit/README.txt diff --git a/egs/ifnenit/v1/local/chain/run_cnn_1a.sh b/egs/ifnenit/v1/local/chain/run_cnn_1a.sh index b0e147d157b..b0ecd547741 100755 --- a/egs/ifnenit/v1/local/chain/run_cnn_1a.sh +++ b/egs/ifnenit/v1/local/chain/run_cnn_1a.sh @@ -123,7 +123,7 @@ if [ $stage -le 4 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) common1="required-time-offsets=0 height-offsets=-2,-1,0,1,2 num-filters-out=36" common2="required-time-offsets=0 height-offsets=-2,-1,0,1,2 num-filters-out=70" mkdir -p $dir/configs diff --git a/egs/ifnenit/v1/local/chain/run_cnn_chainali_1a.sh b/egs/ifnenit/v1/local/chain/run_cnn_chainali_1a.sh index b1f33b41a0c..7f3132d657e 100755 --- a/egs/ifnenit/v1/local/chain/run_cnn_chainali_1a.sh +++ b/egs/ifnenit/v1/local/chain/run_cnn_chainali_1a.sh @@ -128,7 +128,7 @@ if [ $stage -le 4 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) common1="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" common2="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" common3="required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" diff --git a/egs/ifnenit/v1/local/make_features.py b/egs/ifnenit/v1/local/make_features.py index 3a485e32eb1..87afa37c00a 100755 --- a/egs/ifnenit/v1/local/make_features.py +++ b/egs/ifnenit/v1/local/make_features.py @@ -10,7 +10,7 @@ eg. 
local/make_features.py data/train --feat-dim 40 """ - +from __future__ import division import argparse import os @@ -24,8 +24,8 @@ signal(SIGPIPE,SIG_DFL) parser = argparse.ArgumentParser(description="""Generates and saves the feature vectors""") -parser.add_argument('dir', type=str, help='directory of images.scp and is also output directory') -parser.add_argument('--out-ark', type=str, default='-', help='where to write the output feature file') +parser.add_argument('dir', help='directory of images.scp and is also output directory') +parser.add_argument('--out-ark', default='-', help='where to write the output feature file') parser.add_argument('--feat-dim', type=int, default=40, help='size to scale the height of all images') parser.add_argument('--padding', type=int, default=5, help='size to scale the height of all images') args = parser.parse_args() @@ -42,7 +42,7 @@ def write_kaldi_matrix(file_handle, matrix, key): if num_cols != len(matrix[row_index]): raise Exception("All the rows of a matrix are expected to " "have the same length") - file_handle.write(" ".join(map(lambda x: str(x), matrix[row_index]))) + file_handle.write(" ".join([str(x) for x in matrix[row_index]])) if row_index != num_rows - 1: file_handle.write("\n") file_handle.write(" ]\n") @@ -51,7 +51,7 @@ def get_scaled_image(im): scale_size = args.feat_dim sx = im.shape[1] sy = im.shape[0] - scale = (1.0 * scale_size) / sy + scale = (1.0 * scale_size)/ sy nx = int(scale_size) ny = int(scale * sx) im = misc.imresize(im, (nx, ny)) diff --git a/egs/librispeech/s5/local/chain/run_cnn_tdnn.sh b/egs/librispeech/s5/local/chain/run_cnn_tdnn.sh new file mode 100755 index 00000000000..cd8f38d8309 --- /dev/null +++ b/egs/librispeech/s5/local/chain/run_cnn_tdnn.sh @@ -0,0 +1 @@ +tuning/run_cnn_tdnn_1a.sh diff --git a/egs/librispeech/s5/local/chain/run_tdnn_lstm.sh b/egs/librispeech/s5/local/chain/run_tdnn_lstm.sh new file mode 120000 index 00000000000..a4fa11e0908 --- /dev/null +++ b/egs/librispeech/s5/local/chain/run_tdnn_lstm.sh @@ -0,0 +1 @@ +tuning/run_tdnn_lstm_1b.sh \ No newline at end of file diff --git a/egs/librispeech/s5/local/chain/tuning/run_cnn_tdnn_1a.sh b/egs/librispeech/s5/local/chain/tuning/run_cnn_tdnn_1a.sh new file mode 100755 index 00000000000..8ebca6fd650 --- /dev/null +++ b/egs/librispeech/s5/local/chain/tuning/run_cnn_tdnn_1a.sh @@ -0,0 +1,278 @@ +#!/bin/bash + +# This is based on tdnn_1d_sp, but adding cnn as the front-end. +# The cnn-tdnn-f (tdnn_cnn_1a_sp) outperforms the tdnn-f (tdnn_1d_sp). 
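In make_features.py above, get_scaled_image resizes every line image to a fixed height equal to --feat-dim (40 here) while keeping the aspect ratio, so each image column becomes one 40-dimensional feature vector. The arithmetic in isolation, as a rough sketch (the script itself truncates with int() and resizes with scipy.misc.imresize, so the exact width can differ by a pixel):

    def scaled_size(width, height, feat_dim=40):
        # scale factor that maps the original height onto feat_dim rows
        scale = float(feat_dim) / height
        return int(round(width * scale)), feat_dim

    print(scaled_size(300, 60))   # (200, 40): a 300x60 line image becomes 200 frames of dim 40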
+ +# bash local/chain/compare_wer.sh exp/chain_cleaned/tdnn_1d_sp exp/chain_cleaned/tdnn_cnn_1a_sp/ +# System tdnn_1d_sp tdnn_cnn_1a_sp +# WER on dev(fglarge) 3.29 3.34 +# WER on dev(tglarge) 3.44 3.39 +# WER on dev(tgmed) 4.22 4.29 +# WER on dev(tgsmall) 4.72 4.77 +# WER on dev_other(fglarge) 8.71 8.62 +# WER on dev_other(tglarge) 9.05 9.00 +# WER on dev_other(tgmed) 11.09 10.93 +# WER on dev_other(tgsmall) 12.13 12.02 +# WER on test(fglarge) 3.80 3.69 +# WER on test(tglarge) 3.89 3.80 +# WER on test(tgmed) 4.72 4.64 +# WER on test(tgsmall) 5.19 5.16 +# WER on test_other(fglarge) 8.76 8.71 +# WER on test_other(tglarge) 9.19 9.11 +# WER on test_other(tgmed) 11.22 11.00 +# WER on test_other(tgsmall) 12.24 12.16 +# Final train prob -0.0378 -0.0420 +# Final valid prob -0.0374 -0.0400 +# Final train prob (xent) -0.6099 -0.6881 +# Final valid prob (xent) -0.6353 -0.7180 +# Num-parameters 22623456 18100736 + + +set -e + +# configs for 'chain' +stage=0 +decode_nj=50 +train_set=train_960_cleaned +gmm=tri6b_cleaned +nnet3_affix=_cleaned + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=cnn_1a +tree_affix= +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +# TDNN options +frames_per_eg=150,110,100 +remove_egs=true +common_egs_dir= +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.5@0.50,0' + +test_online_decoding=true # if true, it will run the last decoding stage. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # MFCC to filterbank + idct-layer name=idct input=input dim=40 cepstral-lifter=22 affine-transform-file=$dir/configs/idct.mat + + linear-component name=ivector-linear $ivector_affine_opts dim=200 input=ReplaceIndex(ivector, t, 0) + batchnorm-component name=ivector-batchnorm target-rms=0.025 + batchnorm-component name=idct-batchnorm input=idct + + combine-feature-maps-layer name=combine_inputs input=Append(idct-batchnorm, ivector-batchnorm) num-filters1=1 num-filters2=5 height=40 + conv-relu-batchnorm-layer name=cnn1 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn2 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn3 $cnn_opts height-in=40 height-out=20 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=128 + conv-relu-batchnorm-layer name=cnn4 $cnn_opts height-in=20 height-out=20 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=128 + conv-relu-batchnorm-layer name=cnn5 $cnn_opts height-in=20 height-out=10 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=256 + conv-relu-batchnorm-layer name=cnn6 $cnn_opts height-in=10 height-out=10 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=256 + + # the first TDNN-F layer has no bypass + tdnnf-layer name=tdnnf7 $tdnnf_first_opts dim=1536 bottleneck-dim=256 time-stride=0 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer 
name=tdnnf12 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf14 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf15 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf16 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf17 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf18 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + linear-component name=prefinal-l dim=256 $linear_opts + + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 15 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b{09,10,11,12}/$USER/kaldi-data/egs/swbd-$(date +'%m_%d_%H_%M')/s5c/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --use-gpu "wait" \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.0 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0 --constrained false" \ + --egs.chunk-width $frames_per_eg \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.num-chunk-per-minibatch 64 \ + --trainer.frames-per-iter 2500000 \ + --trainer.num-epochs 4 \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.00015 \ + --trainer.optimization.final-effective-lrate 0.000015 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs $remove_egs \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir $lat_dir \ + --dir $dir || exit 1; + +fi + +graph_dir=$dir/graph_tgsmall +if [ $stage -le 16 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 --remove-oov data/lang_test_tgsmall $dir $graph_dir + # remove from the graph, and convert back to const-FST. + fstrmsymbols --apply-to-output=true --remove-arcs=true "echo 3|" $graph_dir/HCLG.fst - | \ + fstconvert --fst_type=const > $graph_dir/temp.fst + mv $graph_dir/temp.fst $graph_dir/HCLG.fst +fi + +iter_opts= +if [ ! 
-z $decode_iter ]; then + iter_opts=" --iter $decode_iter " +fi +if [ $stage -le 17 ]; then + rm $dir/.error 2>/dev/null || true + for decode_set in test_clean test_other dev_clean dev_other; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $decode_nj --cmd "$decode_cmd" $iter_opts \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ + $graph_dir data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_tgsmall || exit 1 + steps/lmrescore.sh --cmd "$decode_cmd" --self-loop-scale 1.0 data/lang_test_{tgsmall,tgmed} \ + data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_{tgsmall,tgmed} || exit 1 + steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_test_{tgsmall,tglarge} \ + data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_{tgsmall,tglarge} || exit 1 + steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_test_{tgsmall,fglarge} \ + data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_{tgsmall,fglarge} || exit 1 + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi + +if $test_online_decoding && [ $stage -le 18 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. + steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/nnet3${nnet3_affix}/extractor $dir ${dir}_online + + rm $dir/.error 2>/dev/null || true + for data in test_clean test_other dev_clean dev_other; do + ( + nspk=$(wc -l $dir/configs/network.xconfig diff --git a/egs/librispeech/s5/local/chain/tuning/run_tdnn_1c.sh b/egs/librispeech/s5/local/chain/tuning/run_tdnn_1c.sh index 29ebe62ddde..3970fa8c4d9 100755 --- a/egs/librispeech/s5/local/chain/tuning/run_tdnn_1c.sh +++ b/egs/librispeech/s5/local/chain/tuning/run_tdnn_1c.sh @@ -112,7 +112,7 @@ if [ $stage -le 14 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.002" linear_opts="orthonormal-constraint=1.0" output_opts="l2-regularize=0.0005 bottleneck-dim=256" diff --git a/egs/librispeech/s5/local/chain/tuning/run_tdnn_1d.sh b/egs/librispeech/s5/local/chain/tuning/run_tdnn_1d.sh index 81b621ef86f..5c488362e59 100755 --- a/egs/librispeech/s5/local/chain/tuning/run_tdnn_1d.sh +++ b/egs/librispeech/s5/local/chain/tuning/run_tdnn_1d.sh @@ -207,7 +207,7 @@ if [ $stage -le 14 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) affine_opts="l2-regularize=0.008 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" tdnnf_opts="l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.75" linear_opts="l2-regularize=0.008 orthonormal-constraint=-1.0" diff --git a/egs/librispeech/s5/local/chain/tuning/run_tdnn_lstm_1a.sh b/egs/librispeech/s5/local/chain/tuning/run_tdnn_lstm_1a.sh new file mode 100755 index 00000000000..4277f769119 --- /dev/null +++ 
b/egs/librispeech/s5/local/chain/tuning/run_tdnn_lstm_1a.sh @@ -0,0 +1,222 @@ +#!/bin/bash +# this is the tdnn-lstmp based on the run_tdnn_lstm_1n.sh under Switchboard. + +# training acoustic model and decoding: +# local/chain/tuning/run_tdnn_lstm_1a.sh +# System tdnn_lstm1a_sp +# WER on dev(fglarge) 3.44 +# WER on dev(tglarge) 3.55 +# WER on dev_other(fglarge) 8.63 +# WER on dev_other(tglarge) 9.09 +# WER on test(fglarge) 3.78 +# WER on test(tglarge) 3.94 +# WER on test_other(fglarge) 8.83 +# WER on test_other(tglarge) 9.09 +# Final train prob -0.0452 +# Final valid prob -0.0477 +# Final train prob (xent) -0.7874 +# Final valid prob (xent) -0.8150 +# Num-parameters 27790288 +# exp/chain_cleaned/tdnn_lstm1a_sp/: num-iters=1303 nj=3..16 num-params=27.8M dim=40+100->6056 combine=-0.041->-0.040 (over 9) xent:train/valid[867,1302,final]=(-1.15,-0.782,-0.787/-1.18,-0.810,-0.815) logprob:train/valid[867,1302,final]=(-0.063,-0.047,-0.045/-0.062,-0.049,-0.048) + +set -e + +# configs for 'chain' +stage=12 +train_stage=-10 +get_egs_stage=-10 +speed_perturb=true +affix=1a +decode_iter= +decode_nj=50 + +# LSTM training options +frames_per_chunk=140,100,160 +frames_per_chunk_primary=$(echo $frames_per_chunk | cut -d, -f1) +chunk_left_context=40 +chunk_right_context=0 +xent_regularize=0.025 +self_repair_scale=0.00001 +label_delay=5 +# decode options +extra_left_context=50 +extra_right_context=0 +dropout_schedule='0,0@0.20,0.3@0.50,0' + +remove_egs=false +common_egs_dir= +nnet3_affix=_cleaned +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 $opts dim=1280 + linear-component name=tdnn2l dim=256 $linear_opts input=Append(-1,0) + relu-batchnorm-layer name=tdnn2 $opts input=Append(0,1) dim=1280 + linear-component name=tdnn3l dim=256 $linear_opts + relu-batchnorm-layer name=tdnn3 $opts dim=1280 + linear-component name=tdnn4l dim=256 $linear_opts input=Append(-1,0) + relu-batchnorm-layer name=tdnn4 $opts input=Append(0,1) dim=1280 + linear-component name=tdnn5l dim=256 $linear_opts + relu-batchnorm-layer name=tdnn5 $opts dim=1280 input=Append(tdnn5l, tdnn3l) + linear-component name=tdnn6l dim=256 $linear_opts input=Append(-3,0) + relu-batchnorm-layer name=tdnn6 $opts input=Append(0,3) dim=1280 + linear-component name=lstm1l dim=256 $linear_opts input=Append(-3,0) + fast-lstmp-layer name=lstm1 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=128 delay=-3 dropout-proportion=0.0 $lstm_opts + relu-batchnorm-layer name=tdnn7 $opts input=Append(0,3,tdnn6l,tdnn4l,tdnn2l) dim=1280 + linear-component name=tdnn8l dim=256 $linear_opts input=Append(-3,0) + relu-batchnorm-layer name=tdnn8 $opts input=Append(0,3) dim=1280 + linear-component name=lstm2l dim=256 $linear_opts input=Append(-3,0) + fast-lstmp-layer name=lstm2 cell-dim=1280 recurrent-projection-dim=256 non-recurrent-projection-dim=128 delay=-3 dropout-proportion=0.0 $lstm_opts + relu-batchnorm-layer name=tdnn9 $opts input=Append(0,3,tdnn8l,tdnn6l,tdnn4l) dim=1280 + linear-component name=tdnn10l dim=256 $linear_opts input=Append(-3,0) + relu-batchnorm-layer name=tdnn10 $opts input=Append(0,3) dim=1280 + 
linear-component name=lstm3l dim=256 $linear_opts input=Append(-3,0) + fast-lstmp-layer name=lstm3 cell-dim=1280 recurrent-projection-dim=256 non-recurrent-projection-dim=128 delay=-3 dropout-proportion=0.0 $lstm_opts + + output-layer name=output input=lstm3 include-log-softmax=false $output_opts + + output-layer name=output-xent input=lstm3 learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/c0{1,2,5,7}/$USER/kaldi-data/egs/swbd-$(date +'%m_%d_%H_%M')/s5c/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.0 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.num-chunk-per-minibatch 64,32 \ + --trainer.frames-per-iter 1500000 \ + --trainer.max-param-change 2.0 \ + --trainer.num-epochs 6 \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.momentum 0.0 \ + --trainer.deriv-truncate-margin 8 \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $frames_per_chunk \ + --egs.chunk-left-context $chunk_left_context \ + --egs.chunk-right-context $chunk_right_context \ + --egs.chunk-left-context-initial 0 \ + --egs.chunk-right-context-final 0 \ + --egs.dir "$common_egs_dir" \ + --cleanup.remove-egs $remove_egs \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir $lat_dir \ + --dir $dir || exit 1; +fi + + +graph_dir=$dir/graph_tgsmall +if [ $stage -le 14 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 --remove-oov data/lang_test_tgsmall $dir $graph_dir + # remove from the graph, and convert back to const-FST. + fstrmsymbols --apply-to-output=true --remove-arcs=true "echo 3|" $graph_dir/HCLG.fst - | \ + fstconvert --fst_type=const > $graph_dir/temp.fst + mv $graph_dir/temp.fst $graph_dir/HCLG.fst +fi + + +iter_opts= +if [ ! 
-z $decode_iter ]; then + iter_opts=" --iter $decode_iter " +fi +if [ $stage -le 15 ]; then + rm $dir/.error 2>/dev/null || true + for decode_set in test_clean test_other dev_clean dev_other; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $decode_nj --cmd "$decode_cmd" $iter_opts \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk "$frames_per_chunk_primary" \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ + $graph_dir data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_tgsmall || exit 1 + steps/lmrescore.sh --cmd "$decode_cmd" --self-loop-scale 1.0 data/lang_test_{tgsmall,tgmed} \ + data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_{tgsmall,tgmed} || exit 1 + steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_test_{tgsmall,tglarge} \ + data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_{tgsmall,tglarge} || exit 1 + steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_test_{tgsmall,fglarge} \ + data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_{tgsmall,fglarge} || exit 1 + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi diff --git a/egs/librispeech/s5/local/chain/tuning/run_tdnn_lstm_1b.sh b/egs/librispeech/s5/local/chain/tuning/run_tdnn_lstm_1b.sh new file mode 100755 index 00000000000..383cc533270 --- /dev/null +++ b/egs/librispeech/s5/local/chain/tuning/run_tdnn_lstm_1b.sh @@ -0,0 +1,257 @@ +#!/bin/bash +# this is the tdnn-lstmp based on the run_tdnn_lstm_1a.sh under Librispeech but with larger model size. 
+ +# training acoustic model and decoding: +# local/chain/tuning/run_tdnn_lstm_1b.sh +# local/chain/compare_wer.sh exp/chain_cleaned/tdnn_lstm1a_sp exp/chain_cleaned/tdnn_lstm1b_sp +# System tdnn_lstm1a_sp tdnn_lstm1b_sp +# WER on dev(fglarge) 3.44 3.36 +# WER on dev(tglarge) 3.55 3.48 +# WER on dev(tgmed) 4.41 4.26 +# WER on dev(tgsmall) 4.82 4.71 +# WER on dev_other(fglarge) 8.63 8.43 +# WER on dev_other(tglarge) 9.09 8.94 +# WER on dev_other(tgmed) 10.99 10.65 +# WER on dev_other(tgsmall) 11.95 11.51 +# WER on test(fglarge) 3.78 3.83 +# WER on test(tglarge) 3.94 3.93 +# WER on test(tgmed) 4.68 4.72 +# WER on test(tgsmall) 5.11 5.10 +# WER on test_other(fglarge) 8.83 8.69 +# WER on test_other(tglarge) 9.09 9.10 +# WER on test_other(tgmed) 11.05 10.86 +# WER on test_other(tgsmall) 12.18 11.83 +# Final train prob -0.0452 -0.0417 +# Final valid prob -0.0477 -0.0459 +# Final train prob (xent) -0.7874 -0.7488 +# Final valid prob (xent) -0.8150 -0.7757 +# Num-parameters 27790288 45245520 + +# rnn-lm rescoring: +# local/rnnlm/tuning/run_tdnn_lstm_1a.sh --ac-model-dir exp/chain_cleaned/tdnn_lstm1b_sp/ +# System tdnn_lstm1b_sp +# WER on dev(fglarge_nbe_rnnlm) 2.73 +# WER on dev(fglarge_lat_rnnlm) 2.83 +# WER on dev(fglarge) 3.36 +# WER on dev(tglarge) 3.48 +# WER on dev_other(fglarge_nbe_rnnlm) 7.20 +# WER on dev_other(fglarge_lat_rnnlm) 7.23 +# WER on dev_other(fglarge) 8.43 +# WER on dev_other(tglarge) 8.94 +# WER on test(fglarge_nbe_rnnlm) 3.10 +# WER on test(fglarge_lat_rnnlm) 3.22 +# WER on test(fglarge) 3.83 +# WER on test(tglarge) 3.93 +# WER on test_other(fglarge_nbe_rnnlm) 7.54 +# WER on test_other(fglarge_lat_rnnlm) 7.65 +# WER on test_other(fglarge) 8.69 +# WER on test_other(tglarge) 9.10 +# Final train prob -0.0417 +# Final valid prob -0.0459 +# Final train prob (xent) -0.7488 +# Final valid prob (xent) -0.7757 +# Num-parameters 45245520 + + + +set -e + +# configs for 'chain' +stage=12 +train_stage=-10 +get_egs_stage=-10 +speed_perturb=true +affix=1b +decode_iter= +decode_nj=50 + +# LSTM training options +frames_per_chunk=140,100,160 +frames_per_chunk_primary=$(echo $frames_per_chunk | cut -d, -f1) +chunk_left_context=40 +chunk_right_context=0 +xent_regularize=0.025 +self_repair_scale=0.00001 +label_delay=5 +# decode options +extra_left_context=50 +extra_right_context=0 +dropout_schedule='0,0@0.20,0.3@0.50,0' + +remove_egs=false +common_egs_dir= +nnet3_affix=_cleaned +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 $opts dim=1280 + linear-component name=tdnn2l dim=320 $linear_opts input=Append(-1,0) + relu-batchnorm-layer name=tdnn2 $opts input=Append(0,1) dim=1280 + linear-component name=tdnn3l dim=320 $linear_opts + relu-batchnorm-layer name=tdnn3 $opts dim=1280 + linear-component name=tdnn4l dim=320 $linear_opts input=Append(-1,0) + relu-batchnorm-layer name=tdnn4 $opts input=Append(0,1) dim=1280 + linear-component name=tdnn5l dim=320 $linear_opts + relu-batchnorm-layer name=tdnn5 $opts dim=1280 input=Append(tdnn5l, tdnn3l) + linear-component name=tdnn6l dim=320 $linear_opts input=Append(-3,0) + relu-batchnorm-layer name=tdnn6 $opts input=Append(0,3) dim=1280 + linear-component name=lstm1l dim=320 $linear_opts input=Append(-3,0) + fast-lstmp-layer name=lstm1 cell-dim=1536 recurrent-projection-dim=384 non-recurrent-projection-dim=384 delay=-3 dropout-proportion=0.0 $lstm_opts + relu-batchnorm-layer name=tdnn7 $opts input=Append(0,3,tdnn6l,tdnn4l,tdnn2l) dim=1280 + linear-component name=tdnn8l dim=320 $linear_opts input=Append(-3,0) + relu-batchnorm-layer name=tdnn8 $opts input=Append(0,3) dim=1280 + linear-component name=lstm2l dim=320 $linear_opts input=Append(-3,0) + fast-lstmp-layer name=lstm2 cell-dim=1536 recurrent-projection-dim=384 non-recurrent-projection-dim=384 delay=-3 dropout-proportion=0.0 $lstm_opts + relu-batchnorm-layer name=tdnn9 $opts input=Append(0,3,tdnn8l,tdnn6l,tdnn4l) dim=1280 + linear-component name=tdnn10l dim=320 $linear_opts input=Append(-3,0) + relu-batchnorm-layer name=tdnn10 $opts input=Append(0,3) dim=1280 + linear-component name=lstm3l dim=320 $linear_opts input=Append(-3,0) + fast-lstmp-layer name=lstm3 cell-dim=1536 recurrent-projection-dim=384 non-recurrent-projection-dim=384: delay=-3 dropout-proportion=0.0 $lstm_opts + + output-layer name=output input=lstm3 include-log-softmax=false $output_opts + + output-layer name=output-xent input=lstm3 learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/c0{1,2,5,7}/$USER/kaldi-data/egs/swbd-$(date +'%m_%d_%H_%M')/s5c/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.0 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.num-chunk-per-minibatch 64,32 \ + --trainer.frames-per-iter 1500000 \ + --trainer.max-param-change 2.0 \ + --trainer.num-epochs 6 \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.momentum 0.0 \ + --trainer.deriv-truncate-margin 8 \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $frames_per_chunk \ + --egs.chunk-left-context $chunk_left_context \ + --egs.chunk-right-context $chunk_right_context \ + --egs.chunk-left-context-initial 0 \ + --egs.chunk-right-context-final 0 \ + --egs.dir "$common_egs_dir" \ + --cleanup.remove-egs $remove_egs \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir $lat_dir \ + --dir $dir || exit 1; +fi + + +graph_dir=$dir/graph_tgsmall +if [ $stage -le 14 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 --remove-oov data/lang_test_tgsmall $dir $graph_dir + # remove from the graph, and convert back to const-FST. + fstrmsymbols --apply-to-output=true --remove-arcs=true "echo 3|" $graph_dir/HCLG.fst - | \ + fstconvert --fst_type=const > $graph_dir/temp.fst + mv $graph_dir/temp.fst $graph_dir/HCLG.fst +fi + + +iter_opts= +if [ ! 
-z $decode_iter ]; then + iter_opts=" --iter $decode_iter " +fi +if [ $stage -le 15 ]; then + rm $dir/.error 2>/dev/null || true + for decode_set in test_clean test_other dev_clean dev_other; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $decode_nj --cmd "$decode_cmd" $iter_opts \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk "$frames_per_chunk_primary" \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ + $graph_dir data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_tgsmall || exit 1 + steps/lmrescore.sh --cmd "$decode_cmd" --self-loop-scale 1.0 data/lang_test_{tgsmall,tgmed} \ + data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_{tgsmall,tgmed} || exit 1 + steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_test_{tgsmall,tglarge} \ + data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_{tgsmall,tglarge} || exit 1 + steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_test_{tgsmall,fglarge} \ + data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_{tgsmall,fglarge} || exit 1 + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi diff --git a/egs/librispeech/s5/local/download_and_untar.sh b/egs/librispeech/s5/local/download_and_untar.sh index d01e681fed7..1bb6d909edc 100755 --- a/egs/librispeech/s5/local/download_and_untar.sh +++ b/egs/librispeech/s5/local/download_and_untar.sh @@ -67,7 +67,9 @@ if [ -f $data/$part.tar.gz ]; then fi fi -if [ ! -f $data/$part.tar.gz ]; then +pushd $data + +if [ ! -f $part.tar.gz ]; then if ! which wget >/dev/null; then echo "$0: wget is not installed." exit 1; @@ -75,20 +77,19 @@ if [ ! -f $data/$part.tar.gz ]; then full_url=$url/$part.tar.gz echo "$0: downloading data from $full_url. This may take some time, please be patient." - cd $data if ! wget --no-check-certificate $full_url; then echo "$0: error executing wget $full_url" exit 1; fi fi -cd $data - if ! tar -xvzf $part.tar.gz; then echo "$0: error un-tarring archive $data/$part.tar.gz" exit 1; fi +popd >&/dev/null + touch $data/LibriSpeech/$part/.complete echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz" diff --git a/egs/librispeech/s5/local/lm/python/text_post_process.py b/egs/librispeech/s5/local/lm/python/text_post_process.py index 4ffbbe04b1f..344c1b291bd 100755 --- a/egs/librispeech/s5/local/lm/python/text_post_process.py +++ b/egs/librispeech/s5/local/lm/python/text_post_process.py @@ -21,10 +21,10 @@ def parse_args(): parser.add_argument('--abort-long-sent', type=bool, default=False, help='If True and a sentence longer than "max-sent-len" detected' +\ 'exit with error code 1. If False, just split the long sentences.') - parser.add_argument('--sent-end-marker', type=str, default="DOTDOTDOT") - parser.add_argument("in_text", type=str, help="Input text") - parser.add_argument("out_text", type=str, help="Output text") - parser.add_argument("sent_bounds", type=str, + parser.add_argument('--sent-end-marker', default="DOTDOTDOT") + parser.add_argument("in_text", help="Input text") + parser.add_argument("out_text", help="Output text") + parser.add_argument("sent_bounds", help="A file that will contain a comma separated list of numbers, s.t. 
if" + "i is in this list, then there is a sententence break after token i") return parser.parse_args() @@ -66,7 +66,7 @@ def parse_args(): n_tokens += 1 start_scan = 4 current_line.append('SUN') - for i in xrange(start_scan, len(opl_tokens)): + for i in range(start_scan, len(opl_tokens)): m = re.match("^[A-Z]+\'?[A-Z\']*$", opl_tokens[i]) if m is not None: n_tokens += 1 diff --git a/egs/librispeech/s5/local/lm/python/text_pre_process.py b/egs/librispeech/s5/local/lm/python/text_pre_process.py index 6228079b3a3..b75d0711d13 100755 --- a/egs/librispeech/s5/local/lm/python/text_pre_process.py +++ b/egs/librispeech/s5/local/lm/python/text_pre_process.py @@ -20,13 +20,13 @@ def parse_args(): parser = argparse.ArgumentParser(description="Pre-process a book's text") - parser.add_argument("--in-encoding", type=str, default="utf-8", + parser.add_argument("--in-encoding", default="utf-8", help="Encoding to use when reading the input text") - parser.add_argument("--out-encoding", type=str, default="ascii", + parser.add_argument("--out-encoding", default="ascii", help="Encoding to use when writing the output text") - parser.add_argument('--sent-end-marker', type=str, default="DOTDOTDOT") - parser.add_argument("in_text", type=str, help="Input text") - parser.add_argument("out_text", type=str, help="Output text") + parser.add_argument('--sent-end-marker', default="DOTDOTDOT") + parser.add_argument("in_text", help="Input text") + parser.add_argument("out_text", help="Output text") return parser.parse_args() # http://rosettacode.org/wiki/Roman_numerals/Decode#Python diff --git a/egs/librispeech/s5/local/rnnlm/tuning/run_tdnn_lstm_1a.sh b/egs/librispeech/s5/local/rnnlm/tuning/run_tdnn_lstm_1a.sh new file mode 100755 index 00000000000..137a972f3d9 --- /dev/null +++ b/egs/librispeech/s5/local/rnnlm/tuning/run_tdnn_lstm_1a.sh @@ -0,0 +1,166 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (author: Daniel Povey) +# 2018 Ke Li + +# This script trains LMs on the librispeech-lm-norm.txt.gz. + +# rnnlm/train_rnnlm.sh: best iteration (out of 143) was 142, linking it to final iteration. +# rnnlm/train_rnnlm.sh: train/dev perplexity was 109.2 / 110.7. 
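The train and dev objective values dumped below appear to be per-word average log-probabilities (natural log), so the perplexities quoted above correspond, up to rounding, to the exponential of their negatives; a quick consistency check under that assumption:

    import math
    final_dev_objf = -4.71               # last dev objf value in the log below
    print(math.exp(-final_dev_objf))     # ~111, in line with the reported dev perplexity of 110.7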
+# Train objf: -5.74 -5.54 -5.44 -5.37 -5.32 -5.28 -5.25 -5.23 -5.20 -5.18 -5.15 -5.14 -5.12 -5.10 -5.09 -5.08 -5.07 -5.05 -5.04 -5.04 -5.03 -5.02 -5.01 -5.00 -4.99 -4.99 -4.98 -4.97 -4.96 -4.96 -4.95 -4.95 -4.94 -4.93 -4.93 -4.92 -4.92 -4.92 -4.91 -4.90 -4.90 -4.89 -4.89 -4.89 -4.88 -4.88 -4.87 -4.87 -4.87 -4.86 -4.86 -4.86 -4.85 -4.85 -4.84 -4.84 -4.84 -4.84 -4.84 -4.83 -4.83 -4.83 -4.82 -4.82 -4.82 -4.82 -4.81 -4.81 -4.81 -4.81 -4.80 -4.80 -4.80 -4.79 -4.79 -4.79 -4.79 -4.78 -4.79 -4.78 -4.78 -4.78 -4.78 -4.77 -4.77 -4.77 -4.77 -4.77 -4.76 -4.76 -4.76 -4.76 -4.76 -4.75 -4.75 -4.75 -4.75 -4.75 -4.74 -4.74 -4.74 -4.74 -4.74 -4.74 -4.73 -4.74 -4.74 -4.73 -4.73 -4.73 -4.73 -4.73 -4.72 -4.73 -4.73 -4.73 -4.72 -4.72 -4.72 -4.72 -4.72 -4.72 -4.72 -4.72 -4.71 -4.71 -4.71 -4.71 -4.71 -4.70 -4.70 -4.70 -4.70 -4.70 -4.69 -4.69 -4.69 -4.69 -4.69 -4.69 -4.68 -4.68 +# Dev objf: -5.99 -5.65 -5.53 -5.44 -5.38 -5.34 -5.30 -5.27 -5.22 -5.20 -5.18 -5.16 -5.14 -5.12 -5.11 -5.10 -5.09 -5.08 -5.07 -5.05 -5.04 -5.04 -5.03 -5.01 -5.00 -4.99 -4.99 -4.98 -4.97 -4.97 0.00 -4.96 -4.95 -4.95 -4.94 -4.93 -4.93 -4.92 -4.92 -4.91 -4.91 -4.90 -4.90 -4.89 -4.89 -4.89 -4.88 -4.88 -4.88 -4.87 -4.87 -4.87 -4.86 -4.86 -4.85 -4.85 -4.87 -4.84 -4.84 -4.84 -4.83 -4.91 -4.83 -4.83 -4.83 -4.82 -4.82 -4.82 -4.82 -4.81 -4.81 -4.81 -4.80 -4.80 -4.80 -4.80 -4.80 -4.79 -4.79 -4.79 -4.79 -4.79 -4.79 -4.78 -4.78 -4.79 -4.78 -4.77 -4.77 -4.77 -4.77 -4.77 -4.77 -4.77 -4.76 -4.76 -4.76 -4.76 -4.76 -4.75 -4.75 -4.75 -4.75 -4.75 -4.75 -4.75 -4.75 -4.75 -4.75 -4.75 -4.75 -4.74 -4.74 -4.74 -4.74 -4.74 -4.74 -4.74 -4.73 -4.74 -4.73 -4.73 -4.73 -4.73 -4.73 -4.73 -4.72 -4.72 -4.72 -4.72 -4.72 -4.72 -4.72 -4.72 -4.71 -4.71 -4.71 -4.71 -4.71 -4.71 -4.71 -4.71 + +# WER summary on dev and test sets +# System tdnn_1d_sp +lattice_rescore +nbest_rescore +# WER on dev(fglarge) 3.34 2.71 2.62 +# WER on dev(tglarge) 3.44 2.75 2.66 +# WER on dev_other(fglarge) 8.70 7.37 7.55 +# WER on dev_other(tglarge) 9.25 7.56 7.73 +# WER on test(fglarge) 3.77 3.12 3.06 +# WER on test(tglarge) 3.85 3.18 3.11 +# WER on test_other(fglarge) 8.91 7.63 7.68 +# WER on test_other(tglarge) 9.31 7.83 7.95 + +# command to get the WERs above: +# tdnn_1d_sp +# for test in dev_clean test_clean dev_other test_other; do for lm in fglarge tglarge; do grep WER exp/chain_cleaned/tdnn_1d_sp/decode_${test}_${lm}/wer* | best_wer.sh; done; done +# tdnn_1d_sp with lattice rescoring +# for test in dev_clean test_clean dev_other test_other; do for lm in fglarge tglarge; do grep WER exp/chain_cleaned/tdnn_1d_sp/decode_${test}_${lm}_rnnlm_1a_rescore/wer* | best_wer.sh; done; done +# tdnn_1d_sp with nbest rescoring +# for test in dev_clean test_clean dev_other test_other; do for lm in fglarge tglarge; do grep WER exp/chain_cleaned/tdnn_1d_sp/decode_${test}_${lm}_rnnlm_1a_nbest_rescore/wer* | best_wer.sh; done; done + +# Begin configuration section. + +dir=exp/rnnlm_lstm_1a +embedding_dim=1024 +lstm_rpd=256 +lstm_nrpd=256 +stage=-10 +train_stage=-10 +epochs=4 + +# variables for lattice rescoring +run_lat_rescore=true +run_nbest_rescore=true +run_backward_rnnlm=false +ac_model_dir=exp/chain_cleaned/tdnn_1d_sp +decode_dir_suffix=rnnlm_1a +ngram_order=4 # approximate the lattice-rescoring by limiting the max-ngram-order + # if it's set, it merges histories in the lattice if they share + # the same ngram history and this prevents the lattice from + # exploding exponentially +pruned_rescore=true + +. ./cmd.sh +. 
./utils/parse_options.sh + +text=data/local/lm/librispeech-lm-norm.txt.gz +lexicon=data/lang_nosp/words.txt +text_dir=data/rnnlm/text +mkdir -p $dir/config +set -e + +for f in $lexicon; do + [ ! -f $f ] && \ + echo "$0: expected file $f to exist; search for run.sh in run.sh" && exit 1 +done + +if [ $stage -le 0 ]; then + mkdir -p $text_dir + if [ ! -f $text ]; then + wget http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz -P data/local/lm + fi + echo -n >$text_dir/dev.txt + # hold out one in every 2000 lines as dev data. + gunzip -c $text | cut -d ' ' -f2- | awk -v text_dir=$text_dir '{if(NR%2000 == 0) { print >text_dir"/dev.txt"; } else {print;}}' >$text_dir/librispeech.txt +fi + +if [ $stage -le 1 ]; then + cp $lexicon $dir/config/ + n=`cat $dir/config/words.txt | wc -l` + echo " $n" >> $dir/config/words.txt + + # words that are not present in words.txt but are in the training or dev data, will be + # mapped to during training. + echo "" >$dir/config/oov.txt + + cat > $dir/config/data_weights.txt <$dir/config/unigram_probs.txt + + # choose features + rnnlm/choose_features.py --unigram-probs=$dir/config/unigram_probs.txt \ + --top-word-features=5000 \ + --use-constant-feature=true \ + --special-words=',,,,' \ + $dir/config/words.txt > $dir/config/features.txt + + cat >$dir/config/xconfig <$lat_dir/splice_opts - fi if [ $stage -le 3 ]; then @@ -133,7 +129,7 @@ if [ $stage -le 4 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) common1="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" common2="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" common3="required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" @@ -185,7 +181,7 @@ if [ $stage -le 5 ]; then --chain.leaky-hmm-coefficient=0.1 \ --chain.l2-regularize=0.00005 \ --chain.apply-deriv-weights=false \ - --chain.lm-opts="--num-extra-lm-states=500" \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ --chain.frame-subsampling-factor=$frame_subsampling_factor \ --chain.alignment-subsampling-factor=1 \ --chain.left-tolerance 3 \ @@ -193,7 +189,7 @@ if [ $stage -le 5 ]; then --trainer.srand=$srand \ --trainer.max-param-change=2.0 \ --trainer.num-epochs=4 \ - --trainer.frames-per-iter=1000000 \ + --trainer.frames-per-iter=2000000 \ --trainer.optimization.num-jobs-initial=3 \ --trainer.optimization.num-jobs-final=16 \ --trainer.optimization.initial-effective-lrate=0.001 \ @@ -201,11 +197,8 @@ if [ $stage -le 5 ]; then --trainer.optimization.shrink-value=1.0 \ --trainer.num-chunk-per-minibatch=64,32 \ --trainer.optimization.momentum=0.0 \ + --trainer.add-option="--optimization.memory-compression-level=2" \ --egs.chunk-width=$chunk_width \ - --egs.chunk-left-context=$chunk_left_context \ - --egs.chunk-right-context=$chunk_right_context \ - --egs.chunk-left-context-initial=0 \ - --egs.chunk-right-context-final=0 \ --egs.dir="$common_egs_dir" \ --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ --cleanup.remove-egs=$remove_egs \ @@ -226,18 +219,20 @@ if [ $stage -le 6 ]; then # as long as phones.txt was compatible. 
utils/mkgraph.sh \ - --self-loop-scale 1.0 data/$lang_test \ + --self-loop-scale 1.0 $lang_decode \ $dir $dir/graph || exit 1; fi if [ $stage -le 7 ]; then frames_per_chunk=$(echo $chunk_width | cut -d, -f1) steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --extra-left-context $chunk_left_context \ - --extra-right-context $chunk_right_context \ - --extra-left-context-initial 0 \ - --extra-right-context-final 0 \ --frames-per-chunk $frames_per_chunk \ --nj $nj --cmd "$cmd" \ $dir/graph data/test $dir/decode_test || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test $dir/decode_test{,_rescored} || exit 1 fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/madcat_ar/v1/local/chain/tuning/run_e2e_cnn_1a.sh b/egs/madcat_ar/v1/local/chain/tuning/run_e2e_cnn_1a.sh new file mode 100755 index 00000000000..3caf8ae4494 --- /dev/null +++ b/egs/madcat_ar/v1/local/chain/tuning/run_e2e_cnn_1a.sh @@ -0,0 +1,133 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +# This script does end2end chain training (i.e. from scratch) +# ./local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ +# System e2e_cnn_1a e2e_cnn_1a (with extra corpus text) +# WER 9.47 5.73 +# WER (rescored) 8.05 5.67 +# CER 2.45 1.45 +# CER (rescored) 2.10 1.42 +# Final train prob -0.0934 -0.0934 +# Final valid prob -0.0746 -0.0746 +# Final train prob (xent) +# Final valid prob (xent) +# Parameters 2.94M 2.94M + +# steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1a/ +# exp/chain/e2e_cnn_1a/: num-iters=98 nj=6..16 num-params=2.9M dim=40->330 combine=-0.071->-0.070 (over 5) logprob:train/valid[64,97,final]=(-0.089,-0.084,-0.093/-0.075,-0.073,-0.075) +set -e + +# configs for 'chain' +stage=0 +train_stage=-10 +get_egs_stage=-10 +affix=1a + +# training options +tdnn_dim=450 +minibatch_size=150=128,64/300=128,64/600=64,32/1200=32,16 +common_egs_dir= +cmvn_opts="--norm-means=false --norm-vars=false" +train_set=train +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! 
cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj 30 --cmd "$cmd" \ + --shared-phones true \ + --type mono \ + data/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat data/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py data/lang \| \ + utils/sym2int.pl -f 2- data/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + common1="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts "$cmvn_opts" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 2000000 \ + --trainer.num-epochs 2 \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial 6 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir data/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi diff --git a/egs/madcat_ar/v1/local/create_line_image_from_page_image.py b/egs/madcat_ar/v1/local/create_line_image_from_page_image.py index 
ba35f8b9ace..650a0704d80 100755 --- a/egs/madcat_ar/v1/local/create_line_image_from_page_image.py +++ b/egs/madcat_ar/v1/local/create_line_image_from_page_image.py @@ -13,6 +13,7 @@ be vertically or horizontally aligned). Hence to extract line image from line bounding box, page image is rotated and line image is cropped and saved. """ +from __future__ import division import sys import argparse @@ -21,22 +22,10 @@ import numpy as np from math import atan2, cos, sin, pi, degrees, sqrt from collections import namedtuple - +import random from scipy.spatial import ConvexHull from PIL import Image from scipy.misc import toimage -import logging - -sys.path.insert(0, 'steps') -logger = logging.getLogger('libs') -logger.setLevel(logging.INFO) -handler = logging.StreamHandler() -handler.setLevel(logging.INFO) -formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - " - "%(funcName)s - %(levelname)s ] %(message)s") -handler.setFormatter(formatter) -logger.addHandler(handler) - parser = argparse.ArgumentParser(description="Creates line images from page image", epilog="E.g. " + sys.argv[0] + " data/LDC2012T15" " data/LDC2013T09 data/LDC2013T15 data/madcat.train.raw.lineid " @@ -60,6 +49,12 @@ help='Path to the downloaded (and extracted) writing conditions file 3') parser.add_argument('--padding', type=int, default=400, help='padding across horizontal/verticle direction') +parser.add_argument('--pixel-scaling', type=int, default=30, + help='padding across horizontal/verticle direction') +parser.add_argument("--subset", type=lambda x: (str(x).lower()=='true'), default=False, + help="only processes subset of data based on writing condition") +parser.add_argument("--augment", type=lambda x: (str(x).lower()=='true'), default=False, + help="performs image augmentation") args = parser.parse_args() """ @@ -93,8 +88,8 @@ def unit_vector(pt0, pt1): (float, float): unit vector """ dis_0_to_1 = sqrt((pt0[0] - pt1[0])**2 + (pt0[1] - pt1[1])**2) - return (pt1[0] - pt0[0]) / dis_0_to_1, \ - (pt1[1] - pt0[1]) / dis_0_to_1 + return (pt1[0] - pt0[0])/ dis_0_to_1, \ + (pt1[1] - pt0[1])/ dis_0_to_1 def orthogonal_vector(vector): @@ -136,7 +131,7 @@ def bounding_area(index, hull): return {'area': len_p * len_o, 'length_parallel': len_p, 'length_orthogonal': len_o, - 'rectangle_center': (min_p + len_p / 2, min_o + len_o / 2), + 'rectangle_center': (min_p + float(len_p)/ 2, min_o + float(len_o)/ 2), 'unit_vector': unit_vector_p, } @@ -149,7 +144,7 @@ def to_xy_coordinates(unit_vector_angle, point): ------ (float, float): converted x,y coordinate of the unit vector. """ - angle_orthogonal = unit_vector_angle + pi / 2 + angle_orthogonal = unit_vector_angle + pi/ 2 return point[0] * cos(unit_vector_angle) + point[1] * cos(angle_orthogonal), \ point[0] * sin(unit_vector_angle) + point[1] * sin(angle_orthogonal) @@ -194,65 +189,6 @@ def rectangle_corners(rectangle): return rotate_points(rectangle['rectangle_center'], rectangle['unit_vector_angle'], corner_points) -def get_orientation(origin, p1, p2): - """ - Given origin and two points, return the orientation of the Point p1 with - regards to Point p2 using origin. - Returns - ------- - integer: Negative if p1 is clockwise of p2. - """ - difference = ( - ((p2[0] - origin[0]) * (p1[1] - origin[1])) - - ((p1[0] - origin[0]) * (p2[1] - origin[1])) - ) - return difference - - -def compute_hull(points): - """ - Given input list of points, return a list of points that - made up the convex hull. 
- Returns - ------- - [(float, float)]: convexhull points - """ - hull_points = [] - start = points[0] - min_x = start[0] - for p in points[1:]: - if p[0] < min_x: - min_x = p[0] - start = p - - point = start - hull_points.append(start) - - far_point = None - while far_point is not start: - p1 = None - for p in points: - if p is point: - continue - else: - p1 = p - break - - far_point = p1 - - for p2 in points: - if p2 is point or p2 is p1: - continue - else: - direction = get_orientation(point, far_point, p2) - if direction > 0: - far_point = p2 - - hull_points.append(far_point) - point = far_point - return hull_points - - def minimum_bounding_box(points): """ Given a list of 2D points, it returns the minimum area rectangle bounding all the points in the point cloud. @@ -272,7 +208,6 @@ def minimum_bounding_box(points): hull_ordered = [points[index] for index in ConvexHull(points).vertices] hull_ordered.append(hull_ordered[0]) - #hull_ordered = compute_hull(points) hull_ordered = tuple(hull_ordered) min_rectangle = bounding_area(0, hull_ordered) @@ -301,8 +236,8 @@ def get_center(im): ------- (int, int): center of the image """ - center_x = im.size[0] / 2 - center_y = im.size[1] / 2 + center_x = float(im.size[0])/ 2 + center_y = float(im.size[1])/ 2 return int(center_x), int(center_y) @@ -314,9 +249,9 @@ def get_horizontal_angle(unit_vector_angle): (float): updated angle of the unit vector to be in radians. It is only in first or fourth quadrant. """ - if unit_vector_angle > pi / 2 and unit_vector_angle <= pi: + if unit_vector_angle > pi/ 2 and unit_vector_angle <= pi: unit_vector_angle = unit_vector_angle - pi - elif unit_vector_angle > -pi and unit_vector_angle < -pi / 2: + elif unit_vector_angle > -pi and unit_vector_angle < -pi/ 2: unit_vector_angle = unit_vector_angle + pi return unit_vector_angle @@ -400,6 +335,36 @@ def update_minimum_bounding_box_input(bounding_box_input): return updated_minimum_bounding_box_input +def dilate_polygon(points, amount_increase): + """ Increases size of polygon given as a list of tuples. + Assumes points in polygon are given in CCW + """ + expanded_points = [] + for index, point in enumerate(points): + prev_point = points[(index - 1) % len(points)] + next_point = points[(index + 1) % len(points)] + prev_edge = np.subtract(point, prev_point) + next_edge = np.subtract(next_point, point) + + prev_normal = ((1 * prev_edge[1]), (-1 * prev_edge[0])) + prev_normal = np.divide(prev_normal, np.linalg.norm(prev_normal)) + next_normal = ((1 * next_edge[1]), (-1 * next_edge[0])) + next_normal = np.divide(next_normal, np.linalg.norm(next_normal)) + + bisect = np.add(prev_normal, next_normal) + bisect = np.divide(bisect, np.linalg.norm(bisect)) + + cos_theta = np.dot(next_normal, bisect) + hyp = float(amount_increase)/ cos_theta + + new_point = np.around(point + hyp * bisect) + new_point = new_point.astype(int) + new_point = new_point.tolist() + new_point = tuple(new_point) + expanded_points.append(new_point) + return expanded_points + + def set_line_image_data(image, line_id, image_file_name, image_fh): """ Given an image, saves a flipped line image. Line image file name is formed by appending the line id at the end page image name. 
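dilate_polygon, added above, pushes every vertex outward along the bisector of the two adjacent outward edge normals, so for a counter-clockwise polygon each side ends up amount_increase pixels further from the interior. A small self-contained check of that behaviour on a 10x10 CCW square, re-stating the same bisector arithmetic with numpy (a sketch, not the recipe's own code path):

import numpy as np

def dilate(points, amount):
    # Same idea as dilate_polygon: outward normal of each adjacent edge,
    # move the vertex along the normalised bisector by amount / cos(theta).
    out = []
    for i, p in enumerate(points):
        prev_edge = np.subtract(p, points[i - 1])
        next_edge = np.subtract(points[(i + 1) % len(points)], p)
        n_prev = np.array([prev_edge[1], -prev_edge[0]], dtype=float)
        n_next = np.array([next_edge[1], -next_edge[0]], dtype=float)
        n_prev /= np.linalg.norm(n_prev)
        n_next /= np.linalg.norm(n_next)
        bisect = n_prev + n_next
        bisect /= np.linalg.norm(bisect)
        hyp = amount / np.dot(n_next, bisect)
        out.append(tuple(np.around(p + hyp * bisect).astype(int).tolist()))
    return out

square = [(0, 0), (10, 0), (10, 10), (0, 10)]   # counter-clockwise
print(dilate(square, 5))
# -> [(-5, -5), (15, -5), (15, 15), (-5, 15)]: every side moved outward by 5 pixels.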
@@ -438,50 +403,83 @@ def get_line_images_from_page_image(image_file_name, madcat_file_path, image_fh) word_coordinate = (int(word_node.getAttribute('x')), int(word_node.getAttribute('y'))) minimum_bounding_box_input.append(word_coordinate) updated_mbb_input = update_minimum_bounding_box_input(minimum_bounding_box_input) - bounding_box = minimum_bounding_box(updated_mbb_input) - - p1, p2, p3, p4 = bounding_box.corner_points - x1, y1 = p1 - x2, y2 = p2 - x3, y3 = p3 - x4, y4 = p4 - min_x = int(min(x1, x2, x3, x4)) - min_y = int(min(y1, y2, y3, y4)) - max_x = int(max(x1, x2, x3, x4)) - max_y = int(max(y1, y2, y3, y4)) - box = (min_x, min_y, max_x, max_y) - region_initial = im.crop(box) - rot_points = [] - p1_new = (x1 - min_x, y1 - min_y) - p2_new = (x2 - min_x, y2 - min_y) - p3_new = (x3 - min_x, y3 - min_y) - p4_new = (x4 - min_x, y4 - min_y) - rot_points.append(p1_new) - rot_points.append(p2_new) - rot_points.append(p3_new) - rot_points.append(p4_new) - - cropped_bounding_box = bounding_box_tuple(bounding_box.area, - bounding_box.length_parallel, - bounding_box.length_orthogonal, - bounding_box.length_orthogonal, - bounding_box.unit_vector, - bounding_box.unit_vector_angle, - set(rot_points) - ) - - rotation_angle_in_rad = get_smaller_angle(cropped_bounding_box) - img2 = region_initial.rotate(degrees(rotation_angle_in_rad), resample = Image.BICUBIC) - x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 = rotated_points( + points_ordered = [updated_mbb_input[index] for index in ConvexHull(updated_mbb_input).vertices] + if args.augment: + for i in range(0, 3): + additional_pixel = random.randint(1, args.pixel_scaling) + mar = dilate_polygon(points_ordered, (i-1)*args.pixel_scaling + additional_pixel + 1) + bounding_box = minimum_bounding_box(mar) + (x1, y1), (x2, y2), (x3, y3), (x4, y4) = bounding_box.corner_points + min_x, min_y = int(min(x1, x2, x3, x4)), int(min(y1, y2, y3, y4)) + max_x, max_y = int(max(x1, x2, x3, x4)), int(max(y1, y2, y3, y4)) + box = (min_x, min_y, max_x, max_y) + region_initial = im.crop(box) + rot_points = [] + p1, p2 = (x1 - min_x, y1 - min_y), (x2 - min_x, y2 - min_y) + p3, p4 = (x3 - min_x, y3 - min_y), (x4 - min_x, y4 - min_y) + rot_points.append(p1) + rot_points.append(p2) + rot_points.append(p3) + rot_points.append(p4) + + cropped_bounding_box = bounding_box_tuple(bounding_box.area, + bounding_box.length_parallel, + bounding_box.length_orthogonal, + bounding_box.length_orthogonal, + bounding_box.unit_vector, + bounding_box.unit_vector_angle, + set(rot_points) + ) + + rotation_angle_in_rad = get_smaller_angle(cropped_bounding_box) + img2 = region_initial.rotate(degrees(rotation_angle_in_rad), resample = Image.BICUBIC) + x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 = rotated_points( + cropped_bounding_box, get_center(region_initial)) + + min_x = int(min(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + min_y = int(min(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + max_x = int(max(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + max_y = int(max(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + box = (min_x, min_y, max_x, max_y) + region_final = img2.crop(box) + line_id = id + '_scale' + str(i) + set_line_image_data(region_final, line_id, image_file_name, image_fh) + else: + bounding_box = minimum_bounding_box(points_ordered) + (x1, y1), (x2, y2), (x3, y3), (x4, y4) = bounding_box.corner_points + min_x, min_y = int(min(x1, x2, x3, x4)), int(min(y1, y2, y3, y4)) + max_x, max_y = int(max(x1, x2, x3, x4)), int(max(y1, y2, y3, y4)) 
+ box = (min_x, min_y, max_x, max_y) + region_initial = im.crop(box) + rot_points = [] + p1, p2 = (x1 - min_x, y1 - min_y), (x2 - min_x, y2 - min_y) + p3, p4 = (x3 - min_x, y3 - min_y), (x4 - min_x, y4 - min_y) + rot_points.append(p1) + rot_points.append(p2) + rot_points.append(p3) + rot_points.append(p4) + + cropped_bounding_box = bounding_box_tuple(bounding_box.area, + bounding_box.length_parallel, + bounding_box.length_orthogonal, + bounding_box.length_orthogonal, + bounding_box.unit_vector, + bounding_box.unit_vector_angle, + set(rot_points) + ) + + rotation_angle_in_rad = get_smaller_angle(cropped_bounding_box) + img2 = region_initial.rotate(degrees(rotation_angle_in_rad), resample = Image.BICUBIC) + x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 = rotated_points( cropped_bounding_box, get_center(region_initial)) - min_x = int(min(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) - min_y = int(min(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) - max_x = int(max(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) - max_y = int(max(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) - box = (min_x, min_y, max_x, max_y) - region_final = img2.crop(box) - set_line_image_data(region_final, id, image_file_name, image_fh) + min_x = int(min(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + min_y = int(min(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + max_x = int(max(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + max_y = int(max(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + box = (min_x, min_y, max_x, max_y) + region_final = img2.crop(box) + set_line_image_data(region_final, id, image_file_name, image_fh) def check_file_location(base_name, wc_dict1, wc_dict2, wc_dict3): @@ -535,16 +533,16 @@ def check_writing_condition(wc_dict, base_name): Returns (bool): True if writing condition matches. """ - return True - writing_condition = wc_dict[base_name].strip() - if writing_condition != 'IUC': - return False - - return True - + if args.subset: + writing_condition = wc_dict[base_name].strip() + if writing_condition != 'IUC': + return False + else: + return True + else: + return True ### main ### - def main(): wc_dict1 = parse_writing_conditions(args.writing_condition1) @@ -564,8 +562,7 @@ def main(): madcat_file_path, image_file_path, wc_dict = check_file_location(base_name, wc_dict1, wc_dict2, wc_dict3) if wc_dict is None or not check_writing_condition(wc_dict, base_name): continue - if madcat_file_path is not None: - get_line_images_from_page_image(image_file_path, madcat_file_path, image_fh) + get_line_images_from_page_image(image_file_path, madcat_file_path, image_fh) if __name__ == '__main__': diff --git a/egs/madcat_ar/v1/local/download_data.sh b/egs/madcat_ar/v1/local/download_data.sh deleted file mode 100755 index 7061be49c2a..00000000000 --- a/egs/madcat_ar/v1/local/download_data.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash - -# Copyright 2018 Ashish Arora -# Apache 2.0 - -# This script downloads data splits for MADCAT Arabic dataset. -# It also check if madcat arabic data is present or not. - -download_dir1=/export/corpora/LDC/LDC2012T15/data -download_dir2=/export/corpora/LDC/LDC2013T09/data -download_dir3=/export/corpora/LDC/LDC2013T15/data -train_split_url=http://www.openslr.org/resources/48/madcat.train.raw.lineid -test_split_url=http://www.openslr.org/resources/48/madcat.test.raw.lineid -dev_split_url=http://www.openslr.org/resources/48/madcat.dev.raw.lineid -data_splits=data/download/data_splits - -. ./cmd.sh -. ./path.sh -. 
./utils/parse_options.sh || exit 1; - -if [ -d $data_splits ]; then - echo "$0: Not downloading the data splits as it is already there." -else - if [ ! -f $data_splits/madcat.train.raw.lineid ]; then - mkdir -p $data_splits - echo "$0: Downloading the data splits..." - wget -P $data_splits $train_split_url || exit 1; - wget -P $data_splits $test_split_url || exit 1; - wget -P $data_splits $dev_split_url || exit 1; - fi - echo "$0: Done downloading the data splits" -fi - -if [ -d $download_dir1 ]; then - echo "$0: madcat arabic data directory is present." -else - if [ ! -f $download_dir1/madcat/*.madcat.xml ]; then - echo "$0: please download madcat data..." - fi -fi diff --git a/egs/madcat_ar/v1/local/extract_features.sh b/egs/madcat_ar/v1/local/extract_features.sh index 70c5498626c..9fe588f31b8 100755 --- a/egs/madcat_ar/v1/local/extract_features.sh +++ b/egs/madcat_ar/v1/local/extract_features.sh @@ -1,10 +1,16 @@ #!/bin/bash + # Copyright 2017 Yiwen Shao # 2018 Ashish Arora +# Apache 2.0 +# This script runs the make features script in parallel. + nj=4 cmd=run.pl feat_dim=40 +augment='no_aug' +verticle_shift=0 echo "$0 $@" . ./cmd.sh @@ -30,9 +36,10 @@ done utils/split_scp.pl $scp $split_scps || exit 1; $cmd JOB=1:$nj $logdir/extract_features.JOB.log \ - local/make_features.py $logdir/images.JOB.scp \ + image/ocr/make_features.py $logdir/images.JOB.scp \ --allowed_len_file_path $data/allowed_lengths.txt \ - --feat-dim $feat_dim \| \ + --feat-dim $feat_dim --augment_type $augment \ + --vertical-shift $verticle_shift \| \ copy-feats --compress=true --compression-method=7 \ ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp diff --git a/egs/madcat_ar/v1/local/extract_lines.sh b/egs/madcat_ar/v1/local/extract_lines.sh index 50129ad38c9..ab87836ae3a 100755 --- a/egs/madcat_ar/v1/local/extract_lines.sh +++ b/egs/madcat_ar/v1/local/extract_lines.sh @@ -11,6 +11,8 @@ writing_condition2=/export/corpora/LDC/LDC2013T09/docs/writing_conditions.tab writing_condition3=/export/corpora/LDC/LDC2013T15/docs/writing_conditions.tab data_split_file=data/download/data_splits/madcat.dev.raw.lineid data=data/local/dev +subset=false +augment=false echo "$0 $@" . ./cmd.sh @@ -35,7 +37,7 @@ done $cmd JOB=1:$nj $log_dir/extract_lines.JOB.log \ local/create_line_image_from_page_image.py $download_dir1 $download_dir2 $download_dir3 \ $log_dir/lines.JOB.scp $data/JOB $writing_condition1 $writing_condition2 $writing_condition3 \ - || exit 1; + --subset $subset --augment $augment || exit 1; ## concatenate the .scp files together. for n in $(seq $nj); do diff --git a/egs/madcat_ar/v1/local/prepare_data.sh b/egs/madcat_ar/v1/local/prepare_data.sh index d808d736845..1049db9826d 100755 --- a/egs/madcat_ar/v1/local/prepare_data.sh +++ b/egs/madcat_ar/v1/local/prepare_data.sh @@ -5,49 +5,65 @@ # 2017 Hossein Hadian # Apache 2.0 -# This script prepares the training and test data for MADCAT Arabic dataset -# (i.e text, images.scp, utt2spk and spk2utt). It calls process_data.py. +# This script downloads the data splits for MADCAT Arabic dataset and prepares the training +# validation, and test data (i.e text, images.scp, utt2spk and spk2utt) by calling process_data.py. +# It also uses Arabic Gigaword text corpus for language modeling. # Eg. local/prepare_data.sh -# Eg. text file: LDC0001_000404_NHR_ARB_20070113.0052_11_LDC0001_00z2 ﻮﺠﻫ ﻮﻌﻘﻟ ﻍﺍﺮﻗ ﺢﺗّﻯ ﺎﻠﻨﺧﺎﻋ +# Eg. 
text file: LDC0001_000399_NHR_ARB_20070113.0052_11_LDC0001_0z11 +# وهناك تداخل بين الرأسمالية الإسرائيلية # utt2spk file: LDC0001_000397_NHR_ARB_20070113.0052_11_LDC0001_00z1 LDC0001 -# images.scp file: LDC0009_000000_arb-NG-2-76513-5612324_2_LDC0009_00z0 -# data/local/lines/1/arb-NG-2-76513-5612324_2_LDC0009_00z0.tif +# images.scp file: LDC0001_000397_NHR_ARB_20070113.0052_11_LDC0001_00z1 +# data/local/train/1/NHR_ARB_20070113.0052_11_LDC0001_00z1.png -stage=0 download_dir1=/export/corpora/LDC/LDC2012T15/data download_dir2=/export/corpora/LDC/LDC2013T09/data download_dir3=/export/corpora/LDC/LDC2013T15/data -writing_condition1=/export/corpora/LDC/LDC2012T15/docs/writing_conditions.tab -writing_condition2=/export/corpora/LDC/LDC2013T09/docs/writing_conditions.tab -writing_condition3=/export/corpora/LDC/LDC2013T15/docs/writing_conditions.tab -data_splits_dir=data/download/data_splits -images_scp_dir=data/local +train_split_url=http://www.openslr.org/resources/48/madcat.train.raw.lineid +test_split_url=http://www.openslr.org/resources/48/madcat.test.raw.lineid +dev_split_url=http://www.openslr.org/resources/48/madcat.dev.raw.lineid +data_splits=data/download/data_splits +stage=0 +download_dir=data/download +gigacorpus=data/local/gigawordcorpus +gigaword_loc=/export/corpora5/LDC/LDC2011T11 +use_extra_corpus_text=true . ./cmd.sh . ./path.sh . ./utils/parse_options.sh || exit 1; -mkdir -p data/{train,test,dev} - -if [ $stage -le 1 ]; then - echo "$0: Processing dev, train and test data..." - echo "Date: $(date)." - local/process_data.py $download_dir1 $download_dir2 $download_dir3 \ - $data_splits_dir/madcat.dev.raw.lineid data/dev $images_scp_dir/dev/images.scp \ - $writing_condition1 $writing_condition2 $writing_condition3 || exit 1 - - local/process_data.py $download_dir1 $download_dir2 $download_dir3 \ - $data_splits_dir/madcat.test.raw.lineid data/test $images_scp_dir/test/images.scp \ - $writing_condition1 $writing_condition2 $writing_condition3 || exit 1 +if [ -d $data_splits ]; then + echo "$0: Not downloading the data splits as it is already there." +else + if [ ! -f $data_splits/madcat.train.raw.lineid ]; then + mkdir -p $data_splits + echo "$0: Downloading the data splits..." + wget -P $data_splits $train_split_url || exit 1; + wget -P $data_splits $test_split_url || exit 1; + wget -P $data_splits $dev_split_url || exit 1; + fi + echo "$0: Done downloading the data splits" +fi - local/process_data.py $download_dir1 $download_dir2 $download_dir3 \ - $data_splits_dir/madcat.train.raw.lineid data/train $images_scp_dir/train/images.scp \ - $writing_condition1 $writing_condition2 $writing_condition3 || exit 1 +if [ -d $download_dir1 ]; then + echo "$0: madcat arabic data directory is present." +else + if [ ! -f $download_dir1/madcat/*.madcat.xml ]; then + echo "$0: please download madcat data..." + fi +fi - for dataset in dev test train; do - echo "$0: Fixing data directory for dataset: $dataset" - echo "Date: $(date)." - image/fix_data_dir.sh data/$dataset +mkdir -p $download_dir data/local +if $use_extra_corpus_text; then + mkdir -p $gigacorpus + cp -r $gigaword_loc/. 
$gigacorpus + for newswire in aaw_arb afp_arb ahr_arb asb_arb hyt_arb nhr_arb qds_arb umh_arb xin_arb; do + for file in $gigacorpus/arb_gw_5/data/$newswire/*.gz; do + gzip -d $file + done + for file in $gigacorpus/arb_gw_5/data/$newswire/*; do + sed -e '/^<[^>]*>$/d; s/``/"/g; s/\x27\x27/"/g' $file >> $gigacorpus/arb_gw_5/data/${newswire}_combined.txt + done done fi diff --git a/egs/madcat_ar/v1/local/prepend_words.py b/egs/madcat_ar/v1/local/prepend_words.py deleted file mode 100755 index d53eb8974bf..00000000000 --- a/egs/madcat_ar/v1/local/prepend_words.py +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# This script, prepend '|' to every words in the transcript to mark -# the beginning of the words for finding the initial-space of every word -# after decoding. - -import sys, io - -infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') -output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') -for line in infile: - output.write(' '.join(["|" + word for word in line.split()]) + '\n') diff --git a/egs/madcat_ar/v1/local/process_data.py b/egs/madcat_ar/v1/local/process_data.py index b57500cf2fa..a39bcfa87d3 100755 --- a/egs/madcat_ar/v1/local/process_data.py +++ b/egs/madcat_ar/v1/local/process_data.py @@ -24,24 +24,28 @@ " data/LDC2013T09 data/LDC2013T15 data/madcat.train.raw.lineid " " data/train data/local/lines ", formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument('database_path1', type=str, +parser.add_argument('database_path1', help='Path to the downloaded (and extracted) madcat data') -parser.add_argument('database_path2', type=str, +parser.add_argument('database_path2', help='Path to the downloaded (and extracted) madcat data') -parser.add_argument('database_path3', type=str, +parser.add_argument('database_path3', help='Path to the downloaded (and extracted) madcat data') -parser.add_argument('data_splits', type=str, +parser.add_argument('data_splits', help='Path to file that contains the train/test/dev split information') -parser.add_argument('out_dir', type=str, +parser.add_argument('out_dir', help='directory location to write output files.') -parser.add_argument('images_scp_path', type=str, +parser.add_argument('images_scp_path', help='Path of input images.scp file(maps line image and location)') -parser.add_argument('writing_condition1', type=str, +parser.add_argument('writing_condition1', help='Path to the downloaded (and extracted) writing conditions file 1') -parser.add_argument('writing_condition2', type=str, +parser.add_argument('writing_condition2', help='Path to the downloaded (and extracted) writing conditions file 2') -parser.add_argument('writing_condition3', type=str, +parser.add_argument('writing_condition3', help='Path to the downloaded (and extracted) writing conditions file 3') +parser.add_argument("--augment", type=lambda x: (str(x).lower()=='true'), default=False, + help="performs image augmentation") +parser.add_argument("--subset", type=lambda x: (str(x).lower()=='true'), default=False, + help="only processes subset of data based on writing condition") args = parser.parse_args() @@ -97,50 +101,42 @@ def check_writing_condition(wc_dict): Returns: (bool): True if writing condition matches. 
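The Gigaword clean-up loop above (in prepare_data.sh) strips lines that consist only of an SGML tag and maps the Penn-style `` / '' quote pairs to plain double quotes before concatenating each newswire into a combined text file. For reference, a rough per-line Python equivalent of that sed command (a sketch only, not part of the recipe):

import re

_TAG_ONLY = re.compile(r'^<[^>]*>$')   # e.g. <P>, <TEXT>, </DOC>

def clean_gigaword_line(line):
    # Mirrors: sed -e '/^<[^>]*>$/d; s/``/"/g; s/\x27\x27/"/g'
    if _TAG_ONLY.match(line.rstrip('\n')):
        return None                    # sed's 'd': drop tag-only lines
    return line.replace('``', '"').replace("''", '"')

for raw in ["<P>", "he said ``hello'' to them"]:
    cleaned = clean_gigaword_line(raw)
    if cleaned is not None:
        print(cleaned)                 # -> he said "hello" to them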
""" - return True - writing_condition = wc_dict[base_name].strip() - if writing_condition != 'IUC': - return False + if args.subset: + writing_condition = wc_dict[base_name].strip() + if writing_condition != 'IUC': + return False + else: + return True + else: + return True - return True - -def get_word_line_mapping(madcat_file_path): +def read_text(madcat_file_path): """ Maps every word in the page image to a corresponding line. Args: - madcat_file_path (string): complete path and name of the madcat xml file + madcat_file_path (string): complete path and name of the madcat xml file corresponding to the page image. Returns: + dict: Mapping every word in the page image to a corresponding line. """ + + word_line_dict = dict() doc = minidom.parse(madcat_file_path) zone = doc.getElementsByTagName('zone') for node in zone: line_id = node.getAttribute('id') - line_word_dict[line_id] = list() word_image = node.getElementsByTagName('token-image') for tnode in word_image: word_id = tnode.getAttribute('id') - line_word_dict[line_id].append(word_id) word_line_dict[word_id] = line_id - -def read_text(madcat_file_path): - """ Maps every word in the page image to a corresponding line. - Args: - madcat_file_path (string): complete path and name of the madcat xml file - corresponding to the page image. - Returns: - dict: Mapping every word in the page image to a corresponding line. - """ text_line_word_dict = dict() - doc = minidom.parse(madcat_file_path) segment = doc.getElementsByTagName('segment') for node in segment: token = node.getElementsByTagName('token') for tnode in token: ref_word_id = tnode.getAttribute('ref_id') word = tnode.getElementsByTagName('source')[0].firstChild.nodeValue - word = unicodedata.normalize('NFKC',word) ref_line_id = word_line_dict[ref_word_id] if ref_line_id not in text_line_word_dict: text_line_word_dict[ref_line_id] = list() @@ -160,7 +156,6 @@ def get_line_image_location(): ### main ### - print("Processing '{}' data...".format(args.out_dir)) text_file = os.path.join(args.out_dir, 'text') @@ -188,23 +183,34 @@ def get_line_image_location(): madcat_xml_path, image_file_path, wc_dict = check_file_location() if wc_dict is None or not check_writing_condition(wc_dict): continue - if madcat_xml_path is not None: - madcat_doc = minidom.parse(madcat_xml_path) - writer = madcat_doc.getElementsByTagName('writer') - writer_id = writer[0].getAttribute('id') - line_word_dict = dict() - word_line_dict = dict() - get_word_line_mapping(madcat_xml_path) - text_line_word_dict = read_text(madcat_xml_path) - base_name = os.path.basename(image_file_path) - base_name, b = base_name.split('.tif') - for lineID in sorted(text_line_word_dict): - updated_base_name = base_name + '_' + str(lineID).zfill(4) +'.png' + madcat_doc = minidom.parse(madcat_xml_path) + writer = madcat_doc.getElementsByTagName('writer') + writer_id = writer[0].getAttribute('id') + text_line_word_dict = read_text(madcat_xml_path) + base_name = os.path.basename(image_file_path).split('.tif')[0] + for line_id in sorted(text_line_word_dict): + if args.augment: + key = (line_id + '.')[:-1] + for i in range(0, 3): + location_id = "_{}_scale{}".format(line_id, i) + line_image_file_name = base_name + location_id + '.png' + location = image_loc_dict[line_image_file_name] + image_file_path = os.path.join(location, line_image_file_name) + line = text_line_word_dict[key] + text = ' '.join(line) + base_line_image_file_name = line_image_file_name.split('.png')[0] + utt_id = "{}_{}_{}".format(writer_id, str(image_num).zfill(6), 
base_line_image_file_name) + text_fh.write(utt_id + ' ' + text + '\n') + utt2spk_fh.write(utt_id + ' ' + writer_id + '\n') + image_fh.write(utt_id + ' ' + image_file_path + '\n') + image_num += 1 + else: + updated_base_name = "{}_{}.png".format(base_name, str(line_id).zfill(4)) location = image_loc_dict[updated_base_name] image_file_path = os.path.join(location, updated_base_name) - line = text_line_word_dict[lineID] + line = text_line_word_dict[line_id] text = ' '.join(line) - utt_id = writer_id + '_' + str(image_num).zfill(6) + '_' + base_name + '_' + str(lineID).zfill(4) + utt_id = "{}_{}_{}_{}".format(writer_id, str(image_num).zfill(6), base_name, str(line_id).zfill(4)) text_fh.write(utt_id + ' ' + text + '\n') utt2spk_fh.write(utt_id + ' ' + writer_id + '\n') image_fh.write(utt_id + ' ' + image_file_path + '\n') diff --git a/egs/madcat_ar/v1/local/score.sh b/egs/madcat_ar/v1/local/score.sh index 2c11aba3e13..31564d25326 100755 --- a/egs/madcat_ar/v1/local/score.sh +++ b/egs/madcat_ar/v1/local/score.sh @@ -1,5 +1,5 @@ #!/bin/bash -steps/scoring/score_kaldi_wer.sh --word_ins_penalty 0.0,0.5,1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5,5.0,5.5,6.0,6.5,7.0 "$@" -steps/scoring/score_kaldi_cer.sh --stage 2 --word_ins_penalty 0.0,0.5,1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5,5.0,5.5,6.0,6.5,7.0 "$@" +steps/scoring/score_kaldi_wer.sh "$@" +steps/scoring/score_kaldi_cer.sh --stage 2 "$@" diff --git a/egs/madcat_ar/v1/local/tl/augment_data.sh b/egs/madcat_ar/v1/local/tl/augment_data.sh new file mode 100755 index 00000000000..cc44aa58a62 --- /dev/null +++ b/egs/madcat_ar/v1/local/tl/augment_data.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright 2018 Hossein Hadian +# 2018 Ashish Arora + +# Apache 2.0 +# This script performs data augmentation. + +nj=4 +cmd=run.pl +feat_dim=40 +verticle_shift=0 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. 
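With --augment true, the loop above emits three line images per original line (suffixes _scale0 through _scale2) and builds each utterance id from the writer id, a zero-padded running image number, and the line-image base name. A tiny sketch of the resulting id layout, with illustrative values only (real ids come from the MADCAT XML and page names):

# Illustrative values only; not taken from the corpus.
writer_id = "LDC0001"
base_name = "NHR_ARB_20070113.0052_11_LDC0001"
line_id = "00z1"
image_num = 397

for i in range(0, 3):                      # one entry per augmentation scale
    line_image = "{}_{}_scale{}.png".format(base_name, line_id, i)
    utt_id = "{}_{}_{}".format(writer_id, str(image_num).zfill(6),
                               line_image.split('.png')[0])
    print(utt_id)
    image_num += 1
# -> ..._000397_..._00z1_scale0, ..._000398_..._00z1_scale1, ..._000399_..._00z1_scale2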
./utils/parse_options.sh || exit 1; + +srcdir=$1 +outdir=$2 +datadir=$3 +aug_set=aug1 +mkdir -p $datadir/augmentations +echo "copying $srcdir to $datadir/augmentations/$aug_set, allowed length, creating feats.scp" + +for set in $aug_set; do + image/copy_data_dir.sh --spk-prefix $set- --utt-prefix $set- \ + $srcdir $datadir/augmentations/$set + cat $srcdir/allowed_lengths.txt > $datadir/augmentations/$set/allowed_lengths.txt + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim $feat_dim \ + --vertical-shift $verticle_shift \ + --augment 'random_shift' $datadir/augmentations/$set +done + +echo " combine original data and data from different augmentations" +utils/combine_data.sh --extra-files images.scp $outdir $srcdir $datadir/augmentations/$aug_set +cat $srcdir/allowed_lengths.txt > $outdir/allowed_lengths.txt diff --git a/egs/madcat_ar/v1/local/tl/chain/run_cnn_e2eali.sh b/egs/madcat_ar/v1/local/tl/chain/run_cnn_e2eali.sh new file mode 100755 index 00000000000..ccbb7119674 --- /dev/null +++ b/egs/madcat_ar/v1/local/tl/chain/run_cnn_e2eali.sh @@ -0,0 +1,229 @@ +#!/bin/bash + +# ./local/chain/compare_wer.sh exp/chain/cnn_e2eali_1a/ +# System cnn_e2eali_1a +# WER 16.78 +# CER 5.22 +# Final train prob -0.1189 +# Final valid prob -0.1319 +# Final train prob (xent) -0.6395 +# Final valid prob (xent) -0.6732 +# Parameters 3.73M + +# steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1a/ +# exp/chain/cnn_e2eali_1a/: num-iters=24 nj=3..15 num-params=3.7M dim=56->392 combine=-0.125->-0.125 (over 1) xent:train/valid[15,23,final]=(-0.850,-1.24,-0.640/-0.901,-1.31,-0.673) logprob:train/valid[15,23,final]=(-0.149,-0.209,-0.119/-0.166,-0.229,-0.132) +set -e -o pipefail + +stage=0 + +nj=30 +train_set=train +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=500 +tdnn_dim=450 +srand=0 +remove_egs=true +lang_decode=data/lang +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts + +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
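augment_data.sh above duplicates the source data directory under an aug1- utterance/speaker prefix, re-extracts features for the copy with the 'random_shift' augmentation type, and then combines the copy with the original, so the training set roughly doubles. A toy sketch of just the prefix-and-combine bookkeeping, with plain dicts standing in for Kaldi data directories (not the actual copy_data_dir.sh / combine_data.sh utilities):

# Toy stand-in for a Kaldi data dir: utt-id -> line-image path.
train = {
    "LDC0001_000397_line1": "data/local/train/1/line1.png",
    "LDC0001_000398_line2": "data/local/train/1/line2.png",
}

def copy_with_prefix(data, prefix):
    # Mirrors the spirit of copy_data_dir.sh --utt-prefix/--spk-prefix:
    # same images, new utterance ids, so the copy can carry different features.
    return {prefix + "-" + utt: path for utt, path in data.items()}

augmented = copy_with_prefix(train, "aug1")
combined = {**train, **augmented}          # combine_data.sh analogue
print(len(train), len(combined))           # 2 4 : training data doubled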
+ exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 4 \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=56 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=56 height-out=56 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=56 height-out=28 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=28 height-out=28 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=28 height-out=28 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=28 height-out=14 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=14 height-out=14 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=14 height-out=14 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ + --chain.frame-subsampling-factor=4 \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=2 \ + --trainer.frames-per-iter=2000000 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=16 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=64,32 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/test $dir/decode_test || exit 1; +fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/madcat_ar/v1/local/chain/run_flatstart_cnn1a.sh b/egs/madcat_ar/v1/local/tl/chain/run_e2e_cnn.sh similarity index 78% rename from egs/madcat_ar/v1/local/chain/run_flatstart_cnn1a.sh rename to egs/madcat_ar/v1/local/tl/chain/run_e2e_cnn.sh index 2c85e982ce6..3fca8cf5fdc 100755 --- a/egs/madcat_ar/v1/local/chain/run_flatstart_cnn1a.sh +++ b/egs/madcat_ar/v1/local/tl/chain/run_e2e_cnn.sh @@ -3,40 +3,37 @@ # This script does end2end chain training (i.e. 
from scratch) -# local/chain/compare_wer.sh exp/chain/e2e_cnn_1a +# ./local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ # System e2e_cnn_1a -# WER 10.71 -# CER 2.85 -# Final train prob -0.0859 -# Final valid prob -0.1266 +# WER 19.30 +# CER 5.72 +# Final train prob -0.0734 +# Final valid prob -0.0607 # Final train prob (xent) # Final valid prob (xent) -# Parameters 2.94M +# Parameters 3.30M # steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1a/ -# exp/chain/e2e_cnn_1a/: num-iters=195 nj=6..16 num-params=2.9M dim=40->324 combine=-0.065->-0.064 (over 5) logprob:train/valid[129,194,final]=(-0.078,-0.077,-0.086/-0.129,-0.126,-0.127) +# exp/chain/e2e_cnn_1a/: num-iters=24 nj=3..15 num-params=3.3M dim=56->292 combine=-0.060->-0.060 (over 1) logprob:train/valid[15,23,final]=(-0.122,-0.143,-0.073/-0.105,-0.132,-0.061) set -e + # configs for 'chain' stage=0 -nj=70 +nj=30 train_stage=-10 get_egs_stage=-10 affix=1a # training options tdnn_dim=450 -num_epochs=2 -num_jobs_initial=6 -num_jobs_final=16 -minibatch_size=150=128,64/300=128,64/600=64,32/1200=32,16 +minibatch_size=150=64,32/300=32,16/600=16,8/1200=8,4 common_egs_dir= -l2_regularize=0.00005 frames_per_iter=1000000 -cmvn_opts="--norm-means=true --norm-vars=true" +cmvn_opts="--norm-means=false --norm-vars=false" train_set=train -lang_test=lang_test +lang_decode=data/lang # End configuration section. echo "$0 $@" # Print the command line for logging @@ -89,16 +86,17 @@ if [ $stage -le 2 ]; then common1="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" common2="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" common3="required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs cat < $dir/configs/network.xconfig - input dim=40 name=input - conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 - conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 - conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 - conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 - conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 - conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 - conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + input dim=56 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=56 height-out=56 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=56 height-out=28 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=28 height-out=28 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=28 height-out=28 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=28 height-out=14 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=14 height-out=14 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=14 height-out=14 time-offsets=-4,0,4 $common3 relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim @@ -118,20 +116,21 @@ if [ $stage -le 3 ]; then --cmd "$cmd" \ --feat.cmvn-opts 
"$cmvn_opts" \ --chain.leaky-hmm-coefficient 0.1 \ - --chain.l2-regularize $l2_regularize \ + --chain.l2-regularize 0.00005 \ --chain.apply-deriv-weights false \ --egs.dir "$common_egs_dir" \ --egs.stage $get_egs_stage \ --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ --chain.frame-subsampling-factor 4 \ --chain.alignment-subsampling-factor 4 \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ --trainer.add-option="--optimization.memory-compression-level=2" \ --trainer.num-chunk-per-minibatch $minibatch_size \ - --trainer.frames-per-iter $frames_per_iter \ - --trainer.num-epochs $num_epochs \ + --trainer.frames-per-iter 2000000 \ + --trainer.num-epochs 2 \ --trainer.optimization.momentum 0 \ - --trainer.optimization.num-jobs-initial $num_jobs_initial \ - --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ --trainer.optimization.initial-effective-lrate 0.001 \ --trainer.optimization.final-effective-lrate 0.0001 \ --trainer.optimization.shrink-value 1.0 \ @@ -151,7 +150,7 @@ if [ $stage -le 4 ]; then # as long as phones.txt was compatible. utils/mkgraph.sh \ - --self-loop-scale 1.0 data/$lang_test \ + --self-loop-scale 1.0 $lang_decode \ $dir $dir/graph || exit 1; fi diff --git a/egs/madcat_ar/v1/local/tl/process_waldo_data.py b/egs/madcat_ar/v1/local/tl/process_waldo_data.py new file mode 100755 index 00000000000..0d278e64122 --- /dev/null +++ b/egs/madcat_ar/v1/local/tl/process_waldo_data.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +""" This script reads image and transcription mapping and creates the following files :text, utt2spk, images.scp. + Eg. local/process_waldo_data.py lines/hyp_line_image_transcription_mapping_kaldi.txt data/test + Eg. text file: LDC0001_000404_NHR_ARB_20070113.0052_11_LDC0001_00z2 ﻮﺠﻫ ﻮﻌﻘﻟ ﻍﺍﺮﻗ ﺢﺗّﻯ ﺎﻠﻨﺧﺎﻋ + utt2spk file: LDC0001_000397_NHR_ARB_20070113.0052_11_LDC0001_00z1 LDC0001 + images.scp file: LDC0009_000000_arb-NG-2-76513-5612324_2_LDC0009_00z0 + data/local/lines/1/arb-NG-2-76513-5612324_2_LDC0009_00z0.tif +""" + +import argparse +import os +import sys + +parser = argparse.ArgumentParser(description="Creates text, utt2spk and images.scp files", + epilog="E.g. " + sys.argv[0] + " data/train data/local/lines ", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('image_transcription_file', type=str, + help='Path to the file containing line image path and transcription information') +parser.add_argument('out_dir', type=str, + help='directory location to write output files.') +args = parser.parse_args() + + +def read_image_text(image_text_path): + """ Given the file path containing, mapping information of line image + and transcription, it returns a dict. The dict contains this mapping + info. It can be accessed via line_id and will provide transcription. 
+ Returns: + -------- + dict: line_id and transcription mapping + """ + image_transcription_dict = dict() + with open(image_text_path, encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split(' ') + image_path = line_vect[0] + line_id = os.path.basename(image_path).split('.png')[0] + transcription = line_vect[1:] + joined_transcription = list() + for word in transcription: + joined_transcription.append(word) + joined_transcription = " ".join(joined_transcription) + image_transcription_dict[line_id] = joined_transcription + return image_transcription_dict + + +### main ### +print("Processing '{}' data...".format(args.out_dir)) +text_file = os.path.join(args.out_dir, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') +utt2spk_file = os.path.join(args.out_dir, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join(args.out_dir, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') + +image_transcription_dict = read_image_text(args.image_transcription_file) +for line_id in sorted(image_transcription_dict.keys()): + writer_id = line_id.strip().split('_')[-3] + updated_line_id = line_id + '.png' + image_file_path = os.path.join('lines', updated_line_id) + text = image_transcription_dict[line_id] + utt_id = line_id + text_fh.write(utt_id + ' ' + text + '\n') + utt2spk_fh.write(utt_id + ' ' + writer_id + '\n') + image_fh.write(utt_id + ' ' + image_file_path + '\n') + diff --git a/egs/madcat_ar/v1/local/tl/run_text_localization.sh b/egs/madcat_ar/v1/local/tl/run_text_localization.sh new file mode 100755 index 00000000000..8d12f7d802f --- /dev/null +++ b/egs/madcat_ar/v1/local/tl/run_text_localization.sh @@ -0,0 +1,143 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian +# 2018 Ashish Arora + +# This script performs full page text recognition on automatically extracted line images +# from madcat arabic data. It is created as a separate scrip, because it performs +# data augmentation, uses smaller language model and calls process_waldo_data for +# test images (automatically extracted line images). Data augmentation increases image +# height hence requires different DNN arachitecture and different chain scripts. + +set -e +stage=0 +nj=70 +# download_dir{1,2,3} points to the database path on the JHU grid. If you have not +# already downloaded the database you can set it to a local directory +# This corpus can be purchased here: +# https://catalog.ldc.upenn.edu/{LDC2012T15,LDC2013T09/,LDC2013T15/} +download_dir1=/export/corpora/LDC/LDC2012T15/data +download_dir2=/export/corpora/LDC/LDC2013T09/data +download_dir3=/export/corpora/LDC/LDC2013T15/data +writing_condition1=/export/corpora/LDC/LDC2012T15/docs/writing_conditions.tab +writing_condition2=/export/corpora/LDC/LDC2013T09/docs/writing_conditions.tab +writing_condition3=/export/corpora/LDC/LDC2013T15/docs/writing_conditions.tab +data_splits_dir=data/download/data_splits +images_scp_dir=data/local +overwrite=false +subset=true +augment=true +verticle_shift=16 +. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. + ## This relates to the queue. +. ./path.sh +. ./utils/parse_options.sh # e.g. this parses the above options + # if supplied. +./local/check_tools.sh + +mkdir -p data/{train,test,dev}/data +mkdir -p data/local/{train,test,dev} +if [ $stage -le 0 ]; then + + if [ -f data/train/text ] && ! 
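read_image_text above parses a file in which every line is a line-image path followed by its transcription, and keys the transcription by the image basename with the .png extension removed; the main block then writes text, utt2spk and images.scp from that dict. A stand-alone sketch of the parsing step on an in-memory example (hypothetical paths and words, not the recipe's mapping file):

import os

# Hypothetical contents of hyp_line_image_transcription_mapping_kaldi.txt:
# each line is "<line image path> <word> <word> ...".
sample = [
    "lines/NHR_ARB_20070113.0052_11_LDC0001_00z1.png first second third",
    "lines/NHR_ARB_20070113.0052_11_LDC0001_00z2.png fourth fifth",
]

image_transcription_dict = {}
for line in sample:
    fields = line.strip().split(' ')
    line_id = os.path.basename(fields[0]).split('.png')[0]   # key: basename minus .png
    image_transcription_dict[line_id] = " ".join(fields[1:])

print(image_transcription_dict["NHR_ARB_20070113.0052_11_LDC0001_00z1"])  # -> first second third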
$overwrite; then + echo "$0: Not processing, probably script have run from wrong stage" + echo "Exiting with status 1 to avoid data corruption" + exit 1; + fi + echo "$0: Downloading data splits...$(date)" + local/download_data.sh --data_splits $data_splits_dir --download_dir1 $download_dir1 \ + --download_dir2 $download_dir2 --download_dir3 $download_dir3 + + for set in train dev; do + data_split_file=$data_splits_dir/madcat.$set.raw.lineid + local/extract_lines.sh --nj $nj --cmd $cmd --data_split_file $data_split_file \ + --download_dir1 $download_dir1 --download_dir2 $download_dir2 \ + --download_dir3 $download_dir3 --writing_condition1 $writing_condition1 \ + --writing_condition2 $writing_condition2 --writing_condition3 $writing_condition3 \ + --data data/local/$set --subset $subset --augment $augment || exit 1 + done + + echo "$0: Preparing data..." + for set in dev train; do + local/process_data.py $download_dir1 $download_dir2 $download_dir3 \ + $data_splits_dir/madcat.$set.raw.lineid data/$set $images_scp_dir/$set/images.scp \ + $writing_condition1 $writing_condition2 $writing_condition3 --augment $augment --subset $subset + image/fix_data_dir.sh data/${set} + done + + local/tl/process_waldo_data.py lines/hyp_line_image_transcription_mapping_kaldi.txt data/test + utils/utt2spk_to_spk2utt.pl data/test/utt2spk > data/test/spk2utt +fi + +if [ $stage -le 1 ]; then + echo "$0: Obtaining image groups. calling get_image2num_frames $(date)." + image/get_image2num_frames.py data/train + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train + for set in dev train test; do + echo "$0: Extracting features and calling compute_cmvn_stats for dataset: $set. $(date)" + local/extract_features.sh --nj $nj --cmd $cmd --feat-dim 40 \ + --verticle_shift $verticle_shift data/$set + steps/compute_cmvn_stats.sh data/$set || exit 1; + done + echo "$0: Fixing data directory for train dataset $(date)." + image/fix_data_dir.sh data/train +fi + +if [ $stage -le 2 ]; then + for set in train; do + echo "$(date) stage 2: Performing augmentation, it will double training data" + local/tl/augment_data.sh --nj $nj --cmd "$cmd" --feat-dim 40 \ + --verticle_shift $verticle_shift data/${set} data/${set}_aug data + steps/compute_cmvn_stats.sh data/${set}_aug || exit 1; + done +fi + +if [ $stage -le 3 ]; then + echo "$0: Preparing BPE..." + cut -d' ' -f2- data/train/text | utils/lang/bpe/reverse.py | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/learn_bpe.py -s 700 > data/local/bpe.txt + + for set in test train dev train_aug; do + cut -d' ' -f1 data/$set/text > data/$set/ids + cut -d' ' -f2- data/$set/text | utils/lang/bpe/reverse.py | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > data/$set/bpe_text + + mv data/$set/text data/$set/text.old + paste -d' ' data/$set/ids data/$set/bpe_text > data/$set/text + rm -f data/$set/bpe_text data/$set/ids + done + + echo "$0:Preparing dictionary and lang..." + local/prepare_dict.sh + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ + data/local/dict "" data/lang/temp data/lang + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 data/lang +fi + +if [ $stage -le 4 ]; then + echo "$0: Estimating a language model for decoding..." 
+ local/tl/train_lm.sh --order 3 + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/3gram_unpruned.arpa.gz \ + data/local/dict/lexicon.txt data/lang +fi + +nj=30 +if [ $stage -le 5 ]; then + echo "$0: Calling the flat-start chain recipe... $(date)." + local/tl/chain/run_e2e_cnn.sh --nj $nj --train_set train_aug +fi + +if [ $stage -le 6 ]; then + echo "$0: Aligning the training data using the e2e chain model...$(date)." + steps/nnet3/align.sh --nj $nj --cmd "$cmd" \ + --use-gpu false \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0 --acoustic-scale=1.0' \ + data/train_aug data/lang exp/chain/e2e_cnn_1a exp/chain/e2e_ali_train +fi + +if [ $stage -le 7 ]; then + echo "$0: Building a tree and training a regular chain model using the e2e alignments...$(date)" + local/tl/chain/run_cnn_e2eali.sh --nj $nj --train_set train_aug +fi diff --git a/egs/madcat_ar/v1/local/tl/train_lm.sh b/egs/madcat_ar/v1/local/tl/train_lm.sh new file mode 100755 index 00000000000..524bb2e9f40 --- /dev/null +++ b/egs/madcat_ar/v1/local/tl/train_lm.sh @@ -0,0 +1,102 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +order=3 +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # use the validation data as the dev set. + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + + cat data/dev/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + + # use the training data as an additional data source. + # we can later fold the dev data into this. + cat data/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. 
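The stage-0 block above (continued just below) lays the text sources out the way pocolm expects: ${dir}/data/text/dev.txt holds the held-out transcriptions used only for estimating interpolation weights, ${dir}/data/text/train.txt is the actual LM data, and real_dev_set.txt is kept outside data/text so it is only used for reporting perplexity. The cut -d " " -f 2- calls simply drop the leading utterance id from each Kaldi text line; a minimal Python equivalent of that step (sketch only, assuming every line starts with an id):

def strip_utt_id(kaldi_text_line):
    # Kaldi 'text' lines are "<utt-id> <word> <word> ..."; keep only the words,
    # like `cut -d " " -f 2-` in the script above.
    return kaldi_text_line.rstrip("\n").split(" ", 1)[1]

line = "LDC0001_000397_NHR_ARB_20070113.0052_11_LDC0001_00z1 first second third"
print(strip_utt_id(line))   # -> first second third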
+ cut -d " " -f 2- < data/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from MADCAT text + cat ${dir}/data/text/train.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=20 --warm-start-ratio=20 \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi diff --git a/egs/madcat_ar/v1/local/train_lm.sh b/egs/madcat_ar/v1/local/train_lm.sh index 3b8a382cb00..903b288a834 100755 --- a/egs/madcat_ar/v1/local/train_lm.sh +++ b/egs/madcat_ar/v1/local/train_lm.sh @@ -6,20 +6,19 @@ # 2017 Hossein Hadian # Apache 2.0 # -# This script trains a LM on the MADCAT training transcriptions. +# This script trains a LM on the training transcriptions. # It is based on the example scripts distributed with PocoLM # It will check if pocolm is installed and if not will proceed with installation set -e stage=0 - +dir=data/local/local_lm +order=6 echo "$0 $@" # Print the command line for logging . ./utils/parse_options.sh || exit 1; -dir=data/local/local_lm lm_dir=${dir}/data -segments=data/train/segmented_words mkdir -p $dir @@ -43,12 +42,10 @@ bypass_metaparam_optim_opt= # These example numbers of metaparameters is for 4-gram model (with min-counts) # running with train_lm.py. # The dev perplexity should be close to the non-bypassed model. -#bypass_metaparam_optim_opt= # Note: to use these example parameters, you may need to remove the .done files # to make sure the make_lm_dir.py be called and tain only 3-gram model #for order in 3; do #rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done - if [ $stage -le 0 ]; then mkdir -p ${dir}/data mkdir -p ${dir}/data/text @@ -65,7 +62,13 @@ if [ $stage -le 0 ]; then # use the training data as an additional data source. # we can later fold the dev data into this. - cat data/train/text | cut -d " " -f 2- > ${dir}/data/text/madcat.txt + cat data/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + + if [ -d "data/local/gigawordcorpus/arb_gw_5/data" ]; then + cat data/local/gigawordcorpus/arb_gw_5/data/nhr_arb_combined.txt | \ + utils/lang/bpe/prepend_words.py | utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > ${dir}/data/text/corpus_text.txt + fi # for reporting perplexities, we'll use the "real" dev set. 
# (the validation data is used as ${dir}/data/text/dev.txt to work @@ -75,12 +78,10 @@ if [ $stage -le 0 ]; then cut -d " " -f 2- < data/test/text > ${dir}/data/real_dev_set.txt # get the wordlist from MADCAT text - cat ${dir}/data/text/madcat.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/text/{train,corpus_text}.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist fi -order=3 - if [ $stage -le 1 ]; then # decide on the vocabulary. # Note: you'd use --wordlist if you had a previously determined word-list @@ -88,7 +89,7 @@ if [ $stage -le 1 ]; then # Note: if you have more than one order, use a certain amount of words as the # vocab and want to restrict max memory for 'sort', echo "$0: training the unpruned LM" - min_counts='train=2 madcat=1' + min_counts='corpus_text=2 train=1' wordlist=${dir}/data/wordlist lm_name="`basename ${wordlist}`_${order}" @@ -96,13 +97,34 @@ if [ $stage -le 1 ]; then lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" fi unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm - train_lm.py --wordlist=${wordlist} --num-splits=5 --warm-start-ratio=1 \ + train_lm.py --wordlist=${wordlist} --num-splits=20 --warm-start-ratio=20 \ --limit-unk-history=true \ ${bypass_metaparam_optim_opt} \ ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' - mkdir -p ${dir}/data/arpa format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz fi + +if [ $stage -le 2 ]; then + echo "$0: pruning the LM (to larger size)" + # Using 20 million n-grams for a big LM for rescoring purposes. + size=20000000 + prune_lm_dir.py --target-num-ngrams=$size --initial-threshold=0.02 ${unpruned_lm_dir} ${dir}/data/lm_${order}_prune_big + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_big 2>&1 | grep -F '[perplexity' + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${dir}/data/lm_${order}_prune_big | gzip -c > ${dir}/data/arpa/${order}gram_big.arpa.gz +fi + +if [ $stage -le 3 ]; then + echo "$0: pruning the LM (to smaller size)" + # Using 10 million n-grams for a smaller LM for graph building. Prune from the + # bigger-pruned LM, it'll be faster. + size=10000000 + prune_lm_dir.py --target-num-ngrams=$size ${dir}/data/lm_${order}_prune_big ${dir}/data/lm_${order}_prune_small + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_small 2>&1 | grep -F '[perplexity' + format_arpa_lm.py ${dir}/data/lm_${order}_prune_small | gzip -c > ${dir}/data/arpa/${order}gram_small.arpa.gz +fi diff --git a/egs/madcat_ar/v1/local/wer_output_filter b/egs/madcat_ar/v1/local/wer_output_filter index c0f03e7178a..d6d46f3f565 100755 --- a/egs/madcat_ar/v1/local/wer_output_filter +++ b/egs/madcat_ar/v1/local/wer_output_filter @@ -2,6 +2,9 @@ # Copyright 2012-2014 Johns Hopkins University (Author: Yenda Trmal) # Apache 2.0 +# This script converts a BPE-encoded text to normal text and performs normalization. +# It is used in scoring. + use utf8; use open qw(:encoding(utf8)); diff --git a/egs/madcat_ar/v1/run.sh b/egs/madcat_ar/v1/run.sh index 14c8bf7a6ce..01bfdbed543 100755 --- a/egs/madcat_ar/v1/run.sh +++ b/egs/madcat_ar/v1/run.sh @@ -11,9 +11,7 @@ decode_gmm=false # download_dir{1,2,3} points to the database path on the JHU grid. 
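Both train_lm.sh variants build the pocolm word list by piping the training text through tr / sort / uniq -c / sort -bnr and then keeping the second column, i.e. all distinct words ordered by descending frequency. A rough Python equivalent with collections.Counter (a sketch reading from in-memory text rather than ${dir}/data/text):

from collections import Counter

corpus_lines = [
    "this is a line",
    "this is another line",
]

counts = Counter(word for line in corpus_lines for word in line.split())

# word_count: (word, count) pairs by descending count, like `uniq -c | sort -bnr`
word_count = sorted(counts.items(), key=lambda kv: -kv[1])
# wordlist: just the words, i.e. awk '{print $2}' on the counted output
wordlist = [word for word, _ in word_count]
print(word_count)   # e.g. [('this', 2), ('is', 2), ('line', 2), ('a', 1), ('another', 1)]
print(wordlist)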
If you have not # already downloaded the database you can set it to a local directory # This corpus can be purchased here: -# https://catalog.ldc.upenn.edu/LDC2012T15, -# https://catalog.ldc.upenn.edu/LDC2013T09/, -# https://catalog.ldc.upenn.edu/LDC2013T15/. +# https://catalog.ldc.upenn.edu/{LDC2012T15,LDC2013T09/,LDC2013T15/} download_dir1=/export/corpora/LDC/LDC2012T15/data download_dir2=/export/corpora/LDC/LDC2013T09/data download_dir3=/export/corpora/LDC/LDC2013T15/data @@ -21,47 +19,50 @@ writing_condition1=/export/corpora/LDC/LDC2012T15/docs/writing_conditions.tab writing_condition2=/export/corpora/LDC/LDC2013T09/docs/writing_conditions.tab writing_condition3=/export/corpora/LDC/LDC2013T15/docs/writing_conditions.tab data_splits_dir=data/download/data_splits - +images_scp_dir=data/local +overwrite=false +subset=false +augment=false +use_extra_corpus_text=true . ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. ## This relates to the queue. . ./path.sh . ./utils/parse_options.sh # e.g. this parses the above options # if supplied. - ./local/check_tools.sh - mkdir -p data/{train,test,dev}/data mkdir -p data/local/{train,test,dev} if [ $stage -le 0 ]; then - echo "$0: Downloading data splits..." - echo "Date: $(date)." - local/download_data.sh --data_splits $data_splits_dir --download_dir1 $download_dir1 \ - --download_dir2 $download_dir2 --download_dir3 $download_dir3 -fi - -if [ $stage -le 1 ]; then - for dataset in test train dev; do - data_split_file=$data_splits_dir/madcat.$dataset.raw.lineid + if [ -f data/train/text ] && ! $overwrite; then + echo "$0: Not processing, probably script have run from wrong stage" + echo "Exiting with status 1 to avoid data corruption" + exit 1; + fi + local/prepare_data.sh --data_splits $data_splits_dir --download_dir1 $download_dir1 \ + --download_dir2 $download_dir2 --download_dir3 $download_dir3 \ + --use_extra_corpus_text $use_extra_corpus_text + + for set in test train dev; do + data_split_file=$data_splits_dir/madcat.$set.raw.lineid local/extract_lines.sh --nj $nj --cmd $cmd --data_split_file $data_split_file \ --download_dir1 $download_dir1 --download_dir2 $download_dir2 \ --download_dir3 $download_dir3 --writing_condition1 $writing_condition1 \ --writing_condition2 $writing_condition2 --writing_condition3 $writing_condition3 \ - --data data/local/$dataset + --data data/local/$set --subset $subset --augment $augment || exit 1 done -fi -if [ $stage -le 2 ]; then - echo "$0: Preparing data..." - local/prepare_data.sh --download_dir1 $download_dir1 --download_dir2 $download_dir2 \ - --download_dir3 $download_dir3 --images_scp_dir data/local \ - --data_splits_dir $data_splits_dir --writing_condition1 $writing_condition1 \ - --writing_condition2 $writing_condition2 --writing_condition3 $writing_condition3 + echo "$0: Processing data..." 
+ for set in dev train test; do + local/process_data.py $download_dir1 $download_dir2 $download_dir3 \ + $data_splits_dir/madcat.$set.raw.lineid data/$set $images_scp_dir/$set/images.scp \ + $writing_condition1 $writing_condition2 $writing_condition3 --augment $augment --subset $subset + image/fix_data_dir.sh data/${set} + done fi -mkdir -p data/{train,test,dev}/data -if [ $stage -le 3 ]; then +if [ $stage -le 1 ]; then for dataset in test train; do local/extract_features.sh --nj $nj --cmd $cmd --feat-dim 40 data/$dataset steps/compute_cmvn_stats.sh data/$dataset || exit 1; @@ -69,33 +70,53 @@ if [ $stage -le 3 ]; then utils/fix_data_dir.sh data/train fi -if [ $stage -le 4 ]; then - echo "$0: Preparing dictionary and lang..." +if [ $stage -le 2 ]; then + echo "$0: Preparing BPE..." + cut -d' ' -f2- data/train/text | utils/lang/bpe/reverse.py | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/learn_bpe.py -s 700 > data/local/bpe.txt + + for set in test train dev; do + cut -d' ' -f1 data/$set/text > data/$set/ids + cut -d' ' -f2- data/$set/text | utils/lang/bpe/reverse.py | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > data/$set/bpe_text + + mv data/$set/text data/$set/text.old + paste -d' ' data/$set/ids data/$set/bpe_text > data/$set/text + rm -f data/$set/bpe_text data/$set/ids + done + + echo "$0:Preparing dictionary and lang..." local/prepare_dict.sh - utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.95 \ - data/local/dict "" data/lang/temp data/lang + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ + data/local/dict "" data/lang/temp data/lang + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 data/lang fi -if [ $stage -le 5 ]; then +if [ $stage -le 3 ]; then echo "$0: Estimating a language model for decoding..." 
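As an aside, the BPE stage above keeps utterance ids intact by cutting them off before the text transform and pasting them back afterwards. A small Python sketch of that id handling (the transform argument stands in for the reverse | prepend_words | apply_bpe | sed pipeline and is only a placeholder):

# Kaldi text files are "<utt-id> <transcript>"; transform the transcript and
# re-attach its id, mirroring the cut/paste dance in the script above.
def transform_text_file(in_path, out_path, transform):
    with open(in_path, encoding="utf-8") as fin, \
         open(out_path, "w", encoding="utf-8") as fout:
        for line in fin:
            utt_id, _, transcript = line.rstrip("\n").partition(" ")
            fout.write(utt_id + " " + transform(transcript) + "\n")

# placeholder transform; the recipe's real transform is the BPE pipeline:
# transform_text_file("data/train/text", "data/train/text.bpe",
#                     lambda s: " ".join(s.split()))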
local/train_lm.sh - utils/format_lm.sh data/lang data/local/local_lm/data/arpa/3gram_unpruned.arpa.gz \ - data/local/dict/lexicon.txt data/lang_test + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/6gram_small.arpa.gz \ + data/local/dict/lexicon.txt data/lang + utils/build_const_arpa_lm.sh data/local/local_lm/data/arpa/6gram_unpruned.arpa.gz \ + data/lang data/lang_rescore_6g fi -if [ $stage -le 6 ]; then +if [ $stage -le 4 ]; then steps/train_mono.sh --nj $nj --cmd $cmd --totgauss 10000 data/train \ data/lang exp/mono fi -if [ $stage -le 7 ] && $decode_gmm; then - utils/mkgraph.sh --mono data/lang_test exp/mono exp/mono/graph +if [ $stage -le 5 ] && $decode_gmm; then + utils/mkgraph.sh --mono data/lang exp/mono exp/mono/graph steps/decode.sh --nj $nj --cmd $cmd exp/mono/graph data/test \ exp/mono/decode_test fi -if [ $stage -le 8 ]; then +if [ $stage -le 6 ]; then steps/align_si.sh --nj $nj --cmd $cmd data/train data/lang \ exp/mono exp/mono_ali @@ -103,14 +124,14 @@ if [ $stage -le 8 ]; then exp/mono_ali exp/tri fi -if [ $stage -le 9 ] && $decode_gmm; then - utils/mkgraph.sh data/lang_test exp/tri exp/tri/graph +if [ $stage -le 7 ] && $decode_gmm; then + utils/mkgraph.sh data/lang exp/tri exp/tri/graph steps/decode.sh --nj $nj --cmd $cmd exp/tri/graph data/test \ exp/tri/decode_test fi -if [ $stage -le 10 ]; then +if [ $stage -le 8 ]; then steps/align_si.sh --nj $nj --cmd $cmd data/train data/lang \ exp/tri exp/tri_ali @@ -119,22 +140,22 @@ if [ $stage -le 10 ]; then data/train data/lang exp/tri_ali exp/tri3 fi -if [ $stage -le 11 ] && $decode_gmm; then - utils/mkgraph.sh data/lang_test exp/tri3 exp/tri3/graph +if [ $stage -le 9 ] && $decode_gmm; then + utils/mkgraph.sh data/lang exp/tri3 exp/tri3/graph steps/decode.sh --nj $nj --cmd $cmd exp/tri3/graph \ data/test exp/tri3/decode_test fi -if [ $stage -le 12 ]; then +if [ $stage -le 10 ]; then steps/align_fmllr.sh --nj $nj --cmd $cmd --use-graphs true \ data/train data/lang exp/tri3 exp/tri3_ali fi -if [ $stage -le 13 ]; then - local/chain/run_cnn_1a.sh +if [ $stage -le 11 ]; then + local/chain/run_cnn.sh fi -if [ $stage -le 14 ]; then - local/chain/run_cnn_chainali_1a.sh --stage 2 +if [ $stage -le 12 ]; then + local/chain/run_cnn_chainali.sh --stage 2 fi diff --git a/egs/madcat_ar/v1/run_end2end.sh b/egs/madcat_ar/v1/run_end2end.sh index 5d27476d3e1..62f4eeb7c71 100755 --- a/egs/madcat_ar/v1/run_end2end.sh +++ b/egs/madcat_ar/v1/run_end2end.sh @@ -7,9 +7,7 @@ nj=70 # download_dir{1,2,3} points to the database path on the JHU grid. If you have not # already downloaded the database you can set it to a local directory # This corpus can be purchased here: -# https://catalog.ldc.upenn.edu/LDC2012T15, -# https://catalog.ldc.upenn.edu/LDC2013T09/, -# https://catalog.ldc.upenn.edu/LDC2013T15/. +# https://catalog.ldc.upenn.edu/{LDC2012T15,LDC2013T09/,LDC2013T15/} download_dir1=/export/corpora/LDC/LDC2012T15/data download_dir2=/export/corpora/LDC/LDC2013T09/data download_dir3=/export/corpora/LDC/LDC2013T15/data @@ -17,7 +15,11 @@ writing_condition1=/export/corpora/LDC/LDC2012T15/docs/writing_conditions.tab writing_condition2=/export/corpora/LDC/LDC2013T09/docs/writing_conditions.tab writing_condition3=/export/corpora/LDC/LDC2013T15/docs/writing_conditions.tab data_splits_dir=data/download/data_splits - +images_scp_dir=data/local +overwrite=false +subset=false +augment=false +use_extra_corpus_text=true . ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. ## This relates to the queue. . 
./path.sh @@ -27,102 +29,105 @@ data_splits_dir=data/download/data_splits mkdir -p data/{train,test,dev}/data mkdir -p data/local/{train,test,dev} - if [ $stage -le 0 ]; then - echo "$0: Downloading data splits..." - echo "Date: $(date)." - local/download_data.sh --data_splits $data_splits_dir --download_dir1 $download_dir1 \ - --download_dir2 $download_dir2 --download_dir3 $download_dir3 -fi -if [ $stage -le 1 ]; then - for dataset in test train dev; do - data_split_file=$data_splits_dir/madcat.$dataset.raw.lineid + if [ -f data/train/text ] && ! $overwrite; then + echo "$0: Not processing, probably script have run from wrong stage" + echo "Exiting with status 1 to avoid data corruption" + exit 1; + fi + + echo "$0: preparing data...$(date)" + local/prepare_data.sh --data_splits $data_splits_dir --download_dir1 $download_dir1 \ + --download_dir2 $download_dir2 --download_dir3 $download_dir3 \ + --use_extra_corpus_text $use_extra_corpus_text + + for set in test train dev; do + data_split_file=$data_splits_dir/madcat.$set.raw.lineid local/extract_lines.sh --nj $nj --cmd $cmd --data_split_file $data_split_file \ --download_dir1 $download_dir1 --download_dir2 $download_dir2 \ --download_dir3 $download_dir3 --writing_condition1 $writing_condition1 \ --writing_condition2 $writing_condition2 --writing_condition3 $writing_condition3 \ - --data data/local/$dataset + --data data/local/$set --subset $subset --augment $augment || exit 1 + done + + echo "$0: Processing data..." + for set in dev train test; do + local/process_data.py $download_dir1 $download_dir2 $download_dir3 \ + $data_splits_dir/madcat.$set.raw.lineid data/$set $images_scp_dir/$set/images.scp \ + $writing_condition1 $writing_condition2 $writing_condition3 --augment $augment --subset $subset + image/fix_data_dir.sh data/${set} done -fi -if [ $stage -le 2 ]; then - echo "$0: Preparing data..." - local/prepare_data.sh --download_dir1 $download_dir1 --download_dir2 $download_dir2 \ - --download_dir3 $download_dir3 --images_scp_dir data/local \ - --data_splits_dir $data_splits_dir --writing_condition1 $writing_condition1 \ - --writing_condition2 $writing_condition2 --writing_condition3 $writing_condition3 fi -if [ $stage -le 3 ]; then - echo "$0: Obtaining image groups. calling get_image2num_frames" - echo "Date: $(date)." - image/get_image2num_frames.py data/train # This will be needed for the next command - # The next command creates a "allowed_lengths.txt" file in data/train - # which will be used by local/make_features.py to enforce the images to - # have allowed lengths. The allowed lengths will be spaced by 10% difference in length. - echo "$0: Obtaining image groups. calling get_allowed_lengths" - echo "Date: $(date)." +if [ $stage -le 1 ]; then + echo "$0: Obtaining image groups. calling get_image2num_frames $(date)." + image/get_image2num_frames.py data/train image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train -fi -if [ $stage -le 4 ]; then - for dataset in test train; do - echo "$0: Extracting features and calling compute_cmvn_stats for dataset: $dataset. " - echo "Date: $(date)." - local/extract_features.sh --nj $nj --cmd $cmd --feat-dim 40 data/$dataset - steps/compute_cmvn_stats.sh data/$dataset || exit 1; + for set in test dev train; do + echo "$0: Extracting features and calling compute_cmvn_stats for dataset: $set. 
$(date)" + local/extract_features.sh --nj $nj --cmd $cmd --feat-dim 40 data/$set + steps/compute_cmvn_stats.sh data/$set || exit 1; done - echo "$0: Fixing data directory for train dataset" - echo "Date: $(date)." + echo "$0: Fixing data directory for train dataset $(date)." utils/fix_data_dir.sh data/train fi -if [ $stage -le 5 ]; then - echo "$0: Preparing dictionary and lang..." - cut -d' ' -f2- data/train/text | local/reverse.py | \ - local/prepend_words.py | \ - utils/lang/bpe/learn_bpe.py -s 700 > data/train/bpe.out +if [ $stage -le 2 ]; then + echo "$0: Preparing BPE..." + cut -d' ' -f2- data/train/text | utils/lang/bpe/reverse.py | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/learn_bpe.py -s 700 > data/local/bpe.txt + for set in test train dev; do cut -d' ' -f1 data/$set/text > data/$set/ids - cut -d' ' -f2- data/$set/text | local/reverse.py | \ - local/prepend_words.py | utils/lang/bpe/apply_bpe.py -c data/train/bpe.out \ + cut -d' ' -f2- data/$set/text | utils/lang/bpe/reverse.py | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ | sed 's/@@//g' > data/$set/bpe_text + mv data/$set/text data/$set/text.old paste -d' ' data/$set/ids data/$set/bpe_text > data/$set/text + rm -f data/$set/bpe_text data/$set/ids done + + echo "$0:Preparing dictionary and lang..." local/prepare_dict.sh - # This recipe uses byte-pair encoding, the silences are part of the words' pronunciations. - # So we set --sil-prob to 0.0 utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ data/local/dict "" data/lang/temp data/lang utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 data/lang fi -if [ $stage -le 6 ]; then +if [ $stage -le 3 ]; then + echo "$0: Calling the flat-start chain recipe... $(date)." + local/chain/run_e2e_cnn.sh +fi + +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +decode_e2e=true +if [ $stage -le 4 ]; then echo "$0: Estimating a language model for decoding..." local/train_lm.sh - utils/format_lm.sh data/lang data/local/local_lm/data/arpa/3gram_unpruned.arpa.gz \ - data/local/dict/lexicon.txt data/lang_test + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/6gram_big.arpa.gz \ + data/local/dict/lexicon.txt $lang_decode + utils/build_const_arpa_lm.sh data/local/local_lm/data/arpa/6gram_unpruned.arpa.gz \ + data/lang $lang_rescore fi -if [ $stage -le 7 ]; then - echo "$0: Calling the flat-start chain recipe..." - echo "Date: $(date)." - local/chain/run_flatstart_cnn1a.sh --nj $nj -fi +if [ $stage -le 5 ] && $decode_e2e; then + echo "$0: $(date) stage 5: decoding end2end setup..." + utils/mkgraph.sh --self-loop-scale 1.0 $lang_decode \ + exp/chain/e2e_cnn_1a/ exp/chain/e2e_cnn_1a/graph || exit 1; -if [ $stage -le 8 ]; then - echo "$0: Aligning the training data using the e2e chain model..." - echo "Date: $(date)." 
- steps/nnet3/align.sh --nj $nj --cmd "$cmd" \ - --use-gpu false \ - --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0 --acoustic-scale=1.0' \ - data/train data/lang exp/chain/e2e_cnn_1a exp/chain/e2e_ali_train -fi + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --nj $nj --cmd "$cmd" \ + exp/chain/e2e_cnn_1a/graph data/test exp/chain/e2e_cnn_1a/decode_test || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test exp/chain/e2e_cnn_1a/decode_test{,_rescored} || exit 1 -if [ $stage -le 9 ]; then - echo "$0: Building a tree and training a regular chain model using the e2e alignments..." - echo "Date: $(date)." - local/chain/run_cnn_e2eali_1b.sh --nj $nj + echo "$0: Done. Date: $(date). Results:" + local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ fi diff --git a/egs/madcat_zh/README.txt b/egs/madcat_zh/README.txt new file mode 100644 index 00000000000..4ea8df8bb3c --- /dev/null +++ b/egs/madcat_zh/README.txt @@ -0,0 +1,5 @@ +This directory contains example scripts for handwriting recognition on +the MADCAT Chinese HWR dataset (LDC2014T13). +This dataset consists of handwritten Chinese documents, scanned +at high resolution and annotated for each line and token. +More info: https://catalog.ldc.upenn.edu/LDC2014T13 diff --git a/egs/madcat_zh/v1/cmd.sh b/egs/madcat_zh/v1/cmd.sh new file mode 100644 index 00000000000..3c8eb9f93a5 --- /dev/null +++ b/egs/madcat_zh/v1/cmd.sh @@ -0,0 +1,13 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export cmd="queue.pl" diff --git a/egs/madcat_zh/v1/image b/egs/madcat_zh/v1/image new file mode 120000 index 00000000000..1668ee99922 --- /dev/null +++ b/egs/madcat_zh/v1/image @@ -0,0 +1 @@ +../../cifar/v1/image/ \ No newline at end of file diff --git a/egs/madcat_zh/v1/local/chain/compare_wer.sh b/egs/madcat_zh/v1/local/chain/compare_wer.sh new file mode 100755 index 00000000000..4eb665fc702 --- /dev/null +++ b/egs/madcat_zh/v1/local/chain/compare_wer.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/cnn{1a,1b} + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... ]" + echo "e.g.: $0 exp/chain/cnn{1a,1b}" + exit 1 +fi + +echo "# $0 $*" +used_epochs=false + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +echo -n "# WER " +for x in $*; do + wer=$(cat $x/decode_test/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. 
+fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo diff --git a/egs/madcat_zh/v1/local/chain/run_cnn.sh b/egs/madcat_zh/v1/local/chain/run_cnn.sh new file mode 120000 index 00000000000..df6f0a468c1 --- /dev/null +++ b/egs/madcat_zh/v1/local/chain/run_cnn.sh @@ -0,0 +1 @@ +tuning/run_cnn_1a.sh \ No newline at end of file diff --git a/egs/madcat_zh/v1/local/chain/run_cnn_chainali.sh b/egs/madcat_zh/v1/local/chain/run_cnn_chainali.sh new file mode 120000 index 00000000000..86568421fe1 --- /dev/null +++ b/egs/madcat_zh/v1/local/chain/run_cnn_chainali.sh @@ -0,0 +1 @@ +tuning/run_cnn_chainali_1b.sh \ No newline at end of file diff --git a/egs/madcat_zh/v1/local/chain/run_e2e_cnn.sh b/egs/madcat_zh/v1/local/chain/run_e2e_cnn.sh new file mode 120000 index 00000000000..d26ba0182ce --- /dev/null +++ b/egs/madcat_zh/v1/local/chain/run_e2e_cnn.sh @@ -0,0 +1 @@ +tuning/run_e2e_cnn_1a.sh \ No newline at end of file diff --git a/egs/madcat_zh/v1/local/chain/tuning/run_cnn_1a.sh b/egs/madcat_zh/v1/local/chain/tuning/run_cnn_1a.sh new file mode 100755 index 00000000000..164d62a7ad9 --- /dev/null +++ b/egs/madcat_zh/v1/local/chain/tuning/run_cnn_1a.sh @@ -0,0 +1,223 @@ +#!/bin/bash + +# Copyright 2017 Hossein Hadian +# 2017 Chun Chieh Chang +# 2017 Ashish Arora + +# steps/info/chain_dir_info.pl exp/chain/cnn_1a/ +# exp/chain/cnn_1a/: num-iters=21 nj=2..4 num-params=4.4M dim=40->364 combine=-0.021->-0.015 xent:train/valid[13,20,final]=(-1.05,-0.701,-0.591/-1.30,-1.08,-1.00) logprob:train/valid[13,20,final]=(-0.061,-0.034,-0.030/-0.107,-0.101,-0.098) + +# local/chain/compare_wer.sh exp/chain/cnn_1a/ exp/chain/cnn_chainali_1b/ exp/chain/e2e_cnn_1a/ +# System cnn_1a cnn_chainali_1b e2e_cnn_1a +# WER 13.51 6.76 10.55 +# Final train prob -0.0291 -0.0138 -0.0702 +# Final valid prob -0.0712 -0.0171 -0.0578 +# Final train prob (xent) -0.3847 -0.4169 +# Final valid prob (xent) -0.4962 -0.5040 + +set -e -o pipefail + +stage=0 + +nj=50 +train_set=train +gmm=tri3 # this is the source gmm-dir that we'll use for alignments; it + # should have alignments for the specified training data. +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +ali=tri3_ali +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +chunk_width=340,300,200,100 +num_leaves=500 +tdnn_dim=450 +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). 
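The grep/awk extraction in compare_wer.sh above can be mirrored in a few lines of Python; this is only a sketch that follows the script's directory layout and field positions, not a replacement for it:

# Sketch of compare_wer.sh's extraction: WER from scoring_kaldi/best_wer
# (field 2) and the final train/valid probs from the "Overall" line of the
# compute_prob logs (field 8), per experiment directory on the command line.
import sys

def best_wer(exp_dir):
    with open(exp_dir + "/decode_test/scoring_kaldi/best_wer") as f:
        return f.read().split()[1]

def final_prob(exp_dir, log_name, want_xent):
    with open(exp_dir + "/log/" + log_name) as f:
        for line in f:
            if "Overall" in line and (("xent" in line) == want_xent):
                return line.split()[7]
    return ""

for exp_dir in sys.argv[1:]:
    print(exp_dir, best_wer(exp_dir),
          final_prob(exp_dir, "compute_prob_train.final.log", False),
          final_prob(exp_dir, "compute_prob_valid.final.log", False))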
+ # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj $nj --cmd "$cmd" ${train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 4 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + common1="height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="height-offsets=-2,-1,0,1,2 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=60 name=input + + conv-relu-batchnorm-layer name=cnn1 height-in=60 height-out=60 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=60 height-out=30 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=30 height-out=30 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=30 height-out=15 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + relu-batchnorm-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn4 input=Append(-4,0,4) dim=$tdnn_dim + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn4 dim=$tdnn_dim target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ + --chain.frame-subsampling-factor=4 \ + --chain.alignment-subsampling-factor=4 \ + --trainer.srand=0 \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=2 \ + --trainer.frames-per-iter=2000000 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=12 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=64,32 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=false \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_test \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/test $dir/decode_test || exit 1; +fi + +echo "$0: Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir + diff --git a/egs/iam/v1/local/chain/run_cnn_chainali_1a.sh b/egs/madcat_zh/v1/local/chain/tuning/run_cnn_chainali_1a.sh similarity index 87% rename from egs/iam/v1/local/chain/run_cnn_chainali_1a.sh rename to egs/madcat_zh/v1/local/chain/tuning/run_cnn_chainali_1a.sh index ee3a1a3d92c..be51bdcc3d1 100755 --- a/egs/iam/v1/local/chain/run_cnn_chainali_1a.sh +++ b/egs/madcat_zh/v1/local/chain/tuning/run_cnn_chainali_1a.sh @@ -2,10 +2,16 @@ # chainali_1a is as 1a except it uses chain alignments (using 1a system) instead of gmm alignments +# ./local/chain/compare_wer.sh exp/chain/cnn_chainali_1a/ exp/chain/cnn_1a/ + +# steps/info/chain_dir_info.pl exp/chain/cnn_chainali_1a/ +# exp/chain/cnn_chainali_1a/: num-iters=21 nj=2..4 num-params=4.4M dim=40->364 combine=-0.002->0.000 xent:train/valid[13,20,final]=(-0.929,-0.711,-0.645/-1.16,-1.04,-0.992) logprob:train/valid[13,20,final]=(-0.029,-0.016,-0.013/-0.051,-0.047,-0.045) + +# cat exp/chain/cnn_chainali_1a/decode_test/scoring_kaldi/best_* + set -e -o pipefail stage=0 - nj=30 train_set=train gmm=tri3 # this is the source gmm-dir that we'll use for alignments; it @@ -13,35 +19,25 @@ gmm=tri3 # this is the source gmm-dir that we'll use for alignments; it nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. 
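The xconfig comments in these chain scripts set learning_rate_factor = 0.5 / xent_regularize for the xent output layer. A quick numeric check of the claim that this makes the xent final layer's effective learning rate independent of the regularization constant:

# The xent branch's objective is weighted by xent_regularize, while its final
# layer's learning rate is scaled by 0.5/xent_regularize, so the product of
# the two scalings is always 0.5, whatever regularization value is chosen.
for xent_regularize in (0.025, 0.1, 0.25):
    learning_rate_factor = 0.5 / xent_regularize
    print(xent_regularize, learning_rate_factor, xent_regularize * learning_rate_factor)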
affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. ali=tri3_ali -chain_model_dir=exp/chain${nnet3_affix}/cnn_1a +chain_model_dir=exp/chain${nnet3_affix}/cnn${affix} common_egs_dir= reporting_email= # chain options train_stage=-10 xent_regularize=0.1 -frame_subsampling_factor=4 -alignment_subsampling_factor=1 # training chunk-options chunk_width=340,300,200,100 num_leaves=500 # we don't need extra left/right context for TDNN systems. -chunk_left_context=0 -chunk_right_context=0 tdnn_dim=450 -# training options -srand=0 -remove_egs=false -lang_test=lang_test # End configuration section. echo "$0 $@" # Print the command line for logging - . ./cmd.sh . ./path.sh . ./utils/parse_options.sh - if ! cuda-compiled; then cat < $dir/configs/network.xconfig input dim=40 name=input - conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 @@ -175,27 +169,23 @@ if [ $stage -le 5 ]; then --chain.l2-regularize=0.00005 \ --chain.apply-deriv-weights=false \ --chain.lm-opts="--num-extra-lm-states=500" \ - --chain.frame-subsampling-factor=$frame_subsampling_factor \ - --chain.alignment-subsampling-factor=$alignment_subsampling_factor \ - --trainer.srand=$srand \ + --chain.frame-subsampling-factor=4 \ + --chain.alignment-subsampling-factor=1 \ + --trainer.srand=0 \ --trainer.max-param-change=2.0 \ - --trainer.num-epochs=4 \ - --trainer.frames-per-iter=1000000 \ - --trainer.optimization.num-jobs-initial=2 \ - --trainer.optimization.num-jobs-final=4 \ + --trainer.num-epochs=2 \ + --trainer.frames-per-iter=2000000 \ + --trainer.optimization.num-jobs-initial=8 \ + --trainer.optimization.num-jobs-final=16 \ --trainer.optimization.initial-effective-lrate=0.001 \ --trainer.optimization.final-effective-lrate=0.0001 \ --trainer.optimization.shrink-value=1.0 \ --trainer.num-chunk-per-minibatch=64,32 \ --trainer.optimization.momentum=0.0 \ --egs.chunk-width=$chunk_width \ - --egs.chunk-left-context=$chunk_left_context \ - --egs.chunk-right-context=$chunk_right_context \ - --egs.chunk-left-context-initial=0 \ - --egs.chunk-right-context-final=0 \ --egs.dir="$common_egs_dir" \ --egs.opts="--frames-overlap-per-eg 0" \ - --cleanup.remove-egs=$remove_egs \ + --cleanup.remove-egs=false \ --use-gpu=true \ --reporting.email="$reporting_email" \ --feat-dir=$train_data_dir \ @@ -211,19 +201,14 @@ if [ $stage -le 6 ]; then # topology file from the model). So you could give it a different # lang directory, one that contained a wordlist and LM of your choice, # as long as phones.txt was compatible. 
- utils/mkgraph.sh \ - --self-loop-scale 1.0 data/$lang_test \ + --self-loop-scale 1.0 data/lang_test \ $dir $dir/graph || exit 1; fi if [ $stage -le 7 ]; then frames_per_chunk=$(echo $chunk_width | cut -d, -f1) steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --extra-left-context $chunk_left_context \ - --extra-right-context $chunk_right_context \ - --extra-left-context-initial 0 \ - --extra-right-context-final 0 \ --frames-per-chunk $frames_per_chunk \ --nj $nj --cmd "$cmd" \ $dir/graph data/test $dir/decode_test || exit 1; diff --git a/egs/madcat_zh/v1/local/chain/tuning/run_cnn_chainali_1b.sh b/egs/madcat_zh/v1/local/chain/tuning/run_cnn_chainali_1b.sh new file mode 100755 index 00000000000..aa61620a92f --- /dev/null +++ b/egs/madcat_zh/v1/local/chain/tuning/run_cnn_chainali_1b.sh @@ -0,0 +1,226 @@ +#!/bin/bash + +# chainali_1b is as chainali_1a except it has 3 more cnn layers and 1 less tdnn layer. +# ./local/chain/compare_wer.sh exp/chain/cnn_chainali_1a/ exp/chain/cnn_chainali_1b/ + +# steps/info/chain_dir_info.pl exp/chain/chainali_cnn_1b/ +# exp/chain/chainali_cnn_1b/: num-iters=21 nj=2..4 num-params=4.0M dim=40->364 combine=-0.009->-0.005 xent:train/valid[13,20,final]=(-1.47,-0.728,-0.623/-1.69,-1.02,-0.940) logprob:train/valid[13,20,final]=(-0.068,-0.030,-0.011/-0.086,-0.056,-0.038) + +# local/chain/compare_wer.sh exp/chain/cnn_1a/ exp/chain/cnn_chainali_1b/ exp/chain/e2e_cnn_1a/ +# System cnn_1a cnn_chainali_1b e2e_cnn_1a +# WER 13.51 6.76 10.55 +# Final train prob -0.0291 -0.0138 -0.0702 +# Final valid prob -0.0712 -0.0171 -0.0578 +# Final train prob (xent) -0.3847 -0.4169 +# Final valid prob (xent) -0.4962 -0.5040 + +set -e -o pipefail + +stage=0 +nj=30 +train_set=train +gmm=tri3 # this is the source gmm-dir that we'll use for alignments; it + # should have alignments for the specified training data. +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1b #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +ali=tri3_ali +chain_model_dir=exp/chain${nnet3_affix}/cnn${affix} +common_egs_dir= +reporting_email= +# chain options +train_stage=-10 +xent_regularize=0.1 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=500 +tdnn_dim=450 +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $chain_model_dir $lat_dir + cp $gmm_lat_dir/splice_opts $lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 4 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + common1="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=60 name=input + + conv-relu-batchnorm-layer name=cnn1 height-in=60 height-out=60 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=60 height-out=30 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=30 height-out=30 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=30 height-out=30 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=30 height-out=15 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=15 height-out=15 time-offsets=-1,0,1 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=15 height-out=15 time-offsets=-1,0,1 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ + --chain.frame-subsampling-factor=4 \ + --chain.alignment-subsampling-factor=1 \ + --trainer.srand=0 \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=2 \ + --trainer.frames-per-iter=2000000 \ + --trainer.optimization.num-jobs-initial=6 \ + --trainer.optimization.num-jobs-final=16 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=64,32 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=false \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_test \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/test $dir/decode_test || exit 1; +fi + +echo "$0: Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/madcat_zh/v1/local/chain/tuning/run_e2e_cnn_1a.sh b/egs/madcat_zh/v1/local/chain/tuning/run_e2e_cnn_1a.sh new file mode 100755 index 00000000000..ffc9a4c8a14 --- /dev/null +++ b/egs/madcat_zh/v1/local/chain/tuning/run_e2e_cnn_1a.sh @@ -0,0 +1,130 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +# local/chain/compare_wer.sh exp/chain/e2e_cnn_1a +# System e2e_cnn_1a +# WER 10.41 +# Final train prob -0.0536 +# Final valid prob -0.0489 +# Final train prob (xent) +# Final valid prob (xent) + +# steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1a/ +# exp/chain/e2e_cnn_1a/: num-iters=63 nj=6..12 num-params=6.1M dim=80->5760 combine=-0.048->-0.048 (over 5) logprob:train/valid[41,62,final]=(-0.062,-0.065,-0.054/-0.058,-0.062,-0.049) + +set -e +# configs for 'chain' +stage=0 +train_stage=-10 +get_egs_stage=-10 +affix=1a + +# training options +tdnn_dim=450 +minibatch_size=150=48,24/300=24,12/600=12,6/1200=4,4 +common_egs_dir= +train_set=train + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! 
cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj 70 --cmd "$cmd" \ + --shared-phones true \ + --type mono \ + data/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat data/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py data/lang \| \ + utils/sym2int.pl -f 2- data/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + common1="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=80 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=80 height-out=80 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=80 height-out=40 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=40 height-out=40 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=40 height-out=40 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=40 height-out=20 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=20 height-out=20 time-offsets=-1,0,1 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=20 height-out=20 time-offsets=-1,0,1 $common3 + conv-relu-batchnorm-layer name=cnn8 height-in=20 height-out=10 time-offsets=-1,0,1 $common3 height-subsample-out=2 + relu-batchnorm-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + # no need to store the egs in a shared storage because we always + # remove them. Anyway, it takes only 5 minutes to generate them. 
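For context, stage 1 of this e2e script estimates a phone-level LM by mapping transcripts to phone sequences and then to integer ids before calling chain-est-phone-lm. A toy sketch of that mapping (the lexicon and symbol table below are illustrative; the real pipeline uses data/lang via text_to_phones.py and sym2int.pl):

# Toy version of the text -> phones -> integer-id mapping that feeds
# chain-est-phone-lm in stage 1 above.
lexicon = {"foo": ["f", "o", "o"], "bar": ["b", "a", "r"]}   # word -> phones (illustrative)
phone_ids = {"f": 1, "o": 2, "b": 3, "a": 4, "r": 5}         # phones.txt-style table (illustrative)

def transcript_to_phone_ids(transcript):
    phones = [p for word in transcript.split() for p in lexicon[word]]
    return [phone_ids[p] for p in phones]

print(transcript_to_phone_ids("foo bar"))   # [1, 2, 2, 3, 4, 5]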
+ + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 2000000 \ + --trainer.num-epochs 2 \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial 6 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir data/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi diff --git a/egs/madcat_zh/v1/local/check_tools.sh b/egs/madcat_zh/v1/local/check_tools.sh new file mode 100755 index 00000000000..00de9778808 --- /dev/null +++ b/egs/madcat_zh/v1/local/check_tools.sh @@ -0,0 +1,49 @@ +#!/bin/bash -u + +# Copyright 2015 (c) Johns Hopkins University (Jan Trmal ) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +[ -f ./path.sh ] && . ./path.sh +set +e + +command -v python3 >&/dev/null \ + || { echo >&2 "python3 not found on PATH. You will have to install Python3, preferably >= 3.6"; exit 1; } + +python3 -c "import numpy" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs numpy installed." + exit 1 +fi + +python3 -c "import scipy" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy installed." + exit 1 +fi + +python3 -c "from scipy.spatial import ConvexHull" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy installed." + exit 1 +fi + +python3 -c "import scipy.misc; scipy.misc.__dict__['imread'];" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy-image, scikit-image and Pillow installed." + exit 1 +fi + + +exit 0 diff --git a/egs/madcat_zh/v1/local/create_line_image_from_page_image.py b/egs/madcat_zh/v1/local/create_line_image_from_page_image.py new file mode 100755 index 00000000000..22af571fc04 --- /dev/null +++ b/egs/madcat_zh/v1/local/create_line_image_from_page_image.py @@ -0,0 +1,536 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# Apache 2.0 +# minimum bounding box part in this script is originally from +#https://github.com/BebeSparkelSparkel/MinimumBoundingBox +#https://startupnextdoor.com/computing-convex-hull-in-python/ + +""" This module will be used for extracting line images from page image. + Given the word segmentation (bounding box around a word) for every word, it will + extract line segmentation. 
To extract line segmentation, it will take word bounding + boxes of a line as input, will create a minimum area bounding box that will contain + all corner points of word bounding boxes. The obtained bounding box (will not necessarily + be vertically or horizontally aligned). Hence to extract line image from line bounding box, + page image is rotated and line image is cropped and saved. +""" + +import sys +import argparse +import os +import xml.dom.minidom as minidom +import numpy as np +from math import atan2, cos, sin, pi, degrees, sqrt +from collections import namedtuple + +from scipy.spatial import ConvexHull +from PIL import Image +from scipy.misc import toimage + +parser = argparse.ArgumentParser(description="Creates line images from page image", + epilog="E.g. " + sys.argv[0] + " data/LDC2012T15" + " data/madcat.train.raw.lineid " + " data/local/lines ", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('database_path1', type=str, + help='Path to the downloaded madcat data directory 1') +parser.add_argument('data_splits', type=str, + help='Path to file that contains the train/test/dev split information') +parser.add_argument('out_dir', type=str, + help='directory location to write output files') +parser.add_argument('--padding', type=int, default=400, + help='padding across horizontal/verticle direction') +args = parser.parse_args() + +""" +bounding_box is a named tuple which contains: + + area (float): area of the rectangle + length_parallel (float): length of the side that is parallel to unit_vector + length_orthogonal (float): length of the side that is orthogonal to unit_vector + rectangle_center(int, int): coordinates of the rectangle center + (use rectangle_corners to get the corner points of the rectangle) + unit_vector (float, float): direction of the length_parallel side. + (it's orthogonal vector can be found with the orthogonal_vector function + unit_vector_angle (float): angle of the unit vector to be in radians. + corner_points [(float, float)]: set that contains the corners of the rectangle +""" + +bounding_box_tuple = namedtuple('bounding_box_tuple', 'area ' + 'length_parallel ' + 'length_orthogonal ' + 'rectangle_center ' + 'unit_vector ' + 'unit_vector_angle ' + 'corner_points' + ) + +def unit_vector(pt0, pt1): + """ Returns an unit vector that points in the direction of pt0 to pt1. + Args: + pt0 (float, float): Point 0. Eg. (1.0, 2.0). + pt1 (float, float): Point 1. Eg. (3.0, 8.0). + + Returns: + (float, float): unit vector that points in the direction of pt0 to pt1. + Eg. 0.31622776601683794, 0.9486832980505138 + """ + dis_0_to_1 = sqrt((pt0[0] - pt1[0])**2 + (pt0[1] - pt1[1])**2) + return (pt1[0] - pt0[0])/ dis_0_to_1, \ + (pt1[1] - pt0[1])/ dis_0_to_1 + + +def orthogonal_vector(vector): + """ From vector returns a orthogonal/perpendicular vector of equal length. + Args: + vector (float, float): A vector. Eg. (0.31622776601683794, 0.9486832980505138). + + Returns: + (float, float): A vector that points in the direction orthogonal to vector. + Eg. - 0.9486832980505138,0.31622776601683794 + """ + return -1 * vector[1], vector[0] + + +def bounding_area(index, hull): + """ Returns a named tuple that mainly contains area of the box that bounds + the hull. This bounding box orintation is same as the orientation of the + lines formed by the point hull[index] and hull[index+1]. + Args: + index (int): Eg. 1. + hull [(float, float)]: list or tuple of point cloud + Eg. ((1.0, -1.0), (2.0, -3.0), (3.0, 4.0), (5.0, 6.0)). 
+ + Returns: a named tuple that contains: + area: area of the rectangle + length_parallel: length of the side that is parallel to unit_vector + length_orthogonal: length of the side that is orthogonal to unit_vector + rectangle_center: coordinates of the rectangle center + (use rectangle_corners to get the corner points of the rectangle) + unit_vector: direction of the length_parallel side. + (it's orthogonal vector can be found with the orthogonal_vector function + """ + unit_vector_p = unit_vector(hull[index], hull[index+1]) + unit_vector_o = orthogonal_vector(unit_vector_p) + + dis_p = tuple(np.dot(unit_vector_p, pt) for pt in hull) + dis_o = tuple(np.dot(unit_vector_o, pt) for pt in hull) + + min_p = min(dis_p) + min_o = min(dis_o) + len_p = max(dis_p) - min_p + len_o = max(dis_o) - min_o + + return {'area': len_p * len_o, + 'length_parallel': len_p, + 'length_orthogonal': len_o, + 'rectangle_center': (min_p + float(len_p)/ 2, min_o + float(len_o)/ 2), + 'unit_vector': unit_vector_p, + } + + +def to_xy_coordinates(unit_vector_angle, point): + """ Returns converted unit vector coordinates in x, y coordinates. + Args: + unit_vector_angle (float): angle of unit vector to be in radians. + Eg. 0.1543 . + point (float, float): Point from origin. Eg. (1.0, 2.0). + + Returns: + (float, float): converted x,y coordinate of the unit vector. + Eg. 0.680742447866183, 2.1299271629971663 + """ + angle_orthogonal = unit_vector_angle + pi/ 2 + return point[0] * cos(unit_vector_angle) + point[1] * cos(angle_orthogonal), \ + point[0] * sin(unit_vector_angle) + point[1] * sin(angle_orthogonal) + + +def rotate_points(center_of_rotation, angle, points): + """ Rotates a point cloud around the center_of_rotation point by angle + Args: + center_of_rotation (float, float): angle of unit vector to be in radians. + Eg. (1.56, -23.4). + angle (float): angle of rotation to be in radians. Eg. 0.1543 . + points [(float, float)]: Points to be a list or tuple of points. Points to be rotated. + Eg. ((1.56, -23.4), (1.56, -23.4)) + + Returns: + [(float, float)]: Rotated points around center of rotation by angle + Eg. ((1.16, -12.4), (2.34, -34.4)) + """ + + rot_points = [] + ang = [] + for pt in points: + diff = tuple([pt[d] - center_of_rotation[d] for d in range(2)]) + diff_angle = atan2(diff[1], diff[0]) + angle + ang.append(diff_angle) + diff_length = sqrt(sum([d**2 for d in diff])) + rot_points.append((center_of_rotation[0] + diff_length * cos(diff_angle), + center_of_rotation[1] + diff_length * sin(diff_angle))) + + return rot_points + + +def rectangle_corners(rectangle): + """ Given rectangle center and its inclination. It returns the corner + locations of the rectangle. + Args: + rectangle (bounding_box): the output of minimum bounding box rectangle + + Returns: + [(float, float)]: 4 corner points of rectangle. + Eg. ((1.0, -1.0), (2.0, -3.0), (3.0, 4.0), (5.0, 6.0)) + """ + corner_points = [] + for i1 in (.5, -.5): + for i2 in (i1, -1 * i1): + corner_points.append((rectangle['rectangle_center'][0] + i1 * rectangle['length_parallel'], + rectangle['rectangle_center'][1] + i2 * rectangle['length_orthogonal'])) + + return rotate_points(rectangle['rectangle_center'], rectangle['unit_vector_angle'], corner_points) + + +# use this function to find the listed properties of the minimum bounding box of a point cloud +def minimum_bounding_box(points): + """ Given a point cloud, it returns the minimum area rectangle bounding all + the points in the point cloud. 
+ Args: + points [(float, float)]: points to be a list or tuple of 2D points + needs to be more than 2 points + + Returns: returns a namedtuple that contains: + area: area of the rectangle + length_parallel: length of the side that is parallel to unit_vector + length_orthogonal: length of the side that is orthogonal to unit_vector + rectangle_center: coordinates of the rectangle center + (use rectangle_corners to get the corner points of the rectangle) + unit_vector: direction of the length_parallel side. RADIANS + (it's orthogonal vector can be found with the orthogonal_vector function + unit_vector_angle: angle of the unit vector + corner_points: set that contains the corners of the rectangle + """ + + if len(points) <= 2: raise ValueError('More than two points required.') + + hull_ordered = [points[index] for index in ConvexHull(points).vertices] + hull_ordered.append(hull_ordered[0]) + hull_ordered = tuple(hull_ordered) + + min_rectangle = bounding_area(0, hull_ordered) + for i in range(1, len(hull_ordered)-1): + rectangle = bounding_area(i, hull_ordered) + if rectangle['area'] < min_rectangle['area']: + min_rectangle = rectangle + + min_rectangle['unit_vector_angle'] = atan2(min_rectangle['unit_vector'][1], min_rectangle['unit_vector'][0]) + min_rectangle['rectangle_center'] = to_xy_coordinates(min_rectangle['unit_vector_angle'], min_rectangle['rectangle_center']) + + return bounding_box_tuple( + area = min_rectangle['area'], + length_parallel = min_rectangle['length_parallel'], + length_orthogonal = min_rectangle['length_orthogonal'], + rectangle_center = min_rectangle['rectangle_center'], + unit_vector = min_rectangle['unit_vector'], + unit_vector_angle = min_rectangle['unit_vector_angle'], + corner_points = set(rectangle_corners(min_rectangle)) + ) + + +def get_center(im): + """ Returns the center pixel location of an image + Args: + im: image + + Returns: + (int, int): center of the image + Eg. 2550, 3300 + """ + center_x = float(im.size[0])/ 2 + center_y = float(im.size[1])/ 2 + return int(center_x), int(center_y) + + +def get_horizontal_angle(unit_vector_angle): + """ Returns angle of the unit vector in first or fourth quadrant. + Args: + angle (float): angle of the unit vector to be in radians. Eg. 0.01543. + + Returns: + (float): updated angle of the unit vector to be in radians. + It is only in first or fourth quadrant. + Eg. 0.01543. + """ + + if unit_vector_angle > pi/ 2 and unit_vector_angle <= pi: + unit_vector_angle = unit_vector_angle - pi + elif unit_vector_angle > -pi and unit_vector_angle < -pi/ 2: + unit_vector_angle = unit_vector_angle + pi + + return unit_vector_angle + + +def get_smaller_angle(bounding_box): + """ Returns smallest absolute angle of a rectangle. + Args: + rectangle (bounding_box): bounding box rectangle + + Returns: + (float): smallest angle of the rectangle to be in radians. + Eg. 0.01543. + """ + + unit_vector = bounding_box.unit_vector + unit_vector_angle = bounding_box.unit_vector_angle + ortho_vector = orthogonal_vector(unit_vector) + ortho_vector_angle = atan2(ortho_vector[1], ortho_vector[0]) + + unit_vector_angle_updated = get_horizontal_angle(unit_vector_angle) + ortho_vector_angle_updated = get_horizontal_angle(ortho_vector_angle) + + if abs(unit_vector_angle_updated) < abs(ortho_vector_angle_updated): + return unit_vector_angle_updated + else: + return ortho_vector_angle_updated + + +def rotated_points(bounding_box, center): + """ Rotates the corners of a bounding box rectangle around the center by smallest angle + of the rectangle. 
It first finds the smallest angle of the rectangle + then rotates it around the given center point. + Args: + rectangle (bounding_box): bounding box rectangle + center (int, int): center point around which the corners of rectangle are rotated. + Eg. (2550, 3300). + + Returns: 4 corner points of rectangle. + Eg. ((1.0, -1.0), (2.0, -3.0), (3.0, 4.0), (5.0, 6.0)) + """ + + p1, p2, p3, p4 = bounding_box.corner_points + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + x4, y4 = p4 + center_x, center_y = center + rotation_angle_in_rad = -get_smaller_angle(bounding_box) + x_dash_1 = (x1 - center_x) * cos(rotation_angle_in_rad) - (y1 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_2 = (x2 - center_x) * cos(rotation_angle_in_rad) - (y2 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_3 = (x3 - center_x) * cos(rotation_angle_in_rad) - (y3 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_4 = (x4 - center_x) * cos(rotation_angle_in_rad) - (y4 - center_y) * sin(rotation_angle_in_rad) + center_x + + y_dash_1 = (y1 - center_y) * cos(rotation_angle_in_rad) + (x1 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_2 = (y2 - center_y) * cos(rotation_angle_in_rad) + (x2 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_3 = (y3 - center_y) * cos(rotation_angle_in_rad) + (x3 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_4 = (y4 - center_y) * cos(rotation_angle_in_rad) + (x4 - center_x) * sin(rotation_angle_in_rad) + center_y + return x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 + + +def pad_image(image): + """ Pads the image around the border. It help in getting + bounding boxes that are slightly outside the page boundary. + Args: + image: page image. + + Returns: + image: page image + """ + + padded_image = Image.new('RGB', (image.size[0] + padding, image.size[1] + padding), "white") + padded_image.paste(im=image, box=(offset, offset)) + return padded_image + + +def update_minimum_bounding_box_input(bounding_box_input): + """ Updates the word bounding box corner points. + Args: + points [(float, float)]: points, a list or tuple of 2D coordinates. + ideally should be more than 2 points + Returns: + points [(float, float)]: points, a list or tuple of 2D coordinates + """ + + updated_minimum_bounding_box_input = [] + for point in bounding_box_input: + x, y = point + new_x = x + offset + new_y = y + offset + word_coordinate = (new_x, new_y) + updated_minimum_bounding_box_input.append(word_coordinate) + + return updated_minimum_bounding_box_input + + +def set_line_image_data(image, line_id, image_file_name): + """ Flips a given line image and saves it. Line image file name + is formed by appending the line id at the end page image name. + Args: + image: line image, non flipped + line_id (string): id of the line image. + image_file_name(string): name of the page image. + + Returns: + """ + + base_name = os.path.splitext(os.path.basename(image_file_name))[0] + image_file_name_wo_tif, b = image_file_name.split('.tif') + line_id = '_' + line_id.zfill(4) + line_image_file_name = base_name + line_id + '.png' + image_path = os.path.join(output_directory, line_image_file_name) + imgray = toimage(image.convert('L')) + imgray.save(image_path) + image_fh.write(image_path + '\n') + +def get_line_images_from_page_image(image_file_name, madcat_file_path): + """ Extracts the line image from page image. + Args: + image_file_name (string): complete path and name of the page image. 
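The eight x_dash/y_dash expressions in rotated_points() above are the standard rotation of a point about a centre; a compact sketch of the same formula applied to a single point:

from math import cos, sin, radians

def rotate_about(point, center, angle_rad):
    # Same arithmetic as rotated_points(), for one point instead of four corners.
    x, y = point
    cx, cy = center
    xr = (x - cx) * cos(angle_rad) - (y - cy) * sin(angle_rad) + cx
    yr = (y - cy) * cos(angle_rad) + (x - cx) * sin(angle_rad) + cy
    return xr, yr

print(rotate_about((2.0, 0.0), (0.0, 0.0), radians(90)))   # ~(0.0, 2.0)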
+ madcat_file_path (string): complete path and name of the madcat xml file + corresponding to the page image. + + Returns: + """ + im_wo_pad = Image.open(image_file_name) + im = pad_image(im_wo_pad) + doc = minidom.parse(madcat_file_path) + zone = doc.getElementsByTagName('zone') + for node in zone: + id = node.getAttribute('id') + token_image = node.getElementsByTagName('token-image') + minimum_bounding_box_input = [] + for token_node in token_image: + word_point = token_node.getElementsByTagName('point') + for word_node in word_point: + word_coordinate = (int(word_node.getAttribute('x')), int(word_node.getAttribute('y'))) + minimum_bounding_box_input.append(word_coordinate) + updated_mbb_input = update_minimum_bounding_box_input(minimum_bounding_box_input) + bounding_box = minimum_bounding_box(updated_mbb_input) + + p1, p2, p3, p4 = bounding_box.corner_points + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + x4, y4 = p4 + min_x = int(min(x1, x2, x3, x4)) + min_y = int(min(y1, y2, y3, y4)) + max_x = int(max(x1, x2, x3, x4)) + max_y = int(max(y1, y2, y3, y4)) + box = (min_x, min_y, max_x, max_y) + region_initial = im.crop(box) + rot_points = [] + p1_new = (x1 - min_x, y1 - min_y) + p2_new = (x2 - min_x, y2 - min_y) + p3_new = (x3 - min_x, y3 - min_y) + p4_new = (x4 - min_x, y4 - min_y) + rot_points.append(p1_new) + rot_points.append(p2_new) + rot_points.append(p3_new) + rot_points.append(p4_new) + + cropped_bounding_box = bounding_box_tuple(bounding_box.area, + bounding_box.length_parallel, + bounding_box.length_orthogonal, + bounding_box.length_orthogonal, + bounding_box.unit_vector, + bounding_box.unit_vector_angle, + set(rot_points) + ) + + rotation_angle_in_rad = get_smaller_angle(cropped_bounding_box) + img2 = region_initial.rotate(degrees(rotation_angle_in_rad), resample=Image.BICUBIC) + x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 = rotated_points( + cropped_bounding_box, get_center(region_initial)) + + + min_x = int(min(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + min_y = int(min(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + max_x = int(max(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + max_y = int(max(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + box = (min_x, min_y, max_x, max_y) + region_final = img2.crop(box) + set_line_image_data(region_final, id, image_file_name) + + +def check_file_location(): + """ Returns the complete path of the page image and corresponding + xml file. + Args: + + Returns: + image_file_name (string): complete path and name of the page image. + madcat_file_path (string): complete path and name of the madcat xml file + corresponding to the page image. + """ + + madcat_file_path1 = os.path.join(data_path1, 'madcat', base_name + '.madcat.xml') + + image_file_path1 = os.path.join(data_path1, 'images', base_name + '.tif') + + if os.path.exists(madcat_file_path1): + return madcat_file_path1, image_file_path1, wc_dict1 + + print("ERROR: path does not exist") + return None, None, None + +def parse_writing_conditions(writing_conditions): + """ Returns a dictionary which have writing condition of each page image. + Args: + writing_conditions(string): complete path of writing condition file. + + Returns: + (dict): dictionary with key as page image name and value as writing condition. 
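get_line_images_from_page_image() above walks the MADCAT XML with xml.dom.minidom, collecting the corner points of every token image inside a zone. A toy, self-contained version of that traversal; the XML snippet is made up and far simpler than real MADCAT markup:

import xml.dom.minidom as minidom

xml_text = """
<doc>
  <zone id="z001">
    <token-image id="t1">
      <point x="10" y="20"/> <point x="90" y="22"/>
      <point x="92" y="55"/> <point x="11" y="53"/>
    </token-image>
  </zone>
</doc>
"""

doc = minidom.parseString(xml_text)
for zone in doc.getElementsByTagName('zone'):
    points = []
    for token in zone.getElementsByTagName('token-image'):
        for pt in token.getElementsByTagName('point'):
            points.append((int(pt.getAttribute('x')), int(pt.getAttribute('y'))))
    print(zone.getAttribute('id'), points)   # z001 [(10, 20), (90, 22), (92, 55), (11, 53)]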
+ """ + + with open(writing_conditions) as f: + file_writing_cond = dict() + for line in f: + line_list = line.strip().split("\t") + file_writing_cond[line_list[0]] = line_list[3] + return file_writing_cond + +def check_writing_condition(wc_dict): + """ Checks if a given page image is writing in a given writing condition. + It is used to create subset of dataset based on writing condition. + Args: + wc_dict (dict): dictionary with key as page image name and value as writing condition. + + Returns: + (bool): True if writing condition matches. + """ + + return True + writing_condition = wc_dict[base_name].strip() + if writing_condition != 'IUC': + return False + + return True + + +### main ### + +data_path1 = os.path.join(args.database_path1, 'data') + +splits_handle = open(args.data_splits, 'r') +splits_data = splits_handle.read().strip().split('\n') + +padding = int(args.padding) +offset = int(padding // 2) + +output_directory = args.out_dir +image_file = os.path.join(output_directory, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') + +writing_conditions1 = os.path.join(args.database_path1, 'docs', 'writing_conditions.tab') + +wc_dict1 = parse_writing_conditions(writing_conditions1) + +prev_base_name = '' +for line in splits_data: + base_name = os.path.splitext(os.path.splitext(line.split(' ')[0])[0])[0] + if prev_base_name != base_name: + prev_base_name = base_name + madcat_file_path, image_file_path, wc_dict = check_file_location() + if wc_dict == None or not check_writing_condition(wc_dict): + continue + if madcat_file_path != None: + get_line_images_from_page_image(image_file_path, madcat_file_path) diff --git a/egs/madcat_zh/v1/local/extract_features.sh b/egs/madcat_zh/v1/local/extract_features.sh new file mode 100755 index 00000000000..9fe588f31b8 --- /dev/null +++ b/egs/madcat_zh/v1/local/extract_features.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +# Copyright 2017 Yiwen Shao +# 2018 Ashish Arora + +# Apache 2.0 +# This script runs the make features script in parallel. + +nj=4 +cmd=run.pl +feat_dim=40 +augment='no_aug' +verticle_shift=0 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +data=$1 +featdir=$data/data +scp=$data/images.scp +logdir=$data/log + +mkdir -p $logdir +mkdir -p $featdir + +# make $featdir an absolute pathname +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +for n in $(seq $nj); do + split_scps="$split_scps $logdir/images.$n.scp" +done + +# split images.scp +utils/split_scp.pl $scp $split_scps || exit 1; + +$cmd JOB=1:$nj $logdir/extract_features.JOB.log \ + image/ocr/make_features.py $logdir/images.JOB.scp \ + --allowed_len_file_path $data/allowed_lengths.txt \ + --feat-dim $feat_dim --augment_type $augment \ + --vertical-shift $verticle_shift \| \ + copy-feats --compress=true --compression-method=7 \ + ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp + +## aggregates the output scp's to get feats.scp +for n in $(seq $nj); do + cat $featdir/images.$n.scp || exit 1; +done > $data/feats.scp || exit 1 diff --git a/egs/madcat_zh/v1/local/extract_lines.sh b/egs/madcat_zh/v1/local/extract_lines.sh new file mode 100755 index 00000000000..ed752e97e13 --- /dev/null +++ b/egs/madcat_zh/v1/local/extract_lines.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright 2018 Ashish Arora + +nj=4 +cmd=run.pl +download_dir=/export/corpora/LDC/LDC2014T13 +dataset_file=data/download/datasplits/madcat.dev.raw.lineid +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. 
./utils/parse_options.sh || exit 1; + +data=$1 +log_dir=$data/log +mkdir -p $log_dir +mkdir -p $data + +for n in $(seq $nj); do + split_scps="$split_scps $log_dir/lines.$n.scp" +done + +utils/split_scp.pl $dataset_file $split_scps || exit 1; + +for n in $(seq $nj); do + mkdir -p $data/$n +done + +$cmd JOB=1:$nj $log_dir/extract_lines.JOB.log \ + local/create_line_image_from_page_image.py $download_dir $log_dir/lines.JOB.scp $data/JOB \ + || exit 1; + +## concatenate the .scp files together. +for n in $(seq $nj); do + cat $data/$n/images.scp || exit 1; +done > $data/images.scp || exit 1 diff --git a/egs/madcat_zh/v1/local/prepare_data.sh b/egs/madcat_zh/v1/local/prepare_data.sh new file mode 100755 index 00000000000..ba35b90b173 --- /dev/null +++ b/egs/madcat_zh/v1/local/prepare_data.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 + +# This script downloads the Madcat Chinese handwriting database and prepares the training +# and test data (i.e text, images.scp, utt2spk and spk2utt) by calling process_data.py. +# It also downloads the LOB and Brown text corpora. It downloads the database files +# only if they do not already exist in download directory. + +# Eg. local/prepare_data.sh +# Eg. text file: 000_a01-000u-00 A MOVE to stop Mr. Gaitskell from +# utt2spk file: 000_a01-000u-00 000 +# images.scp file: 000_a01-000u-00 data/local/lines/a01/a01-000u/a01-000u-00.png +# spk2utt file: 000 000_a01-000u-00 000_a01-000u-01 000_a01-000u-02 000_a01-000u-03 + +download_dir1=/export/corpora/LDC/LDC2014T13/data +train_split_url=http://www.openslr.org/resources/50/madcat.train.raw.lineid +test_split_url=http://www.openslr.org/resources/50/madcat.test.raw.lineid +dev_split_url=http://www.openslr.org/resources/50/madcat.dev.raw.lineid +data_split_dir=data/download/datasplits + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +if [ -d $data_split_dir ]; then + echo "$0: Not downloading the data splits as it is already there." +else + if [ ! -f $data_split_dir/madcat.train.raw.lineid ]; then + mkdir -p $data_split_dir + echo "$0: Downloading the data splits..." + wget -P $data_split_dir $train_split_url || exit 1; + wget -P $data_split_dir $test_split_url || exit 1; + wget -P $data_split_dir $dev_split_url || exit 1; + fi + echo "$0: Done downloading the data splits" +fi + +if [ -d $download_dir1 ]; then + echo "$0: madcat chinese data directory is present." +else + if [ ! -f $download_dir1/madcat/*.madcat.xml ]; then + echo "$0: please download madcat data..." + fi +fi diff --git a/egs/madcat_zh/v1/local/prepare_dict.sh b/egs/madcat_zh/v1/local/prepare_dict.sh new file mode 100755 index 00000000000..f9cd8387fad --- /dev/null +++ b/egs/madcat_zh/v1/local/prepare_dict.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +# Copyright 2017 Hossein Hadian +# 2017 Chun Chieh Chang +# 2017 Ashish Arora + +# This script prepares the dictionary. 
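prepare_dict.sh below builds a character-level lexicon: every word in the training transcripts maps to the sequence of its characters, and those characters serve as the "phones" of this handwriting setup. A plain Python sketch of the mapping that the inline python3 one-liner further down performs, on a few made-up words:

words = ['今天', '天气', 'hello']
for w in sorted(set(words)):
    print(w, ' '.join(list(w)))
# hello h e l l o
# 今天 今 天
# 天气 天 气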
+ +set -e +dir=data/local/dict +mkdir -p $dir + +#local/prepare_lexicon.py data/train $dir +cat data/train/text | cut -d' ' -f2- | tr ' ' '\n' | sort -u | sed '/^$/d' | \ + python3 -c \ + 'import sys, io; \ + sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf8"); \ + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf8"); \ + [sys.stdout.write(line.strip() + " " + " ".join(list(line.strip())) + "\n") for line in sys.stdin];' > $dir/lexicon.txt + +cut -d' ' -f2- $dir/lexicon.txt | tr ' ' '\n' | sort -u >$dir/nonsilence_phones.txt || exit 1; + +echo ' SIL' >> $dir/lexicon.txt + +echo SIL > $dir/silence_phones.txt + +echo SIL >$dir/optional_silence.txt + +echo -n "" >$dir/extra_questions.txt diff --git a/egs/madcat_zh/v1/local/process_data.py b/egs/madcat_zh/v1/local/process_data.py new file mode 100755 index 00000000000..994a4486420 --- /dev/null +++ b/egs/madcat_zh/v1/local/process_data.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora + +""" This script reads the extracted IAM database files and creates + the following files (for the data subset selected via --dataset): + text, utt2spk, images.scp. + + Eg. local/process_data.py data/local data/train data --dataset train + Eg. text file: 000_a01-000u-00 A MOVE to stop Mr. Gaitskell from + utt2spk file: 000_a01-000u-00 000 + images.scp file: 000_a01-000u-00 data/local/lines/a01/a01-000u/a01-000u-00.png +""" + +import argparse +import os +import sys +import xml.dom.minidom as minidom +import unicodedata + +parser = argparse.ArgumentParser(description="Creates text, utt2spk and images.scp files", + epilog="E.g. " + sys.argv[0] + " data/LDC2012T15" + " data/LDC2013T09 data/LDC2013T15 data/madcat.train.raw.lineid " + " data/train data/local/lines ", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('database_path1', + help='Path to the downloaded (and extracted) madcat data') +parser.add_argument('data_splits', + help='Path to file that contains the train/test/dev split information') +parser.add_argument('out_dir', + help='directory location to write output files.') +args = parser.parse_args() + + +def check_file_location(): + """ Returns the complete path of the page image and corresponding + xml file. + Args: + + Returns: + image_file_name (string): complete path and name of the page image. + madcat_file_path (string): complete path and name of the madcat xml file + corresponding to the page image. + """ + + madcat_file_path1 = os.path.join(args.database_path1, 'data', 'madcat', base_name + '.madcat.xml') + + image_file_path1 = os.path.join(args.database_path1, 'data', 'images', base_name + '.tif') + + if os.path.exists(madcat_file_path1): + return madcat_file_path1, image_file_path1, wc_dict1 + + return None, None, None + + +def parse_writing_conditions(writing_conditions): + """ Returns a dictionary which have writing condition of each page image. + Args: + writing_conditions(string): complete path of writing condition file. + + Returns: + (dict): dictionary with key as page image name and value as writing condition. + """ + + with open(writing_conditions) as f: + file_writing_cond = dict() + for line in f: + line_list = line.strip().split("\t") + file_writing_cond[line_list[0]] = line_list[3] + return file_writing_cond + + +def check_writing_condition(wc_dict): + """ Checks if a given page image is writing in a given writing condition. + It is used to create subset of dataset based on writing condition. 
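parse_writing_conditions() above reads a tab-separated file and keeps column 1 (the page id) and column 4 (the writing condition). An illustration with made-up rows; only the first and fourth columns reflect the real layout:

import io

fake_tab = "XYZ_0001\twriter_17\tlined\tIUC\nXYZ_0002\twriter_03\tunlined\tICR\n"
file_writing_cond = {}
for line in io.StringIO(fake_tab):
    fields = line.strip().split('\t')
    file_writing_cond[fields[0]] = fields[3]
print(file_writing_cond)   # {'XYZ_0001': 'IUC', 'XYZ_0002': 'ICR'}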
+ Args: + wc_dict (dict): dictionary with key as page image name and value as writing condition. + + Returns: + (bool): True if writing condition matches. + """ + + return True + writing_condition = wc_dict[base_name].strip() + if writing_condition != 'IUC': + return False + + return True + + +def get_word_line_mapping(madcat_file_path): + """ Maps every word in the page image to a corresponding line. + Args: + madcat_file_path (string): complete path and name of the madcat xml file + corresponding to the page image. + + Returns: + """ + + doc = minidom.parse(madcat_file_path) + zone = doc.getElementsByTagName('zone') + for node in zone: + line_id = node.getAttribute('id') + line_word_dict[line_id] = list() + word_image = node.getElementsByTagName('token-image') + for tnode in word_image: + word_id = tnode.getAttribute('id') + line_word_dict[line_id].append(word_id) + word_line_dict[word_id] = line_id + + +def read_text(madcat_file_path): + """ Maps every word in the page image to a corresponding line. + Args: + madcat_file_path (string): complete path and name of the madcat xml file + corresponding to the page image. + + Returns: + dict: Mapping every word in the page image to a corresponding line. + """ + + text_line_word_dict = dict() + doc = minidom.parse(madcat_file_path) + segment = doc.getElementsByTagName('segment') + for node in segment: + token = node.getElementsByTagName('token') + for tnode in token: + segment_id = tnode.getAttribute('id') + ref_word_id = tnode.getAttribute('ref_id') + word = tnode.getElementsByTagName('source')[0].firstChild.nodeValue + word = unicodedata.normalize('NFKC',word) + ref_line_id = word_line_dict[ref_word_id] + if ref_line_id not in text_line_word_dict: + text_line_word_dict[ref_line_id] = list() + text_line_word_dict[ref_line_id].append(word) + return text_line_word_dict + + +def get_line_image_location(): + image_loc_dict = dict() # Stores image base name and location + image_loc_vect = input_image_fh.read().strip().split("\n") + for line in image_loc_vect: + base_name = os.path.basename(line) + location_vect = line.split('/') + location = "/".join(location_vect[:-1]) + image_loc_dict[base_name]=location + return image_loc_dict + + +text_file = os.path.join(args.out_dir, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') +utt2spk_file = os.path.join(args.out_dir, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join(args.out_dir, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') + +data_path1 = os.path.join(args.database_path1, 'data') + +input_image_file = os.path.join(args.out_dir, 'lines', 'images.scp') +input_image_fh = open(input_image_file, 'r', encoding='utf-8') + +writing_conditions1 = os.path.join(args.database_path1, 'docs', 'writing_conditions.tab') + +wc_dict1 = parse_writing_conditions(writing_conditions1) +image_loc_dict = get_line_image_location() + +image_num = 0 +with open(args.data_splits) as f: + prev_base_name = '' + for line in f: + base_name = os.path.splitext(os.path.splitext(line.split(' ')[0])[0])[0] + if prev_base_name != base_name: + prev_base_name = base_name + madcat_xml_path, image_file_path, wc_dict = check_file_location() + if wc_dict is None or not check_writing_condition(wc_dict): + continue + if madcat_xml_path is not None: + madcat_doc = minidom.parse(madcat_xml_path) + writer = madcat_doc.getElementsByTagName('writer') + writer_id = writer[0].getAttribute('id') + line_word_dict = dict() + word_line_dict = dict() + 
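The rest of this loop (below) writes one entry per text line to text, utt2spk and images.scp. A toy illustration, with made-up ids, of the Kaldi-style key and transcript format it produces:

writer_id, image_num, base_name, line_id = 'w042', 7, 'XYZ_0001', 3    # made-up ids
words = ['今天', '天气']
utt_id = "{}_{}_{}_{}".format(writer_id, str(image_num).zfill(6), base_name, str(line_id).zfill(4))
text = ' '.join(''.join(words))    # the transcript becomes space-separated characters
print(utt_id, text)                # w042_000007_XYZ_0001_0003 今 天 天 气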
get_word_line_mapping(madcat_xml_path) + text_line_word_dict = read_text(madcat_xml_path) + base_name = os.path.basename(image_file_path) + base_name, b = base_name.split('.tif') + for lineID in sorted(text_line_word_dict): + updated_base_name = "{}_{}.png".format(base_name, str(lineID).zfill(4)) + location = image_loc_dict[updated_base_name] + image_file_path = os.path.join(location, updated_base_name) + line = text_line_word_dict[lineID] + text = ' '.join(''.join(line)) + utt_id = "{}_{}_{}_{}".format(writer_id, str(image_num).zfill(6), base_name, str(lineID).zfill(4)) + text_fh.write(utt_id + ' ' + text + '\n') + utt2spk_fh.write(utt_id + ' ' + writer_id + '\n') + image_fh.write(utt_id + ' ' + image_file_path + '\n') + image_num = image_num + 1 diff --git a/egs/madcat_zh/v1/local/process_segments.py b/egs/madcat_zh/v1/local/process_segments.py new file mode 100755 index 00000000000..3d09c0df190 --- /dev/null +++ b/egs/madcat_zh/v1/local/process_segments.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Chun Chieh Chang + +""" This script reads the provided word segmentations of chinese + and ensures that all of them are normalized to the same + unicode form. +""" + +import argparse +import os +import unicodedata + +parser = argparse.ArgumentParser(description="""Takes in word segmentations and normalizes character form.""") +parser.add_argument('segmentation_path', type=str, + help='Path to chinese word segmentation') +parser.add_argument('out_dir', type=str, + help='Where to write output file') +args = parser.parse_args() + +segment_file = os.path.join(args.out_dir, 'segmented_words') +segment_fh = open(segment_file, 'w', encoding='utf-8') + +with open(args.segmentation_path, encoding='utf-8') as f: + for line in f: + line_normalize = unicodedata.normalize('NFKC', line) + segment_fh.write(line_normalize + '\n') diff --git a/egs/madcat_zh/v1/local/score.sh b/egs/madcat_zh/v1/local/score.sh new file mode 100755 index 00000000000..31564d25326 --- /dev/null +++ b/egs/madcat_zh/v1/local/score.sh @@ -0,0 +1,5 @@ +#!/bin/bash + + +steps/scoring/score_kaldi_wer.sh "$@" +steps/scoring/score_kaldi_cer.sh --stage 2 "$@" diff --git a/egs/madcat_zh/v1/local/train_lm.sh b/egs/madcat_zh/v1/local/train_lm.sh new file mode 100755 index 00000000000..a8e2dc71f28 --- /dev/null +++ b/egs/madcat_zh/v1/local/train_lm.sh @@ -0,0 +1,108 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains an LM on the LOB+Brown text data and IAM training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 + +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +dir=data/local/local_lm +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. 
+ else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +#bypass_metaparam_optim_opt= +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # use the validation data as the dev set. + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + + cat data/dev/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + + # use the training data as an additional data source. + # we can later fold the dev data into this. + cat data/train/text | cut -d " " -f 2- > ${dir}/data/text/madcat.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < data/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from IAM text + cat ${dir}/data/text/madcat.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +order=3 + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='train=2 madcat=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=5 --warm-start-ratio=1 \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + #log-prob: -5.05603614242 [perplexity = 156.967086371] over 19477.0 words + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi diff --git a/egs/madcat_zh/v1/local/wer_output_filter b/egs/madcat_zh/v1/local/wer_output_filter new file mode 100755 index 00000000000..5d5290ad8c3 --- /dev/null +++ b/egs/madcat_zh/v1/local/wer_output_filter @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# Copyright 2017 Hossein Hadian + +# This is a filter used in scoring. It separates all +# punctuations from words. For e.g. this sentence: + +# "They have come!" 
he said reverently, gripping his +# hands. "Isn't it a glorious thing! Long awaited." + +# is converted to this: + +# " They have come ! " he said reverently , gripping his +# hands . " Isn ' t it a glorious thing ! Long awaited . " + +import sys +import io +import re +from collections import OrderedDict + +sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf8"); +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf8"); + +re_dict = OrderedDict([("“","\""), ("”","\"")]) +pattern = re.compile("|".join(re.escape(key) for key in re_dict.keys())) + +for line in sys.stdin: + words = line.strip().split() + uttid = words[0] + transcript = ' '.join(words[1:]) + transcript_fixed = pattern.sub(lambda x: re_dict[x.group()], transcript) + sys.stdout.write(uttid + " " + transcript_fixed + "\n") diff --git a/egs/madcat_zh/v1/path.sh b/egs/madcat_zh/v1/path.sh new file mode 100755 index 00000000000..2d17b17a84a --- /dev/null +++ b/egs/madcat_zh/v1/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/madcat_zh/v1/run.sh b/egs/madcat_zh/v1/run.sh new file mode 100755 index 00000000000..b3ef370c830 --- /dev/null +++ b/egs/madcat_zh/v1/run.sh @@ -0,0 +1,159 @@ +#!/bin/bash + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora +# 2017 Hossein Hadian + +set -e +stage=0 +nj=50 +decode_gmm=true +# madcat_database points to the database path on the JHU grid. If you have not +# already downloaded the database you can set it to a local directory +# like "data/download" and follow the instructions +# in "local/download_data.sh" to download the database: +# data_split_dir is an unofficial datasplit that is used. +# The datasplits can be found on http://www.openslr.org/51/ +madcat_database=/export/corpora/LDC/LDC2014T13 +data_split_dir=data/download/datasplits +overwrite=false +corpus_dir=/export/corpora5/handwriting_ocr/corpus_data/zh/ + +. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. + ## This relates to the queue. +. ./path.sh +. ./utils/parse_options.sh # e.g. this parses the above options + # if supplied. +./local/check_tools.sh + +# Start from stage=-1 for using extra corpus text +if [ $stage -le -1 ]; then + echo "$(date): getting corpus text for language modelling..." + mkdir -p data/local/text/cleaned + cat $corpus_dir/* > data/local/text/zh.txt + head -20000 data/local/text/zh.txt > data/local/text/cleaned/val.txt + tail -n +20000 data/local/text/zh.txt > data/local/text/cleaned/corpus.txt +fi + +mkdir -p data/{train,test,dev}/lines +if [ $stage -le 0 ]; then + + if [ -f data/train/text ] && ! $overwrite; then + echo "$0: Not processing, probably script have run from wrong stage" + echo "Exiting with status 1 to avoid data corruption" + exit 1; + fi + + echo "$0: Preparing data..." + local/prepare_data.sh --download-dir1 $madcat_database/data --data-split-dir $data_split_dir + + for dataset in train test dev; do + local/extract_lines.sh --nj $nj --cmd $cmd \ + --download-dir $madcat_database + --dataset-file $data_split_dir/madcat.${dataset}.raw.lineid \ + data/${dataset}/lines + done + + echo "$0: Processing data..." 
+ for set in dev train test; do + local/process_data.py $madcat_database $data_split_dir/madcat.$set.raw.lineid data/$set + image/fix_data_dir.sh data/$set + done +fi + +mkdir -p data/{train,test,dev}/data +if [ $stage -le 1 ]; then + for dataset in train test dev; do + local/extract_features.sh --nj $nj --cmd $cmd --feat-dim 60 data/$dataset + steps/compute_cmvn_stats.sh data/$dataset + done +fi + +if [ $stage -le 2 ]; then +echo "$0: Preparing dictionary and lang..." + local/prepare_dict.sh + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 16 --sil-prob 0.95 \ + --position-dependent-phones false \ + data/local/dict "" data/lang/temp data/lang +fi + +if [ $stage -le 3 ]; then + echo "$0: Estimating a language model for decoding..." + local/train_lm.sh + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/3gram_unpruned.arpa.gz \ + data/local/dict/lexicon.txt data/lang_test +fi + +if [ $stage -le 4 ]; then + steps/train_mono.sh --nj $nj --cmd $cmd --totgauss 10000 data/train \ + data/lang exp/mono +fi + +if [ $stage -le 5 ] && $decode_gmm; then + utils/mkgraph.sh --mono data/lang_test exp/mono exp/mono/graph + + steps/decode.sh --nj $nj --cmd $cmd exp/mono/graph data/test \ + exp/mono/decode_test +fi + +if [ $stage -le 6 ]; then + steps/align_si.sh --nj $nj --cmd $cmd data/train data/lang \ + exp/mono exp/mono_ali + + steps/train_deltas.sh --cmd $cmd --context-opts "--context-width=2 --central-position=1" \ + 50000 20000 data/train data/lang \ + exp/mono_ali exp/tri +fi + +if [ $stage -le 7 ] && $decode_gmm; then + utils/mkgraph.sh data/lang_test exp/tri exp/tri/graph + + steps/decode.sh --nj $nj --cmd $cmd exp/tri/graph data/test \ + exp/tri/decode_test +fi + +if [ $stage -le 8 ]; then + steps/align_si.sh --nj $nj --cmd $cmd data/train data/lang \ + exp/tri exp/tri_ali + + steps/train_lda_mllt.sh --cmd $cmd \ + --splice-opts "--left-context=3 --right-context=3" \ + --context-opts "--context-width=2 --central-position=1" 50000 20000 \ + data/train data/lang exp/tri_ali exp/tri2 +fi + +if [ $stage -le 9 ] && $decode_gmm; then + utils/mkgraph.sh data/lang_test exp/tri2 exp/tri2/graph + + steps/decode.sh --nj $nj --cmd $cmd exp/tri2/graph \ + data/test exp/tri2/decode_test +fi + +if [ $stage -le 10 ]; then + steps/align_fmllr.sh --nj $nj --cmd $cmd --use-graphs true \ + data/train data/lang exp/tri2 exp/tri2_ali + + steps/train_sat.sh --cmd $cmd --context-opts "--context-width=2 --central-position=1" \ + 50000 20000 data/train data/lang \ + exp/tri2_ali exp/tri3 +fi + +if [ $stage -le 11 ] && $decode_gmm; then + utils/mkgraph.sh data/lang_test exp/tri3 exp/tri3/graph + + steps/decode_fmllr.sh --nj $nj --cmd $cmd exp/tri3/graph \ + data/test exp/tri3/decode_test +fi + +if [ $stage -le 12 ]; then + steps/align_fmllr.sh --nj $nj --cmd $cmd --use-graphs true \ + data/train data/lang exp/tri3 exp/tri3_ali +fi + +if [ $stage -le 13 ]; then + local/chain/run_cnn_1a.sh +fi + +if [ $stage -le 14 ]; then + local/chain/run_cnn_chainali_1b.sh --chain-model-dir exp/chain/cnn_1a --stage 2 +fi diff --git a/egs/madcat_zh/v1/run_end2end.sh b/egs/madcat_zh/v1/run_end2end.sh new file mode 100755 index 00000000000..7e0fc1e25d1 --- /dev/null +++ b/egs/madcat_zh/v1/run_end2end.sh @@ -0,0 +1,107 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +set -e +stage=0 +nj=50 +username= +password= +# iam_database points to the database path on the JHU grid. 
If you have not +# already downloaded the database you can set it to a local directory +# like "data/download" and follow the instructions +# in "local/prepare_data.sh" to download the database: +madcat_database=/export/corpora/LDC/LDC2014T13 +data_split_dir=data/download/datasplits +overwrite=false +corpus_dir=/export/corpora5/handwriting_ocr/corpus_data/zh/ + +. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. + ## This relates to the queue. +. ./path.sh +. ./utils/parse_options.sh # e.g. this parses the above options + # if supplied. +./local/check_tools.sh + + +# Start from stage=-1 for using extra corpus text +if [ $stage -le -1 ]; then + echo "$(date): getting corpus text for language modelling..." + mkdir -p data/local/text/cleaned + cat $corpus_dir/* > data/local/text/zh.txt + head -20000 data/local/text/zh.txt > data/local/text/cleaned/val.txt + tail -n +20000 data/local/text/zh.txt > data/local/text/cleaned/corpus.txt +fi + +if [ $stage -le 0 ]; then + if [ -f data/train/text ] && ! $overwrite; then + echo "$0: Not processing, probably script have run from wrong stage" + echo "Exiting with status 1 to avoid data corruption" + exit 1; + fi + + echo "$0: Preparing data..." + local/prepare_data.sh --download-dir1 $madcat_database/data --data-split-dir $data_split_dir + + for dataset in train test dev; do + local/extract_lines.sh --nj $nj --cmd $cmd \ + --download-dir $madcat_database \ + --dataset-file $data_split_dir/madcat.${dataset}.raw.lineid \ + data/${dataset}/lines + done + + echo "$0: Processing data..." + for set in dev train test; do + local/process_data.py $madcat_database $data_split_dir/madcat.$set.raw.lineid data/$set + image/fix_data_dir.sh data/$set + done + +fi + +mkdir -p data/{train,test}/data +if [ $stage -le 1 ]; then + image/get_image2num_frames.py --feat-dim 80 data/train # This will be needed for the next command + # The next command creates a "allowed_lengths.txt" file in data/train + # which will be used by local/make_features.py to enforce the images to + # have allowed lengths. The allowed lengths will be spaced by 10% difference in length. + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train + echo "$0: Preparing the test and train feature files..." + for dataset in train test; do + local/extract_features.sh --nj $nj --cmd $cmd --feat-dim 80 data/$dataset + steps/compute_cmvn_stats.sh data/$dataset + done + utils/fix_data_dir.sh data/train +fi + +if [ $stage -le 2 ]; then + echo "$0: Preparing dictionary and lang..." + local/prepare_dict.sh + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 16 --sil-prob 0.95 \ + --position-dependent-phones false \ + data/local/dict "" data/lang/temp data/lang +fi + +if [ $stage -le 3 ]; then + echo "$0: calling the flat-start chain recipe..." + local/chain/run_e2e_cnn.sh +fi + +lang_decode=data/lang_test +decode_e2e=true +if [ $stage -le 4 ]; then + echo "$0: Estimating a language model for decoding..." + local/train_lm.sh + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/3gram_unpruned.arpa.gz \ + data/local/dict/lexicon.txt $lang_decode +fi + +if [ $stage -le 5 ] && $decode_e2e; then + echo "$0: $(date) stage 5: decoding end2end setup..." + utils/mkgraph.sh --self-loop-scale 1.0 $lang_decode \ + exp/chain/e2e_cnn_1a/ exp/chain/e2e_cnn_1a/graph || exit 1; + + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --nj $nj --cmd "$cmd" \ + exp/chain/e2e_cnn_1a/graph data/test exp/chain/e2e_cnn_1a/decode_test || exit 1; + + echo "$0: Done. 
Date: $(date). Results:" + local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ +fi diff --git a/egs/madcat_zh/v1/steps b/egs/madcat_zh/v1/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/madcat_zh/v1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/madcat_zh/v1/utils b/egs/madcat_zh/v1/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/madcat_zh/v1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/mini_librispeech/s5/local/chain/run_cnn_tdnn.sh b/egs/mini_librispeech/s5/local/chain/run_cnn_tdnn.sh new file mode 120000 index 00000000000..ab83f3c43e8 --- /dev/null +++ b/egs/mini_librispeech/s5/local/chain/run_cnn_tdnn.sh @@ -0,0 +1 @@ +tuning/run_cnn_tdnn_1a.sh \ No newline at end of file diff --git a/egs/mini_librispeech/s5/local/chain/tuning/run_cnn_tdnn_1a.sh b/egs/mini_librispeech/s5/local/chain/tuning/run_cnn_tdnn_1a.sh new file mode 100755 index 00000000000..c8f2503b578 --- /dev/null +++ b/egs/mini_librispeech/s5/local/chain/tuning/run_cnn_tdnn_1a.sh @@ -0,0 +1,307 @@ +#!/bin/bash + +# run_cnn_tdnn_1a.sh is modified from run_tdnn_1h.sh, but adding CNN layers +# near the beginning. + +# local/chain/compare_wer.sh --online exp/chain/tdnn1h_sp exp/chain/cnn_tdnn1a_sp +# System tdnn1h_sp cnn_tdnn1a_sp +#WER dev_clean_2 (tgsmall) 12.09 11.15 +# [online:] 12.11 11.17 +#WER dev_clean_2 (tglarge) 8.59 7.79 +# [online:] 8.76 7.80 +# Final train prob -0.0493 -0.0467 +# Final valid prob -0.0805 -0.0789 +# Final train prob (xent) -1.1730 -1.0767 +# Final valid prob (xent) -1.3872 -1.3070 +# Num-params 5207856 4492816 + +# Set -e here so that we catch if any executable fails immediately +set -euo pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +decode_nj=10 +train_set=train_clean_5 +test_sets=dev_clean_2 +gmm=tri3b +nnet3_affix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=1a # affix for the TDNN directory name +tree_affix= +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +# training options +# training chunk-options +chunk_width=140,100,160 +dropout_schedule='0,0@0.20,0.3@0.50,0' +common_egs_dir= +xent_regularize=0.1 + +# training options +srand=0 +remove_egs=true +reporting_email= + +#decode options +test_online_decoding=true # if true, it will run the last decoding stage. + + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 11 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 75 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 12 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 13 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + + cnn_opts="l2-regularize=0.03" + ivector_affine_opts="l2-regularize=0.03" + tdnn_opts="l2-regularize=0.03 dropout-proportion=0.0 dropout-per-dim-continuous=true" + tdnnf_first_opts="l2-regularize=0.03 dropout-proportion=0.0 bypass-scale=0.0" + tdnnf_opts="l2-regularize=0.03 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.03 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.015" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # this takes the MFCCs and generates filterbank coefficients. The MFCCs + # are more compressible so we prefer to dump the MFCCs to disk rather + # than filterbanks. + idct-layer name=idct input=input dim=40 cepstral-lifter=22 affine-transform-file=$dir/configs/idct.mat + + linear-component name=ivector-linear $ivector_affine_opts dim=200 input=ReplaceIndex(ivector, t, 0) + batchnorm-component name=ivector-batchnorm target-rms=0.025 + + batchnorm-component name=idct-batchnorm input=idct + combine-feature-maps-layer name=combine_inputs input=Append(idct-batchnorm, ivector-batchnorm) num-filters1=1 num-filters2=5 height=40 + + conv-relu-batchnorm-layer name=cnn1 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=48 learning-rate-factor=0.333 max-change=0.25 + conv-relu-batchnorm-layer name=cnn2 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=48 + conv-relu-batchnorm-layer name=cnn3 $cnn_opts height-in=40 height-out=20 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn4 $cnn_opts height-in=20 height-out=20 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn5 $cnn_opts height-in=20 height-out=10 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn6 $cnn_opts height-in=10 height-out=5 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=128 + + # the first TDNN-F layer has no bypass (since dims don't match), and a larger bottleneck so the + # information bottleneck doesn't become a problem. (we use time-stride=0 so no splicing, to + # limit the num-parameters). 
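The learning_rate_factor line earlier in this hunk, like the matching one-line changes to run_tdnn_1b.sh through run_tdnn_1h.sh later in the patch, adds parentheses so the shell-embedded snippet runs under either Python 2 or Python 3:

# Under Python 2, `print 0.5/0.1` is a print statement; under Python 3 it is a
# syntax error.  With parentheses the same line parses in both: as a statement
# printing a parenthesised expression in Python 2, and as a function call in
# Python 3.
xent_regularize = 0.1
print (0.5 / xent_regularize)   # -> 5.0 with either interpreter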
+ tdnnf-layer name=tdnnf7 $tdnnf_first_opts dim=768 bottleneck-dim=192 time-stride=0 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf14 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + tdnnf-layer name=tdnnf15 $tdnnf_opts dim=768 bottleneck-dim=96 time-stride=3 + linear-component name=prefinal-l dim=192 $linear_opts + + ## adding the layers for chain branch + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts small-dim=192 big-dim=768 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + # adding the layers for xent branch + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts small-dim=192 big-dim=768 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 14 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/fs0{1,2}/$USER/kaldi-data/egs/mini_librispeech-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.0 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=20 \ + --trainer.frames-per-iter=3000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=5 \ + --trainer.optimization.initial-effective-lrate=0.002 \ + --trainer.optimization.final-effective-lrate=0.0002 \ + --trainer.num-chunk-per-minibatch=128,64 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 15 ]; then + # Note: it's not important to give mkgraph.sh the lang directory with the + # matched topology (since it gets the topology file from the model). 
+ utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_test_tgsmall \ + $tree_dir $tree_dir/graph_tgsmall || exit 1; +fi + +if [ $stage -le 16 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l /dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l $dir/configs/network.xconfig diff --git a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1b.sh b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1b.sh index 110b7b87415..3d0c2d63902 100755 --- a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1b.sh +++ b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1b.sh @@ -154,7 +154,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1c.sh b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1c.sh index fe6f1b50f9e..081af8fe2f8 100755 --- a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1c.sh +++ b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1c.sh @@ -150,7 +150,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1d.sh b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1d.sh index 225b36f909c..04df38d4da3 100755 --- a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1d.sh +++ b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1d.sh @@ -150,7 +150,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1e.sh b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1e.sh index 565387003ff..cdf9bb584f4 100755 --- a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1e.sh +++ b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1e.sh @@ -148,7 +148,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.05" output_opts="l2-regularize=0.01" diff --git a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1f.sh b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1f.sh index 9cc6d93022a..d1385ff2be5 100755 --- a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1f.sh +++ b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1f.sh @@ -156,7 +156,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree 
|grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.05" output_opts="l2-regularize=0.02 bottleneck-dim=192" diff --git a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1g.sh b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1g.sh index e234b847aa7..ad51780e191 100755 --- a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1g.sh +++ b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1g.sh @@ -155,7 +155,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.05 dropout-per-dim-continuous=true" output_opts="l2-regularize=0.02 bottleneck-dim=192" diff --git a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1g20.sh b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1g20.sh index 18540806028..dbfe5c5a07a 100755 --- a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1g20.sh +++ b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1g20.sh @@ -168,7 +168,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.05 dropout-per-dim-continuous=true" output_opts="l2-regularize=0.02 bottleneck-dim=192" diff --git a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1h.sh b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1h.sh index 776247f5ea3..cc4123e2755 100755 --- a/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1h.sh +++ b/egs/mini_librispeech/s5/local/chain/tuning/run_tdnn_1h.sh @@ -151,7 +151,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) tdnn_opts="l2-regularize=0.03 dropout-proportion=0.0 dropout-per-dim-continuous=true" tdnnf_opts="l2-regularize=0.03 dropout-proportion=0.0 bypass-scale=0.66" diff --git a/egs/mini_librispeech/s5/local/download_and_untar.sh b/egs/mini_librispeech/s5/local/download_and_untar.sh index 93c18f263d2..5a27219f676 100755 --- a/egs/mini_librispeech/s5/local/download_and_untar.sh +++ b/egs/mini_librispeech/s5/local/download_and_untar.sh @@ -28,9 +28,11 @@ if [ ! -d "$data" ]; then exit 1; fi +data=$(readlink -f $data) + part_ok=false list="dev-clean-2 train-clean-5" -for x in $list; do +for x in $list; do if [ "$part" == $x ]; then part_ok=true; fi done if ! $part_ok; then @@ -49,7 +51,8 @@ if [ -f $data/LibriSpeech/$part/.complete ]; then fi -sizes="126046265 332747356" +#sizes="126046265 332747356" +sizes="126046265 332954390" if [ -f $data/$part.tar.gz ]; then size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}') @@ -77,6 +80,7 @@ if [ ! 
-f $data/$part.tar.gz ]; then echo "$0: error executing wget $full_url" exit 1; fi + cd - fi cd $data diff --git a/egs/mini_librispeech/s5/local/download_lm.sh b/egs/mini_librispeech/s5/local/download_lm.sh index 185d4811768..b37ae599118 100755 --- a/egs/mini_librispeech/s5/local/download_lm.sh +++ b/egs/mini_librispeech/s5/local/download_lm.sh @@ -4,14 +4,15 @@ # 2017 Daniel Povey # Apache 2.0 -if [ $# -ne "2" ]; then - echo "Usage: $0 " - echo "e.g.: $0 http://www.openslr.org/resources/11 data/local/lm" +if [ $# -ne "3" ]; then + echo "Usage: $0 " data/local/lang_tmp_nosp data/lang_nosp + +if [ $stage -le 0 ]; then + cp -r data/local/dict_nosp data/local/dict_nosp_basevocab + echo "#nonterm:unk" > data/local/dict_nosp_basevocab/nonterminals.txt + + utils/prepare_lang.sh data/local/dict_nosp_basevocab \ + "" data/local/lang_tmp_nosp $lang_base +fi + +if [ $stage -le 1 ]; then + # note: does appear in that arpa file, with a reasonable probability + # (0.0)... presumably because the vocab that the arpa file was built with was + # not vast, so there were plenty of OOVs. It would be possible to adjust its + # probability with adjust_unk_arpa.pl, but for now we just leave it as-is. + # The appears quite a few times in the ARPA. In the language model we + # replaced it with #nonterm:unk, which will later expand to our custom graph + # of new words. + + # We don't want the #nonterm:unk on the output side of G.fst, or it would + # appear in the decoded output, so we remove it using the 'fstrmsymbols' command. + + nonterm_unk=$(grep '#nonterm:unk' $lang_base/words.txt | awk '{print $2}') + + gunzip -c data/local/lm/lm_tgsmall.arpa.gz | \ + sed 's//#nonterm:unk/g' | \ + arpa2fst --disambig-symbol=#0 \ + --read-symbol-table=$lang_base/words.txt - | \ + fstrmsymbols --remove-from-output=true "echo $nonterm_unk|" - $lang_base/G.fst +fi + + +if [ $stage -le 2 ]; then + # make the top-level part of the graph. + utils/mkgraph.sh --self-loop-scale 1.0 $lang_base $tree_dir $tree_dir/extvocab_nosp_top +fi + +if [ $stage -le 3 ] && $run_g2p; then + # you may have to do some stuff manually to install sequitur, to get this to work. + dict=data/local/dict_nosp_basevocab + steps/dict/train_g2p.sh --silence-phones $dict/silence_phones.txt $dict/lexicon.txt $tree_dir/extvocab_nosp_g2p +fi + + +if [ $stage -le 4 ]; then + # Create data/local/dict_nosp_newvocab as a dict-dir containing just the + # newly created vocabulary entries (but the same phone list as our old setup, not + # that it matters) + + mkdir -p $tree_dir/extvocab_nosp_lexicon + + # First find a list of words in the test set that are out of vocabulary. + # Of course this is totally cheating. 
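The awk one-liner just below gathers the words that occur in the test transcripts but are missing from words.txt, i.e. the out-of-vocabulary items that are then sent to G2P. A rough Python equivalent; the function and the commented-out paths are illustrative, not part of the recipe:

def find_oov(words_txt, text):
    with open(words_txt) as f:
        vocab = {line.split()[0] for line in f if line.strip()}
    oov = set()
    with open(text) as f:
        for line in f:
            for word in line.split()[1:]:    # first field is the utterance id
                if word not in vocab:
                    oov.add(word)
    return sorted(oov)

# e.g.: print('\n'.join(find_oov('data/lang/words.txt', 'data/dev_clean_2/text')))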
+ awk -v w=data/lang/words.txt 'BEGIN{while(getline $tree_dir/extvocab_nosp_lexicon/words + echo "$0: generating g2p entries for $(wc -l <$tree_dir/extvocab_nosp_lexicon/words) words" + + if $run_g2p; then + steps/dict/apply_g2p.sh $tree_dir/extvocab_nosp_lexicon/words $tree_dir/extvocab_nosp_g2p $tree_dir/extvocab_nosp_lexicon + else + cat <$tree_dir/extvocab_nosp_lexicon/lexicon.lex +HARDWIGG 0.962436 HH AA1 R D W IH1 G +SUDVESTR 0.162048 S AH1 D V EY1 S T R +SUDVESTR 0.133349 S AH1 D V EH1 S T R +SUDVESTR 0.114376 S AH1 D V EH1 S T ER0 +VINOS 0.558345 V IY1 N OW0 Z +VINOS 0.068883 V AY1 N OW0 Z +VINOS 0.068431 V IY1 N OW0 S +DOMA 0.645714 D OW1 M AH0 +DOMA 0.118255 D UW1 M AH0 +DOMA 0.080682 D OW0 M AH0 +GWYNPLAINE'S 0.983053 G W IH1 N P L EY1 N Z +SHIMERDA 0.610922 SH IH0 M EH1 R D AH0 +SHIMERDA 0.175678 SH IY0 M EH1 R D AH0 +SHIMERDA 0.069785 SH AY1 M ER1 D AH0 +MYRDALS 0.479183 M IH1 R D AH0 L Z +MYRDALS 0.135225 M ER1 D AH0 L Z +MYRDALS 0.115478 M IH1 R D L Z +HEUCHERA 0.650042 HH OY1 K IH1 R AH0 +HEUCHERA 0.119363 HH OY1 K EH1 R AH0 +HEUCHERA 0.077907 HH OY1 K ER0 AH0 +IMPARA 0.906222 IH0 M P AA1 R AH0 +VERLOC'S 0.564847 V ER0 L AA1 K S +VERLOC'S 0.173540 V ER1 L AH0 K S +VERLOC'S 0.050543 V ER1 L AA1 K S +UNTRUSSING 0.998019 AH0 N T R AH1 S IH0 NG +DARFHULVA 0.317057 D AA2 F UH1 L V AH0 +DARFHULVA 0.262882 D AA2 F HH UH1 L V AH0 +DARFHULVA 0.064055 D AA2 F HH UW1 L V AH0 +FINNACTA 0.594586 F IH1 N AH0 K T AH0 +FINNACTA 0.232454 F IH1 N AE1 K T AH0 +FINNACTA 0.044733 F IH1 N IH0 K T AH0 +YOKUL 0.845279 Y OW1 K AH0 L +YOKUL 0.051082 Y OW2 K AH0 L +YOKUL 0.029435 Y OW0 K AH0 L +CONGAL 0.504228 K AA1 NG G AH0 L +CONGAL 0.151648 K AA2 NG G AH0 L +CONGAL 0.137837 K AH0 N JH AH0 L +DELECTASTI 0.632180 D IH0 L EH0 K T EY1 S T IY0 +DELECTASTI 0.203808 D IH0 L EH1 K T EY1 S T IY0 +DELECTASTI 0.066722 D IH0 L EH0 K T AE1 S T IY0 +YUNDT 0.975077 Y AH1 N T +QUINCI 0.426115 K W IH1 N S IY0 +QUINCI 0.369324 K W IH1 N CH IY0 +QUINCI 0.064507 K W IY0 N CH IY0 +BIRDIKINS 0.856979 B ER1 D IH0 K AH0 N Z +BIRDIKINS 0.045315 B ER1 D AH0 K AH0 N Z +SNEFFELS 0.928413 S N EH1 F AH0 L Z +FJORDUNGR 0.130629 F Y AO1 R D UW0 NG G R +FJORDUNGR 0.125082 F Y AO1 R D AH0 NG G R +FJORDUNGR 0.111035 F Y AO1 R D UH1 NG R +YULKA 0.540253 Y UW1 L K AH0 +YULKA 0.295588 Y AH1 L K AH0 +YULKA 0.076631 Y UH1 L K AH0 +LACQUEY'S 0.987908 L AE1 K IY0 Z +OSSIPON'S 0.651400 AA1 S AH0 P AA2 N Z +OSSIPON'S 0.118444 AA1 S AH0 P AA0 N Z +OSSIPON'S 0.106377 AA1 S AH0 P AH0 N Z +SAKNUSSEMM 0.060270 S AE1 K N AH1 S EH1 M +SAKNUSSEMM 0.044992 S AE1 K N AH0 S EH1 M +SAKNUSSEMM 0.044084 S AA0 K N AH1 S EH1 M +CONGAL'S 0.618287 K AA1 NG G AH0 L Z +CONGAL'S 0.185952 K AA2 NG G AH0 L Z +CONGAL'S 0.115143 K AH0 N G AH0 L Z +TARRINZEAU 0.159153 T AA1 R IY0 N Z OW1 +TARRINZEAU 0.136536 T AA1 R AH0 N Z OW1 +TARRINZEAU 0.100924 T EH1 R IY0 N Z OW1 +SHIMERDAS 0.230819 SH IH0 M EH1 R D AH0 Z +SHIMERDAS 0.216235 SH IH0 M EH1 R D AH0 S +SHIMERDAS 0.073311 SH AY1 M ER1 D AH0 Z +RUGGEDO'S 0.821285 R UW0 JH EY1 D OW0 Z +RUGGEDO'S 0.166825 R AH1 G AH0 D OW0 Z +CORNCAKES 0.934118 K AO1 R N K EY2 K S +VENDHYA 0.616662 V EH0 N D Y AH0 +VENDHYA 0.178349 V EH1 N D Y AH0 +VENDHYA 0.160768 V AA1 N D Y AH0 +GINGLE 0.919815 G IH1 NG G AH0 L +STUPIRTI 0.422653 S T UW0 P IH1 R T IY0 +STUPIRTI 0.126925 S T UW1 P IH0 R T IY0 +STUPIRTI 0.078422 S T UW1 P AH0 R T IY0 +HERBIVORE 0.950887 HH ER1 B IH0 V AO2 R +BRION'S 0.838326 B R AY1 AH0 N Z +BRION'S 0.140310 B R IY0 AH0 N Z +DELAUNAY'S 0.993259 D EH1 L AO0 N EY0 Z +KHOSALA 0.920908 K OW0 S AA1 L AH0 +BRANDD 0.827461 B 
R AE1 N D +BRANDD 0.085646 B R AE2 N D +GARDAR 0.598675 G AA0 R D AA1 R +GARDAR 0.289831 G AA1 R D AA2 R +GARDAR 0.057983 G AA0 R D AA2 R +MACKLEWAIN 0.570209 M AE1 K AH0 L W EY0 N +MACKLEWAIN 0.101477 M AH0 K AH0 L W EY0 N +MACKLEWAIN 0.067905 M AE1 K AH0 L W EY2 N +LIBANO 0.993297 L IY0 B AA1 N OW0 +MOLING 0.782578 M OW1 L IH0 NG +MOLING 0.059362 M OW2 L IH0 NG +MOLING 0.056217 M AA1 L IH0 NG +BENNYDECK'S 0.583859 B EH1 N IY0 D EH0 K S +BENNYDECK'S 0.276699 B EH1 N IH0 D EH0 K S +BENNYDECK'S 0.028343 B EH1 N IH0 D IH0 K S +MACKLEWAIN'S 0.615766 M AE1 K AH0 L W EY0 N Z +MACKLEWAIN'S 0.109585 M AH0 K AH0 L W EY0 N Z +MACKLEWAIN'S 0.039423 M AE1 K AH0 L W AH0 N Z +PRESTY 0.616071 P R EH1 S T IY0 +PRESTY 0.288701 P R AH0 S T IY0 +BREADHOUSE 0.995874 B R EH1 D HH AW2 S +BUZZER'S 0.992495 B AH1 Z ER0 Z +BHUNDA 0.502439 B UW1 N D AH0 +BHUNDA 0.267733 B AH0 N D AH0 +BHUNDA 0.193772 B UH1 N D AH0 +PINKIES 0.998440 P IH1 NG K IY0 Z +TROKE 0.723320 T R OW1 K +TROKE 0.269707 T R OW2 K +OSSIPON 0.728486 AA1 S AH0 P AA2 N +OSSIPON 0.098752 AA1 S AH0 P AH0 N +OSSIPON 0.033957 AA1 S AH0 P AO0 N +RIVERLIKE 0.991731 R IH1 V ER0 L AY2 K +NICLESS 0.478183 N IH1 K L AH0 S +NICLESS 0.159889 N IH0 K L AH0 S +NICLESS 0.120611 N IH1 K L IH0 S +TRAMPE 0.959184 T R AE1 M P +VERLOC 0.610461 V ER0 L AA1 K +VERLOC 0.128479 V ER1 L AH0 K +VERLOC 0.073687 V ER1 L AA0 K +GANNY 0.991703 G AE1 N IY0 +AMBROSCH 0.302906 AE0 M B R OW1 SH +AMBROSCH 0.201163 AE0 M B R AO1 SH +AMBROSCH 0.109274 AE1 M B R AO1 SH +FIBI 0.619154 F IH1 B IY0 +FIBI 0.163168 F IY1 B IY0 +FIBI 0.083443 F AY1 B IY0 +IROLG 0.823123 IH0 R OW1 L G +IROLG 0.053196 IH0 R OW1 L JH +IROLG 0.021038 IH0 R OW1 L JH IY1 +BALVASTRO 0.251546 B AA0 L V AA1 S T R OW0 +BALVASTRO 0.213351 B AE0 L V AE1 S T R OW0 +BALVASTRO 0.133005 B AA0 L V AE1 S T R OW0 +BOOLOOROO 0.676757 B UW1 L UW1 R UW0 +BOOLOOROO 0.173653 B UW1 L UH2 R UW0 +BOOLOOROO 0.086501 B UW1 L UH0 R UW0 +EOF + fi + + # extend_lang.sh needs it to have basename 'lexiconp.txt'. + mv $tree_dir/extvocab_nosp_lexicon/lexicon.lex $tree_dir/extvocab_nosp_lexicon/lexiconp.txt + + [ -f data/lang_nosp_extvocab/G.fst ] && rm data/lang_nosp_extvocab/G.fst + utils/lang/extend_lang.sh data/lang_nosp_basevocab $tree_dir/extvocab_nosp_lexicon/lexiconp.txt data/lang_nosp_extvocab +fi + +if [ $stage -le 5 ]; then + # make the G.fst for the extra words. Just assign equal probabilities to all of + # them. The words will all transition from state 1 to 2. + cat < $lang_ext/G.txt +0 1 #nonterm_begin +2 3 #nonterm_end +3 +EOF + lexicon=$tree_dir/extvocab_nosp_lexicon/lexiconp.txt + num_words=$(wc -l <$lexicon) + cost=$(perl -e "print log($num_words)"); + awk -v cost=$cost '{print 1, 2, $1, $1, cost}' <$lexicon >>$lang_ext/G.txt + fstcompile --isymbols=$lang_ext/words.txt --osymbols=$lang_ext/words.txt <$lang_ext/G.txt | \ + fstarcsort --sort_type=ilabel >$lang_ext/G.fst +fi + +if [ $stage -le 6 ]; then + # make the part of the graph that will be included. + # Refer to the 'compile-graph' commands in ./simple_demo.sh for how you'd do + # this in code. 
+ utils/mkgraph.sh --self-loop-scale 1.0 $lang_ext $tree_dir $tree_dir/extvocab_nosp_part +fi + +if [ $stage -le 7 ]; then + offset=$(grep nonterm_bos $lang_ext/phones.txt | awk '{print $2}') + nonterm_unk=$(grep nonterm:unk $lang_ext/phones.txt | awk '{print $2}') + + mkdir -p $tree_dir/extvocab_nosp_combined + [ -d $tree_dir/extvocab_nosp_combined/phones ] && rm -r $tree_dir/extvocab_nosp_combined/phones + # the decoding script expects words.txt and phones/, copy them from the extvocab_part + # graph directory where they will have suitable values. + cp -r $tree_dir/extvocab_nosp_part/{words.txt,phones.txt,phones/} $tree_dir/extvocab_nosp_combined + + # the following, due to --write-as-grammar=false, compiles it into an FST + # which can be decoded by our normal decoder. + make-grammar-fst --write-as-grammar=false --nonterm-phones-offset=$offset $tree_dir/extvocab_nosp_top/HCLG.fst \ + $nonterm_unk $tree_dir/extvocab_nosp_part/HCLG.fst $tree_dir/extvocab_nosp_combined/HCLG.fst + + # the following compiles it and writes as GrammarFst. The size is 176M, vs. 182M for HCLG.fst. + # In other examples, of course the difference might be more. + + make-grammar-fst --write-as-grammar=true --nonterm-phones-offset=$offset $tree_dir/extvocab_nosp_top/HCLG.fst \ + $nonterm_unk $tree_dir/extvocab_nosp_part/HCLG.fst $tree_dir/extvocab_nosp_combined/HCLG.gra +fi + + +if [ $stage -le 8 ]; then + # OK, now we actually decode the test data. For reference, the command which was used to + # decode the test data in the current (at the time of writing) chain TDNN system + # local/chain/run_tdnn.sh (as figured out by running it from that stage), was: + # steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --frames-per-chunk 140 --nj 38 \ + # --cmd "queue.pl --mem 4G --num-threads 4" --online-ivector-dir exp/nnet3/ivectors_dev_clean_2_hires \ + # exp/chain/tree_sp/graph_tgsmall data/dev_clean_2_hires exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2 + + # We just replace the graph with the one in $treedir/extvocab_nosp_combined. + + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --frames-per-chunk 140 --nj 38 \ + --cmd "queue.pl --mem 4G --num-threads 4" --online-ivector-dir exp/nnet3/ivectors_dev_clean_2_hires \ + exp/chain/tree_sp/extvocab_nosp_combined data/dev_clean_2_hires exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_nosp_comb + + + +# grep WER exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_nosp_comb/wer_* | utils/best_wer.sh +#%WER 11.79 [ 2375 / 20138, 195 ins, 343 del, 1837 sub ] exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_nosp_comb/wer_12_0.0# s5: grep WER exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_nosp_comb/wer_* | utils/best_wer.sh + + #.. versus the baseline below note, the baseline is not 100% comparable as it used the + # silence probabilities, which the grammar-decoding does not (yet) support... 
+ # s5: grep WER exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2/wer_* | utils/best_wer.sh + # %WER 12.01 [ 2418 / 20138, 244 ins, 307 del, 1867 sub ] exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2/wer_13_0.0 +fi + +if [ $stage -le 9 ]; then + steps/nnet3/decode_grammar.sh --acwt 1.0 --post-decode-acwt 10.0 --frames-per-chunk 140 --nj 38 \ + --cmd "queue.pl --mem 4G --num-threads 4" --online-ivector-dir exp/nnet3/ivectors_dev_clean_2_hires \ + exp/chain/tree_sp/extvocab_nosp_combined data/dev_clean_2_hires exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_nosp_comb_gra + + # The WER when decoding with the grammar FST directly is exactly the same: + # s5: grep WER exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_nosp_comb_gra/wer_* | utils/best_wer.sh + # %WER 11.79 [ 2375 / 20138, 195 ins, 343 del, 1837 sub ] exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_nosp_comb_gra/wer_12_0.0 +fi diff --git a/egs/mini_librispeech/s5/local/grammar/extend_vocab_demo_silprobs.sh b/egs/mini_librispeech/s5/local/grammar/extend_vocab_demo_silprobs.sh new file mode 100755 index 00000000000..28c58dfa453 --- /dev/null +++ b/egs/mini_librispeech/s5/local/grammar/extend_vocab_demo_silprobs.sh @@ -0,0 +1,326 @@ +#!/usr/bin/env bash + +# This script demonstrates how to use the grammar-decoding framework to build +# graphs made out of more than one part. (This version uses word-specific +# silence probabilities). It demonstrates using `fstequivalent` +# that the graph constructed this way is equivalent to what you would create if +# you had the LM all as a single piece. This uses the command line tools to +# expand to a regular FST (--write-as-grammar=false) In practice you might not +# want do to that, since the result might be large, and since writing the entire +# thing might take too much time. The code itself allows you to construct these +# GrammarFst objects in lightweight way and decode using them. + +# Unfortunately the filenames here are not very well through through. I hope to +# rework this when I have time. + +stage=0 +run_g2p=false # set this to true to run the g2p stuff, it's slow so + # by default we fake it by providing what it previously output +set -e + +. ./path.sh +. utils/parse_options.sh + + +tree_dir=exp/chain/tree_sp +lang_base=data/lang_basevocab +lang_ext=data/lang_extvocab + +# For the purposes of this script we just need a biphone tree and associated +# transition-model for testing, because we're testing it at the graph level, +# i.e. testing equivalence of compiled HCLG graphs; there is no decoding +# involved here. + +# We're doing this with the "no-silprobs" dictionary dir for now, as we +# need to write some scripts to support silprobs with this. + +# For reference, here is how we could create the 'lang' dir for the +# baseline. +#utils/prepare_lang.sh data/local/dict \ +# "" data/local/lang_tmp data/lang + +if [ $stage -le 0 ]; then + cp -r data/local/dict data/local/dict_basevocab + echo "#nonterm:unk" > data/local/dict_basevocab/nonterminals.txt + + utils/prepare_lang.sh data/local/dict_basevocab \ + "" data/local/lang_tmp $lang_base +fi + +if [ $stage -le 1 ]; then + # note: does appear in that arpa file, with a reasonable probability + # (0.0)... presumably because the vocab that the arpa file was built with was + # not vast, so there were plenty of OOVs. It would be possible to adjust its + # probability with adjust_unk_arpa.pl, but for now we just leave it as-is. + # The appears quite a few times in the ARPA. 
In the language model we + # replaced it with #nonterm:unk, which will later expand to our custom graph + # of new words. + + # We don't want the #nonterm:unk on the output side of G.fst, or it would + # appear in the decoded output, so we remove it using the 'fstrmsymbols' command. + + nonterm_unk=$(grep '#nonterm:unk' $lang_base/words.txt | awk '{print $2}') + + gunzip -c data/local/lm/lm_tgsmall.arpa.gz | \ + sed 's//#nonterm:unk/g' | \ + arpa2fst --disambig-symbol=#0 \ + --read-symbol-table=$lang_base/words.txt - | \ + fstrmsymbols --remove-from-output=true "echo $nonterm_unk|" - $lang_base/G.fst +fi + + +if [ $stage -le 2 ]; then + # make the top-level part of the graph. + utils/mkgraph.sh --self-loop-scale 1.0 $lang_base $tree_dir $tree_dir/extvocab_top +fi + +if [ $stage -le 3 ] && $run_g2p; then + # you may have to do some stuff manually to install sequitur, to get this to work. + dict=data/local/dict_basevocab + steps/dict/train_g2p.sh --silence-phones $dict/silence_phones.txt $dict/lexicon.txt $tree_dir/extvocab_g2p +fi + + +if [ $stage -le 4 ]; then + # Create data/local/dict_newvocab as a dict-dir containing just the + # newly created vocabulary entries (but the same phone list as our old setup, not + # that it matters) + + mkdir -p $tree_dir/extvocab_lexicon + + # First find a list of words in the test set that are out of vocabulary. + # Of course this is totally cheating. + awk -v w=data/lang/words.txt 'BEGIN{while(getline $tree_dir/extvocab_lexicon/words + echo "$0: generating g2p entries for $(wc -l <$tree_dir/extvocab_lexicon/words) words" + + if $run_g2p; then + steps/dict/apply_g2p.sh $tree_dir/extvocab_lexicon/words $tree_dir/extvocab_g2p $tree_dir/extvocab_lexicon + else + cat <$tree_dir/extvocab_lexicon//lexicon.lex +HARDWIGG 0.962436 HH AA1 R D W IH1 G +SUDVESTR 0.162048 S AH1 D V EY1 S T R +SUDVESTR 0.133349 S AH1 D V EH1 S T R +SUDVESTR 0.114376 S AH1 D V EH1 S T ER0 +VINOS 0.558345 V IY1 N OW0 Z +VINOS 0.068883 V AY1 N OW0 Z +VINOS 0.068431 V IY1 N OW0 S +DOMA 0.645714 D OW1 M AH0 +DOMA 0.118255 D UW1 M AH0 +DOMA 0.080682 D OW0 M AH0 +GWYNPLAINE'S 0.983053 G W IH1 N P L EY1 N Z +SHIMERDA 0.610922 SH IH0 M EH1 R D AH0 +SHIMERDA 0.175678 SH IY0 M EH1 R D AH0 +SHIMERDA 0.069785 SH AY1 M ER1 D AH0 +MYRDALS 0.479183 M IH1 R D AH0 L Z +MYRDALS 0.135225 M ER1 D AH0 L Z +MYRDALS 0.115478 M IH1 R D L Z +HEUCHERA 0.650042 HH OY1 K IH1 R AH0 +HEUCHERA 0.119363 HH OY1 K EH1 R AH0 +HEUCHERA 0.077907 HH OY1 K ER0 AH0 +IMPARA 0.906222 IH0 M P AA1 R AH0 +VERLOC'S 0.564847 V ER0 L AA1 K S +VERLOC'S 0.173540 V ER1 L AH0 K S +VERLOC'S 0.050543 V ER1 L AA1 K S +UNTRUSSING 0.998019 AH0 N T R AH1 S IH0 NG +DARFHULVA 0.317057 D AA2 F UH1 L V AH0 +DARFHULVA 0.262882 D AA2 F HH UH1 L V AH0 +DARFHULVA 0.064055 D AA2 F HH UW1 L V AH0 +FINNACTA 0.594586 F IH1 N AH0 K T AH0 +FINNACTA 0.232454 F IH1 N AE1 K T AH0 +FINNACTA 0.044733 F IH1 N IH0 K T AH0 +YOKUL 0.845279 Y OW1 K AH0 L +YOKUL 0.051082 Y OW2 K AH0 L +YOKUL 0.029435 Y OW0 K AH0 L +CONGAL 0.504228 K AA1 NG G AH0 L +CONGAL 0.151648 K AA2 NG G AH0 L +CONGAL 0.137837 K AH0 N JH AH0 L +DELECTASTI 0.632180 D IH0 L EH0 K T EY1 S T IY0 +DELECTASTI 0.203808 D IH0 L EH1 K T EY1 S T IY0 +DELECTASTI 0.066722 D IH0 L EH0 K T AE1 S T IY0 +YUNDT 0.975077 Y AH1 N T +QUINCI 0.426115 K W IH1 N S IY0 +QUINCI 0.369324 K W IH1 N CH IY0 +QUINCI 0.064507 K W IY0 N CH IY0 +BIRDIKINS 0.856979 B ER1 D IH0 K AH0 N Z +BIRDIKINS 0.045315 B ER1 D AH0 K AH0 N Z +SNEFFELS 0.928413 S N EH1 F AH0 L Z +FJORDUNGR 0.130629 F Y AO1 R D UW0 NG G R +FJORDUNGR 0.125082 F Y 
AO1 R D AH0 NG G R +FJORDUNGR 0.111035 F Y AO1 R D UH1 NG R +YULKA 0.540253 Y UW1 L K AH0 +YULKA 0.295588 Y AH1 L K AH0 +YULKA 0.076631 Y UH1 L K AH0 +LACQUEY'S 0.987908 L AE1 K IY0 Z +OSSIPON'S 0.651400 AA1 S AH0 P AA2 N Z +OSSIPON'S 0.118444 AA1 S AH0 P AA0 N Z +OSSIPON'S 0.106377 AA1 S AH0 P AH0 N Z +SAKNUSSEMM 0.060270 S AE1 K N AH1 S EH1 M +SAKNUSSEMM 0.044992 S AE1 K N AH0 S EH1 M +SAKNUSSEMM 0.044084 S AA0 K N AH1 S EH1 M +CONGAL'S 0.618287 K AA1 NG G AH0 L Z +CONGAL'S 0.185952 K AA2 NG G AH0 L Z +CONGAL'S 0.115143 K AH0 N G AH0 L Z +TARRINZEAU 0.159153 T AA1 R IY0 N Z OW1 +TARRINZEAU 0.136536 T AA1 R AH0 N Z OW1 +TARRINZEAU 0.100924 T EH1 R IY0 N Z OW1 +SHIMERDAS 0.230819 SH IH0 M EH1 R D AH0 Z +SHIMERDAS 0.216235 SH IH0 M EH1 R D AH0 S +SHIMERDAS 0.073311 SH AY1 M ER1 D AH0 Z +RUGGEDO'S 0.821285 R UW0 JH EY1 D OW0 Z +RUGGEDO'S 0.166825 R AH1 G AH0 D OW0 Z +CORNCAKES 0.934118 K AO1 R N K EY2 K S +VENDHYA 0.616662 V EH0 N D Y AH0 +VENDHYA 0.178349 V EH1 N D Y AH0 +VENDHYA 0.160768 V AA1 N D Y AH0 +GINGLE 0.919815 G IH1 NG G AH0 L +STUPIRTI 0.422653 S T UW0 P IH1 R T IY0 +STUPIRTI 0.126925 S T UW1 P IH0 R T IY0 +STUPIRTI 0.078422 S T UW1 P AH0 R T IY0 +HERBIVORE 0.950887 HH ER1 B IH0 V AO2 R +BRION'S 0.838326 B R AY1 AH0 N Z +BRION'S 0.140310 B R IY0 AH0 N Z +DELAUNAY'S 0.993259 D EH1 L AO0 N EY0 Z +KHOSALA 0.920908 K OW0 S AA1 L AH0 +BRANDD 0.827461 B R AE1 N D +BRANDD 0.085646 B R AE2 N D +GARDAR 0.598675 G AA0 R D AA1 R +GARDAR 0.289831 G AA1 R D AA2 R +GARDAR 0.057983 G AA0 R D AA2 R +MACKLEWAIN 0.570209 M AE1 K AH0 L W EY0 N +MACKLEWAIN 0.101477 M AH0 K AH0 L W EY0 N +MACKLEWAIN 0.067905 M AE1 K AH0 L W EY2 N +LIBANO 0.993297 L IY0 B AA1 N OW0 +MOLING 0.782578 M OW1 L IH0 NG +MOLING 0.059362 M OW2 L IH0 NG +MOLING 0.056217 M AA1 L IH0 NG +BENNYDECK'S 0.583859 B EH1 N IY0 D EH0 K S +BENNYDECK'S 0.276699 B EH1 N IH0 D EH0 K S +BENNYDECK'S 0.028343 B EH1 N IH0 D IH0 K S +MACKLEWAIN'S 0.615766 M AE1 K AH0 L W EY0 N Z +MACKLEWAIN'S 0.109585 M AH0 K AH0 L W EY0 N Z +MACKLEWAIN'S 0.039423 M AE1 K AH0 L W AH0 N Z +PRESTY 0.616071 P R EH1 S T IY0 +PRESTY 0.288701 P R AH0 S T IY0 +BREADHOUSE 0.995874 B R EH1 D HH AW2 S +BUZZER'S 0.992495 B AH1 Z ER0 Z +BHUNDA 0.502439 B UW1 N D AH0 +BHUNDA 0.267733 B AH0 N D AH0 +BHUNDA 0.193772 B UH1 N D AH0 +PINKIES 0.998440 P IH1 NG K IY0 Z +TROKE 0.723320 T R OW1 K +TROKE 0.269707 T R OW2 K +OSSIPON 0.728486 AA1 S AH0 P AA2 N +OSSIPON 0.098752 AA1 S AH0 P AH0 N +OSSIPON 0.033957 AA1 S AH0 P AO0 N +RIVERLIKE 0.991731 R IH1 V ER0 L AY2 K +NICLESS 0.478183 N IH1 K L AH0 S +NICLESS 0.159889 N IH0 K L AH0 S +NICLESS 0.120611 N IH1 K L IH0 S +TRAMPE 0.959184 T R AE1 M P +VERLOC 0.610461 V ER0 L AA1 K +VERLOC 0.128479 V ER1 L AH0 K +VERLOC 0.073687 V ER1 L AA0 K +GANNY 0.991703 G AE1 N IY0 +AMBROSCH 0.302906 AE0 M B R OW1 SH +AMBROSCH 0.201163 AE0 M B R AO1 SH +AMBROSCH 0.109274 AE1 M B R AO1 SH +FIBI 0.619154 F IH1 B IY0 +FIBI 0.163168 F IY1 B IY0 +FIBI 0.083443 F AY1 B IY0 +IROLG 0.823123 IH0 R OW1 L G +IROLG 0.053196 IH0 R OW1 L JH +IROLG 0.021038 IH0 R OW1 L JH IY1 +BALVASTRO 0.251546 B AA0 L V AA1 S T R OW0 +BALVASTRO 0.213351 B AE0 L V AE1 S T R OW0 +BALVASTRO 0.133005 B AA0 L V AE1 S T R OW0 +BOOLOOROO 0.676757 B UW1 L UW1 R UW0 +BOOLOOROO 0.173653 B UW1 L UH2 R UW0 +BOOLOOROO 0.086501 B UW1 L UH0 R UW0 +EOF + fi + + # extend_lang.sh needs it to have basename 'lexiconp.txt'. 
+ mv $tree_dir/extvocab_lexicon/lexicon.lex $tree_dir/extvocab_lexicon/lexiconp.txt + + [ -f data/lang_extvocab/G.fst ] && rm data/lang_extvocab/G.fst + utils/lang/extend_lang.sh data/lang_basevocab $tree_dir/extvocab_lexicon/lexiconp.txt data/lang_extvocab +fi + +if [ $stage -le 5 ]; then + # make the G.fst for the extra words. Just assign equal probabilities to all of + # them. The words will all transition from state 1 to 2. + cat < $lang_ext/G.txt +0 1 #nonterm_begin +2 3 #nonterm_end +3 +EOF + lexicon=$tree_dir/extvocab_lexicon/lexiconp.txt + num_words=$(wc -l <$lexicon) + cost=$(perl -e "print log($num_words)"); + awk -v cost=$cost '{print 1, 2, $1, $1, cost}' <$lexicon >>$lang_ext/G.txt + fstcompile --isymbols=$lang_ext/words.txt --osymbols=$lang_ext/words.txt <$lang_ext/G.txt | \ + fstarcsort --sort_type=ilabel >$lang_ext/G.fst +fi + +if [ $stage -le 6 ]; then + # make the part of the graph that will be included. + # Refer to the 'compile-graph' commands in ./simple_demo.sh for how you'd do + # this in code. + utils/mkgraph.sh --self-loop-scale 1.0 $lang_ext $tree_dir $tree_dir/extvocab_part +fi + +if [ $stage -le 7 ]; then + offset=$(grep nonterm_bos $lang_ext/phones.txt | awk '{print $2}') + nonterm_unk=$(grep nonterm:unk $lang_ext/phones.txt | awk '{print $2}') + + mkdir -p $tree_dir/extvocab_combined + [ -d $tree_dir/extvocab_combined/phones ] && rm -r $tree_dir/extvocab_combined/phones + # the decoding script expects words.txt and phones/, copy them from the extvocab_part + # graph directory where they will have suitable values. + cp -r $tree_dir/extvocab_part/{words.txt,phones.txt,phones/} $tree_dir/extvocab_combined + + # the following, due to --write-as-grammar=false, compiles it into an FST + # which can be decoded by our normal decoder. + make-grammar-fst --write-as-grammar=false --nonterm-phones-offset=$offset $tree_dir/extvocab_top/HCLG.fst \ + $nonterm_unk $tree_dir/extvocab_part/HCLG.fst $tree_dir/extvocab_combined/HCLG.fst + + # the following compiles it and writes as GrammarFst. The size is 176M, vs. 182M for HCLG.fst. + # In other examples, of course the difference might be more. + + make-grammar-fst --write-as-grammar=true --nonterm-phones-offset=$offset $tree_dir/extvocab_top/HCLG.fst \ + $nonterm_unk $tree_dir/extvocab_part/HCLG.fst $tree_dir/extvocab_combined/HCLG.gra +fi + + +if [ $stage -le 8 ]; then + # OK, now we actually decode the test data. For reference, the command which was used to + # decode the test data in the current (at the time of writing) chain TDNN system + # local/chain/run_tdnn.sh (as figured out by running it from that stage), was: + # steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --frames-per-chunk 140 --nj 38 \ + # --cmd "queue.pl --mem 4G --num-threads 4" --online-ivector-dir exp/nnet3/ivectors_dev_clean_2_hires \ + # exp/chain/tree_sp/graph_tgsmall data/dev_clean_2_hires exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2 + + # We just replace the graph with the one in $treedir/extvocab_combined. 
+ + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --frames-per-chunk 140 --nj 38 \ + --cmd "queue.pl --mem 4G --num-threads 4" --online-ivector-dir exp/nnet3/ivectors_dev_clean_2_hires \ + exp/chain/tree_sp/extvocab_combined data/dev_clean_2_hires exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_comb + + # s5: grep WER exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_comb/wer_* | utils/best_wer.sh + # %WER 11.42 [ 2300 / 20138, 227 ins, 275 del, 1798 sub ] exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_comb/wer_12_0.0 + + #.. versus the baseline below: + # s5: grep WER exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2/wer_* | utils/best_wer.sh + # %WER 12.01 [ 2418 / 20138, 244 ins, 307 del, 1867 sub ] exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2/wer_13_0.0 +fi + +if [ $stage -le 9 ]; then + steps/nnet3/decode_grammar.sh --acwt 1.0 --post-decode-acwt 10.0 --frames-per-chunk 140 --nj 38 \ + --cmd "queue.pl --mem 4G --num-threads 4" --online-ivector-dir exp/nnet3/ivectors_dev_clean_2_hires \ + exp/chain/tree_sp/extvocab_combined data/dev_clean_2_hires exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_comb_gra + + # WER with grammar decoding is exactly the same as decoding from the converted FST. + # grep WER exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_comb_gra/wer_* | utils/best_wer.sh + # %WER 11.42 [ 2300 / 20138, 227 ins, 275 del, 1798 sub ] exp/chain/tdnn1h_sp/decode_tgsmall_dev_clean_2_ev_comb_gra/wer_12_0.0 +fi diff --git a/egs/mini_librispeech/s5/local/grammar/simple_demo.sh b/egs/mini_librispeech/s5/local/grammar/simple_demo.sh new file mode 100755 index 00000000000..a4edeb8091c --- /dev/null +++ b/egs/mini_librispeech/s5/local/grammar/simple_demo.sh @@ -0,0 +1,177 @@ +#!/usr/bin/env bash + +# This script demonstrates how to use the grammar-decoding framework to build +# graphs made out of more than one part. It demonstrates using `fstequivalent` +# that the graph constructed this way is equivalent to what you would create if +# you had the LM all as a single piece. This uses the command line tools to +# expand to a regular FST (--write-as-grammar=false) In practice you might not +# want do to that, since the result might be large, and since writing the entire +# thing might take too much time. The code itself allows you to construct these +# GrammarFst objects in lightweight way and decode using them. + +stage=0 +set -e +. ./path.sh +. utils/parse_options.sh + + +tree_dir=exp/chain/tree_sp + +# For the purposes of this script we just need a biphone tree and associated +# transition-model for testing, because we're testing it at the graph level, +# i.e. testing equivalence of compiled HCLG graphs; there is no decoding +# involved here. + +# We're doing this with the "no-silprobs" dictionary dir for now, as we +# need to write some scripts to support silprobs with this. 
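+ +# Overview of the stages below (a summary of what this demo does): grammar1 builds a +# single, monolithic G.fst and its HCLG.fst as the baseline; grammar2a builds the top-level +# graph in which part of the grammar is replaced by the user-defined nonterminal +# #nonterm:contact_list; grammar2b builds the sub-graph for that nonterminal (delimited by +# #nonterm_begin / #nonterm_end); make-grammar-fst then splices the sub-graph into the +# top-level graph, and fstequivalent checks that the combined graph matches the baseline.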
+ +# For reference, the original command we +#utils/prepare_lang.sh data/local/dict_nosp \ +# "" data/local/lang_tmp_nosp data/lang_nosp + +if [ $stage -le 0 ]; then + [ -d data/local/dict_nosp_grammar1 ] && rm -r data/local/dict_nosp_grammar1 + cp -r data/local/dict_nosp data/local/dict_nosp_grammar1 + echo "#nonterm:contact_list" > data/local/dict_nosp_grammar1/nonterminals.txt + + [ -f data/lang_nosp_grammar1/G.fst ] && rm data/lang_nosp_grammar1/G.fst + utils/prepare_lang.sh data/local/dict_nosp_grammar1 \ + "" data/local/lang_tmp_nosp data/lang_nosp_grammar1 +fi + + + +if [ $stage -le 1 ]; then + # Most contents of these directories will be the same, only G.fst differs, but + # it's our practice to make these things as directories combining G.fst with + # everything else. + rm -r data/lang_nosp_grammar2{a,b} 2>/dev/null || true + cp -r data/lang_nosp_grammar1 data/lang_nosp_grammar2a + cp -r data/lang_nosp_grammar1 data/lang_nosp_grammar2b +fi + +if [ $stage -le 2 ]; then + # Create a simple G.fst in data/lang_nosp_grammar1, which won't + # actually use any grammar stuff, it will be a baseline to test against. + + lang=data/lang_nosp_grammar1 + cat < $lang/G.fst +0 1 GROUP GROUP +0 1 4.0 +1 2 ONE ONE 0.69314718055994 +1 2 TWO TWO 0.69314718055994 +1 2 5.0 +2 3 ASSIST ASSIST 0.69314718055994 +2 0.69314718055994 +3 +EOF + utils/mkgraph.sh --self-loop-scale 1.0 $lang $tree_dir $tree_dir/grammar1 + + # test that the binary 'compile-graph' does the same thing as mkgraph.sh. + compile-graph --read-disambig-syms=$lang/phones/disambig.int $tree_dir/tree $tree_dir/1.mdl $lang/L_disambig.fst $lang/G.fst $tree_dir/grammar1/HCLG2.fst + + if ! fstequivalent --delta=0.01 --random=true --npath=100 $tree_dir/grammar1/HCLG{,2}.fst; then + echo "$0: two methods of producing graph in $tree_dir/grammar1 were different." + exit 1 + fi +fi + + +if [ $stage -le 3 ]; then + # create the top-level graph in data/lang_nosp_grammar2a + + # you can of course choose to put what symbols you want on the output side, as + # long as they are defined in words.txt. #nonterm:contact_list, #nonterm_begin + # and #nonterm_end would be defined in this example. This might be useful in + # situations where you want to keep track of the structure of calling + # nonterminals. + lang=data/lang_nosp_grammar2a + cat <$lang/G.fst +0 1 GROUP GROUP +0 1 4.0 +1 2 #nonterm:contact_list +2 3 ASSIST ASSIST 0.69314718055994 +2 0.69314718055994 +3 +EOF + utils/mkgraph.sh --self-loop-scale 1.0 $lang $tree_dir $tree_dir/grammar2a + + # test that the binary 'compile-graph' does the same thing as mkgraph.sh. + offset=$(grep nonterm_bos $lang/phones.txt | awk '{print $2}') # 364 + compile-graph --nonterm-phones-offset=$offset --read-disambig-syms=$lang/phones/disambig.int \ + $tree_dir/tree $tree_dir/1.mdl $lang/L_disambig.fst $lang/G.fst $tree_dir/grammar2a/HCLG2.fst + + if ! fstequivalent --delta=0.01 --random=true --npath=100 $tree_dir/grammar2a/HCLG{,2}.fst; then + echo "$0: two methods of producing graph in $tree_dir/grammar2a were different." + exit 1 + fi +fi + +if [ $stage -le 4 ]; then + # Create the graph for the nonterminal in data/lang_nosp_grammar2b + # Again, we don't choose to put these symbols on the output side, but it would + # be possible to do so. 
+ lang=data/lang_nosp_grammar2b + cat < $lang/G.fst +0 1 #nonterm_begin +1 2 ONE ONE 0.69314718055994 +1 2 TWO TWO 0.69314718055994 +1 2 5.0 +2 3 #nonterm_end +3 +EOF + utils/mkgraph.sh --self-loop-scale 1.0 $lang $tree_dir $tree_dir/grammar2b + + + # test that the binary 'compile-graph' does the same thing as mkgraph.sh. + offset=$(grep nonterm_bos $lang/phones.txt | awk '{print $2}') # 364 + compile-graph --nonterm-phones-offset=$offset --read-disambig-syms=$lang/phones/disambig.int \ + $tree_dir/tree $tree_dir/1.mdl $lang/L_disambig.fst $lang/G.fst $tree_dir/grammar2b/HCLG2.fst + + if ! fstequivalent --delta=0.01 --random=true --npath=100 $tree_dir/grammar2b/HCLG{,2}.fst; then + echo "$0: two methods of producing graph in $tree_dir/grammar2b were different." + exit 1 + fi +fi + +if [ $stage -le 5 ]; then + # combine the top-level graph and the sub-graph together using the command + # line tools. (In practice you might want to do this from application code). + + lang=data/lang_nosp_grammar2a + offset=$(grep nonterm_bos $lang/phones.txt | awk '{print $2}') # 364 + clist=$(grep nonterm:contact_list $lang/phones.txt | awk '{print $2}') # 368 + + # the graph in $tree_dir/grammar2/HCLG.fst will be a normal FST (ConstFst) + # that was expanded from the grammar. (we use --write-as-grammar=false to + # make it expand). This is to test equivalence to the one in + # $tree_dir/grammar1/ + mkdir -p $tree_dir/grammar2 + make-grammar-fst --write-as-grammar=false --nonterm-phones-offset=$offset $tree_dir/grammar2a/HCLG.fst \ + $clist $tree_dir/grammar2b/HCLG.fst $tree_dir/grammar2/HCLG.fst +fi + +if [ $stage -le 6 ]; then + # Test equivalence using a random path; this can be useful for debugging if + # fstequivalent fails. + echo "$0: will print costs with the two FSTs, for one random path." + fstrandgen $tree_dir/grammar1/HCLG.fst > path.fst + for x in 1 2; do + fstproject --project_output=false path.fst | fstcompose - $tree_dir/grammar${x}/HCLG.fst | fstcompose - <(fstproject --project_output=true path.fst) > composed.fst + start_state=$(fstprint composed.fst | head -n 1 | awk '{print $1}') + fstshortestdistance --reverse=true composed.fst | awk -v s=$start_state '{if($1 == s) { print $2; }}' + done + +fi + +if [ $stage -le 7 ]; then + echo "$0: will test equivalence using fstequivalent" + if fstequivalent --delta=0.01 --random=true --npath=100 $tree_dir/grammar1/HCLG.fst $tree_dir/grammar2/HCLG.fst; then + echo "$0: success: the two were equivalent" + else + echo "$0: failure: the two were inequivalent" + fi +fi diff --git a/egs/mini_librispeech/s5/local/grammar/simple_demo_silprobs.sh b/egs/mini_librispeech/s5/local/grammar/simple_demo_silprobs.sh new file mode 100755 index 00000000000..414227f2ad6 --- /dev/null +++ b/egs/mini_librispeech/s5/local/grammar/simple_demo_silprobs.sh @@ -0,0 +1,175 @@ +#!/usr/bin/env bash + +# simple_demo_silprobs.sh is a version of simple_demo.sh that uses a lexicon +# with word-specific silence probabilities. + +# These scripts demonstrate how to use the grammar-decoding framework to build +# graphs made out of more than one part. They demonstrate, using `fstequivalent`, +# that the graph constructed this way is equivalent to what you would create if +# you had the LM all as a single piece. This uses the command line tools to +# expand to a regular FST (--write-as-grammar=false). In practice you might not +# want to do that, since the result might be large, and since writing the entire +# thing might take too much time.
The code itself allows you to construct these +# GrammarFst objects in lightweight way and decode using them. + +stage=0 +set -e +. ./path.sh +. utils/parse_options.sh + + +tree_dir=exp/chain/tree_sp + +# For the purposes of this script we just need a biphone tree and associated +# transition-model for testing, because we're testing it at the graph level, +# i.e. testing equivalence of compiled HCLG graphs; there is no decoding +# involved here. + + +# For reference, the original command we +#utils/prepare_lang.sh data/local/dict \ +# "" data/local/lang_tmp data/lang + +if [ $stage -le 0 ]; then + [ -d data/local/dict_grammar1 ] && rm -r data/local/dict_grammar1 + cp -r data/local/dict data/local/dict_grammar1 + echo "#nonterm:contact_list" > data/local/dict_grammar1/nonterminals.txt + + utils/prepare_lang.sh data/local/dict_grammar1 \ + "" data/local/lang_tmp data/lang_grammar1 +fi + + + +if [ $stage -le 1 ]; then + # Most contents of these directories will be the same, only G.fst differs, but + # it's our practice to make these things as directories combining G.fst with + # everything else. + rm -r data/lang_grammar2{a,b} 2>/dev/null || true + cp -r data/lang_grammar1 data/lang_grammar2a + cp -r data/lang_grammar1 data/lang_grammar2b +fi + +if [ $stage -le 2 ]; then + # Create a simple G.fst in data/lang_grammar1, which won't + # actually use any grammar stuff, it will be a baseline to test against. + + lang=data/lang_grammar1 + cat < $lang/G.fst +0 1 GROUP GROUP +1 2 ONE ONE 0.69314718055994 +1 2 TWO TWO 0.69314718055994 +1 2 5.0 +2 3 ASSIST ASSIST 0.69314718055994 +2 0.69314718055994 +3 +EOF + utils/mkgraph.sh --self-loop-scale 1.0 $lang $tree_dir $tree_dir/grammar1 + + # test that the binary 'compile-graph' does the same thing as mkgraph.sh. + compile-graph --read-disambig-syms=$lang/phones/disambig.int $tree_dir/tree $tree_dir/1.mdl $lang/L_disambig.fst $lang/G.fst $tree_dir/grammar1/HCLG2.fst + + if ! fstequivalent --delta=0.01 --random=true --npath=100 $tree_dir/grammar1/HCLG{,2}.fst; then + echo "$0: two methods of producing graph in $tree_dir/grammar1 were different." + exit 1 + fi +fi + + +if [ $stage -le 3 ]; then + # create the top-level graph in data/lang_grammar2a + + # you can of course choose to put what symbols you want on the output side, as + # long as they are defined in words.txt. #nonterm:contact_list, #nonterm_begin + # and #nonterm_end would be defined in this example. This might be useful in + # situations where you want to keep track of the structure of calling + # nonterminals. + lang=data/lang_grammar2a + cat < $lang/G.fst +0 1 GROUP GROUP +1 2 #nonterm:contact_list +2 3 ASSIST ASSIST 0.69314718055994 +2 0.69314718055994 +3 +EOF + utils/mkgraph.sh --self-loop-scale 1.0 $lang $tree_dir $tree_dir/grammar2a + + # test that the binary 'compile-graph' does the same thing as mkgraph.sh. + offset=$(grep nonterm_bos $lang/phones.txt | awk '{print $2}') # 364 + compile-graph --nonterm-phones-offset=$offset --read-disambig-syms=$lang/phones/disambig.int \ + $tree_dir/tree $tree_dir/1.mdl $lang/L_disambig.fst $lang/G.fst $tree_dir/grammar2a/HCLG2.fst + + if ! fstequivalent --delta=0.01 --random=true --npath=100 $tree_dir/grammar2a/HCLG{,2}.fst; then + echo "$0: two methods of producing graph in $tree_dir/grammar2a were different." + exit 1 + fi +fi + +if [ $stage -le 4 ]; then + # Create the graph for the nonterminal in data/lang_grammar2b + # Again, we don't choose to put these symbols on the output side, but it would + # be possible to do so. 
+ lang=data/lang_grammar2b + cat < $lang/G.fst +0 1 #nonterm_begin +1 2 ONE ONE 0.69314718055994 +1 2 TWO TWO 0.69314718055994 +1 2 5.0 +2 3 #nonterm_end +3 +EOF + utils/mkgraph.sh --self-loop-scale 1.0 $lang $tree_dir $tree_dir/grammar2b + + + # test that the binary 'compile-graph' does the same thing as mkgraph.sh. + offset=$(grep nonterm_bos $lang/phones.txt | awk '{print $2}') # 364 + compile-graph --nonterm-phones-offset=$offset --read-disambig-syms=$lang/phones/disambig.int \ + $tree_dir/tree $tree_dir/1.mdl $lang/L_disambig.fst $lang/G.fst $tree_dir/grammar2b/HCLG2.fst + + if ! fstequivalent --delta=0.01 --random=true --npath=100 $tree_dir/grammar2b/HCLG{,2}.fst; then + echo "$0: two methods of producing graph in $tree_dir/grammar2b were different." + exit 1 + fi +fi + +if [ $stage -le 5 ]; then + # combine the top-level graph and the sub-graph together using the command + # line tools. (In practice you might want to do this from appliation code). + + lang=data/lang_grammar2a + offset=$(grep nonterm_bos $lang/phones.txt | awk '{print $2}') # 364 + clist=$(grep nonterm:contact_list $lang/phones.txt | awk '{print $2}') # 368 + + # the graph in $tree_dir/grammar2/HCLG.fst will be a normal FST (ConstFst) + # that was expanded from the grammar. (we use --write-as-grammar=false to + # make it expand it). This is to test equivalence to the one in + # $tree_dir/grammar1/ + mkdir -p $tree_dir/grammar2 + make-grammar-fst --write-as-grammar=false --nonterm-phones-offset=$offset $tree_dir/grammar2a/HCLG.fst \ + $clist $tree_dir/grammar2b/HCLG.fst $tree_dir/grammar2/HCLG.fst +fi + +if [ $stage -le 6 ]; then + # Test equivalence using a random path.. can be useful for debugging if + # fstequivalent fails. + echo "$0: will print costs with the two FSTs, for one random path." + fstrandgen $tree_dir/grammar1/HCLG.fst > path.fst + for x in 1 2; do + fstproject --project_output=false path.fst | fstcompose - $tree_dir/grammar${x}/HCLG.fst | fstcompose - <(fstproject --project_output=true path.fst) > composed.fst + start_state=$(fstprint composed.fst | head -n 1 | awk '{print $1}') + fstshortestdistance --reverse=true composed.fst | awk -v s=$start_state '{if($1 == s) { print $2; }}' + done + +fi + +if [ $stage -le 7 ]; then + echo "$0: will test equivalece using fstequivalent" + if fstequivalent --delta=0.01 --random=true --npath=100 $tree_dir/grammar1/HCLG.fst $tree_dir/grammar2/HCLG.fst; then + echo "$0: success: the two were equivalent" + else + echo "$0: failure: the two were inequivalent" + fi +fi diff --git a/egs/mini_librispeech/s5/local/kws/compile_keywords.sh b/egs/mini_librispeech/s5/local/kws/compile_keywords.sh new file mode 100755 index 00000000000..9f88b9665ff --- /dev/null +++ b/egs/mini_librispeech/s5/local/kws/compile_keywords.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Copyright (c) 2015-2018, Johns Hopkins University (Yenda Trmal ) +# License: Apache 2.0 + +# Begin configuration section. +silence_word= +filter='OOV=0' +# End configuration section +echo $0 "$@" +. ./utils/parse_options.sh || exit 1; + +set -e -o pipefail +set -o nounset # Treat unset variables as an error + + +data=$1 +lang=$2 +workdir=$3 + +mkdir -p $workdir +if [ -f $data/categories ] ; then + cat $data/categories | \ + local/search/filter_by_category.pl $data/categories "$filter" > $workdir/categories + + if [ ! -s $workdir/categories ]; then + echo "$0: WARNING: $workdir/categories is zero-size. That means no keyword" + echo "$0: WARNING: was found that fits the filter \"$filter\". That might be expected." 
+ touch $workdir/keywords.int + touch $workdir/keywords.fsts + exit 0 + fi + grep -w -F -f <(awk '{print $1}' $workdir/categories) \ + $data/keywords.int > $workdir/keywords.int +else + cp $data/keywords.int $workdir/keywords.int +fi + + + +if [ -s $workdir/keywords.int ]; then + if [ -z $silence_word ]; then + transcripts-to-fsts ark:$workdir/keywords.int \ + ark,scp,t:$workdir/keywords.fsts,- | sort -o $workdir/keywords.scp + else + silence_int=`grep -w $silence_word $lang/words.txt | awk '{print $2}'` + [ -z $silence_int ] && \ + echo "$0: Error: could not find integer representation of silence word $silence_word" && exit 1; + transcripts-to-fsts ark:$data/keywords.int ark,t:- | \ + awk -v 'OFS=\t' -v silint=$silence_int '{ + if (NF == 4 && $1 != 0) { print $1, $1, silint, silint; } print; + }' | fstcopy ark:- ark,scp,t:$workdir/keywords.fsts,- | \ + sort -o $workdir/keywords.scp + fi +else + echo "$0: WARNING: $workdir/keywords.int is zero-size. That means no keyword" + echo "$0: WARNING: was found in the dictionary. That might be expected -- or not." + touch $workdir/keywords.fsts +fi + diff --git a/egs/mini_librispeech/s5/local/kws/create_categories.pl b/egs/mini_librispeech/s5/local/kws/create_categories.pl new file mode 100755 index 00000000000..4a9e3314c41 --- /dev/null +++ b/egs/mini_librispeech/s5/local/kws/create_categories.pl @@ -0,0 +1,112 @@ +#!/usr/bin/env perl +#=============================================================================== +# Copyright 2015-2018 (Author: Yenda Trmal ) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== +my $Usage = < + e.g.: $0 keywords.txt + or $0 --results results + +Allowed options: + --results : instead of keyword specification format, keyword search + results format is assumed. + +NOTE: + If you need both information, you can call the script twice (with different + parameters) and call local/search/normalize_categories.pl to merge (and normalize) + these two tables together. 
+EOU + +use strict; +use warnings; +use utf8; +use POSIX; +use Data::Dumper; +use Getopt::Long; +use open qw(:std :utf8); + +binmode STDIN, ":utf8"; +binmode STDOUT, ":utf8"; +binmode STDERR, ":utf8"; + +my $result_format; +GetOptions("results", \$result_format) or do { + print STDERR "Cannot parse the command-line parameters.\n"; + print STDERR "$Usage\n"; + die "Cannot continue\n" +}; + +if ( @ARGV > 1 ) { + print STDERR "Incorrect number of command-line parameters\n"; + print STDERR "$Usage\n"; + die "Cannot continue\n" +} + +sub QuantizeCount { + my $count = shift @_; + + if ($count <= 0) { + return "0"; + } elsif ($count == 1) { + return "000-001"; + } elsif ($count <= 5) { + return "002-005"; + } elsif ($count <=10) { + return "006-010"; + } elsif ($count <=20) { + return "011-020"; + } elsif ($count <=100) { + return "021-100"; + } else { + return "101-inf"; + } +} + +if (not $result_format ) { + my $kwlist_name=$ARGV[0]; + while (my $line = <>) { + chomp $line; + my ($kwid, $text) = split " ", $line, 2; + + my @words = split " ", $text; + printf "$kwid NGramOrder=%03d\n", scalar @words; + printf "$kwid Characters=%03d\n", length(join("", @words)); + print "$kwid $kwid\n"; + } +} else { + my $prev_kwid = ""; + my $count = 0; + + while (my $line = <>) { + chomp $line; + my @entries = split " ", $line; + next unless @entries; + + if ($prev_kwid ne $entries[0]) { + if ($prev_kwid) { + print "$prev_kwid ResCount=$count\n"; + print "$prev_kwid ResCountQuant=" . QuantizeCount($count) . "\n"; + } + $count = 0; + $prev_kwid = $entries[0]; + } + $count += 1; + } +} + + diff --git a/egs/mini_librispeech/s5/local/kws/create_hitlist.sh b/egs/mini_librispeech/s5/local/kws/create_hitlist.sh new file mode 100755 index 00000000000..be06a3b9312 --- /dev/null +++ b/egs/mini_librispeech/s5/local/kws/create_hitlist.sh @@ -0,0 +1,72 @@ +#!/bin/bash +# Copyright 2012-2018 Johns Hopkins University (Author: Guoguo Chen, Yenda Trmal) +# Apache 2.0. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +cmd=run.pl +scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" +beam=10 +retry_beam=40 +boost_silence=1.0 + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 5 ]; then + echo "This script takes an ali directory and creates the corresponding RTTM file" + echo "" + echo "Usage: create_hitlist.sh " + echo " e.g.: create_hitlist.sh data/heldout data/lang data/local/lang_tmp exp/heldout_ali data/heldout/kws" + echo "main options (for others, see top of script file)" + echo " --cmd (utils/run.pl|utils/queue.pl ) " + + exit 1; +fi + +set -e +set -o pipefail +set -u + +data=$1 +lang=$2 +lang_tmp=$3 +dir=$4 +kws=$5 + +oov=`cat $lang/oov.txt` +mkdir -p $dir/log + +echo "$0: writing alignments." +wbegin=`grep "#1" $lang/phones.txt | head -1 | awk '{print $2}'` +wend=`grep "#2" $lang/phones.txt | head -1 | awk '{print $2}'` + +if [ ! 
-f $lang/L_align.fst ]; then + echo "$0: generating $lang/L_align.fst" + local/kws/make_L_align.sh $lang_tmp $lang $lang 2>&1 | tee $dir/log/L_align.log +fi + +$cmd $dir/log/ali_to_hitlist.log \ + set -e -o pipefail\; \ + ali-to-phones $dir/final.mdl "ark:gunzip -c $dir/ali.*.gz|" ark,t:- \| \ + phones-to-prons $lang/L_align.fst $wbegin $wend ark:- "ark,s:utils/sym2int.pl -f 2- --map-oov '$oov' $lang/words.txt <$data/text|" ark,t:- \| \ + prons-to-wordali ark:- "ark:ali-to-phones --write-lengths=true $dir/final.mdl 'ark:gunzip -c $dir/ali.*.gz|' ark,t:- |" ark,t:- \| \ + local/kws/generate_hitlist.pl $kws/keywords.int \|\ + utils/sym2int.pl -f 2 $kws/utt.map \> $kws/hitlist + +echo "$0: done generating hitlist" + + +exit 0; diff --git a/egs/mini_librispeech/s5/local/kws/example/keywords.txt b/egs/mini_librispeech/s5/local/kws/example/keywords.txt new file mode 100644 index 00000000000..118de904297 --- /dev/null +++ b/egs/mini_librispeech/s5/local/kws/example/keywords.txt @@ -0,0 +1,7 @@ +KWS_001 GOOD MORNING +KWS_002 SCOTLAND YARD +KWS_003 CLERGYMAN +KWS_004 UNKNOWN +KWS_005 WHITE FLAG +KWS_006 DON'T CRY + diff --git a/egs/mini_librispeech/s5/local/kws/filter_kws_results.pl b/egs/mini_librispeech/s5/local/kws/filter_kws_results.pl new file mode 100755 index 00000000000..37549249bdc --- /dev/null +++ b/egs/mini_librispeech/s5/local/kws/filter_kws_results.pl @@ -0,0 +1,189 @@ +#!/usr/bin/env perl +#=============================================================================== +# Copyright 2015-2018 (Author: Yenda Trmal ) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +my $Usage = < > output + e.g.: gunzip -c exp/tri5/kws/result.*.gz | $0 > exp/tri5/kws/results + +Allowed options: + --nbest : how many best results (for each KWID) should be printed + (int, default -1, i.e. no limit) + --duptime : duplicates detection, tolerance (in frames) for being + the same hits (int, default = 50) + --likes + --probs + +CAVEATS: + The script tries to be memory-effective. The impact of this is that we + assume the results are sorted by KWID (i.e. all entries with the same KWID + are in a continuous block). The user is responsible for sorting it. +EOU + +use strict; +use warnings; +use utf8; +use POSIX; +use Data::Dumper; +use Getopt::Long; + +# if parameter nbest > 0, then filters the result list so that there is no +# more than nbest hits in the output for each of the KWID +# + +my $nbest = -1; +my $duptime = 50; +my $likes = 0; + +#print STDERR join(" ", $0, @ARGV) . 
"\n"; +GetOptions ("nbest=f" => \$nbest, + "likes" => \$likes, + "probs" => sub{ $likes = 0}, + "duptime=i" => \$duptime) || do { + print STDERR "Cannot parse the command-line parameters.\n"; + print STDERR "$Usage\n"; + die "Cannot continue\n" +}; + +if (@ARGV != 0) { + print STDERR "Incorrect number of command-line parameters\n"; + print STDERR "$Usage\n"; + die "Cannot continue\n" +} + +# Function for sorting +sub KwslistOutputSort { + if ($a->[0] ne $b->[0]) { + if ($a->[0] =~ m/[0-9]+$/ && $b->[0] =~ m/[0-9]+$/) { + ($a->[0] =~ /([0-9]*)$/)[0] <=> ($b->[0] =~ /([0-9]*)$/)[0] + } else { + $a->[0] cmp $b->[0]; + } + } elsif ($a->[5] ne $b->[5]) { + $b->[5] <=> $a->[5]; + } else { + $a->[1] cmp $b->[1]; + } +} + +sub KwslistDupSort { + my ($a, $b, $duptime) = @_; + if ($a->[1] ne $b->[1]) { + #file + $a->[1] cmp $b->[1]; + } elsif (abs($a->[2]-$b->[2]) >= $duptime){ + #start + $a->[2] <=> $b->[2]; + } elsif ($a->[4] ne $b->[4]) { + #score + $b->[4] <=> $a->[4]; + } else { + #end time + $b->[3] <=> $a->[3]; + } +} + +my @RESULTS; +my %SEEN_KWS; +my $kw = ""; + +while ( my $line = ) { + chomp $line; + my @F = split " ", $line; + @F == 5 || die "$0: Bad number of columns in raw results \"$line\"\n"; + + $F[4] = -$F[4] if $likes; + + if ($F[0] eq $kw) { + push @RESULTS, \@F; + } elsif ($kw eq "" ) { + @RESULTS = (); + push @RESULTS, \@F; + $kw = $F[0]; + } else { + + my @results; + my @tmp = sort { KwslistDupSort($a, $b, $duptime) } @RESULTS; + + @results = (); + if (@tmp >= 1) {push(@results, $tmp[0])}; + for (my $i = 1; $i < scalar(@tmp); $i ++) { + my $prev = $results[-1]; + my $curr = $tmp[$i]; + if ((abs($prev->[2]-$curr->[2]) < $duptime ) && + ($prev->[1] eq $curr->[1])) { + next; + } else { + push(@results, $curr); + } + } + + # this is probably needed only when nbest > 0 + @results = sort { ($b->[4] + 0.0) <=> ($a->[4] + 0.0) } @results; + + my $len; + if( $nbest > 0) { + $len = scalar @results < $nbest ? scalar @results : $nbest; + } else { + $len = scalar @results; + } + for (my $i=0; $i < $len; $i++) { + $results[$i]->[4] = -$results[$i]->[4] if $likes; + print join(" ", @{$results[$i]}) . "\n"; + } + + @RESULTS = (); + push @RESULTS, \@F; + $kw = $F[0]; + } +} +do { + my @results; + my @tmp = sort { KwslistDupSort($a, $b, $duptime) } @RESULTS; + + @results = (); + if (@tmp >= 1) {push(@results, $tmp[0])}; + for (my $i = 1; $i < scalar(@tmp); $i ++) { + my $prev = $results[-1]; + my $curr = $tmp[$i]; + if ((abs($prev->[2]-$curr->[2]) < $duptime ) && + ($prev->[1] eq $curr->[1])) { + next; + } else { + push(@results, $curr); + } + } + + # this is probably needed only when nbest > 0 + @results = sort { ($b->[4] + 0.0) <=> ($a->[4] + 0.0) } @results; + + my $len; + if( $nbest > 0) { + $len = scalar @results < $nbest ? scalar @results : $nbest; + } else { + $len = scalar @results; + } + for (my $i=0; $i < $len; $i++) { + $results[$i]->[4] = -$results[$i]->[4] if $likes; + print join(" ", @{$results[$i]}) . "\n"; + } +} + + diff --git a/egs/mini_librispeech/s5/local/kws/generate_hitlist.pl b/egs/mini_librispeech/s5/local/kws/generate_hitlist.pl new file mode 100755 index 00000000000..41df32626a6 --- /dev/null +++ b/egs/mini_librispeech/s5/local/kws/generate_hitlist.pl @@ -0,0 +1,117 @@ +#!/usr/bin/env perl +#=============================================================================== +# Copyright 2018 (Author: Yenda Trmal ) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +# this will generate the hitlist (list of all hits) using the word-level +# alignments +# Format of the file +# utt-id word-1 duration-1 ; word-2 duration-2 ; .... +# it is exactly the same format that you can get from ali-to-phones with +# parameter --write-lengths (see the script create_hitlist.sh for complete +# example) + +# The script is not very optimized -- the finding of the hits in the utterance +# is done by concatenating the word_ids sequence using '|' and then by searching +# for a substring processed the same way. After that, we workout the word-level +# indices of the individual hits (remember, there may be more hits per utterance) +# Probably still faster than rolling our own searching algorithm due to the fact +# that it goes directly to (optimized) perl's runtime function + +use strict; +use warnings; +use utf8; + +if ((scalar @ARGV > 2) || (scalar @ARGV < 1)) { + print STDERR "Usage: $0 []\n"; + print STDERR "E.g.\n"; + print STDERR " $0 data/train_clean_5/kws/keywords.int < exp/tri3b_ali_train_clean_5/align.txt\n"; + die "Incorrect number of arguments." +} + +my $keyword_file = shift @ARGV; +open(my $keywords, "<$keyword_file") or + die "Cannot open $keyword_file for reading"; + +my @KW; +while (<$keywords>) { + chomp; + next unless $_; + my @F = split; + my $kwid = shift @F; + push @KW, [$kwid, \@F]; +} + +while (<>) { + chomp; + next unless $_; + + my @F = split(" ", $_, 2); + my $utt_id = shift @F; + @F = split(/ ; /, $F[0]); + + my $frames_prev = 0; + my @UTT; + foreach my $entry (@F) { + (my $word, my $frames) = split(" ", $entry, 2); + if ($word ne 0) { + my $frames_start = $frames_prev; + my $frames_end = $frames_start + $frames; + $frames_prev = $frames_end; + push @UTT, [$word + 0, $frames_start, $frames_end]; + } else { + $frames_prev += $frames; + } + } + + my $utt_string = '|' . join('|', map { $_->[0] } @UTT) . '|'; + my %utt_indices; + my $counter = 0; + my $idx = 0; + #mapping between the position in the utt_string and the position of + #the word in the original utterance + while () { + $idx = index($utt_string, '|', $idx); + last if $idx == -1; + $utt_indices{$idx} = $counter; + $idx += 1; + $counter +=1 + } + + + foreach my $kw (@KW) { + my $kw_string = "|" . join('|', @{$kw->[1]}) . 
'|'; + my $kwlen = scalar @{$kw->[1]}; + + my $idx = 0; + my @all_idx; + while () { + $idx = index($utt_string, $kw_string, $idx); + last if $idx == -1; + push @all_idx, $idx; + $idx += 1; + } + + foreach my $hit (@all_idx) { + my $start_idx = $utt_indices{$hit}; + my $end_idx = $start_idx + $kwlen - 1; + my $start = $UTT[$start_idx]->[1]; + my $end = $UTT[$end_idx]->[2]; + + print "$kw->[0] $utt_id $start $end 0\n"; + } + } +} diff --git a/egs/mini_librispeech/s5/local/kws/keywords_to_indices.pl b/egs/mini_librispeech/s5/local/kws/keywords_to_indices.pl new file mode 100755 index 00000000000..7eb721cf1c3 --- /dev/null +++ b/egs/mini_librispeech/s5/local/kws/keywords_to_indices.pl @@ -0,0 +1,123 @@ +#!/usr/bin/env perl +# Copyright 2012-2018 Johns Hopkins University (Author: Yenda Trmal) +# Apache 2.0. + +use Data::Dumper; +$Data::Dumper::Indent = 1; + +binmode STDOUT, ":utf8"; +binmode STDIN, ":utf8"; + +sub permute { + + my $last = pop @_; + + unless(@_) { + return map([$_], @$last); + } + + return map { + my $left = $_; + map([@$left, $_], @$last) + } + permute(@_); +} + +$oov_count=0; + +$ignore_oov = 0; +$ignore_first_field = 0; +for($x = 0; $x < 2; $x++) { + if ($ARGV[0] eq "--map-oov") { + shift @ARGV; $map_oov = shift @ARGV; + } + if ($ARGV[0] eq "-f") { + shift @ARGV; + $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; + } + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. + } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; + } + } +} + +$symtab = shift @ARGV; +if (!defined $symtab) { + print STDERR "Usage: sym2int.pl [options] symtab [input transcriptions] > output transcriptions\n" . + "options: [--map-oov ] [-f ]\n" . 
+ "note: can look like 4-5, or 4-, or 5-, or 1.\n"; +} +open(F, "<:encoding(UTF-8)", $symtab) || die "Error opening symbol table file $symtab"; +while() { + @A = split(" ", $_); + @A == 2 || die "bad line in symbol table file: $_"; + + if ( not defined( $sym2int{$A[0]} ) ) { + $sym2int{$A[0]} = []; + } + push @{ $sym2int{$A[0]} }, $A[1] + 0; +} +#print Dumper(\%sym2int); + +if (defined $map_oov && $map_oov !~ m/^\d+$/) { # not numeric-> look it up + if (!defined $sym2int{$map_oov}) { die "OOV symbol $map_oov not defined."; } + $map_oov = $sym2int{$map_oov}; +} + +$lines=0; +while (<>) { + @A = split(" ", $_); + @B = (); + $lines = $lines + 1; + $undefined_words = 0; + for ($n = 1; $n < @A; $n++) { + $a = $A[$n]; + $i = $sym2int{$a}; + if (!defined ($i)) { + if (defined $map_oov) { + if ($num_warning++ < $max_warning) { + print STDERR "sym2int.pl: replacing $a with $map_oov\n"; + if ($num_warning == $max_warning) { + print STDERR "sym2int.pl: not warning for OOVs any more times\n"; + } + } + $i = [ $map_oov ]; + } else { + $pos = $n+1; + die "sym2int.pl: undefined symbol $a (in position $pos)\n"; + } + $undefined_words = $undefined_words + 1; + } + $a = $i; + push @B, $a; + } + #if ( defined $sym2int{$A[$n]} ) { + # push @B, $sym2int{$A[$n]}; + #} else { + # push @B, [0]; + #} + if ($undefined_words > 0) { + $oov_count = $oov_count + 1; + } + @C = permute @B; + #print Dumper(\@B); + #print Dumper(\@C); + foreach $phrase ( @C ) { + print "$A[0] "; + print join(" ", @{$phrase}); + print "\n"; + } +} + +print STDERR "Found $oov_count phrases containing (at least one) OOV...\n"; + diff --git a/egs/mini_librispeech/s5/local/kws/make_L_align.sh b/egs/mini_librispeech/s5/local/kws/make_L_align.sh new file mode 100755 index 00000000000..72a1e9e3f4c --- /dev/null +++ b/egs/mini_librispeech/s5/local/kws/make_L_align.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Copyright 2013-2018 Johns Hopkins University (authors: Guoguo Chen, Yenda Trmal) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. 
+ + +set -o pipefail +set -e +set -x + +if [ $# -ne 3 ]; then + echo "This is a simple script that will generate the L_align.fst" + echo "The FST L_align.fst is used for getting the force-aligned " + echo "utterances" + echo "The script automatically recognizes whether the probabilistic lexicon" + echo "is used and will use the correct file" + echo "" + echo "usage: local/kws/make_L_align.sh <lang-local-dir> <lang-dir> <out-dir>" + echo "e.g.: local/kws/make_L_align.sh data/local/lang data/lang data/lang" + exit 1; +fi + +tmpdir=$1 +dir=$2 +outdir=$3 + +silphone=`cat $dir/phones/optional_silence.txt` || exit 1; + +# Create lexicon with alignment info +if [ -f $tmpdir/lexicon.txt ] ; then + cat $tmpdir/lexicon.txt | \ + awk '{printf("%s #1 ", $1); for (n=2; n <= NF; n++) { printf("%s ", $n); } print "#2"; }' | \ + utils/make_lexicon_fst.pl - 0.5 $silphone | \ + fstcompile --isymbols=$dir/phones.txt --osymbols=$dir/words.txt \ + --keep_isymbols=false --keep_osymbols=false | \ + fstarcsort --sort_type=olabel > $outdir/L_align.fst +elif [ -f $tmpdir/lexiconp.txt ] ; then + cat $tmpdir/lexiconp.txt | \ + awk '{printf("%s #1 ", $1); for (n=3; n <= NF; n++) { printf("%s ", $n); } print "#2"; }' | \ + utils/make_lexicon_fst.pl - 0.5 $silphone | \ + fstcompile --isymbols=$dir/phones.txt --osymbols=$dir/words.txt \ + --keep_isymbols=false --keep_osymbols=false | \ + fstarcsort --sort_type=olabel > $outdir/L_align.fst +else + echo >&2 "Neither $tmpdir/lexicon.txt nor $tmpdir/lexiconp.txt exists" + exit 1 +fi +exit 0; diff --git a/egs/mini_librispeech/s5/local/kws/normalize_results_kst.pl b/egs/mini_librispeech/s5/local/kws/normalize_results_kst.pl new file mode 100755 index 00000000000..5e8e6419959 --- /dev/null +++ b/egs/mini_librispeech/s5/local/kws/normalize_results_kst.pl @@ -0,0 +1,203 @@ +#!/usr/bin/env perl +#=============================================================================== +# Copyright 2015-2018 (Author: Yenda Trmal ) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License.
+#=============================================================================== +my $Usage = < results.normalized + +Allowed options: + --probs : the input is probabilities instead of neg-loglikelihoods + + --duration|--trials : size of the searched collectiona in seconds (float) + --beta : the FA vs MISS rate (float, default 999.9) + --ntrue-scale : scales for scaling the expected count of true hits (float, default 1.0) + --thr|--threshold : the decision threshold (float, default 0.5) +EOU + +use strict; +use warnings; +use utf8; +use POSIX; +use Data::Dumper; +use Getopt::Long; + +my $ntrue_scale = 1.0; +my $global_thr = 0.5; +my $beta = 999.9; +my $duration = 35785.578; +my $ntrue_table_filename; +my $probs=0; +my $bsum_filename; + +GetOptions("duration|trials=f" => \$duration, + "ntrue-scale=f" => \$ntrue_scale, + "beta=f" => \$beta, + "probs" => \$probs, + "thr|threshold=f" => \$global_thr, + "ntrue-table=s" => \$ntrue_table_filename, + "bsum-table=s" => \$bsum_filename) or do + { + print STDERR "$0: Cannot parse the command-line parameters.\n"; + print STDERR "$Usage\n"; + die "$0: Cannot continue\n" +}; + +if (@ARGV != 0) { + print STDERR "$0: Incorrect number of command-line parameters\n"; + print STDERR "$Usage\n"; + die "$0: Cannot continue\n" +} + +sub ComputeKST { + my @instances = @{shift @_}; + my $ntrue_scale = shift @_; + my %ntrue_table = %{shift @_}; + + + my $ntrue = 0; + foreach my $elem(@instances) { + $ntrue += $elem->[4]; + } + #$ntrue = $ntrue / @instances; + if (defined ($ntrue_table{$instances[0]->[0]})) { + #print STDERR "For KW " . $instances[0]->[0] . " using the value " . $ntrue_table{$instances[0]->[0]} . "\n"; + $ntrue = $ntrue * $ntrue_table{$instances[0]->[0]}; + } else { + #print STDERR "Using the default vsalue $ntrue_scale\n"; + $ntrue = $ntrue * $ntrue_scale; + } + + my $thr = $beta * $ntrue / ( $duration + $ntrue * ($beta - 1)); + return $thr; +} + +sub ComputeKSTWithExpected { + my @instances = @{shift @_}; + my %expected_table = %{shift @_}; + my $ntrue_scale = shift @_; + my %ntrue_table = %{shift @_}; + + + my $ntrue = $expected_table{$instances[0]->[0]}; + #$ntrue = $ntrue / @instances; + if (defined ($ntrue_table{$instances[0]->[0]})) { + #print STDERR "For KW " . $instances[0]->[0] . " using the value " . $ntrue_table{$instances[0]->[0]} . "\n"; + $ntrue = $ntrue * $ntrue_table{$instances[0]->[0]}; + } else { + #print STDERR "Using the default vsalue $ntrue_scale\n"; + $ntrue = $ntrue * $ntrue_scale; + } + + my $thr = $beta * $ntrue / ( $duration + $ntrue * ($beta - 1)); + return $thr; +} +sub NormalizeScores { + my @instances = @{shift @_}; + my $thr = shift @_; + my $global_thr = shift @_; + + + if ($thr == 0) { + $thr = 0.001; + } + my $q = log($global_thr)/log($thr); + + foreach my $elem(@instances) { + $elem->[4] = pow($elem->[4], $q); + } +} + +sub WriteResults { + my @instances = @{shift @_}; + + foreach my $elem(@instances) { + print join(" ", @{$elem}) . "\n"; + die "$0: " . join(" ", @{$elem}) . 
"\n" if $elem->[-1] > 1.0; + } + +} + +my $KWID; +my @putative_hits; +my %NTRUE_TABLE = (); + +my %BSUM=(); +if (defined $bsum_filename) { + open(BSUMF, $bsum_filename) or die "$0: Cannot open $bsum_filename"; + while (my $line = ) { + chomp $line; + next unless (($line =~ m/^\s*KW/) || ($line =~ m/^Keyword\s*KW/)); + $line =~ s/^Keyword//g; + $line =~ s/^\s+|\s+$//g; + my @entries = split /\s*\|\s*/, $line; + $BSUM{$entries[0]} = $entries[12]; + } + close(BSUMF); +} + +if ( defined $ntrue_table_filename) { + open (F, $ntrue_table_filename) or die "$0: Cannot open the Ntrue-table file\n"; + while (my $line = ) { + my @entries=split(" ", $line); + + die "$0: The Ntrue-table does not have expected format\n" if @entries != 2; + $NTRUE_TABLE{$entries[0]} = $entries[1] + 0.0; + } + close (F); +} + +while (my $line = ) { + chomp $line; + (my $kwid, my $file, my $start, my $end, my $score) = split " ", $line; + + if ($KWID && ($kwid ne $KWID)) { + + my $thr = ComputeKST(\@putative_hits, $ntrue_scale, \%NTRUE_TABLE ); + if ((defined $BSUM{$KWID}) && (scalar @putative_hits > 100)) { + print STDERR "$0: $KWID $thr $BSUM{$KWID} " . log($thr)/log($global_thr) . "\n"; + my $old_thr = $thr; + $thr = pow($BSUM{$KWID}, log($thr)/log($global_thr)); + } + if ($thr < 0.9999 ) { + NormalizeScores(\@putative_hits, $thr, $global_thr); + WriteResults(\@putative_hits); + } + + $KWID = $kwid; + @putative_hits = (); + } elsif ( not $KWID ) { + $KWID = $kwid; + } + + unless ($probs) { + $score = exp(-$score); + } + push @putative_hits, [$kwid, $file, $start, $end, $score]; +} + +if ($KWID) { + my $thr = ComputeKST(\@putative_hits, $ntrue_scale, \%NTRUE_TABLE ); + if ((defined $BSUM{$KWID}) && (scalar @putative_hits > 100)) { + $thr = pow($BSUM{$KWID}, log($thr)/log($global_thr)); + } + if ($thr < 0.9999 ) { + NormalizeScores(\@putative_hits, $thr, $global_thr); + WriteResults(\@putative_hits); + } +} + diff --git a/egs/mini_librispeech/s5/local/kws/run_kws.sh b/egs/mini_librispeech/s5/local/kws/run_kws.sh new file mode 100755 index 00000000000..8e7b56f0082 --- /dev/null +++ b/egs/mini_librispeech/s5/local/kws/run_kws.sh @@ -0,0 +1,108 @@ +#!/bin/bash +# Copyright (c) 2018, Johns Hopkins University (Yenda Trmal ) +# License: Apache 2.0 + +# Begin configuration section. +flen=0.01 +stage=0 +cmd=run.pl +data=data/dev_clean_2 +lang=data/lang +keywords=local/kws/example/keywords.txt +output=data/dev_clean_2/kws/ +# End configuration section + +. ./utils/parse_options.sh +. ./path.sh + +set -e -o pipefail +set -o nounset # Treat unset variables as an error + +mkdir -p $output +if [ $stage -le 1 ] ; then + ## generate the auxiliary data files + ## utt.map + ## wav.map + ## trials + ## frame_length + ## keywords.int + + ## For simplicity, we do not generate the following files + ## categories + + ## We will generate the following files later + ## hitlist + ## keywords.fsts + + [ ! 
-f $data/utt2dur ] &&
+    utils/data/get_utt2dur.sh $data
+
+  duration=$(cat $data/utt2dur | awk '{sum += $2} END{print sum}' )
+
+  echo $duration > $output/trials
+  echo $flen > $output/frame_length
+
+  echo "Number of trials: $(cat $output/trials)"
+  echo "Frame lengths: $(cat $output/frame_length)"
+
+  echo "Generating map files"
+  cat $data/utt2dur | awk 'BEGIN{i=1}; {print $1, i; i+=1;}' > $output/utt.map
+  cat $data/wav.scp | awk 'BEGIN{i=1}; {print $1, i; i+=1;}' > $output/wav.map
+
+  cp $lang/words.txt $output/words.txt
+  cp $keywords $output/keywords.txt
+  cat $output/keywords.txt | \
+    local/kws/keywords_to_indices.pl --map-oov 0 $output/words.txt | \
+    sort -u > $output/keywords.int
+fi
+
+if [ $stage -le 2 ] ; then
+  ## this step generates the file hitlist
+
+  ## in many cases, when the reference hits are given, the following two steps \
+  ## are not needed
+  ## we create the alignments of the data directory
+  ## this is only so that we can obtain the hitlist
+  steps/align_fmllr.sh --nj 5 --cmd "$cmd" \
+    $data data/lang exp/tri3b exp/tri3b_ali_$(basename $data)
+
+  local/kws/create_hitlist.sh $data $lang data/local/lang_tmp \
+    exp/tri3b_ali_$(basename $data) $output
+fi
+
+if [ $stage -le 3 ] ; then
+  ## this step generates the file keywords.fsts
+
+  ## compile the keywords (it's done via tmp work dirs, so that
+  ## you can use the keywords filtering and then just run fsts-union
+  local/kws/compile_keywords.sh $output $lang $output/tmp.2
+  cp $output/tmp.2/keywords.fsts $output/keywords.fsts
+  # for example
+  # fsts-union scp:<(sort data/$dir/kwset_${set}/tmp*/keywords.scp) \
+  #   ark,t:"|gzip -c >data/$dir/kwset_${set}/keywords.fsts.gz"
+  ##
+fi
+
+system=exp/chain/tdnn1h_sp_online/decode_tglarge_dev_clean_2/
+if [ $stage -le 4 ]; then
+  ## this is not exactly necessary for a single system and single keyword set
+  ## but if you have multiple keyword sets, then it avoids having to recompute
+  ## the indices unnecessarily every time (see --indices-dir and --skip-indexing
+  ## parameters to the search script below).
+  for lmwt in `seq 8 14` ; do
+    steps/make_index.sh --cmd "$cmd" --lmwt $lmwt --acwt 1.0 \
+      --frame-subsampling-factor 3\
+      $output $lang $system $system/kws_indices_$lmwt
+  done
+fi
+
+if [ $stage -le 5 ]; then
+  ## find the hits, normalize and score
+  local/kws/search.sh --cmd "$cmd" --min-lmwt 8 --max-lmwt 14 \
+    --indices-dir $system/kws_indices --skip-indexing true\
+    $lang $data $system
+fi
+
+echo "Done"
+
+
diff --git a/egs/mini_librispeech/s5/local/kws/score.sh b/egs/mini_librispeech/s5/local/kws/score.sh
new file mode 100755
index 00000000000..b056e150e83
--- /dev/null
+++ b/egs/mini_librispeech/s5/local/kws/score.sh
@@ -0,0 +1,147 @@
+#!/bin/bash
+
+# Copyright 2012-2018  Johns Hopkins University (Author: Guoguo Chen, Yenda Trmal)
+# Apache 2.0.
+
+# Begin configuration section.
+# case_insensitive=true
+extraid=
+min_lmwt=8
+max_lmwt=12
+cmd=run.pl
+stage=0
+ntrue_from=
+# End configuration section.
+
+help_message="$0: score the kwslist using the F4DE scorer from NIST
+  Example:
+    $0 [additional-parameters]
+  where the most important additional parameters can be:
+  --extraid  # use this when a non-default kws task was set up
+             (using the kws_setup.sh --extraid) for a single Kaldi data dir"
+
+echo $0 $@
+[ -f ./path.sh ] && . ./path.sh; # source the path.
+.
parse_options.sh || exit 1; + + +if [ $# -ne 3 ]; then + printf "FATAL: incorrect number of variables given to the script\n\n" + printf "$help_message\n" + exit 1; +fi + +set -e -o pipefail + +langdir=$1 +if [ -z $extraid ] ; then + kwsdatadir=$2/kws +else + kwsdatadir=$2/kwset_${extraid} +fi +kwsoutputdir="$3" + +trials=$(cat $kwsdatadir/trials) +mkdir -p $kwsoutputdir/log/ + +if [ $stage -le 0 ] ; then + if [ -z "$ntrue_from" ]; then + for LMWT in $(seq $min_lmwt $max_lmwt) ; do + mkdir -p ${kwsoutputdir}_$LMWT/details/ + mkdir -p ${kwsoutputdir}_$LMWT/scoring/ + + # as we need to sweep through different ntrue-scales we will + # we will do it in one parallel command -- it will be more effective + # than sweeping in a loop and for all lmwts in parallel (as usuallyu + # there will be just a couple of different lmwts, but the ntrue-scale + # has a larger dynamic range + $cmd NTRUE=1:21 $kwsoutputdir/log/score.${LMWT}.NTRUE.log \ + ntrue=\$\(perl -e 'print 1+(NTRUE-1)/5.0' \) '&&' \ + cat ${kwsoutputdir}_$LMWT/results \|\ + local/kws/normalize_results_kst.pl --trials $trials --ntrue-scale \$ntrue \|\ + local/kws/filter_kws_results.pl --probs --nbest 200 \|\ + compute-atwv $trials ark,t:$kwsdatadir/hitlist ark:- \ + \> ${kwsoutputdir}_$LMWT/scoring/score.NTRUE.txt + + ntrue=$(grep ATWV ${kwsoutputdir}_$LMWT/scoring/score.*.txt | \ + sort -k2,2nr -t '=' | head -n 1 | \ + sed 's/.*score\.\([0-9][0-9]*\)\.txt.*/\1/g') + #The calculation of ntrue must be the same as in the command above + echo "$ntrue" > ${kwsoutputdir}_$LMWT/details/ntrue_raw + ntrue=$(perl -e "print 1+($ntrue-1)/5.0") + echo "$ntrue" > ${kwsoutputdir}_$LMWT/details/ntrue + done + else + for LMWT in $(seq $min_lmwt $max_lmwt) ; do + mkdir -p ${kwsoutputdir}_$LMWT/details/ + mkdir -p ${kwsoutputdir}_$LMWT/scoring/ + + cp ${ntrue_from}_${LMWT}/details/ntrue ${kwsoutputdir}_${LMWT}/details/ntrue + [ -f ${ntrue_from}_${LMWT}/details/ntrue_raw ] && \ + cp ${ntrue_from}_${LMWT}/details/ntrue_raw ${kwsoutputdir}_${LMWT}/details/ntrue_raw + echo "$ntrue_from" > ${kwsoutputdir}_${LMWT}/details/ntrue_from + done + fi +fi + +if [ $stage -le 1 ] ; then + $cmd LMWT=$min_lmwt:$max_lmwt $kwsoutputdir/log/normalize.LMWT.log \ + cat ${kwsoutputdir}_LMWT/results \|\ + local/kws/normalize_results_kst.pl --trials $trials --ntrue-scale \$\(cat ${kwsoutputdir}_LMWT/details/ntrue\)\ + \> ${kwsoutputdir}_LMWT/details/results + + $cmd LMWT=$min_lmwt:$max_lmwt $kwsoutputdir/log/score.final.LMWT.log \ + cat ${kwsoutputdir}_LMWT/details/results \|\ + compute-atwv $trials ark,t:$kwsdatadir/hitlist ark:- \ + ${kwsoutputdir}_LMWT/details/alignment.csv \> ${kwsoutputdir}_LMWT/details/score.txt '&&' \ + cp ${kwsoutputdir}_LMWT/details/score.txt ${kwsoutputdir}_LMWT/score.txt + + if [ -f $kwsdatadir/categories ]; then + $cmd LMWT=$min_lmwt:$max_lmwt $kwsoutputdir/log/per-category-stats.LMWT.log \ + cat ${kwsoutputdir}_LMWT/details/alignment.csv \|\ + perl local/search/per_category_stats.pl --sweep-step 0.005 $trials \ + $kwsdatadir/categories \> ${kwsoutputdir}_LMWT/details/per-category-score.txt + else + echo "$0: Categories file not found, not generating per-category scores" + fi +fi + +if [ $stage -le 2 ]; then +if [ -f $kwsdatadir/f4de_attribs ] ; then + language="" + flen=0.01 + kwlist_name="" + . 
$kwsdatadir/f4de_attribs #override the previous variables + + ecf=$kwsdatadir/ecf.xml + rttm=$kwsdatadir/rttm + kwlist=$kwsdatadir/kwlist.xml + + $cmd LMWT=$min_lmwt:$max_lmwt $kwsoutputdir/log/f4de_prepare.LMWT.log \ + mkdir -p ${kwsoutputdir}_LMWT/f4de/ '&&' cat $kwlist \| \ + local/search/annotate_kwlist.pl $kwsdatadir/categories \> ${kwsoutputdir}_LMWT/f4de/kwlist.xml + + $cmd LMWT=$min_lmwt:$max_lmwt $kwsoutputdir/log/f4de_write_kwslist.LMWT.log \ + cat ${kwsoutputdir}_LMWT/details/results \| \ + utils/int2sym.pl -f 2 $kwsdatadir/utt.map \| \ + local/search/utt_to_files.pl --flen $flen $kwsdatadir/../segments \|\ + local/search/write_kwslist.pl --flen $flen --language $language \ + --kwlist-id $kwlist_name \> ${kwsoutputdir}_LMWT/f4de/kwslist.xml + + $cmd LMWT=$min_lmwt:$max_lmwt $kwsoutputdir/log/f4de_score.LMWT.log \ + KWSEval -e $ecf -r $rttm -t ${kwsoutputdir}_LMWT/f4de/kwlist.xml -a \ + --zGlobalMeasures Optimum --zGlobalMeasures Supremum \ + -O -B -q 'Characters:regex=.*' -q 'NGramOrder:regex=.*' \ + -O -B -q 'OOV:regex=.*' -q 'BaseOOV:regex=.*' \ + -s ${kwsoutputdir}_LMWT/f4de/kwslist.xml -c -o -b -d -f ${kwsoutputdir}_LMWT/f4de/ + + $cmd LMWT=$min_lmwt:$max_lmwt $kwsoutputdir/log/f4de_report.LMWT.log \ + local/kws_oracle_threshold.pl --duration $trials \ + ${kwsoutputdir}_LMWT/f4de/alignment.csv \> ${kwsoutputdir}_LMWT/f4de/metrics.txt +fi +fi + +echo "$0: Done" +exit 0; + + diff --git a/egs/mini_librispeech/s5/local/kws/search.sh b/egs/mini_librispeech/s5/local/kws/search.sh new file mode 100755 index 00000000000..1c69b0da556 --- /dev/null +++ b/egs/mini_librispeech/s5/local/kws/search.sh @@ -0,0 +1,208 @@ +#!/bin/bash +# Copyright 2012-2018 Johns Hopkins University (Author: Guoguo Chen, Yenda Trmal) +# License: Apache 2.0 + + +help_message="$(basename $0): do keyword indexing and search. data-dir is assumed to have + kws/ subdirectory that specifies the terms to search for. Output is in + decode-dir/kws/ + Usage: + $(basename $0) " + +# Begin configuration section. +min_lmwt=8 +max_lmwt=12 +cmd=run.pl +model= +skip_scoring=false +skip_optimization=false # true can speed it up if #keywords is small. +max_states=350000 +indices_dir= +kwsout_dir= +stage=0 +word_ins_penalty=0 +extraid= +silence_word= # specify this if you did to in kws_setup.sh, it's more accurate. +strict=false +duptime=0.6 +ntrue_scale=1.0 +frame_subsampling_factor=1 +nbest=-1 +max_silence_frames=50 +skip_indexing=false +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. parse_options.sh || exit 1; + +set -u +set -e +set -o pipefail + + +if [[ "$#" -ne "3" ]] ; then + echo -e "$0: FATAL: wrong number of script parameters!\n\n" + printf "$help_message\n\n" + exit 1; +fi + +silence_opt= + +langdir=$1 +datadir=$2 +decodedir=$3 + +if [ -z $extraid ] ; then + kwsdatadir=$datadir/kws +else + kwsdatadir=$datadir/kwset_${extraid} +fi + +if [ -z $extraid ] ; then + kwsoutdir=$decodedir/kws +else + kwsoutdir=$decodedir/kwset_${extraid} +fi + + +if [ -z $indices_dir ]; then + indices_dir=$kwsoutdir +fi + +if [ ! -z "$model" ]; then + model_flags="--model $model" +else + model_flags= +fi + +mkdir -p $kwsoutdir +for d in "$datadir" "$kwsdatadir" "$langdir" "$decodedir"; do + if [ ! 
-d "$d" ]; then + echo "$0: FATAL: expected directory $d to exist" + exit 1; + fi +done + +echo "$0: Searching: $kwsdatadir" +duration=$(cat $kwsdatadir/trials) +echo "$0: Duration: $duration" + + +frame_subsampling_factor=1 +if [ -f $decodedir/../frame_subsampling_factor ] ; then + frame_subsampling_factor=$(cat $decodedir/../frame_subsampling_factor) + echo "$0: Frame subsampling factor autodetected: $frame_subsampling_factor" +elif [ -f $decodedir/../../frame_subsampling_factor ] ; then + frame_subsampling_factor=$(cat $decodedir/../../frame_subsampling_factor) + echo "$0: Frame subsampling factor autodetected: $frame_subsampling_factor" +fi + +if [ $stage -le 0 ] ; then + if [ ! -f $indices_dir/.done.index ] && ! $skip_indexing ; then + [ ! -d $indices_dir ] && mkdir $indices_dir + for lmwt in $(seq $min_lmwt $max_lmwt) ; do + indices=${indices_dir}_$lmwt + mkdir -p $indices + + acwt=$(perl -e "print 1.0/$lmwt") + [ ! -z $silence_word ] && silence_opt="--silence-word $silence_word" + steps/make_index.sh $silence_opt --cmd "$cmd" --acwt $acwt $model_flags\ + --skip-optimization $skip_optimization --max-states $max_states \ + --word-ins-penalty $word_ins_penalty --max-silence-frames $max_silence_frames\ + --frame-subsampling-factor ${frame_subsampling_factor} \ + $kwsdatadir $langdir $decodedir $indices || exit 1 + done + touch $indices_dir/.done.index + else + echo "$0: Assuming indexing has been aready done. If you really need to re-run " + echo "$0: the indexing again, delete the file $indices_dir/.done.index" + fi +fi + +keywords=$kwsdatadir/keywords.fsts +if [ -f $keywords ] ; then + echo "$0: Using ${keywords} for search" + keywords="ark:$keywords" +elif [ -f ${keywords}.gz ] ; then + echo "$0: Using ${keywords}.gz for search" + keywords="ark:gunzip -c ${keywords}.gz |" +else + echo "$0: The keyword file ${keywords}[.gz] does not exist" +fi + + +if [ $stage -le 1 ]; then + for lmwt in $(seq $min_lmwt $max_lmwt) ; do + kwsoutput=${kwsoutdir}_$lmwt + indices=${indices_dir}_$lmwt + nj=$(cat $indices/num_jobs) + + + for f in $indices/index.1.gz ; do + [ ! 
-f $f ] && echo "$0: no such file $f" && exit 1; + done + + mkdir -p $kwsoutput/log + $cmd JOB=1:$nj $kwsoutput/log/search.JOB.log \ + set -e -o pipefail '&&' \ + kws-search --strict=$strict --negative-tolerance=-1 \ + --frame-subsampling-factor=${frame_subsampling_factor} \ + "ark:gzip -cdf $indices/index.JOB.gz|" "$keywords" \ + "ark,t:| sort -u | gzip -c > $kwsoutput/result.JOB.gz" \ + "ark,t:| sort -u | gzip -c > $kwsoutput/stats.JOB.gz" || exit 1; + done +fi + +if [ $stage -le 2 ]; then + for lmwt in $(seq $min_lmwt $max_lmwt) ; do + kwsoutput=${kwsoutdir}_$lmwt + indices=${indices_dir}_$lmwt + nj=$(cat $indices/num_jobs) + + # This is a memory-efficient way how to do the filtration + # we do this in this way because the result.* files can be fairly big + # and we do not want to run into troubles with memory + files="" + for job in $(seq 1 $nj); do + if [ -f $kwsoutput/result.${job}.gz ] ; then + files="$files <(gunzip -c $kwsoutput/result.${job}.gz)" + elif [ -f $kwsoutput/result.${job} ] ; then + files="$files $kwsoutput/result.${job}" + else + echo >&2 "The file $kwsoutput/result.${job}[.gz] does not exist" + exit 1 + fi + done + # we have to call it using eval as we need the bash to interpret + # the (possible) command substitution in case of gz files + # bash -c would probably work as well, but would spawn another + # shell instance + eval "sort -m -u $files" |\ + local/kws/filter_kws_results.pl --likes --nbest $nbest > $kwsoutput/results || exit 1 + done +fi + +if [ -z $extraid ] ; then + extraid_flags= +else + extraid_flags=" --extraid ""$extraid"" " +fi + +if [ $stage -le 4 ]; then + if $skip_scoring ; then + echo "$0: Not scoring, because --skip-scoring true was issued" + elif [ ! -x local/kws/score.sh ] ; then + echo "$0: Not scoring, because the file local/kws_score.sh is not present" + else + echo "$0: Scoring KWS results" + local/kws/score.sh --cmd "$cmd" \ + --min-lmwt $min_lmwt --max-lmwt $max_lmwt $extraid_flags \ + $langdir $datadir ${kwsoutdir} || exit 1; + fi +fi + +echo "$0: Done" +exit 0 + diff --git a/egs/mini_librispeech/s5/local/nnet3/run_ivector_common.sh b/egs/mini_librispeech/s5/local/nnet3/run_ivector_common.sh index 2663fb12ee5..fd06d81e88d 100755 --- a/egs/mini_librispeech/s5/local/nnet3/run_ivector_common.sh +++ b/egs/mini_librispeech/s5/local/nnet3/run_ivector_common.sh @@ -52,7 +52,7 @@ if [ $stage -le 3 ]; then echo "$0: creating high-resolution MFCC features" mfccdir=data/${train_set}_sp_hires/data if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then - utils/create_split_dir.pl /export/b1{5,6,8,9}/$USER/kaldi-data/mfcc/mini_librispeech-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + utils/create_split_dir.pl /export/fs0{1,2}/$USER/kaldi-data/mfcc/mini_librispeech-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage fi for datadir in ${train_set}_sp ${test_sets}; do @@ -122,7 +122,7 @@ if [ $stage -le 6 ]; then ivectordir=exp/nnet3${nnet3_affix}/ivectors_${train_set}_sp_hires if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $ivectordir/storage ]; then - utils/create_split_dir.pl /export/b0{5,6,7,8}/$USER/kaldi-data/ivectors/mini_librispeech-$(date +'%m_%d_%H_%M')/s5/$ivectordir/storage $ivectordir/storage + utils/create_split_dir.pl /export/fs0{1,2}/$USER/kaldi-data/ivectors/mini_librispeech-$(date +'%m_%d_%H_%M')/s5/$ivectordir/storage $ivectordir/storage fi diff --git a/egs/mini_librispeech/s5/local/nnet3/tuning/run_tdnn_lstm_1a.sh b/egs/mini_librispeech/s5/local/nnet3/tuning/run_tdnn_lstm_1a.sh index de858973c98..c2f90df4b5c 100755 --- a/egs/mini_librispeech/s5/local/nnet3/tuning/run_tdnn_lstm_1a.sh +++ b/egs/mini_librispeech/s5/local/nnet3/tuning/run_tdnn_lstm_1a.sh @@ -99,7 +99,7 @@ if [ $stage -le 10 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $ali_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/mini_librispeech/s5/local/nnet3/tuning/run_tdnn_lstm_1b.sh b/egs/mini_librispeech/s5/local/nnet3/tuning/run_tdnn_lstm_1b.sh index ba4ecc268df..2b3c2844972 100755 --- a/egs/mini_librispeech/s5/local/nnet3/tuning/run_tdnn_lstm_1b.sh +++ b/egs/mini_librispeech/s5/local/nnet3/tuning/run_tdnn_lstm_1b.sh @@ -102,7 +102,7 @@ if [ $stage -le 10 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $ali_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20 delay=-3 dropout-proportion=0.0" mkdir -p $dir/configs diff --git a/egs/mini_librispeech/s5/local/nnet3/tuning/run_tdnn_lstm_1c.sh b/egs/mini_librispeech/s5/local/nnet3/tuning/run_tdnn_lstm_1c.sh index 74df56b0537..5118cb0f8bd 100755 --- a/egs/mini_librispeech/s5/local/nnet3/tuning/run_tdnn_lstm_1c.sh +++ b/egs/mini_librispeech/s5/local/nnet3/tuning/run_tdnn_lstm_1c.sh @@ -100,7 +100,7 @@ if [ $stage -le 10 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $ali_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) tdnn_opts="l2-regularize=0.05" lstm_opts="l2-regularize=0.01 decay-time=20 delay=-3 dropout-proportion=0.0" output_opts="l2-regularize=0.01" diff --git a/egs/mini_librispeech/s5/path.sh b/egs/mini_librispeech/s5/path.sh index 705600ad47a..34244b27f2e 100644 --- a/egs/mini_librispeech/s5/path.sh +++ b/egs/mini_librispeech/s5/path.sh @@ -1,4 +1,5 @@ export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 . $KALDI_ROOT/tools/config/common_path.sh diff --git a/egs/mini_librispeech/s5/run.sh b/egs/mini_librispeech/s5/run.sh index 30b0e8bda7c..68905ed68d1 100755 --- a/egs/mini_librispeech/s5/run.sh +++ b/egs/mini_librispeech/s5/run.sh @@ -1,7 +1,7 @@ #!/bin/bash # Change this location to somewhere where you want to put the data. 
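+# (Any path with a few gigabytes of free space works here, e.g. a hypothetical
+# /export/corpora/mini_librispeech; the default below keeps the corpus inside the
+# recipe directory.)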
-data=/export/a05/dgalvez/ +data=./corpus/ data_url=www.openslr.org/resources/31 lm_url=www.openslr.org/resources/11 @@ -21,7 +21,7 @@ for part in dev-clean-2 train-clean-5; do done if [ $stage -le 0 ]; then - local/download_lm.sh $lm_url data/local/lm + local/download_lm.sh $lm_url $data data/local/lm fi if [ $stage -le 1 ]; then @@ -199,5 +199,7 @@ if [ $stage -le 9 ]; then local/chain/run_tdnn.sh --stage 0 fi +# local/grammar/simple_demo.sh + # Don't finish until all background decoding jobs are finished. wait diff --git a/egs/multi_en/s5/local/chain/run_tdnn_lstm.sh b/egs/multi_en/s5/local/chain/run_tdnn_lstm.sh new file mode 120000 index 00000000000..8e647598556 --- /dev/null +++ b/egs/multi_en/s5/local/chain/run_tdnn_lstm.sh @@ -0,0 +1 @@ +tuning/run_tdnn_lstm_1a.sh \ No newline at end of file diff --git a/egs/multi_en/s5/local/chain/run_tdnn_opgru.sh b/egs/multi_en/s5/local/chain/run_tdnn_opgru.sh index aedd4c8b4ac..20d4c87b289 120000 --- a/egs/multi_en/s5/local/chain/run_tdnn_opgru.sh +++ b/egs/multi_en/s5/local/chain/run_tdnn_opgru.sh @@ -1 +1 @@ -tuning/run_tdnn_opgru_1a.sh \ No newline at end of file +tuning/run_tdnn_opgru_1b.sh \ No newline at end of file diff --git a/egs/multi_en/s5/local/chain/tuning/run_tdnn_5b.sh b/egs/multi_en/s5/local/chain/tuning/run_tdnn_5b.sh index 9f8c49387b1..96f5fdac8f3 100755 --- a/egs/multi_en/s5/local/chain/tuning/run_tdnn_5b.sh +++ b/egs/multi_en/s5/local/chain/tuning/run_tdnn_5b.sh @@ -132,7 +132,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.0015 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" linear_opts="l2-regularize=0.0015 orthonormal-constraint=-1.0" output_opts="l2-regularize=0.001" diff --git a/egs/multi_en/s5/local/chain/tuning/run_tdnn_lstm_1a.sh b/egs/multi_en/s5/local/chain/tuning/run_tdnn_lstm_1a.sh new file mode 100755 index 00000000000..62266334962 --- /dev/null +++ b/egs/multi_en/s5/local/chain/tuning/run_tdnn_lstm_1a.sh @@ -0,0 +1,319 @@ +#!/bin/bash +# Copyright 2017 University of Chinese Academy of Sciences (UCAS) Gaofeng Cheng +# 2018 Xiaohui Zhang +# 2018 Vimal Manohar +# Apache 2.0 + +# This recipe is similar with tdnn_lstm_1b recipefrom fisher_swbd/s5, and is currently +# the best performing multi-en recipe. 
+ +# System tdnn_opgru_1b_sp tdnn_lstm_1a_sp +# WER on eval2000(tg) 11.4 11.4 +# WER on eval2000(fg) 11.2 11.2 +# WER on rt03(tg) 11.1 10.7 +# WER on rt03(fg) 10.8 10.5 +# Final train prob -0.091 -0.095 +# Final valid prob -0.091 -0.089 +# Final train prob (xent) -0.990 -0.970 +# Final valid prob (xent) -0.091 -0.9638 +# Num-parameters 34976320 39704128 + +# ./steps/info/chain_dir_info.pl exp/multi_a/chain/tdnn_lstm_1a_sp +# exp/multi_a/chain/tdnn_lstm_1a_sp: num-iters=2096 nj=3..16 num-params=39.7M dim=40+100->6176 combine=-0.088->-0.087 (over 3) +# xent:train/valid[1395,2095,final]=(-1.38,-0.960,-0.970/-1.39,-0.964,-0.964) +# logprob:train/valid[1395,2095,final]=(-0.117,-0.091,-0.095/-0.109,-0.087,-0.089) + +# online results +# Eval2000 +# %WER 14.2 | 2628 21594 | 87.8 8.6 3.5 2.1 14.2 49.1 | exp/multi_a/chain/tdnn_lstm_1a_sp_online/decode_eval2000/score_8_0.0/eval2000_hires.ctm.callhm.filt.sys +# %WER 11.4 | 4459 42989 | 90.3 7.0 2.7 1.7 11.4 46.1 | exp/multi_a/chain/tdnn_lstm_1a_sp_online/decode_eval2000/score_8_0.0/eval2000_hires.ctm.filt.sys +# %WER 8.4 | 1831 21395 | 92.8 5.3 2.0 1.2 8.4 41.2 | exp/multi_a/chain/tdnn_lstm_1a_sp_online/decode_eval2000/score_9_0.0/eval2000_hires.ctm.swbd.filt.sys +# %WER 14.0 | 2628 21594 | 88.0 8.5 3.4 2.1 14.0 48.6 | exp/multi_a/chain/tdnn_lstm_1a_sp_online/decode_eval2000_fg/score_8_0.0/eval2000_hires.ctm.callhm.filt.sys +# %WER 11.2 | 4459 42989 | 90.5 6.9 2.6 1.7 11.2 45.4 | exp/multi_a/chain/tdnn_lstm_1a_sp_online/decode_eval2000_fg/score_8_0.0/eval2000_hires.ctm.filt.sys +# %WER 8.1 | 1831 21395 | 93.1 5.1 1.8 1.2 8.1 40.6 | exp/multi_a/chain/tdnn_lstm_1a_sp_online/decode_eval2000_fg/score_9_0.0/eval2000_hires.ctm.swbd.filt.sys + +# RT03 +# %WER 8.7 | 3970 36721 | 92.2 5.3 2.5 1.0 8.7 37.3 | exp/multi_a/chain/tdnn_lstm_1a_sp_online/decode_rt03/score_7_0.0/rt03_hires.ctm.fsh.filt.sys +# %WER 10.8 | 8420 76157 | 90.4 6.5 3.2 1.2 10.8 40.1 | exp/multi_a/chain/tdnn_lstm_1a_sp_online/decode_rt03/score_8_0.0/rt03_hires.ctm.filt.sys +# %WER 12.7 | 4450 39436 | 88.7 7.7 3.6 1.4 12.7 42.5 | exp/multi_a/chain/tdnn_lstm_1a_sp_online/decode_rt03/score_8_0.0/rt03_hires.ctm.swbd.filt.sys +# %WER 8.5 | 3970 36721 | 92.4 5.1 2.5 0.9 8.5 37.2 | exp/multi_a/chain/tdnn_lstm_1a_sp_online/decode_rt03_fg/score_7_1.0/rt03_hires.ctm.fsh.filt.sys +# %WER 10.5 | 8420 76157 | 90.6 6.3 3.1 1.2 10.5 40.1 | exp/multi_a/chain/tdnn_lstm_1a_sp_online/decode_rt03_fg/score_8_0.0/rt03_hires.ctm.filt.sys +# %WER 12.4 | 4450 39436 | 88.9 7.2 3.9 1.3 12.4 42.7 | exp/multi_a/chain/tdnn_lstm_1a_sp_online/decode_rt03_fg/score_9_0.0/rt03_hires.ctm.swbd.filt.sys + +set -e + +# configs for 'chain' +stage=-10 +train_stage=-10 +get_egs_stage=-10 +speed_perturb=true +multi=multi_a +gmm=tri5a +decode_iter= +decode_dir_affix= +decode_nj=50 + +# training options +frames_per_chunk=140,100,160 +frames_per_chunk_primary=$(echo $frames_per_chunk | cut -d, -f1) +chunk_left_context=40 +chunk_right_context=0 +xent_regularize=0.025 +self_repair_scale=0.00001 +label_delay=5 +# decode options +extra_left_context=50 +extra_right_context=0 +dropout_schedule='0,0@0.20,0.3@0.50,0' +num_epochs=4 + +remove_egs=false +common_egs_dir= + +test_online_decoding=true # if true, it will run the last decoding stage. + +nnet3_affix= +tdnn_affix=_1a + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 11 ]; then + # Build a tree using our new topology. 
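+  # The 'chain' model operates at a reduced frame rate (hence the
+  # --frame-subsampling-factor 3 below) and uses the special topology written to
+  # $lang/topo above, so the tree is rebuilt on top of that topology instead of being
+  # reused from the GMM system; 7000 is (roughly) the requested number of tree leaves.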
+ + if [ -f $treedir/final.mdl ]; then + echo "$treedir exists. Remove it or skip this stage." + exit 1 + fi + + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 7000 data/$train_set $lang $lats_dir $treedir +fi + +if [ $stage -le 12 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + lstm_opts="dropout-proportion=0.0 decay-time=40" + + relu_dim=1024 + cell_dim=1024 + projection_dim=256 + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=$relu_dim + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=$relu_dim + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=$relu_dim + + # check steps/libs/nnet3/xconfig/lstm.py for the other options and defaults + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-3 $lstm_opts + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=$relu_dim + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=$relu_dim + fast-lstmp-layer name=lstm2 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-3 $lstm_opts + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=$relu_dim + relu-batchnorm-layer name=tdnn7 input=Append(-3,0,3) dim=$relu_dim + fast-lstmp-layer name=lstm3 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-3 $lstm_opts + + ## adding the layers for chain branch + output-layer name=output input=lstm3 output-delay=$label_delay include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + output-layer name=output-xent input=lstm3 output-delay=$label_delay dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/multi-en-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir exp/$multi/nnet3${nnet3_affix}/ivectors_${train_set} \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.num-chunk-per-minibatch 64,32 \ + --trainer.frames-per-iter 1500000 \ + --trainer.max-param-change 2.0 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.shrink-value 0.99 \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.dropout-schedule=$dropout_schedule \ + --trainer.optimization.momentum 0.0 \ + --trainer.deriv-truncate-margin 8 \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $frames_per_chunk \ + --egs.chunk-left-context $chunk_left_context \ + --egs.chunk-right-context $chunk_right_context \ + --egs.chunk-left-context-initial 0 \ + --egs.chunk-right-context-final 0 \ + --egs.dir "$common_egs_dir" \ + --cleanup.remove-egs $remove_egs \ + --feat-dir data/${train_set}_hires \ + --tree-dir $treedir \ + --lat-dir $lats_dir \ + --dir $dir || exit 1; +fi + +lang_suffix=${lang_dir##*lang} + +if [ $stage -le 14 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 $lang_dir \ + $dir $dir/graph${lang_suffix} +fi + +graph_dir=$dir/graph${lang_suffix} +if [ $stage -le 15 ]; then + iter_opts= + [ -z $extra_left_context ] && extra_left_context=$chunk_left_context; + [ -z $extra_right_context ] && extra_right_context=$chunk_right_context; + if [ ! -z $decode_iter ]; then + iter_opts=" --iter $decode_iter " + fi + for decode_set in eval2000 rt03; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 50 --cmd "$decode_cmd" $iter_opts \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk "$frames_per_chunk_primary" \ + --online-ivector-dir exp/$multi/nnet3${nnet3_affix}/ivectors_${decode_set} \ + $graph_dir data/${decode_set}_hires \ + $dir/decode${lang_suffix}_${decode_set}${decode_dir_affix:+_$decode_dir_affix}${decode_iter:+_iter$decode_iter} + + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + $lang_dir $rescore_lang_dir data/${decode_set}_hires \ + $dir/decode${lang_suffix}_${decode_set}${decode_dir_affix:+_$decode_dir_affix}{,_fg}${decode_iter:+_iter$decode_iter} || exit 1; + ) & + done +fi +wait; + +if $test_online_decoding && [ $stage -le 16 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. 
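+  # For example (a hypothetical sketch, not part of this recipe): if online pitch were
+  # added, you would typically pass something like
+  #   --add-pitch true --online-pitch-config conf/online_pitch.conf
+  # (check steps/online/nnet3/prepare_online_decoding.sh for the exact option names).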
+ steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/nnet3/extractor $dir ${dir}_online + + rm $dir/.error 2>/dev/null || true + for decode_set in train_dev eval2000; do + ( + # note: we just give it "$decode_set" as it only uses the wav.scp, the + # feature type does not matter. + + steps/online/nnet3/decode.sh --nj $decode_nj --cmd "$decode_cmd" $iter_opts \ + --acwt 1.0 --post-decode-acwt 10.0 \ + $graph_dir data/${decode_set}_hires \ + ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_tg || exit 1; + if $has_fisher; then + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \ + ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_{tg,fsh_fg} || exit 1; + fi + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in online decoding" + exit 1 + fi +fi + +exit 0; diff --git a/egs/multi_en/s5/local/chain/tuning/run_tdnn_opgru_1a.sh b/egs/multi_en/s5/local/chain/tuning/run_tdnn_opgru_1a.sh index 98e7c2ed6c1..79cd3eb3014 100755 --- a/egs/multi_en/s5/local/chain/tuning/run_tdnn_opgru_1a.sh +++ b/egs/multi_en/s5/local/chain/tuning/run_tdnn_opgru_1a.sh @@ -150,7 +150,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) gru_opts="dropout-per-frame=true dropout-proportion=0.0 " mkdir -p $dir/configs diff --git a/egs/multi_en/s5/local/chain/tuning/run_tdnn_opgru_1b.sh b/egs/multi_en/s5/local/chain/tuning/run_tdnn_opgru_1b.sh new file mode 100755 index 00000000000..a7170af9431 --- /dev/null +++ b/egs/multi_en/s5/local/chain/tuning/run_tdnn_opgru_1b.sh @@ -0,0 +1,316 @@ +#!/bin/bash +# Copyright 2018 Xiaohui Zhang +# 2017 University of Chinese Academy of Sciences (UCAS) Gaofeng Cheng +# Apache 2.0 + +# This is similar with tdnn_opgru_1a but with correct num_leaves (7k rather than 11k), +# aligments from lattices when building the tree, and better l2-regularization as opgru-1a +# from fisher-swbd. 
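+# (In the comparison below, 'tg' means decoding with the trigram LM and 'fg' means
+# rescoring with the 4-gram const-arpa LM.)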
+ +# ./local/chain/compare_wer_general.sh tdnn_opgru_1a_sp tdnn_opgru_1b_sp +# System tdnn_opgru_1a_sp tdnn_opgru_1b_sp +# WER on eval2000(tg) 11.6 11.4 +# WER on eval2000(fg) 11.5 11.2 +# WER on rt03(tg) 11.5 11.1 +# WER on rt03(fg) 11.2 10.8 +# Final train prob -0.088 -0.091 +# Final valid prob -0.088 -0.091 +# Final train prob (xent) -1.048 -0.990 +# Final valid prob (xent) -1.0253 -0.091 +# Num-parameters 37364848 34976320 + + +# ./steps/info/chain_dir_info.pl exp/${multi}/chain/tdnn_opgru_1b_sp +# exp/${multi}/chain/tdnn_opgru_1b_sp: num-iters=2621 nj=3..16 num-params=35.0M dim=40+100->6176 combine=-0.098->-0.096 (over 4) +# xent:train/valid[1744,2620,final]=(-1.49,-0.991,-0.990/-1.51,-1.01,-1.01) +# logprob:train/valid[1744,2620,final]=(-0.118,-0.091,-0.091/-0.117,-0.093,-0.091) + +# online results +# Eval2000 +# %WER 14.3 | 2628 21594 | 87.8 8.9 3.3 2.1 14.3 49.8 | exp/${multi}/chain/tdnn_opgru_1b_sp_online/decode_eval2000_fsh_sw1_tg/score_7_0.0/eval2000_hires.ctm.callhm.filt.sys +# %WER 11.4 | 4459 42989 | 90.2 7.2 2.7 1.6 11.4 46.3 | exp/${multi}/chain/tdnn_opgru_1b_sp_online/decode_eval2000_fsh_sw1_tg/score_8_0.5/eval2000_hires.ctm.filt.sys +# %WER 8.4 | 1831 21395 | 92.7 5.3 2.0 1.1 8.4 41.8 | exp/${multi}/chain/tdnn_opgru_1b_sp_online/decode_eval2000_fsh_sw1_tg/score_10_0.0/eval2000_hires.ctm.swbd.filt.sys +# %WER 14.2 | 2628 21594 | 88.0 8.8 3.3 2.2 14.2 49.4 | exp/${multi}/chain/tdnn_opgru_1b_sp_online/decode_eval2000_fsh_sw1_fg/score_7_0.0/eval2000_hires.ctm.callhm.filt.sys +# %WER 11.4 | 4459 42989 | 90.3 7.1 2.6 1.7 11.4 45.9 | exp/${multi}/chain/tdnn_opgru_1b_sp_online/decode_eval2000_fsh_sw1_fg/score_8_0.0/eval2000_hires.ctm.filt.sys +# %WER 8.2 | 1831 21395 | 92.9 5.1 2.0 1.1 8.2 41.3 | exp/${multi}/chain/tdnn_opgru_1b_sp_online/decode_eval2000_fsh_sw1_fg/score_11_0.0/eval2000_hires.ctm.swbd.filt.sys + +# RT03 +# %WER 9.0 | 3970 36721 | 92.0 5.5 2.4 1.1 9.0 37.9 | exp/${multi}/chain/tdnn_opgru_1b_sp_online/decode_rt03_fsh_sw1_tg/score_7_0.0/rt03_hires.ctm.fsh.filt.sys +# %WER 11.2 | 8420 76157 | 90.0 6.8 3.2 1.2 11.2 41.1 | exp/${multi}/chain/tdnn_opgru_1b_sp_online/decode_rt03_fsh_sw1_tg/score_8_0.0/rt03_hires.ctm.filt.sys +# %WER 13.2 | 4450 39436 | 88.1 7.5 4.4 1.3 13.2 44.1 | exp/${multi}/chain/tdnn_opgru_1b_sp_online/decode_rt03_fsh_sw1_tg/score_10_0.0/rt03_hires.ctm.swbd.filt.sys +# %WER 8.7 | 3970 36721 | 92.3 5.1 2.6 1.0 8.7 37.8 | exp/${multi}/chain/tdnn_opgru_1b_sp_online/decode_rt03_fsh_sw1_fg/score_8_0.0/rt03_hires.ctm.fsh.filt.sys +# %WER 10.9 | 8420 76157 | 90.3 6.5 3.1 1.2 10.9 40.6 | exp/${multi}/chain/tdnn_opgru_1b_sp_online/decode_rt03_fsh_sw1_fg/score_8_0.0/rt03_hires.ctm.filt.sys +# %WER 12.9 | 4450 39436 | 88.5 7.9 3.6 1.4 12.9 43.1 | exp/${multi}/chain/tdnn_opgru_1b_sp_online/decode_rt03_fsh_sw1_fg/score_8_0.0/rt03_hires.ctm.swbd.filt.sys + +set -e + +# configs for 'chain' +stage=-1 +train_stage=-10 +get_egs_stage=-10 +speed_perturb=true +multi=multi_a +gmm=tri5a +dir=exp/${multi}/chain/tdnn_opgru_1b # Note: _sp will get added to this if $speed_perturb == true. +decode_iter= +decode_dir_affix= +rescore=true # whether to rescore lattices +dropout_schedule='0,0@0.20,0.2@0.50,0' + +# training options +num_epochs=4 +chunk_width=150 +chunk_left_context=40 +chunk_right_context=0 +xent_regularize=0.025 +self_repair_scale=0.00001 +label_delay=5 +# decode options +extra_left_context=50 +extra_right_context=0 +frames_per_chunk= + +remove_egs=false +common_egs_dir= + +affix= +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. 
./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 11 ]; then + # Build a tree using our new topology. + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 7000 data/$train_set $lang exp/${multi}/${gmm}_lats_nodup$suffix $treedir +fi + +if [ $stage -le 12 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + gru_opts="dropout-per-frame=true dropout-proportion=0.0 " + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2, ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=1024 + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=1024 + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=1024 + + # check steps/libs/nnet3/xconfig/lstm.py for the other options and defaults + norm-opgru-layer name=opgru1 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $gru_opts + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=1024 + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=1024 + norm-opgru-layer name=opgru2 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $gru_opts + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=1024 + relu-batchnorm-layer name=tdnn7 input=Append(-3,0,3) dim=1024 + norm-opgru-layer name=opgru3 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $gru_opts + + ## adding the layers for chain branch + output-layer name=output input=opgru3 output-delay=$label_delay include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + output-layer name=output-xent input=opgru3 output-delay=$label_delay dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then
+    utils/create_split_dir.pl \
+     /export/b0{3,7,9,8}/$USER/kaldi-data/egs/multi-en-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage
+  fi
+
+  steps/nnet3/chain/train.py --stage $train_stage \
+    --cmd "$decode_cmd" \
+    --feat.online-ivector-dir exp/${multi}/nnet3/ivectors_${train_set} \
+    --feat.cmvn-opts "--norm-means=false --norm-vars=false" \
+    --chain.xent-regularize $xent_regularize \
+    --chain.leaky-hmm-coefficient 0.1 \
+    --chain.l2-regularize 0.00005 \
+    --chain.apply-deriv-weights false \
+    --chain.lm-opts="--num-extra-lm-states=2000" \
+    --trainer.num-chunk-per-minibatch 64 \
+    --trainer.frames-per-iter 1200000 \
+    --trainer.max-param-change 2.0 \
+    --trainer.num-epochs $num_epochs \
+    --trainer.optimization.shrink-value 0.99 \
+    --trainer.optimization.num-jobs-initial 3 \
+    --trainer.optimization.num-jobs-final 16 \
+    --trainer.optimization.initial-effective-lrate 0.001 \
+    --trainer.optimization.final-effective-lrate 0.0001 \
+    --trainer.dropout-schedule $dropout_schedule \
+    --trainer.optimization.momentum 0.0 \
+    --trainer.deriv-truncate-margin 8 \
+    --egs.stage $get_egs_stage \
+    --egs.opts "--frames-overlap-per-eg 0" \
+    --egs.chunk-width $chunk_width \
+    --egs.chunk-left-context $chunk_left_context \
+    --egs.chunk-right-context $chunk_right_context \
+    --egs.chunk-left-context-initial 0 \
+    --egs.chunk-right-context-final 0 \
+    --egs.dir "$common_egs_dir" \
+    --cleanup.remove-egs $remove_egs \
+    --feat-dir data/${train_set}_hires \
+    --tree-dir $treedir \
+    --lat-dir exp/${multi}/tri5a_lats_nodup$suffix \
+    --dir $dir || exit 1;
+fi
+
+if [ $stage -le 14 ]; then
+  # Note: it might appear that this $lang directory is mismatched, and it is as
+  # far as the 'topo' is concerned, but this script doesn't read the 'topo' from
+  # the lang directory.
+  utils/mkgraph.sh --self-loop-scale 1.0 data/lang_${multi}_${gmm}_fsh_sw1_tg $dir $dir/graph_fsh_sw1_tg
+fi
+
+decode_suff=fsh_sw1_tg
+graph_dir=$dir/graph_fsh_sw1_tg
+if [ $stage -le 15 ]; then
+  rm $dir/.error 2>/dev/null || true
+  [ -z $extra_left_context ] && extra_left_context=$chunk_left_context;
+  [ -z $extra_right_context ] && extra_right_context=$chunk_right_context;
+  [ -z $frames_per_chunk ] && frames_per_chunk=$chunk_width;
+  if [ ! -z $decode_iter ]; then
+    iter_opts=" --iter $decode_iter "
+  fi
+  if $rescore && [ !
-f data/lang_${multi}_${gmm}_fsh_sw1_fg/G.carpa ]; then + LM_fg=data/local/lm/4gram-mincount/lm_unpruned.gz + utils/build_const_arpa_lm.sh $LM_fg data/lang_${multi}_${gmm}_fsh_sw1_tg data/lang_${multi}_${gmm}_fsh_sw1_fg + fi + for decode_set in rt03 eval2000; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 50 --cmd "$decode_cmd" $iter_opts \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk "$frames_per_chunk" \ + --online-ivector-dir exp/${multi}/nnet3/ivectors_${decode_set} \ + $graph_dir data/${decode_set}_hires \ + $dir/decode_${decode_set}${decode_dir_affix:+_$decode_dir_affix}_${decode_suff} || exit 1; + if $rescore; then + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_${multi}_${gmm}_fsh_sw1_{tg,fg} data/${decode_set}_hires \ + $dir/decode_${decode_set}${decode_dir_affix:+_$decode_dir_affix}_fsh_sw1_{tg,fg} || exit 1; + fi + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi + +test_online_decoding=true +lang=data/lang_${multi}_${gmm}_fsh_sw1_tg +if $test_online_decoding && [ $stage -le 16 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. + steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/${multi}/nnet3/extractor $dir ${dir}_online + + rm $dir/.error 2>/dev/null || true + for decode_set in rt03 eval2000; do + ( + # note: we just give it "$decode_set" as it only uses the wav.scp, the + # feature type does not matter. + + steps/online/nnet3/decode.sh --nj 50 --cmd "$decode_cmd" $iter_opts \ + --acwt 1.0 --post-decode-acwt 10.0 \ + $graph_dir data/${decode_set}_hires \ + ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_${decode_suff} || exit 1; + if $rescore; then + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_${multi}_${gmm}_fsh_sw1_{tg,fg} data/${decode_set}_hires \ + ${dir}_online/decode_${decode_set}${decode_dir_affix:+_$decode_dir_affix}_fsh_sw1_{tg,fg} || exit 1; + fi + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in online decoding" + exit 1 + fi +fi + +exit 0; diff --git a/egs/multi_en/s5/local/format_acronyms_ctm_eval2000.py b/egs/multi_en/s5/local/format_acronyms_ctm_eval2000.py index 3c447c5976a..75cc4458d85 100755 --- a/egs/multi_en/s5/local/format_acronyms_ctm_eval2000.py +++ b/egs/multi_en/s5/local/format_acronyms_ctm_eval2000.py @@ -10,6 +10,7 @@ # en_4156 B 414.58 0.16 l # en_4156 B 414.74 0.17 a +from __future__ import division import argparse,re __author__ = 'Minhua Wu' @@ -27,7 +28,7 @@ if items[4].find(".") != -1: letters = items[4].split("._") acronym_period = round(float(items[3]), 2) - letter_slot = round(acronym_period / len(letters), 2) + letter_slot = round(acronym_period/len(letters), 2) time_start = round(float(items[2]), 2) for l in letters[:-1]: time = " %.2f %.2f " % (time_start, letter_slot) diff --git a/egs/multi_en/s5/local/format_acronyms_ctm_rt03.py b/egs/multi_en/s5/local/format_acronyms_ctm_rt03.py index 59814beb4ea..8438bbdaf81 100755 --- a/egs/multi_en/s5/local/format_acronyms_ctm_rt03.py +++ b/egs/multi_en/s5/local/format_acronyms_ctm_rt03.py @@ -10,6 +10,7 @@ # en_4156 B 414.58 0.16 l # en_4156 B 414.74 0.17 a +from __future__ import division import argparse,re 
__author__ = 'Minhua Wu' @@ -27,7 +28,7 @@ if items[4].find(".") != -1: letters = items[4].split("._") acronym_period = round(float(items[3]), 2) - letter_slot = round(acronym_period / len(letters), 2) + letter_slot = round(acronym_period/len(letters), 2) time_start = round(float(items[2]), 2) for l in letters[:-1]: time = " %.2f %.2f " % (time_start, letter_slot) diff --git a/egs/multi_en/s5/local/g2p/apply_g2p.sh b/egs/multi_en/s5/local/g2p/apply_g2p.sh deleted file mode 100755 index 8484155800d..00000000000 --- a/egs/multi_en/s5/local/g2p/apply_g2p.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash - -# Copyright 2016 Allen Guo -# 2017 Xiaohui Zhang -# Apache License 2.0 - -# This script applies a trained Phonetisarus G2P model to -# synthesize pronunciations for missing words (i.e., words in -# transcripts but not the lexicon), and output the expanded lexicon. - -var_counts=1 - -. ./path.sh || exit 1 -. parse_options.sh || exit 1; - -if [ $# -ne "4" ]; then - echo "Usage: $0 " - exit 1 -fi - -model=$1 -workdir=$2 -lexicon=$3 -outlexicon=$4 - -mkdir -p $workdir - -# awk command from http://stackoverflow.com/questions/2626274/print-all-but-the-first-three-columns -echo 'Gathering missing words...' -cat data/*/train/text | \ - local/count_oovs.pl $lexicon | \ - awk '{if (NF > 3 ) {for(i=4; i $workdir/missing.txt -cat $workdir/missing.txt | \ - grep "^[a-z]*$" > $workdir/missing_onlywords.txt - -echo 'Synthesizing pronunciations for missing words...' -phonetisaurus-apply --nbest $var_counts --model $model --thresh 5 --accumulate --word_list $workdir/missing_onlywords.txt > $workdir/missing_g2p_${var_counts}.txt - -echo "Adding new pronunciations to $lexicon" -cat "$lexicon" $workdir/missing_g2p_${var_counts}.txt | sort | uniq > $outlexicon diff --git a/egs/multi_en/s5/local/g2p/train_g2p.sh b/egs/multi_en/s5/local/g2p/train_g2p.sh deleted file mode 100755 index 43e75f6608d..00000000000 --- a/egs/multi_en/s5/local/g2p/train_g2p.sh +++ /dev/null @@ -1,67 +0,0 @@ -#!/bin/bash - -# Copyright 2017 Intellisist, Inc. (Author: Navneeth K) -# 2017 Xiaohui Zhang -# Apache License 2.0 - -# This script trains a g2p model using Phonetisaurus and SRILM. - -stage=0 -silence_phones= - -echo "$0 $@" # Print the command line for logging - -[ -f ./path.sh ] && . ./path.sh; # source the path. -. utils/parse_options.sh || exit 1; - - -if [ $# -ne 2 ]; then - echo "Usage: $0 " - exit 1; -fi - -lexicondir=$1 -outdir=$2 - -[ ! -f $lexicondir/lexicon.txt ] && echo "Cannot find $lexicondir/lexicon.txt" && exit - -isuconv=`which uconv` -if [ -z $isuconv ]; then - echo "uconv was not found. You must install the icu4c package." - exit 1; -fi - -mkdir -p $outdir - - -# For input lexicon, remove pronunciations containing non-utf-8-encodable characters, -# and optionally remove words that are mapped to a single silence phone from the lexicon. -if [ $stage -le 0 ]; then - lexicon=$lexicondir/lexicon.txt - if [ ! -z "$silence_phones" ]; then - awk 'NR==FNR{a[$1] = 1; next} {s=$2;for(i=3;i<=NF;i++) s=s" "$i; if(!(s in a)) print $1" "s}' \ - $silence_phones $lexicon | \ - awk '{printf("%s\t",$1); for (i=2;i 0'> $outdir/lexicon_tab_separated.txt - else - awk '{printf("%s\t",$1); for (i=2;i 0'> $outdir/lexicon_tab_separated.txt - fi -fi - -if [ $stage -le 1 ]; then - # Align lexicon stage. 
Lexicon is assumed to have first column tab separated - phonetisaurus-align --input=$outdir/lexicon_tab_separated.txt --ofile=${outdir}/aligned_lexicon.corpus || exit 1; -fi - -if [ $stage -le 2 ]; then - # Convert aligned lexicon to arpa using srilm. - ngram-count -order 7 -kn-modify-counts-at-end -gt1min 0 -gt2min 0 \ - -gt3min 0 -gt4min 0 -gt5min 0 -gt6min 0 -gt7min 0 -ukndiscount \ - -text ${outdir}/aligned_lexicon.corpus -lm ${outdir}/aligned_lexicon.arpa -fi - -if [ $stage -le 3 ]; then - # Convert the arpa file to FST. - phonetisaurus-arpa2wfst --lm=${outdir}/aligned_lexicon.arpa --ofile=${outdir}/model.fst -fi diff --git a/egs/multi_en/s5/local/normalize_transcript.py b/egs/multi_en/s5/local/normalize_transcript.py index 4572f4d658d..c640723a885 100755 --- a/egs/multi_en/s5/local/normalize_transcript.py +++ b/egs/multi_en/s5/local/normalize_transcript.py @@ -7,6 +7,7 @@ # This script normalizes the given "text" (transcript) file. The normalized result # is printed to STDOUT. This normalization should be applied to all corpora. +from __future__ import print_function import re import sys @@ -26,7 +27,7 @@ def normalize(utt): def main(): if len(sys.argv) != 2: - print 'Usage: local/normalize_transcript.py [text_file]' + print('Usage: local/normalize_transcript.py [text_file]') sys.exit(1) with open(sys.argv[1], 'r') as f: for line in f.readlines(): diff --git a/egs/multi_en/s5/local/tedlium_join_suffix.py b/egs/multi_en/s5/local/tedlium_join_suffix.py index c85e8f364f6..47db4ce0b05 100755 --- a/egs/multi_en/s5/local/tedlium_join_suffix.py +++ b/egs/multi_en/s5/local/tedlium_join_suffix.py @@ -12,6 +12,7 @@ # Apache 2.0 +from __future__ import print_function import sys from codecs import open diff --git a/egs/multi_en/s5/run.sh b/egs/multi_en/s5/run.sh index 3a1262101aa..034ffeb4e66 100755 --- a/egs/multi_en/s5/run.sh +++ b/egs/multi_en/s5/run.sh @@ -58,8 +58,8 @@ if [ $stage -le 1 ]; then # We prepare the basic dictionary in data/local/dict_combined. local/prepare_dict.sh $swbd $tedlium2 ( - local/g2p/train_g2p.sh --stage 0 --silence-phones \ - "data/local/dict_combined/silence_phones.txt" data/local/dict_combined exp/g2p || touch exp/g2p/.error + steps/dict/train_g2p_phonetisaurus.sh --stage 0 --silence-phones \ + "data/local/dict_combined/silence_phones.txt" data/local/dict_combined/lexicon.txt exp/g2p || touch exp/g2p/.error ) & fi @@ -114,8 +114,28 @@ if [ $stage -le 4 ]; then mkdir -p $dict_dir rm $dict_dir/lexiconp.txt 2>/dev/null || true cp data/local/dict_combined/{extra_questions,nonsilence_phones,silence_phones,optional_silence}.txt $dict_dir - local/g2p/apply_g2p.sh --var-counts 1 exp/g2p/model.fst data/local/g2p_phonetisarus \ - data/local/dict_combined/lexicon.txt $dict_dir/lexicon.txt || exit 1; + + echo 'Gathering missing words...' 
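+  # The pipeline below collects transcript words missing from the combined lexicon:
+  # count_oovs.pl flags the OOV tokens, the awk keeps only the word fields, and the grep
+  # restricts the list to purely alphabetic words before they are passed to the G2P model.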
+ + lexicon=data/local/dict_combined/lexicon.txt + g2p_tmp_dir=data/local/g2p_phonetisarus + mkdir -p $g2p_tmp_dir + + # awk command from http://stackoverflow.com/questions/2626274/print-all-but-the-first-three-columns + cat data/*/train/text | \ + local/count_oovs.pl $lexicon | \ + awk '{if (NF > 3 ) {for(i=4; i $g2p_tmp_dir/missing.txt + cat $g2p_tmp_dir/missing.txt | \ + grep "^[a-z]*$" > $g2p_tmp_dir/missing_onlywords.txt + + steps/dict/apply_g2p_phonetisaurus.sh --nbest 1 $g2p_tmp_dir/missing_onlywords.txt exp/g2p exp/g2p/oov_lex || exit 1; + cp exp/g2p/oov_lex/lexicon.lex $g2p_tmp_dir/missing_lexicon.txt + + extended_lexicon=$dict_dir/lexicon.txt + echo "Adding new pronunciations to get extended lexicon $extended_lexicon" + cat <(cut -f 1,3 $g2p_tmp_dir/missing_lexicon.txt) $lexicon | sort | uniq > $extended_lexicon fi # We'll do multiple iterations of pron/sil-prob estimation. So the structure of diff --git a/egs/reverb/s5/README.txt b/egs/reverb/s5/README.txt index 1daa214edb6..0ac97059952 100644 --- a/egs/reverb/s5/README.txt +++ b/egs/reverb/s5/README.txt @@ -1,130 +1,36 @@ -Improved multi condition training baseline for REVERB challenge based on Kaldi -============================================================================== +Improved baseline for REVERB challenge +====================================== -updated -Wed Apr 29 19:10:33 EDT 2015 Shinji Watanabe - -updated -Wed Apr 9 12:14:02 CEST 2014 Felix Weninger - -original: -Wed Nov 6 14:47:59 EST 2013 Felix Weninger +This is an improvement over "Improved multi condition training baseline" from Felix Weninger & Shinji Watanabe Key specs: -- MFCC-LDA-STC front-end -- Boosted MMI trained GMM-HMM -- Utterance-based adaptation using basis fMLLR -- Tri-gram LM minimum Bayes risk decoding - -WER [%] -@ Language model weight = 15 -Avg(SimData_(far|near)) = 11.73 -Avg(RealData) = 30.44 -@ Language model weight = 16 (optimal) -Avg(SimData_(far|near)) = 11.72 -Avg(RealData) = 30.28 - -See RESULTS in more detail - -Kaldi SVN rev. 5035, 4/26/15 -tested on Ubuntu 13.04 +- Nara-WPE and BeamformIt front-end enhancement +- TDNN acoustic model +RESULT: +For experiment results, please see RESULTS for more detail REFERENCE: ++++++++ If you find this software useful for your own research, please cite the -following paper: +following papers: Felix Weninger, Shinji Watanabe, Jonathan Le Roux, John R. Hershey, Yuuki Tachioka, Jürgen Geiger, Björn Schuller, Gerhard Rigoll: "The MERL/MELCO/TUM system for the REVERB Challenge using Deep Recurrent Neural Network Feature Enhancement", Proc. REVERB Workshop, IEEE, Florence, Italy, May 2014. +Lukas Drude, Jahn Heymann, Christoph Boeddeker, and Reinhold Haeb-Umbach: +"NARA-WPE: A Python package for weighted prediction error dereverberation in +Numpy and Tensorflow for online and offline processing." In Speech Communication; +13th ITG-Symposium, pp. 1-5. VDE, 2018. INSTRUCTIONS: +++++++++++++ - -1) Set the path names in corpus.sh.default, - and copy this file to "corpus.sh" - ------ -2) [optional:] If you have speech enhancement (processed waveforms), then - -3a) Change directories and data preparation steps - For example, you could have something like - - local/REVERB_wsjcam0_data_prep.sh /path/to/processed/REVERB_WSJCAM0_dt REVERB_dt_derev dt - - The first argument is supposed to point to a folder that has the same - structure as the REVERB corpus. 
- -3b) run the multi-condition training steps in run.sh with the processed - training set, e.g., REVERB_tr_cut_derev, if you want to investigate - recognizer re-training - - - Any system that has _mc in its name uses multi-condition training - - You probably want to change the system names if you are using enhanced - data for training (e.g. tri2b_mc -> tri2b_mc_derev) - -3c) Add your re-trained recognizer to the list of recognizers that are - discriminatively re-trained - -3d) Modify the decoding steps in run.sh so that they use enhanced data and add - your re-trained recognizer(s) to the list ------ - -4) Execute the training and recognition steps by +1) Execute the training and recognition steps by ./run.sh Depending on your system specs (# of CPUs, RAM) you might want (or have) to - change the number of parallel jobs -- this is controlled by the nj_train, - nj_bg, and nj_tg variables (# of jobs for training, for bi-gram and tri-gram - decoding). - - If you also want to have the re-implementation of the HTK baseline in Kaldi - (tri2a and tri2a_mc systems), set the do_tri2a variable to true in run.sh. - -5) Execute - - ./local/get_results.sh - - to display the results corresponding to Table 1 in - the following paper, - - Felix Weninger, Shinji Watanabe, Jonathan Le Roux, John R. Hershey, Yuuki - Tachioka, Jürgen Geiger, Björn Schuller, Gerhard Rigoll: "The MERL/MELCO/TUM - system for the REVERB Challenge using Deep Recurrent Neural Network Feature - Enhancement", to appear in Proc. REVERB Workshop, IEEE, Florence, Italy, 2014. - - NOTE: It is very common to have slightly different results (up to +/- 1% - absolute WER per REVERB task file) on different machines. The reason for - this is not fully known. - - NOTE 2: By default, only the LDA-STC systems are trained - set do_tri2a in - run.sh to true to also train the Delta+Delta-Delta systems (cf. above). - ------ -6) You can get more recognition results (for other combinations of front-ends, - adaptation, language model, etc.), by - - $> local/summarize_results.pl [options] [ [ >>>>>> 77343718c6dc1936d7374b4948be4706d6f9ee2a diff --git a/egs/reverb/s5/conf/decode_dnn.config b/egs/reverb/s5/conf/decode_dnn.config deleted file mode 100644 index bfaae86702e..00000000000 --- a/egs/reverb/s5/conf/decode_dnn.config +++ /dev/null @@ -1,2 +0,0 @@ -beam=18.0 # beam for decoding. Was 13.0 in the scripts. -latbeam=10.0 # this has most effect on size of the lattices. diff --git a/egs/reverb/s5/conf/fbank.conf b/egs/reverb/s5/conf/fbank.conf deleted file mode 100644 index c4b73674cab..00000000000 --- a/egs/reverb/s5/conf/fbank.conf +++ /dev/null @@ -1,2 +0,0 @@ -# No non-default options for now. - diff --git a/egs/reverb/s5/conf/mfcc_hires.conf b/egs/reverb/s5/conf/mfcc_hires.conf new file mode 100644 index 00000000000..fd64b62eb16 --- /dev/null +++ b/egs/reverb/s5/conf/mfcc_hires.conf @@ -0,0 +1,10 @@ +# config for high-resolution MFCC features, intended for neural network training. +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. 
+--sample-frequency=16000 +--num-mel-bins=40 +--num-ceps=40 +--low-freq=40 +--high-freq=-400 diff --git a/egs/reverb/s5/conf/online_cmvn.conf b/egs/reverb/s5/conf/online_cmvn.conf new file mode 100644 index 00000000000..7748a4a4dd3 --- /dev/null +++ b/egs/reverb/s5/conf/online_cmvn.conf @@ -0,0 +1 @@ +# configuration file for apply-cmvn-online, used in the script ../local/run_online_decoding.sh diff --git a/egs/reverb/s5/conf/reverb_beamformit.cfg b/egs/reverb/s5/conf/reverb_beamformit.cfg new file mode 100755 index 00000000000..70fdd858651 --- /dev/null +++ b/egs/reverb/s5/conf/reverb_beamformit.cfg @@ -0,0 +1,50 @@ +#BeamformIt sample configuration file for AMI data (http://groups.inf.ed.ac.uk/ami/download/) + +# scrolling size to compute the delays +scroll_size = 250 + +# cross correlation computation window size +window_size = 500 + +#amount of maximum points for the xcorrelation taken into account +nbest_amount = 4 + +#flag wether to apply an automatic noise thresholding +do_noise_threshold = 1 + +#Percentage of frames with lower xcorr taken as noisy +noise_percent = 10 + +######## acoustic modelling parameters + +#transition probabilities weight for multichannel decoding +trans_weight_multi = 25 +trans_weight_nbest = 25 + +### + +#flag wether to print the feaures after setting them, or not +print_features = 1 + +#flag wether to use the bad frames in the sum process +do_avoid_bad_frames = 1 + +#flag to use the best channel (SNR) as a reference +#defined from command line +do_compute_reference = 1 + +#flag wether to use a uem file or not(process all the file) +do_use_uem_file = 0 + +#flag wether to use an adaptative weights scheme or fixed weights +do_adapt_weights = 1 + +#flag wether to output the sph files or just run the system to create the auxiliary files +do_write_sph_files = 1 + +####directories where to store/retrieve info#### +#channels_file = ./cfg-files/channels + +#show needs to be passed as argument normally, here a default one is given just in case +#show_id = Ttmp + diff --git a/egs/reverb/s5/local/Generate_mcTrainData_cut.m b/egs/reverb/s5/local/Generate_mcTrainData_cut.m index cc01ff89b7d..831ff6a5226 100755 --- a/egs/reverb/s5/local/Generate_mcTrainData_cut.m +++ b/egs/reverb/s5/local/Generate_mcTrainData_cut.m @@ -1,13 +1,13 @@ function Generate_mcTrainData_cut(WSJ_dir_name, save_dir) % % Input variables: -% WSJ_dir_name: string name of user's clean wsjcam0 corpus directory -% (*Directory structure for wsjcam0 corpushas to be kept as it is after obtaining it from LDC. +% WSJ_dir_name: string name of WAV file directory converted from original wsjcam0 SPHERE files +% (*Directory structure for wsjcam0 corpus to be kept as it is after obtaining it from LDC. % Otherwise this script does not work.) % % This function generates multi-condition traiing data % based on the following items: -% 1. wsjcam0 corpus (distributed from the LDC) +% 1. wsjcam0 corpus (WAV files) % 2. room impulse responses (ones under ./RIR/) % 3. noise (ones under ./NOISE/). % Generated data has the same directory structure as original wsjcam0 corpus. 
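The data-generation function above is driven from the shell elsewhere in this patch (local/generate_data.sh); a minimal headless invocation, with placeholder paths, would look roughly like:

    # run the multi-condition data generation non-interactively; both directory
    # arguments are placeholders (converted WAV corpus, output directory)
    matlab -nodisplay -r "addpath(genpath('.')); Generate_mcTrainData_cut('/path/to/WSJCAM0_wav', '/path/to/REVERB_WSJCAM0_tr'); exit"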
@@ -26,8 +26,6 @@ function Generate_mcTrainData_cut(WSJ_dir_name, save_dir) display(['Name of directory for original wsjcam0: ',WSJ_dir_name]) display(['Name of directory to save generated multi-condition training data: ',save_dir]) -unix(['chmod u+x sphere_to_wave.csh']); -unix(['chmod u+x bin/*']); % Parameters related to acoustic conditions SNRdB=20; @@ -89,7 +87,6 @@ function Generate_mcTrainData_cut(WSJ_dir_name, save_dir) save_dir_tr=[save_dir,'/data/mc_train/']; end mkdir([save_dir_tr]); -%mkdir([save_dir,'/taskfiles/']) mic_idx=['A';'B';'C';'D';'E';'F';'G';'H']; prev_fname='dummy'; @@ -114,13 +111,12 @@ function Generate_mcTrainData_cut(WSJ_dir_name, save_dir) end prev_fname=fname(1:idx1(end)); - % load (sphere format) speech signal - x=read_sphere([WSJ_dir_name,'/data/', fname]); - x=x/(2^15); % conversion from short-int to float + % load speech signal + x=audioread([WSJ_dir_name, '/data/', fname, '.wav'])'; % load RIR and noise for "THIS" utterance - eval(['RIR=wavread(RIR_sim',num2str(rcount),');']); - eval(['NOISE=wavread([noise_sim',num2str(ceil(rcount/4)),',''_',num2str(ncount),'.wav'']);']); + eval(['RIR=audioread(RIR_sim',num2str(rcount),');']); + eval(['NOISE=audioread([noise_sim',num2str(ceil(rcount/4)),',''_',num2str(ncount),'.wav'']);']); % Generate 8ch noisy reverberant data y=gen_obs(x,RIR,NOISE,SNRdB); @@ -138,8 +134,9 @@ function Generate_mcTrainData_cut(WSJ_dir_name, save_dir) y=y/4; % common normalization to all the data to prevent clipping % denominator was decided experimentally - for ch=1:8 - eval(['wavwrite(y(:,',num2str(ch),'),16000,''',save_dir_tr fname,'_ch',num2str(ch),'.wav'');']); + for ch=1:8 + outfilename = [save_dir_tr, fname, '_ch', num2str(ch), '.wav']; + eval(['audiowrite(outfilename, y(:,',num2str(ch),'), 16000);']); end display(['sentence ',num2str(fcount),' (out of 7861) finished! (Multi-condition training data)']) diff --git a/egs/reverb/s5/local/REVERB_create_mcdata.sh b/egs/reverb/s5/local/REVERB_create_mcdata.sh deleted file mode 100755 index 4cc776aa159..00000000000 --- a/egs/reverb/s5/local/REVERB_create_mcdata.sh +++ /dev/null @@ -1,74 +0,0 @@ -#!/bin/bash - -# Copyright 2013 MERL (author: Shinji Watanabe) -# Contains some code by Microsoft Corporation, Johns Hopkins University (author: Daniel Povey) - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - -if [ $# -ne 2 ]; then - printf "\nUSAGE: %s \n\n" `basename $0` - echo "e.g.,:" - echo " `basename $0` /archive/speech-db/processed/public/REVERB/wsjcam0 data_mc_tr" - exit 1; -fi - -wsjcam0_dir=$1 -reverb_tr_dir=$2 - -dir=`pwd`/data/local/reverb_tools -mkdir -p $dir $reverb_tr_dir -lmdir=`pwd`/data/local/nist_lm - -# Download tools -URL1="http://reverb2014.dereverberation.com/tools/reverb_tools_for_Generate_mcTrainData.tgz" -URL2="http://reverb2014.dereverberation.com/tools/REVERB_TOOLS_FOR_ASR_ver2.0.tgz" -for f in $URL1 $URL2; do - x=`basename $f` - if [ ! 
-e $dir/$x ]; then - wget $f -O $dir/$x || exit 1; - tar zxvf $dir/$x -C $dir || exit 1; - fi -done -URL3="http://reverb2014.dereverberation.com/tools/taskFiles_et.tgz" -x=`basename $URL3` -if [ ! -e $dir/$x ]; then - wget $URL3 -O $dir/$x || exit 1; - tar zxvf $dir/$x -C $dir || exit 1; - cp -fr $dir/`basename $x .tgz`/* $dir/ReleasePackage/reverb_tools_for_asr_ver2.0/taskFiles/ -fi - -# Download and install nist tools -pushd $dir/ReleasePackage/reverb_tools_for_asr_ver2.0 -perl -ape "s|^main$|targetSPHEREDir\=tools/SPHERE\ninstall_nist|;" installTools > installnist -chmod u+x installnist -./installnist -popd - -# Make mcTrainData -cp local/Generate_mcTrainData_cut.m $dir/reverb_tools_for_Generate_mcTrainData/ -pushd $dir/reverb_tools_for_Generate_mcTrainData/ -# copied nist tools required for the following matlab command -cp $dir/ReleasePackage/reverb_tools_for_asr_ver2.0/tools/SPHERE/nist/bin/{h_strip,w_decode} ./bin/ - -tmpdir=`mktemp -d tempXXXXX ` -tmpmfile=$tmpdir/run_mat.m -cat < $tmpmfile -addpath(genpath('.')) -Generate_mcTrainData_cut('$wsjcam0_dir', '$reverb_tr_dir'); -EOF -cat $tmpmfile | matlab -nodisplay -rm -rf $tmpdir -popd - -echo "Successfully generated multi-condition training data and stored it in $reverb_tr_dir." && exit 0; diff --git a/egs/reverb/s5/local/REVERB_mcwsjav_data_prep.sh b/egs/reverb/s5/local/REVERB_mcwsjav_data_prep.sh deleted file mode 100755 index a4599f97702..00000000000 --- a/egs/reverb/s5/local/REVERB_mcwsjav_data_prep.sh +++ /dev/null @@ -1,165 +0,0 @@ -#!/bin/bash - -# Copyright 2013 MERL (author: Felix Weninger) -# Contains some code by Microsoft Corporation, Johns Hopkins University (author: Daniel Povey) - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -# for REVERB challenge: - -dir=`pwd`/data/local/data -lmdir=`pwd`/data/local/nist_lm -mkdir -p $dir $lmdir -local=`pwd`/local -utils=`pwd`/utils -root=`pwd` - -. ./path.sh # Needed for KALDI_ROOT -export PATH=$PATH:$KALDI_ROOT/tools/irstlm/bin -sph2pipe=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe -if [ ! -x $sph2pipe ]; then - echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; - exit 1; -fi - -cd $dir - -MIC=primary - -# input corpus (original or processed, tr or dt, etc.) -RWSJ=$1 -if [ ! -d "$RWSJ" ]; then - echo Could not find directory $RWSJ! Check pathnames in corpus.sh! - exit 1 -fi - -mcwsjav_mlf=$RWSJ/mlf/WSJ.mlf -if [ ! -z "$4" ]; then - mcwsjav_mlf=$4 -fi - -# the name of the dataset to be created -dataset=REVERB_Real_dt - -# the WSJCAM0 set that the set is based on (tr, dt, ...) -# this will be used to find the correct transcriptions etc. -dt_or_x=dt - -if [ ! -z "$2" ]; then - dataset=$2 -fi -# dt or et -if [ ! 
-z "$3" ]; then - dt_or_x=$3 -fi - -# unfortunately, we need a pointer to HTK baseline -# since the corpus does NOT contain the data set descriptions -# for the REVERB Challenge - -taskFileDir=$dir/../reverb_tools/ReleasePackage/reverb_tools_for_asr_ver2.0/taskFiles/1ch -#taskFiles=`ls $taskFileDir/*Data_dt_for_*` -taskFiles=`ls $taskFileDir/RealData_${dt_or_x}_for_1ch_{far,near}*` - -dir2=$dir/$dataset -mkdir -p $dir2 - -for taskFile in $taskFiles; do - -set=`basename $taskFile` - - -echo $mcwsjav_mlf - -# MLF transcription correction -# taken from HTK baseline script -sed -e ' -# dos to unix line feed conversion -s/\x0D$//' \ --e " - s/\x60//g # remove unicode character grave accent. - " \ --e " - # fix the single quote for the word yield - # and the quoted ROOTS - # e.g. yield' --> yield - # reason: YIELD' is not in dict, while YIELD is - s/YIELD'/YIELD/g - s/'ROOTS'/ROOTS/g - s/'WHERE/WHERE/g - s/PEOPLE'/PEOPLE/g - s/SIT'/SIT/g - s/'DOMINEE/DOMINEE/g - s/CHURCH'/CHURCH/g" \ --e ' - # fix the single missing double full stop issue at the end of an utterance - # e.g. I. C. N should be I. C. N. - # reason: N is not in dict, while N. is - /^[A-Z]$/ { - # append a line - N - # search for single dot on the second line - /\n\./ { - # found it - now replace the - s/\([A-Z]\)\n\./\1\.\n\./ - } - }' \ -$mcwsjav_mlf |\ -perl $local/mlf2text.pl > $dir2/$set.txt1 - -#exit - -#taskFile=$taskFileDir/$set -# contains pointer to wav files with relative path --> add absolute path -echo taskFile = $taskFile -awk '{print "'$RWSJ'"$1}' < $taskFile > $dir2/${set}.flist || exit 1; - -# this is like flist2scp.pl but it can take wav file list as input -(perl -e 'while(<>){ - m:^\S+/[\w\-]*_(T\w{6,7})\.wav$: || die "Bad line $_"; - $id = lc $1; - print "$id $_"; -}' < $dir2/$set.flist || exit 1) | sort > $dir2/${set}_wav.scp - - -# Make the utt2spk and spk2utt files. -cat $dir2/${set}_wav.scp | awk '{print $1, $1}' > $dir2/$set.utt2spk || exit 1; -cat $dir2/$set.utt2spk | $utils/utt2spk_to_spk2utt.pl > $dir2/$set.spk2utt || exit 1; - -awk '{print $1}' < $dir2/$set.utt2spk |\ -$local/find_transcripts_txt.pl $dir2/$set.txt1 | sort | uniq > $dir2/$set.txt -#rm $dir2/$set.txt1 - -# Create directory structure required by decoding scripts - -cd $root -mkdir -p data/$dataset/$set -cp $dir2/${set}_wav.scp data/$dataset/$set/wav.scp || exit 1; -cp $dir2/$set.txt data/$dataset/$set/text || exit 1; -cp $dir2/$set.spk2utt data/$dataset/$set/spk2utt || exit 1; -cp $dir2/$set.utt2spk data/$dataset/$set/utt2spk || exit 1; - -echo "Data preparation for $set succeeded" -#echo "Put files into $dir2/$set.*" - - -mfccdir=mfcc/$dataset -#for x in test_eval92_clean test_eval92_5k_clean dev_dt_05_clean dev_dt_20_clean train_si84_clean; do -#for x in si_tr; do -steps/make_mfcc.sh --cmd "$train_cmd" --nj 10 \ - data/$dataset/$set exp/make_mfcc/$dataset/$set $mfccdir || exit 1; -steps/compute_cmvn_stats.sh data/$dataset/$set exp/make_mfcc/$dataset/$set $mfccdir || exit 1; - -done diff --git a/egs/reverb/s5/local/REVERB_wsjcam0_data_prep.sh b/egs/reverb/s5/local/REVERB_wsjcam0_data_prep.sh deleted file mode 100755 index 6ab2f2f4b73..00000000000 --- a/egs/reverb/s5/local/REVERB_wsjcam0_data_prep.sh +++ /dev/null @@ -1,117 +0,0 @@ -#!/bin/bash - -# Copyright 2013 MERL (author: Felix Weninger) -# Contains some code by Microsoft Corporation, Johns Hopkins University (author: Daniel Povey) - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - -dir=$PWD/data/local/data -lmdir=$PWD/data/local/nist_lm -mkdir -p $dir $lmdir -local=$PWD/local -utils=$PWD/utils -root=$PWD - -. ./path.sh # Needed for KALDI_ROOT -export PATH=$PATH:$KALDI_ROOT/tools/irstlm/bin -sph2pipe=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe -if [ ! -x $sph2pipe ]; then - echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; - exit 1; -fi - -RWSJ=$1 # input corpus (original or processed, tr or dt, etc.) -dataset=REVERB_dt # the name of the dataset to be created -if [ ! -z "$2" ]; then - dataset=$2 -fi -dt_or_x=dt # the WSJCAM0 set that the set is based on (tr, dt, ...) -# this will be used to find the correct transcriptions etc. -if [ ! -z "$3" ]; then - dt_or_x=$3 -fi - -if [ ! -d "$RWSJ" ]; then - echo Could not find directory $RWSJ! Check pathnames in corpus.sh! - exit 1 -fi - -cd $dir -MIC=primary - -# unfortunately, we need a pointer to HTK baseline -# since the corpus does NOT contain the data set descriptions -# for the REVERB Challenge -taskFileDir=$dir/../reverb_tools/ReleasePackage/reverb_tools_for_asr_ver2.0/taskFiles/1ch -#taskFiles=`ls $taskFileDir/*Data_dt_for_*` -nch=1 -if [ "$dt_or_x" = "tr" ]; then - taskFiles=`ls $taskFileDir/SimData_tr_for_${nch}ch*` || exit 1 -else - taskFiles=`ls $taskFileDir/SimData_${dt_or_x}_for_${nch}ch_{far,near}*` || exit 1 -fi -for taskFile in $taskFiles; do - -set=`basename $taskFile` - -#taskFile=$taskFileDir/$set -dir2=$dir/$dataset -mkdir -p $dir2 -# contains pointer to wav files with relative path --> add absolute path -echo taskFile = $taskFile -awk '{print "'$RWSJ/data'"$1}' < $taskFile > $dir2/${set}.flist || exit 1; - -# this is like flist2scp.pl but it can take wav file list as input -perl -e 'while(<>){ - m:^\S+/(\w{8})\w*\.wav$: || die "Bad line $_"; - $id = lc $1; - print "$id $_"; -}' < $dir2/$set.flist | sort > $dir2/${set}_wav.scp || exit 1; - -# find transcriptions of given utterances in si_dt.dot -# create a trans1 file for each set, convert to txt (kaldi "MLF") -dot=$dir/si_${dt_or_x}.dot -perl -e 'while (<>) { chomp; if (m/\/(\w{8})[^\/]+$/) { print $1, "\n"; } }' $taskFile |\ -perl $local/find_transcripts_singledot.pl $dot \ -> $dir2/$set.trans1 || exit 1; - -noiseword=""; -cat $dir2/$set.trans1 | $local/normalize_transcript.pl $noiseword | sort | uniq > $dir2/$set.txt || exit 1; -#exit - - -# Make the utt2spk and spk2utt files. 
-cat $dir2/${set}_wav.scp | awk '{print $1, $1}' > $dir2/$set.utt2spk || exit 1; -cat $dir2/$set.utt2spk | $utils/utt2spk_to_spk2utt.pl > $dir2/$set.spk2utt || exit 1; - -# Create directory structure required by decoding scripts -cd $root -mkdir -p data/$dataset/$set -cp $dir2/${set}_wav.scp data/$dataset/$set/wav.scp || exit 1; -cp $dir2/$set.txt data/$dataset/$set/text || exit 1; -cp $dir2/$set.spk2utt data/$dataset/$set/spk2utt || exit 1; -cp $dir2/$set.utt2spk data/$dataset/$set/utt2spk || exit 1; - -echo "Data preparation for $set succeeded" -#echo "Put files into $dir2/$set.*" - - -mfccdir=mfcc/$dataset -#for x in test_eval92_clean test_eval92_5k_clean dev_dt_05_clean dev_dt_20_clean train_si84_clean; do -#for x in si_tr; do -steps/make_mfcc.sh --cmd "$train_cmd" --nj 10 \ - data/$dataset/$set exp/make_mfcc/$dataset/$set $mfccdir || exit 1; -steps/compute_cmvn_stats.sh data/$dataset/$set exp/make_mfcc/$dataset/$set $mfccdir || exit 1; - -done diff --git a/egs/reverb/s5/local/calc_wer.sh b/egs/reverb/s5/local/calc_wer.sh deleted file mode 100755 index c4b5eeb87f3..00000000000 --- a/egs/reverb/s5/local/calc_wer.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash - -# Copyright 2016 MERL (author: Shinji Watanabe) - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - -. ./cmd.sh -. ./path.sh - -lmw=15 -am="tri2a" -lm="bg_5k" -decode="" - -. utils/parse_options.sh - -if [ ! -z $decode ]; then - decode="_$decode" -fi - -dir="exp/$am/decode${decode}_${lm}_REVERB_" -echo "####################" -echo "${dir}*dt*" -for a in `echo ${dir}*dt* | tr " " "\n" | grep -v "A\.si"`; do - echo $a | awk -F '_' '{for(i=NF-6;i [ ... ]" + echo "e.g.: $0 exp/chain/tdnn_{b,c}_sp" + echo "or (with epoch numbers for discriminative training):" + echo "$0 exp/chain/tdnn_b_sp_disc:{1,2,3}" + exit 1 +fi + +echo "# $0 $*" + +include_looped=false +if [ "$1" == "--looped" ]; then + include_looped=true + shift +fi +include_online=false +if [ "$1" == "--online" ]; then + include_online=true + shift +fi + + +used_epochs=false + +# this function set_names is used to separate the epoch-related parts of the name +# [for discriminative training] and the regular parts of the name. 
+# If called with a colon-free directory name, like: +# set_names exp/chain/tdnn_lstm1e_sp_bi_smbr +# it will set dir=exp/chain/tdnn_lstm1e_sp_bi_smbr and epoch_infix="" +# If called with something like: +# set_names exp/chain/tdnn_d_sp_smbr:3 +# it will set dir=exp/chain/tdnn_d_sp_smbr and epoch_infix="_epoch3" + + +set_names() { + if [ $# != 1 ]; then + echo "compare_wer_general.sh: internal error" + exit 1 # exit the program + fi + dirname=$(echo $1 | cut -d: -f1) + epoch=$(echo $1 | cut -s -d: -f2) + if [ -z $epoch ]; then + epoch_infix="" + else + used_epochs=true + epoch_infix=_epoch${epoch} + fi +} + + + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +strings=( + "#WER dev_clean_2 (tgsmall) " + "#WER dev_clean_2 (tglarge) ") + +for n in 0 1; do + echo -n "${strings[$n]}" + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + decode_names=(tgsmall_dev_clean_2 tglarge_dev_clean_2) + + wer=$(cat $dirname/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + if $include_looped; then + echo -n "# [looped:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat $dirname/decode_looped_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi + if $include_online; then + echo -n "# [online:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat ${dirname}_online/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi +done + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent)" +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent)" +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo diff --git a/egs/reverb/s5/local/chain/run_tdnn.sh b/egs/reverb/s5/local/chain/run_tdnn.sh new file mode 120000 index 00000000000..34499362831 --- /dev/null +++ b/egs/reverb/s5/local/chain/run_tdnn.sh @@ -0,0 +1 @@ +tuning/run_tdnn_1a.sh \ No newline at end of file diff --git a/egs/reverb/s5/local/chain/run_tdnn_lstm.sh b/egs/reverb/s5/local/chain/run_tdnn_lstm.sh new file mode 120000 index 00000000000..8e647598556 --- /dev/null +++ b/egs/reverb/s5/local/chain/run_tdnn_lstm.sh @@ -0,0 +1 @@ +tuning/run_tdnn_lstm_1a.sh \ No newline at end of file diff --git a/egs/reverb/s5/local/chain/tuning/run_tdnn_1a.sh b/egs/reverb/s5/local/chain/tuning/run_tdnn_1a.sh new file mode 100755 index 00000000000..c8b4997161e --- /dev/null +++ b/egs/reverb/s5/local/chain/tuning/run_tdnn_1a.sh @@ -0,0 +1,281 @@ +#!/bin/bash + +# Set -e here so that we catch if any executable fails immediately +set -euo pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). 
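Because the script sources utils/parse_options.sh, every variable declared in the configuration block below can be overridden on the command line (with underscores written as dashes); a typical invocation, using illustrative values only, might be:

    # rerun only the later stages of the chain recipe, keeping the default affix and training set
    # (the option values here are examples, not prescribed by the recipe)
    local/chain/tuning/run_tdnn_1a.sh --stage 14 --affix 1a --train-set tr_simu_8ch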
+stage=0 +nj=96 +train_set=tr_simu_8ch +test_sets="dt_real_1ch dt_simu_1ch et_real_1ch et_simu_1ch" +gmm=tri3 +nnet3_affix=_tr_simu_8ch +lm_suffix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=1a # affix for the TDNN directory name +tree_affix= +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +# training options +# training chunk-options +chunk_width=140,100,160 +# we don't need extra left/right context for TDNN systems. +chunk_left_context=0 +chunk_right_context=0 +common_egs_dir= +xent_regularize=0.1 + +# training options +srand=0 +remove_egs=true +reporting_email= + +#decode options +test_online_decoding=false # if true, it will run the last decoding stage. + + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 11 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj ${nj} --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 12 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 13 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + opts="l2-regularize=0.05" + output_opts="l2-regularize=0.01 bottleneck-dim=320" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 $opts dim=512 + relu-batchnorm-layer name=tdnn2 $opts dim=512 input=Append(-1,0,1) + relu-batchnorm-layer name=tdnn3 $opts dim=512 + relu-batchnorm-layer name=tdnn4 $opts dim=512 input=Append(-1,0,1) + relu-batchnorm-layer name=tdnn5 $opts dim=512 + relu-batchnorm-layer name=tdnn6 $opts dim=512 input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn7 $opts dim=512 input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn8 $opts dim=512 input=Append(-6,-3,0) + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain $opts dim=512 target-rms=0.5 + output-layer name=output include-log-softmax=false $output_opts dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output 
that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn8 $opts dim=512 target-rms=0.5 + output-layer name=output-xent $output_opts dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 14 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/chime5-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=10 \ + --trainer.frames-per-iter=3000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=4 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=256,128,64 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 15 ]; then + # Note: it's not important to give mkgraph.sh the lang directory with the + # matched topology (since it gets the topology file from the model). 
+ utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang${lm_suffix}/ \ + $tree_dir $tree_dir/graph${lm_suffix} || exit 1; +fi + +if [ $stage -le 16 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + steps/nnet3/decode.sh \ + --acwt 1.0 --post-decode-acwt 10.0 \ + --extra-left-context $chunk_left_context \ + --extra-right-context $chunk_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk $frames_per_chunk \ + --nj 8 --cmd "$decode_cmd" --num-threads 4 \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${data}_hires \ + $tree_dir/graph${lm_suffix} data/${data}_hires ${dir}/decode${lm_suffix}_${data} || exit 1 + ) || touch $dir/.error & + done + wait + [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 +fi + +# Not testing the 'looped' decoding separately, because for +# TDNN systems it would give exactly the same results as the +# normal decoding. + +if $test_online_decoding && [ $stage -le 17 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. + steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/nnet3${nnet3_affix}/extractor ${dir} ${dir}_online + + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l $lang/topo + fi +fi + +if [ $stage -le 11 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj ${nj} --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 12 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + +if [ $stage -le 13 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + + lstm_opts="decay-time=40" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=$hidden_dim + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=$hidden_dim + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=$hidden_dim + + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-3 dropout-proportion=0.0 $lstm_opts + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=$hidden_dim + fast-lstmp-layer name=lstm2 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-3 dropout-proportion=0.0 $lstm_opts + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn7 input=Append(-3,0,3) dim=$hidden_dim + fast-lstmp-layer name=lstm3 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-3 dropout-proportion=0.0 $lstm_opts + relu-batchnorm-layer name=tdnn8 input=Append(-3,0,3) dim=$hidden_dim + relu-batchnorm-layer name=tdnn9 input=Append(-3,0,3) dim=$hidden_dim + fast-lstmp-layer name=lstm4 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-3 dropout-proportion=0.0 $lstm_opts + + ## adding the layers for chain branch + output-layer name=output input=lstm4 output-delay=$label_delay include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + output-layer name=output-xent input=lstm4 output-delay=$label_delay dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 14 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/chime5-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + mkdir -p $dir/egs + touch $dir/egs/.nodelete # keep egs around when that run dies. + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$train_cmd --mem 4G" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.num-chunk-per-minibatch 64,32 \ + --trainer.frames-per-iter 1500000 \ + --trainer.max-param-change 2.0 \ + --trainer.num-epochs $num_epochs \ + --trainer.srand=$srand \ + --trainer.optimization.shrink-value 0.99 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=16 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.momentum=0.0 \ + --trainer.deriv-truncate-margin 8 \ + --egs.stage $get_egs_stage \ + --egs.opts="--frames-overlap-per-eg 0" \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --cleanup.remove-egs=$remove_egs \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 15 ]; then + # Note: it's not important to give mkgraph.sh the lang directory with the + # matched topology (since it gets the topology file from the model). + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang${lm_suffix}/ \ + $tree_dir $tree_dir/graph${lm_suffix} || exit 1; +fi + +if [ $stage -le 16 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + steps/nnet3/decode.sh \ + --acwt 1.0 --post-decode-acwt 10.0 \ + --extra-left-context $chunk_left_context \ + --extra-right-context $chunk_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk $frames_per_chunk \ + --nj 8 --cmd "$decode_cmd" --num-threads 4 \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${data}_hires \ + $tree_dir/graph${lm_suffix} data/${data}_hires ${dir}/decode${lm_suffix}_${data} || exit 1 + ) || touch $dir/.error & + done + wait + [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 +fi + +# Not testing the 'looped' decoding separately, because for +# TDNN systems it would give exactly the same results as the +# normal decoding. + +if $test_online_decoding && [ $stage -le 17 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. 
+ steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/nnet3${nnet3_affix}/extractor ${dir} ${dir}_online + + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l " + echo "options" + echo " --cmd # Command to run in parallel with" + echo " --nch # nch of WPE to use for computing SE scores" + echo " --enable_pesq # Boolean flag to enable PESQ" + exit 1; +fi + +reverb_data=$1 +enhancement_directory=$2 +pesqdir=$3 +enhancement_directory_sim=$enhancement_directory/WPE/${nch}ch/REVERB_WSJCAM0_dt/data/ +enhancement_directory_real=$enhancement_directory/WPE/${nch}ch/MC_WSJ_AV_Dev/ +expdir=${PWD}/exp/compute_se_${nch}ch +if $enable_pesq; then + compute_pesq=1 +else + compute_pesq=0 +fi + +pushd local/REVERB_scores_source/REVERB-SPEENHA.Release04Oct/evaltools +$cmd $expdir/compute_se_real.log matlab -nodisplay -nosplash -r "addpath('SRMRToolbox'); score_RealData('$reverb_data','$enhancement_directory_real');exit" +$cmd $expdir/compute_se_sim.log matlab -nodisplay -nosplash -r "addpath('SRMRToolbox'); score_SimData('$reverb_data','$enhancement_directory_sim','$pesqdir',$compute_pesq);exit" +popd +rm -rf $expdir/scores +mv local/REVERB_scores_source/REVERB-SPEENHA.Release04Oct/scores $expdir/ diff --git a/egs/reverb/s5/local/download_se_eval_tool.sh b/egs/reverb/s5/local/download_se_eval_tool.sh new file mode 100755 index 00000000000..0d7bb8305ea --- /dev/null +++ b/egs/reverb/s5/local/download_se_eval_tool.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# Copyright 2018 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +# This script downloads the official REVERB challenge SE scripts and SRMR toolbox +# This script also downloads and compiles PESQ +# please make sure that you or your institution have the license to report PESQ +# Apache 2.0 + +wget 'https://www.itu.int/rec/dologin_pub.asp?lang=e&id=T-REC-P.862-200102-I!!SOFT-ZST-E&type=items' -O PESQ.zip +unzip PESQ.zip -d local/PESQ_sources +rm PESQ.zip +cd local/PESQ_sources/P862/Software/source +gcc *.c -lm -o PESQ +cd ../../../../../ +mv local/PESQ_sources/P862/Software/source/PESQ local/ + +wget 'https://reverb2014.dereverberation.com/tools/REVERB-SPEENHA.Release04Oct.zip' -O REVERB_scores.zip +unzip REVERB_scores.zip -d local/REVERB_scores_source +rm REVERB_scores.zip + +pushd local/REVERB_scores_source/REVERB-SPEENHA.Release04Oct/evaltools +perl -i -pe 's/wavread/audioread/g' prog/score_sim.m +git clone https://github.com/MuSAELab/SRMRToolbox.git +perl -i -pe 's/wavread/audioread/g' SRMRToolbox/libs/preprocess.m +perl -i -pe 's/SRMR_main/SRMR/g' prog/score_real.m +perl -i -pe 's/SRMR_main/SRMR/g' prog/score_sim.m +perl -i -pe 's/\+wb //g' prog/calcpesq.m +perl -i -pe 's/pesq_/_pesq_/g' prog/calcpesq.m +perl -n -i -e 'print unless /remove target file name/' prog/calcpesq.m +patch score_RealData.m -i ../../../score_RealData.patch -o score_RealData_new.m +mv score_RealData_new.m score_RealData.m +patch score_SimData.m -i ../../../score_SimData.patch -o score_SimData_new.m +mv score_SimData_new.m score_SimData.m +popd diff --git a/egs/reverb/s5/local/generate_data.sh b/egs/reverb/s5/local/generate_data.sh new file mode 100755 index 00000000000..3228f0e1b3c --- /dev/null +++ b/egs/reverb/s5/local/generate_data.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# +# Copyright 2018 Johns Hopkins University (Author: Shinji Watanabe) +# Apache 2.0 +# This script is adapted from data preprations scripts in the Kaldi reverb recipe +# 
https://github.com/kaldi-asr/kaldi/tree/master/egs/reverb/s5/local + +# Begin configuration section. +wavdir=${PWD}/wav +# End configuration section + +. ./utils/parse_options.sh # accept options.. you can run this run.sh with the + +. ./path.sh + +echo >&2 "$0" "$@" +if [ $# -ne 1 ] ; then + echo >&2 "$0" "$@" + echo >&2 "$0: Error: wrong number of arguments" + echo -e >&2 "Usage:\n $0 [opts] " + echo -e >&2 "eg:\n $0 /export/corpora3/LDC/LDC95S24/wsjcam0" + exit 1 +fi + +set -e -o pipefail + +wsjcam0=$1 +mkdir -p ${wavdir} + +# tool directory +dir=${PWD}/data/local/reverb_tools +mkdir -p ${dir} + +# Download tools +URL1="http://reverb2014.dereverberation.com/tools/reverb_tools_for_Generate_mcTrainData.tgz" +URL2="http://reverb2014.dereverberation.com/tools/REVERB_TOOLS_FOR_ASR_ver2.0.tgz" +for f in $URL1 $URL2; do + x=`basename $f` + if [ ! -e $dir/$x ]; then + wget $f -O $dir/$x || exit 1; + tar zxvf $dir/$x -C $dir || exit 1; + fi +done +URL3="http://reverb2014.dereverberation.com/tools/taskFiles_et.tgz" +x=`basename $URL3` +if [ ! -e $dir/$x ]; then + wget $URL3 -O $dir/$x || exit 1; + tar zxvf $dir/$x -C $dir || exit 1; + cp -fr $dir/`basename $x .tgz`/* $dir/ReleasePackage/reverb_tools_for_asr_ver2.0/taskFiles/ +fi + +# generate WAV files for matlab +echo "generating WAV files" +sph2pipe=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe +if [ ! -x $sph2pipe ]; then + echo "Could not find (or execute) the sph2pipe program at ${sph2pipe}"; + exit 1; +fi +for sph in `cat ${dir}/reverb_tools_for_Generate_mcTrainData/etc/audio_si_tr.lst`; do + d=`dirname ${wavdir}/WSJCAM0/data/${sph}` + if [ ! -d "${d}" ]; then + mkdir -p ${d} + fi + ${sph2pipe} -f wav ${wsjcam0}/data/${sph}.wv1 > ${wavdir}/WSJCAM0/data/${sph}.wav +done +nwav=`find ${wavdir}/WSJCAM0/data/primary_microphone/si_tr | grep .wav | wc -l` +echo "generated ${nwav} WAV files (it must be 7861)" +[ "$nwav" -eq 7861 ] || echo "Warning: expected 7861 WAV files, got $nwav" + +# generalte training data +reverb_tr_dir=${wavdir}/REVERB_WSJCAM0_tr +cp local/Generate_mcTrainData_cut.m $dir/reverb_tools_for_Generate_mcTrainData/ +pushd $dir/reverb_tools_for_Generate_mcTrainData/ +tmpdir=`mktemp -d tempXXXXX ` +tmpmfile=$tmpdir/run_mat.m +cat < $tmpmfile +addpath(genpath('.')) +Generate_mcTrainData_cut('$wavdir/WSJCAM0', '$reverb_tr_dir'); +EOF +cat $tmpmfile | matlab -nodisplay +rm -rf $tmpdir +popd + +echo "Successfully generated multi-condition training data and stored it in $reverb_tr_dir." && exit 0; diff --git a/egs/reverb/s5/local/get_results.sh b/egs/reverb/s5/local/get_results.sh index 7c74736e5d1..8867961dcdd 100755 --- a/egs/reverb/s5/local/get_results.sh +++ b/egs/reverb/s5/local/get_results.sh @@ -1,18 +1,86 @@ #!/bin/bash -# Reproduce selected results in Table 1 from Weninger et al. 
(2014) # "Our baselines" - -# LDA-STC fMLLR MCT DT LM MBR -# No No No No BG No -local/calc_wer.sh -# No No Yes No BG No -local/calc_wer.sh --am tri2a_mc -# No Yes Yes No BG No -local/calc_wer.sh --am tri2a_mc --decode basis_fmllr -# Yes Yes Yes No TG No -local/calc_wer.sh --am tri2b_mc --lm tg_5k --decode basis_fmllr -# Yes Yes Yes Yes TG No -local/calc_wer.sh --am tri2b_mc_mmi_b0.1 --lm tg_5k --decode basis_fmllr -# Yes Yes Yes Yes TG Yes -local/calc_wer.sh --am tri2b_mc_mmi_b0.1 --lm tg_5k --decode mbr_basis_fmllr +echo "########################################" +echo "GMM RESULTs:" +echo "exp/tri3/decode_dt_real_1ch" +cat exp/tri3/decode_dt_real_1ch/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_dt_simu_1ch" +cat exp/tri3/decode_dt_simu_1ch/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_et_real_1ch" +cat exp/tri3/decode_et_real_1ch/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_et_simu_1ch" +cat exp/tri3/decode_et_simu_1ch/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_dt_real_1ch_wpe" +cat exp/tri3/decode_dt_real_1ch_wpe/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_dt_simu_1ch_wpe" +cat exp/tri3/decode_dt_simu_1ch_wpe/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_et_real_1ch_wpe" +cat exp/tri3/decode_et_real_1ch_wpe/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_et_simu_1ch_wpe" +cat exp/tri3/decode_et_simu_1ch_wpe/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_dt_real_2ch_wpe" +cat exp/tri3/decode_dt_real_2ch_wpe/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_dt_simu_2ch_wpe" +cat exp/tri3/decode_dt_simu_2ch_wpe/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_et_real_2ch_wpe" +cat exp/tri3/decode_et_real_2ch_wpe/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_et_simu_2ch_wpe" +cat exp/tri3/decode_et_simu_2ch_wpe/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_dt_real_8ch_wpe" +cat exp/tri3/decode_dt_real_8ch_wpe/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_dt_simu_8ch_wpe" +cat exp/tri3/decode_dt_simu_8ch_wpe/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_et_real_8ch_wpe" +cat exp/tri3/decode_et_real_8ch_wpe/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_et_simu_8ch_wpe" +cat exp/tri3/decode_et_simu_8ch_wpe/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_dt_real_2ch_beamformit" +cat exp/tri3/decode_dt_real_2ch_beamformit/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_dt_simu_2ch_beamformit" +cat exp/tri3/decode_dt_simu_2ch_beamformit/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_et_real_2ch_beamformit" +cat exp/tri3/decode_et_real_2ch_beamformit/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_et_simu_2ch_beamformit" +cat exp/tri3/decode_et_simu_2ch_beamformit/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_dt_real_8ch_beamformit" +cat exp/tri3/decode_dt_real_8ch_beamformit/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_dt_simu_8ch_beamformit" +cat exp/tri3/decode_dt_simu_8ch_beamformit/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_et_real_8ch_beamformit" +cat exp/tri3/decode_et_real_8ch_beamformit/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_dt_cln" +cat exp/tri3/decode_dt_cln/scoring_kaldi/best_wer* +echo "" +echo "exp/tri3/decode_et_cln" +cat exp/tri3/decode_et_cln/scoring_kaldi/best_wer* +echo "########################################" +echo "TDNN RESULTs:" +echo "exp/chain_tr_simu_8ch/tdnn1a_sp/decode_test_tg_5k_dt*" +cat 
exp/chain_tr_simu_8ch/tdnn1a_sp/decode_test_tg_5k_dt*/scoring_kaldi/best_wer_* +echo "" +echo "exp/chain_tr_simu_8ch/tdnn1a_sp/decode_test_tg_5k_et*" +cat exp/chain_tr_simu_8ch/tdnn1a_sp/decode_test_tg_5k_et*/scoring_kaldi/best_wer_* diff --git a/egs/reverb/s5/local/nnet3/compare_wer.sh b/egs/reverb/s5/local/nnet3/compare_wer.sh new file mode 100755 index 00000000000..095e85cc338 --- /dev/null +++ b/egs/reverb/s5/local/nnet3/compare_wer.sh @@ -0,0 +1,132 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/tdnn_{c,d}_sp +# For use with discriminatively trained systems you specify the epochs after a colon: +# for instance, +# local/chain/compare_wer.sh exp/chain/tdnn_c_sp exp/chain/tdnn_c_sp_smbr:{1,2,3} + + +if [ $# == 0 ]; then + echo "Usage: $0: [--looped] [--online] [ ... ]" + echo "e.g.: $0 exp/chain/tdnn_{b,c}_sp" + echo "or (with epoch numbers for discriminative training):" + echo "$0 exp/chain/tdnn_b_sp_disc:{1,2,3}" + exit 1 +fi + +echo "# $0 $*" + +include_looped=false +if [ "$1" == "--looped" ]; then + include_looped=true + shift +fi +include_online=false +if [ "$1" == "--online" ]; then + include_online=true + shift +fi + + +used_epochs=false + +# this function set_names is used to separate the epoch-related parts of the name +# [for discriminative training] and the regular parts of the name. +# If called with a colon-free directory name, like: +# set_names exp/chain/tdnn_lstm1e_sp_bi_smbr +# it will set dir=exp/chain/tdnn_lstm1e_sp_bi_smbr and epoch_infix="" +# If called with something like: +# set_names exp/chain/tdnn_d_sp_smbr:3 +# it will set dir=exp/chain/tdnn_d_sp_smbr and epoch_infix="_epoch3" + + +set_names() { + if [ $# != 1 ]; then + echo "compare_wer_general.sh: internal error" + exit 1 # exit the program + fi + dirname=$(echo $1 | cut -d: -f1) + epoch=$(echo $1 | cut -s -d: -f2) + if [ -z $epoch ]; then + epoch_infix="" + else + used_epochs=true + epoch_infix=_epoch${epoch} + fi +} + + + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +strings=( + "#WER dev_clean_2 (tgsmall) " + "#WER dev_clean_2 (tglarge) ") + +for n in 0 1; do + echo -n "${strings[$n]}" + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + decode_names=(tgsmall_dev_clean_2 tglarge_dev_clean_2) + + wer=$(cat $dirname/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + if $include_looped; then + echo -n "# [looped:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat $dirname/decode_looped_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi + if $include_online; then + echo -n "# [online:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat ${dirname}_online/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi +done + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. 
+fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.{final,combined}.log 2>/dev/null | grep log-like | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.{final,combined}.log 2>/dev/null | grep log-like | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train acc " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.{final,combined}.log 2>/dev/null | grep accuracy | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid acc " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.{final,combined}.log 2>/dev/null | grep accuracy | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo diff --git a/egs/reverb/s5/local/nnet3/run_ivector_common.sh b/egs/reverb/s5/local/nnet3/run_ivector_common.sh new file mode 100755 index 00000000000..3af3ad77565 --- /dev/null +++ b/egs/reverb/s5/local/nnet3/run_ivector_common.sh @@ -0,0 +1,149 @@ +#!/bin/bash + +set -euo pipefail + +# This script is called from local/nnet3/run_tdnn.sh and +# local/chain/run_tdnn.sh (and may eventually be called by more +# scripts). It contains the common feature preparation and +# iVector-related parts of the script. See those scripts for examples +# of usage. + +stage=0 +train_set=train_worn_u100k +test_sets="dev_worn dev_beamformit_ref" +gmm=tri3 +nj=96 + +nnet3_affix=_train_worn_u100k + +. ./cmd.sh +. ./path.sh +. utils/parse_options.sh + +gmm_dir=exp/${gmm} +ali_dir=exp/${gmm}_ali_${train_set}_sp + +for f in data/${train_set}/feats.scp ${gmm_dir}/final.mdl; do + if [ ! -f $f ]; then + echo "$0: expected file $f to exist" + exit 1 + fi +done + +if [ $stage -le 1 ]; then + # Although the nnet will be trained by high resolution data, we still have to + # perturb the normal data to get the alignment _sp stands for speed-perturbed + echo "$0: preparing directory for low-resolution speed-perturbed data (for alignment)" + utils/data/perturb_data_dir_speed_3way.sh data/${train_set} data/${train_set}_sp + echo "$0: making MFCC features for low-resolution speed-perturbed data" + steps/make_mfcc.sh --cmd "$train_cmd" --nj 20 data/${train_set}_sp || exit 1; + steps/compute_cmvn_stats.sh data/${train_set}_sp || exit 1; + utils/fix_data_dir.sh data/${train_set}_sp +fi + +if [ $stage -le 2 ]; then + echo "$0: aligning with the perturbed low-resolution data" + steps/align_fmllr.sh --nj ${nj} --cmd "$train_cmd" \ + data/${train_set}_sp data/lang $gmm_dir $ali_dir || exit 1 +fi + +if [ $stage -le 3 ]; then + # Create high-resolution MFCC features (with 40 cepstra instead of 13). + # this shows how you can split across multiple file-systems. + echo "$0: creating high-resolution MFCC features" + mfccdir=data/${train_set}_sp_hires/data + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/b1{4,5,6,8}/$USER/kaldi-data/mfcc/reverb-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + for datadir in ${train_set}_sp ${test_sets}; do + utils/copy_data_dir.sh data/$datadir data/${datadir}_hires + done + + # do volume-perturbation on the training data prior to extracting hires + # features; this helps make trained nnets more invariant to test data volume. 
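  # (A quick, hedged sanity check of what the next command does; the path simply
  # reuses the ${train_set} variable defined above.) After volume perturbation,
  # each wav.scp entry should be a sox pipe that rescales the signal by a
  # randomly chosen factor rather than a bare wav path:
  #   head -n 1 data/${train_set}_sp_hires/wav.scp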
+ utils/data/perturb_data_dir_volume.sh data/${train_set}_sp_hires || exit 1; + + for datadir in ${train_set}_sp ${test_sets}; do + steps/make_mfcc.sh --nj 20 --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/${datadir}_hires || exit 1; + steps/compute_cmvn_stats.sh data/${datadir}_hires || exit 1; + utils/fix_data_dir.sh data/${datadir}_hires || exit 1; + done +fi + +if [ $stage -le 4 ]; then + echo "$0: computing a subset of data to train the diagonal UBM." + # We'll use about a quarter of the data. + mkdir -p exp/nnet3${nnet3_affix}/diag_ubm + temp_data_root=exp/nnet3${nnet3_affix}/diag_ubm + + num_utts_total=$(wc -l &2 "$0" "$@" +if [ $# -ne 1 ] ; then + echo >&2 "$0" "$@" + echo >&2 "$0: Error: wrong number of arguments" + echo -e >&2 "Usage:\n $0 [opts] " + echo -e >&2 "eg:\n $0 /export/corpora5/REVERB_2014/REVERB" + exit 1 +fi + +set -e -o pipefail + +reverb=$1 + +# working directory +dir=${PWD}/data/local/data +mkdir -p ${dir} + +for task in dt et; do + if [ ${task} == 'dt' ]; then + mlf=${reverb}/MC_WSJ_AV_Dev/mlf/WSJ.mlf + elif [ ${task} == 'et' ]; then + mlf=${reverb}/MC_WSJ_AV_Eval/mlf/WSJ.mlf + fi + # MLF transcription correction + # taken from HTK baseline script + sed -e ' +# dos to unix line feed conversion +s/\x0D$//' \ + -e " + s/\x60//g # remove unicode character grave accent. + " \ + -e " + # fix the single quote for the word yield + # and the quoted ROOTS + # e.g. yield' --> yield + # reason: YIELD' is not in dict, while YIELD is + s/YIELD'/YIELD/g + s/'ROOTS'/ROOTS/g + s/'WHERE/WHERE/g + s/PEOPLE'/PEOPLE/g + s/SIT'/SIT/g + s/'DOMINEE/DOMINEE/g + s/CHURCH'/CHURCH/g" \ + -e ' + # fix the single missing double full stop issue at the end of an utterance + # e.g. I. C. N should be I. C. N. + # reason: N is not in dict, while N. 
is + /^[A-Z]$/ { + # append a line + N + # search for single dot on the second line + /\n\./ { + # found it - now replace the + s/\([A-Z]\)\n\./\1\.\n\./ + } + }' \ + $mlf |\ + perl local/mlf2text.pl > ${dir}/${task}.txt +done + + +noiseword=""; +for nch in 1 2 8; do + taskdir=data/local/reverb_tools/ReleasePackage/reverb_tools_for_asr_ver2.0/taskFiles/${nch}ch + # make a wav list + for task in dt et; do + if [ ${task} == 'dt' ]; then + audiodir=${reverb}/MC_WSJ_AV_Dev + audiodir_wpe=${wavdir}/WPE/${nch}ch/MC_WSJ_AV_Dev + elif [ ${task} == 'et' ]; then + audiodir=${reverb}/MC_WSJ_AV_Eval + audiodir_wpe=${wavdir}/WPE/${nch}ch/MC_WSJ_AV_Eval + fi + for x in `ls ${taskdir} | grep RealData | grep _${task}_`; do + perl -se 'while(<>){m:^\S+/[\w\-]*_(T\w{6,7})\.wav$: || die "Bad line $_"; $id = lc $1; print "$id $dir$_";}' -- -dir=${audiodir} ${taskdir}/$x |\ + sed -e "s/^\(...\)/\1_${x}_\1/" + done > ${dir}/${task}_real_${nch}ch_wav.scp + for x in `ls ${taskdir} | grep RealData | grep _${task}_`; do + perl -se 'while(<>){m:^\S+/[\w\-]*_(T\w{6,7})\.wav$: || die "Bad line $_"; $id = lc $1; print "$id $dir$_";}' -- -dir=${audiodir_wpe} ${taskdir}/$x |\ + sed -e "s/^\(...\)/\1_${x}_\1/" + done > ${dir}/${task}_real_${nch}ch_wpe_wav.scp + done + # make a transcript + for task in dt et; do + for x in `ls ${taskdir} | grep RealData | grep _${task}_`; do + perl -se 'while(<>){m:^\S+/[\w\-]*_(T\w{6,7})\.wav$: || die "Bad line $_"; $id = lc $1; print "$id\n";}' ${taskdir}/$x |\ + perl local/find_transcripts_txt.pl ${dir}/${task}.txt |\ + sed -e "s/^\(...\)/\1_${x}_\1/" + done > ${dir}/${task}_real_${nch}ch.trans1 || exit 1; + cat ${dir}/${task}_real_${nch}ch.trans1 | local/normalize_transcript.pl ${noiseword} > ${dir}/${task}_real_${nch}ch.txt || exit 1; + done + + # Make the utt2spk and spk2utt files. 
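  # Sketch of what the awk pair below does, using a hypothetical wav.scp key of
  # the shape produced above (the taskfile name is illustrative): the whole key
  # becomes the utterance id and its leading three-character code the speaker id.
  #   echo "t21_RealData_dt_for_8ch_far_room1_A_t21c0210 /path/x.wav" \
  #     | awk '{print $1}' | awk -F '_' '{print $0 " " $1}'
  #   # -> t21_RealData_dt_for_8ch_far_room1_A_t21c0210 t21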
+ for task in dt et; do + cat ${dir}/${task}_real_${nch}ch_wav.scp | awk '{print $1}' | awk -F '_' '{print $0 " " $1}' > ${dir}/${task}_real_${nch}ch.utt2spk || exit 1; + cat ${dir}/${task}_real_${nch}ch.utt2spk | ./utils/utt2spk_to_spk2utt.pl > ${dir}/${task}_real_${nch}ch.spk2utt || exit 1; + done +done + +# finally copy the above files to the data directory +for nch in 1 2 8; do + for task in dt et; do + datadir=data/${task}_real_${nch}ch + mkdir -p ${datadir} + sort ${dir}/${task}_real_${nch}ch_wav.scp > ${datadir}/wav.scp + sort ${dir}/${task}_real_${nch}ch.txt > ${datadir}/text + sort ${dir}/${task}_real_${nch}ch.utt2spk > ${datadir}/utt2spk + sort ${dir}/${task}_real_${nch}ch.spk2utt > ${datadir}/spk2utt + ./utils/fix_data_dir.sh ${datadir} + if [ ${nch} != 1 ]; then + datadir=data/${task}_real_${nch}ch_beamformit + mkdir -p ${datadir} + sort ${dir}/${task}_real_1ch_wpe_wav.scp | sed -e "s/-[1-8]_/-bf${nch}_/" | sed -e "s/WPE\/1ch/WPE\/${nch}ch/" > ${datadir}/wav.scp + sort ${dir}/${task}_real_1ch.txt > ${datadir}/text + sort ${dir}/${task}_real_1ch.utt2spk > ${datadir}/utt2spk + sort ${dir}/${task}_real_1ch.spk2utt > ${datadir}/spk2utt + ./utils/fix_data_dir.sh ${datadir} + fi + datadir=data/${task}_real_${nch}ch_wpe + mkdir -p ${datadir} + sort ${dir}/${task}_real_1ch_wpe_wav.scp | sed -e "s/WPE\/1ch/WPE\/${nch}ch/" > ${datadir}/wav.scp + sort ${dir}/${task}_real_1ch.txt > ${datadir}/text + sort ${dir}/${task}_real_1ch.utt2spk > ${datadir}/utt2spk + sort ${dir}/${task}_real_1ch.spk2utt > ${datadir}/spk2utt + ./utils/fix_data_dir.sh ${datadir} + done +done diff --git a/egs/reverb/s5/local/prepare_simu_data.sh b/egs/reverb/s5/local/prepare_simu_data.sh new file mode 100755 index 00000000000..8757021ddd7 --- /dev/null +++ b/egs/reverb/s5/local/prepare_simu_data.sh @@ -0,0 +1,150 @@ +#!/bin/bash +# +# Copyright 2018 Johns Hopkins University (Author: Shinji Watanabe) +# Copyright 2018 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +# Apache 2.0 +# This script is adapted from data preparation scripts in the Kaldi reverb recipe +# https://github.com/kaldi-asr/kaldi/tree/master/egs/reverb/s5/local + +# Begin configuration section. +wavdir=${PWD}/wav +# End configuration section +. ./utils/parse_options.sh # accept options.. you can run this run.sh with the + +. 
./path.sh + +echo >&2 "$0" "$@" +if [ $# -ne 2 ] ; then + echo >&2 "$0" "$@" + echo >&2 "$0: Error: wrong number of arguments" + echo -e >&2 "Usage:\n $0 [opts] " + echo -e >&2 "eg:\n $0 /export/corpora5/REVERB_2014/REVERB /export/corpora3/LDC/LDC95S24/wsjcam0" + exit 1 +fi + +set -e -o pipefail + +reverb=$1 +wsjcam0=$2 + +# tool directory +tooldir=${PWD}/data/local/reverb_tools + +# working directory +dir=${PWD}/data/local/data +mkdir -p ${dir} + +# make a one dot file for train, dev, and eval data +# the directory structure of WSJCAM0 is not consistent and we need such process for each task +cp ${wsjcam0}/data/primary_microphone/etc/si_tr.dot ${dir}/tr.dot +cat ${wsjcam0}/data/primary_microphone/etc/si_dt*.dot | sort > ${dir}/dt.dot +cat ${wsjcam0}/data/*/si_et*/*/*.dot | sort > ${dir}/et.dot + +noiseword=""; +for nch in 1 2 8; do + taskdir=data/local/reverb_tools/ReleasePackage/reverb_tools_for_asr_ver2.0/taskFiles/${nch}ch + # make a wav list + task=tr + for x in `ls ${taskdir} | grep SimData | grep _${task}_`; do + perl -se 'while (<>) { chomp; if (m/\/(\w{8})[^\/]+$/) { print $1, " ", $dir, $_, "\n"; } }' -- -dir=${wavdir}/REVERB_WSJCAM0_${task}/data ${taskdir}/$x |\ + sed -e "s/^\(...\)/\1_${x}_\1/" + done > ${dir}/${task}_simu_${nch}ch_wav.scp + for task in dt et; do + for x in `ls ${taskdir} | grep SimData | grep _${task}_ | grep -e far -e near`; do + perl -se 'while (<>) { chomp; if (m/\/(\w{8})[^\/]+$/) { print $1, " ", $dir, $_, "\n"; } }' -- -dir=${reverb}/REVERB_WSJCAM0_${task}/data ${taskdir}/$x |\ + sed -e "s/^\(...\)/\1_${x}_\1/" + done > ${dir}/${task}_simu_${nch}ch_wav.scp + if [ ${nch} == 1 ]; then + for x in `ls ${taskdir} | grep SimData | grep _${task}_ | grep -e cln`; do + perl -se 'while (<>) { chomp; if (m/\/(\w{8})[^\/]+$/) { print $1, " ", $dir, $_, "\n"; } }' -- -dir=${reverb}/REVERB_WSJCAM0_${task}/data ${taskdir}/$x |\ + sed -e "s/^\(...\)/\1_${x}_\1/" + done > ${dir}/${task}_cln_wav.scp + fi + done + + task=tr + for x in `ls ${taskdir} | grep SimData | grep _${task}_`; do + perl -se 'while (<>) { chomp; if (m/\/(\w{8})[^\/]+$/) { print $1, " ", $dir, $_, "\n"; } }' -- -dir=${wavdir}/WPE/${nch}ch/REVERB_WSJCAM0_${task}/data ${taskdir}/$x |\ + sed -e "s/^\(...\)/\1_${x}_\1/" + done > ${dir}/${task}_simu_${nch}ch_wpe_wav.scp + for task in dt et; do + for x in `ls ${taskdir} | grep SimData | grep _${task}_ | grep -e far -e near`; do + perl -se 'while (<>) { chomp; if (m/\/(\w{8})[^\/]+$/) { print $1, " ", $dir, $_, "\n"; } }' -- -dir=${wavdir}/WPE/${nch}ch/REVERB_WSJCAM0_${task}/data ${taskdir}/$x |\ + sed -e "s/^\(...\)/\1_${x}_\1/" + done > ${dir}/${task}_simu_${nch}ch_wpe_wav.scp + done + + # make a transcript + task=tr + for x in `ls ${taskdir} | grep SimData | grep _${task}_`; do + perl -e 'while (<>) { chomp; if (m/\/(\w{8})[^\/]+$/) { print $1, "\n"; } }' ${taskdir}/$x |\ + perl local/find_transcripts_singledot.pl ${dir}/${task}.dot |\ + sed -e "s/^\(...\)/\1_${x}_\1/" + done > ${dir}/${task}_simu_${nch}ch.trans1 || exit 1; + cat ${dir}/${task}_simu_${nch}ch.trans1 | local/normalize_transcript.pl ${noiseword} > ${dir}/${task}_simu_${nch}ch.txt || exit 1; + for task in dt et; do + for x in `ls ${taskdir} | grep SimData | grep _${task}_ | grep -e far -e near`; do + perl -e 'while (<>) { chomp; if (m/\/(\w{8})[^\/]+$/) { print $1, "\n"; } }' ${taskdir}/$x |\ + perl local/find_transcripts_singledot.pl ${dir}/${task}.dot |\ + sed -e "s/^\(...\)/\1_${x}_\1/" + done > ${dir}/${task}_simu_${nch}ch.trans1 || exit 1; + cat ${dir}/${task}_simu_${nch}ch.trans1 | 
local/normalize_transcript.pl ${noiseword} > ${dir}/${task}_simu_${nch}ch.txt || exit 1; + if [ ${nch} == 1 ]; then + for x in `ls ${taskdir} | grep SimData | grep _${task}_ | grep -e cln`; do + perl -e 'while (<>) { chomp; if (m/\/(\w{8})[^\/]+$/) { print $1, "\n"; } }' ${taskdir}/$x |\ + perl local/find_transcripts_singledot.pl ${dir}/${task}.dot |\ + sed -e "s/^\(...\)/\1_${x}_\1/" + done > ${dir}/${task}_cln.trans1 || exit 1; + cat ${dir}/${task}_cln.trans1 | local/normalize_transcript.pl ${noiseword} > ${dir}/${task}_cln.txt || exit 1; + fi + done + + # Make the utt2spk and spk2utt files. + for task in tr dt et; do + cat ${dir}/${task}_simu_${nch}ch_wav.scp | awk '{print $1}' | awk -F '_' '{print $0 " " $1}' > ${dir}/${task}_simu_${nch}ch.utt2spk || exit 1; + cat ${dir}/${task}_simu_${nch}ch.utt2spk | ./utils/utt2spk_to_spk2utt.pl > ${dir}/${task}_simu_${nch}ch.spk2utt || exit 1; + done + for task in dt et; do + cat ${dir}/${task}_cln_wav.scp | awk '{print $1}' | awk -F '_' '{print $0 " " $1}' > ${dir}/${task}_cln.utt2spk || exit 1; + cat ${dir}/${task}_cln.utt2spk | ./utils/utt2spk_to_spk2utt.pl > ${dir}/${task}_cln.spk2utt || exit 1; + done +done + +# finally copy the above files to the data directory +for nch in 1 2 8; do + for task in tr dt et; do + datadir=data/${task}_simu_${nch}ch + mkdir -p ${datadir} + sort ${dir}/${task}_simu_${nch}ch_wav.scp > ${datadir}/wav.scp + sort ${dir}/${task}_simu_${nch}ch.txt > ${datadir}/text + sort ${dir}/${task}_simu_${nch}ch.utt2spk > ${datadir}/utt2spk + sort ${dir}/${task}_simu_${nch}ch.spk2utt > ${datadir}/spk2utt + ./utils/fix_data_dir.sh ${datadir} + if [ ${task} != 'tr' ]; then + datadir=data/${task}_simu_${nch}ch_wpe + mkdir -p ${datadir} + sort ${dir}/${task}_simu_1ch_wpe_wav.scp | sed -e "s/WPE\/1ch/WPE\/${nch}ch/" > ${datadir}/wav.scp + sort ${dir}/${task}_simu_1ch.txt > ${datadir}/text + sort ${dir}/${task}_simu_1ch.utt2spk > ${datadir}/utt2spk + sort ${dir}/${task}_simu_1ch.spk2utt > ${datadir}/spk2utt + ./utils/fix_data_dir.sh ${datadir} + if [ ${nch} != 1 ]; then + datadir=data/${task}_simu_${nch}ch_beamformit + mkdir -p ${datadir} + sort ${dir}/${task}_simu_1ch_wpe_wav.scp | sed -e "s/ch1/bf${nch}/" | sed -e "s/WPE\/1ch/WPE\/${nch}ch/" > ${datadir}/wav.scp + sort ${dir}/${task}_simu_1ch.txt > ${datadir}/text + sort ${dir}/${task}_simu_1ch.utt2spk > ${datadir}/utt2spk + sort ${dir}/${task}_simu_1ch.spk2utt > ${datadir}/spk2utt + ./utils/fix_data_dir.sh ${datadir} + else + datadir=data/${task}_cln + mkdir -p ${datadir} + sort ${dir}/${task}_cln_wav.scp > ${datadir}/wav.scp + sort ${dir}/${task}_cln.txt > ${datadir}/text + sort ${dir}/${task}_cln.utt2spk > ${datadir}/utt2spk + sort ${dir}/${task}_cln.spk2utt > ${datadir}/spk2utt + ./utils/fix_data_dir.sh ${datadir} + fi + fi + done +done diff --git a/egs/reverb/s5/local/run_beamform.sh b/egs/reverb/s5/local/run_beamform.sh new file mode 100755 index 00000000000..1c8aade7287 --- /dev/null +++ b/egs/reverb/s5/local/run_beamform.sh @@ -0,0 +1,142 @@ +#!/bin/bash + +# Copyright 2015, Mitsubishi Electric Research Laboratories, MERL (Author: Shinji Watanabe) +# Copyright 2018, Johns Hopkins University (Author: Aswin Shanmugam Subramanian) + +. ./cmd.sh +. ./path.sh + +# Config: +nj=50 +cmd=run.pl + +. 
utils/parse_options.sh || exit 1; + +if [ $# != 1 ]; then + echo "Wrong #arguments ($#, expected 1)" + echo "Usage: local/run_beamform.sh [options] " + echo "main options (for others, see top of script file)" + echo " --nj # number of parallel jobs" + echo " --cmd # Command to run in parallel with" + exit 1; +fi + +odir=$1 +dir=${PWD}/data/local/data + +if [ -z $BEAMFORMIT ] ; then + export BEAMFORMIT=$KALDI_ROOT/tools/extras/BeamformIt +fi +export PATH=${PATH}:$BEAMFORMIT +! hash BeamformIt && echo "Missing BeamformIt, run 'cd ../../../tools/; extras/install_beamformit.sh;'" && exit 1 + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +for task in dt et; do + for nch in 2 8; do + wdir=exp/beamform_real_${task}_${nch}ch + mkdir -p $wdir/log + arrays=$wdir/channels + output_wavfiles=$wdir/wavfiles.list + if [ ${nch} == 2 ]; then + allwavs=`cat ${dir}/${task}_real_${nch}ch_wpe_wav.scp | cut -d " " -f2` + allwavs_beamformit=`cat data/${task}_real_${nch}ch_beamformit/wav.scp | cut -d " " -f2` + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%2==1' > $wdir/channels.1st + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%2==0' > $wdir/channels.2nd + echo $allwavs_beamformit | tr ' ' '\n' | rev | sort | rev | awk -F 'WPE/' '{print $2}' | awk -F '.wav' '{print $1}' > $output_wavfiles + paste -d" " $output_wavfiles $wdir/channels.1st $wdir/channels.2nd > $arrays + elif [ ${nch} == 8 ]; then + allwavs=`cat ${dir}/${task}_real_${nch}ch_wpe_wav.scp | cut -d " " -f2` + allwavs_beamformit=`cat data/${task}_real_${nch}ch_beamformit/wav.scp | cut -d " " -f2` + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==1' > $wdir/channels.1st + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==2' > $wdir/channels.2nd + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==3' > $wdir/channels.3rd + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==4' > $wdir/channels.4th + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==5' > $wdir/channels.5th + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==6' > $wdir/channels.6th + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==7' > $wdir/channels.7th + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==0' > $wdir/channels.8th + echo $allwavs_beamformit | tr ' ' '\n' | rev | sort | rev | awk -F 'WPE/' '{print $2}' | awk -F '.wav' '{print $1}' > $output_wavfiles + paste -d" " $output_wavfiles $wdir/channels.1st $wdir/channels.2nd $wdir/channels.3rd $wdir/channels.4th $wdir/channels.5th $wdir/channels.6th $wdir/channels.7th $wdir/channels.8th > $arrays + fi + # split the list for parallel processing + split_wavfiles="" + for n in `seq $nj`; do + split_wavfiles="$split_wavfiles $output_wavfiles.$n" + done + utils/split_scp.pl $output_wavfiles $split_wavfiles || exit 1; + + echo -e "Beamforming - $task - real - $nch ch\n" + # making a shell script for each job + for n in `seq $nj`; do + cat <<-EOF > $wdir/log/beamform.$n.sh + while read line; do + $BEAMFORMIT/BeamformIt -s \$line -c $arrays \ + --config_file `pwd`/conf/reverb_beamformit.cfg \ + --result_dir $odir + done < $output_wavfiles.$n + EOF + done + + chmod a+x $wdir/log/beamform.*.sh + $cmd JOB=1:$nj $wdir/log/beamform.JOB.log \ + $wdir/log/beamform.JOB.sh + done +done + +for task in dt et; do + for nch in 2 8; do + wdir=exp/beamform_simu_${task}_${nch}ch + mkdir -p $wdir/log + arrays=$wdir/channels + 
output_wavfiles=$wdir/wavfiles.list + if [ ${nch} == 2 ]; then + allwavs=`cat ${dir}/${task}_simu_${nch}ch_wpe_wav.scp | grep "ch[1-2].wav" | cut -d " " -f2` + allwavs_beamformit=`cat data/${task}_simu_${nch}ch_beamformit/wav.scp | grep "bf2.wav" | cut -d " " -f2` + echo $allwavs | tr ' ' '\n' | grep 'ch1' | sort > $wdir/channels.1st + echo $allwavs | tr ' ' '\n' | grep 'ch2' | sort > $wdir/channels.2nd + echo $allwavs_beamformit | tr ' ' '\n' | awk -F 'WPE/' '{print $2}' | sort | awk -F '.wav' '{print $1}' > $output_wavfiles + paste -d" " $output_wavfiles $wdir/channels.1st $wdir/channels.2nd > $arrays + elif [ ${nch} == 8 ]; then + allwavs=`cat ${dir}/${task}_simu_${nch}ch_wpe_wav.scp | grep "ch[1-8].wav" | cut -d " " -f2` + allwavs_beamformit=`cat data/${task}_simu_${nch}ch_beamformit/wav.scp | grep "bf8.wav" | cut -d " " -f2` + echo $allwavs | tr ' ' '\n' | grep 'ch1' | sort > $wdir/channels.1st + echo $allwavs | tr ' ' '\n' | grep 'ch2' | sort > $wdir/channels.2nd + echo $allwavs | tr ' ' '\n' | grep 'ch3' | sort > $wdir/channels.3rd + echo $allwavs | tr ' ' '\n' | grep 'ch4' | sort > $wdir/channels.4th + echo $allwavs | tr ' ' '\n' | grep 'ch5' | sort > $wdir/channels.5th + echo $allwavs | tr ' ' '\n' | grep 'ch6' | sort > $wdir/channels.6th + echo $allwavs | tr ' ' '\n' | grep 'ch7' | sort > $wdir/channels.7th + echo $allwavs | tr ' ' '\n' | grep 'ch8' | sort > $wdir/channels.8th + echo $allwavs_beamformit | tr ' ' '\n' | awk -F 'WPE/' '{print $2}' | sort | awk -F '.wav' '{print $1}' > $output_wavfiles + paste -d" " $output_wavfiles $wdir/channels.1st $wdir/channels.2nd $wdir/channels.3rd $wdir/channels.4th $wdir/channels.5th $wdir/channels.6th $wdir/channels.7th $wdir/channels.8th > $arrays + fi + # split the list for parallel processing + split_wavfiles="" + for n in `seq $nj`; do + split_wavfiles="$split_wavfiles $output_wavfiles.$n" + done + utils/split_scp.pl $output_wavfiles $split_wavfiles || exit 1; + + echo -e "Beamforming - $task - simu - $nch ch\n" + # making a shell script for each job + for n in `seq $nj`; do + cat <<-EOF > $wdir/log/beamform.$n.sh + while read line; do + $BEAMFORMIT/BeamformIt -s \$line -c $arrays \ + --config_file `pwd`/conf/reverb_beamformit.cfg \ + --result_dir $odir + done < $output_wavfiles.$n + EOF + done + + chmod a+x $wdir/log/beamform.*.sh + $cmd JOB=1:$nj $wdir/log/beamform.JOB.log \ + $wdir/log/beamform.JOB.sh + done +done +echo "`basename $0` Done." 
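Both the beamforming script above and the WPE scripts that follow build their per-utterance channel tables with the same idiom for the real data: reversing each path before sorting groups the channel files of one utterance together (the utterance id sits at the end of the file name), awk 'NR%N==k' then picks out channel k, and paste stitches the per-channel lists back into one row per utterance. A minimal sketch with made-up file names:

    # two utterances, two channels each; the channel token precedes the utterance id
    printf '%s\n' A-1_u1.wav A-2_u1.wav A-1_u2.wav A-2_u2.wav > all.list
    rev all.list | sort | rev | awk 'NR%2==1' > ch1.list   # first channel of every utterance
    rev all.list | sort | rev | awk 'NR%2==0' > ch2.list   # second channel of every utterance
    paste -d" " ch1.list ch2.list
    # -> A-1_u1.wav A-2_u1.wav
    #    A-1_u2.wav A-2_u2.wav

run_beamform.sh prepends the output name to each row so BeamformIt can look up the channels of a given show, while run_wpe.sh pastes the input channels followed by the matching output paths and run_wpe.py splits that argument list in half into inputs and outputs.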
diff --git a/egs/reverb/s5/local/run_wpe.py b/egs/reverb/s5/local/run_wpe.py new file mode 100644 index 00000000000..cc9cd41927a --- /dev/null +++ b/egs/reverb/s5/local/run_wpe.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# Copyright 2018 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +# Apache 2.0 +# Works with both python2 and python3 + +import numpy as np +import soundfile as sf +import time +import os, errno +from tqdm import tqdm +import argparse + +from nara_wpe.wpe import wpe +from nara_wpe.utils import stft, istft +from nara_wpe import project_root + +parser = argparse.ArgumentParser() +parser.add_argument('--files', '-f', nargs='+') +args = parser.parse_args() + +input_files = args.files[:len(args.files)//2] +output_files = args.files[len(args.files)//2:] +out_dir = os.path.dirname(output_files[0]) +try: + os.makedirs(out_dir) +except OSError as e: + if e.errno != errno.EEXIST: + raise + +stft_options = dict( + size=512, + shift=128, + window_length=None, + fading=True, + pad=True, + symmetric_window=False +) + +sampling_rate = 16000 +delay = 3 +iterations = 5 +taps = 10 + +signal_list = [ + sf.read(f)[0] + for f in input_files +] +y = np.stack(signal_list, axis=0) +Y = stft(y, **stft_options).transpose(2, 0, 1) +Z = wpe(Y, iterations=iterations, statistics_mode='full').transpose(1, 2, 0) +z = istft(Z, size=stft_options['size'], shift=stft_options['shift']) + +for d in range(len(signal_list)): + sf.write(output_files[d], z[d,:], sampling_rate) diff --git a/egs/reverb/s5/local/run_wpe.sh b/egs/reverb/s5/local/run_wpe.sh new file mode 100755 index 00000000000..d1ea56c6c55 --- /dev/null +++ b/egs/reverb/s5/local/run_wpe.sh @@ -0,0 +1,172 @@ +#!/bin/bash +# Copyright 2018 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +# Apache 2.0 + +. ./cmd.sh +. ./path.sh + +# Config: +nj=50 +cmd=run.pl + +. utils/parse_options.sh || exit 1; + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +miniconda_dir=$HOME/miniconda3/ +if [ ! -d $miniconda_dir ]; then + echo "$miniconda_dir does not exist. Please run '../../../tools/extras/install_miniconda.sh' and '../../../tools/extras/install_wpe.sh';" +fi + +# check if WPE is installed +result=`$HOME/miniconda3/bin/python -c "\ +try: + import nara_wpe + print('1') +except ImportError: + print('0')"` + +if [ "$result" == "1" ]; then + echo "WPE is installed" +else + echo "WPE is not installed. 
Please run ../../../tools/extras/install_wpe.sh" +fi + +dir=${PWD}/data/local/data + +for task in dt et; do + for nch in 1 2 8; do + wdir=exp/wpe_real_${task}_${nch}ch + mkdir -p $wdir/log + arrays=$wdir/channels + output_wavfiles=$wdir/wavfiles.list + if [ ${nch} == 1 ]; then + allwavs=`cat ${dir}/${task}_real_1ch_wav.scp | cut -d " " -f2` + allwavs_output=`cat ${dir}/${task}_real_1ch_wpe_wav.scp | cut -d " " -f2` + echo $allwavs | tr ' ' '\n' > $wdir/channels_input + echo $allwavs_output | tr ' ' '\n' > $wdir/channels_output + paste -d" " $wdir/channels_input $wdir/channels_output > $arrays + elif [ ${nch} == 2 ]; then + allwavs=`cat ${dir}/${task}_real_2ch_wav.scp | cut -d " " -f2` + allwavs_output=`cat ${dir}/${task}_real_2ch_wpe_wav.scp | cut -d " " -f2` + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%2==1' > $wdir/channels.1st + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%2==0' > $wdir/channels.2nd + echo $allwavs_output | tr ' ' '\n' | rev | sort | rev | awk 'NR%2==1' > $wdir/channels_output.1st + echo $allwavs_output | tr ' ' '\n' | rev | sort | rev | awk 'NR%2==0' > $wdir/channels_output.2nd + paste -d" " $wdir/channels.1st $wdir/channels.2nd $wdir/channels_output.1st $wdir/channels_output.2nd > $arrays + elif [ ${nch} == 8 ]; then + allwavs=`cat ${dir}/${task}_real_8ch_wav.scp | cut -d " " -f2` + allwavs_output=`cat ${dir}/${task}_real_8ch_wpe_wav.scp | cut -d " " -f2` + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==1' > $wdir/channels.1st + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==2' > $wdir/channels.2nd + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==3' > $wdir/channels.3rd + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==4' > $wdir/channels.4th + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==5' > $wdir/channels.5th + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==6' > $wdir/channels.6th + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==7' > $wdir/channels.7th + echo $allwavs | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==0' > $wdir/channels.8th + echo $allwavs_output | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==1' > $wdir/channels_output.1st + echo $allwavs_output | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==2' > $wdir/channels_output.2nd + echo $allwavs_output | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==3' > $wdir/channels_output.3rd + echo $allwavs_output | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==4' > $wdir/channels_output.4th + echo $allwavs_output | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==5' > $wdir/channels_output.5th + echo $allwavs_output | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==6' > $wdir/channels_output.6th + echo $allwavs_output | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==7' > $wdir/channels_output.7th + echo $allwavs_output | tr ' ' '\n' | rev | sort | rev | awk 'NR%8==0' > $wdir/channels_output.8th + paste -d" " $wdir/channels.1st $wdir/channels.2nd $wdir/channels.3rd $wdir/channels.4th $wdir/channels.5th $wdir/channels.6th $wdir/channels.7th $wdir/channels.8th $wdir/channels_output.1st $wdir/channels_output.2nd $wdir/channels_output.3rd $wdir/channels_output.4th $wdir/channels_output.5th $wdir/channels_output.6th $wdir/channels_output.7th $wdir/channels_output.8th > $arrays + fi + + # split the list for parallel processing + split_wavfiles="" + for n in `seq $nj`; do + split_wavfiles="$split_wavfiles $output_wavfiles.$n" + done + utils/split_scp.pl $arrays $split_wavfiles || exit 1; + + echo -e "Dereverberation - $task - real 
- $nch ch\n" + # making a shell script for each job + for n in `seq $nj`; do + cat <<-EOF > $wdir/log/wpe.$n.sh + while read line; do + $HOME/miniconda3/bin/python local/run_wpe.py \ + --file \$line + done < $output_wavfiles.$n + EOF + done + + chmod a+x $wdir/log/wpe.*.sh + $cmd JOB=1:$nj $wdir/log/wpe.JOB.log \ + $wdir/log/wpe.JOB.sh + done +done + +for task in dt et; do + for nch in 1 2 8; do + wdir=exp/wpe_simu_${task}_${nch}ch + mkdir -p $wdir/log + arrays=$wdir/channels + output_wavfiles=$wdir/wavfiles.list + if [ ${nch} == 1 ]; then + allwavs=`cat ${dir}/${task}_simu_1ch_wav.scp | cut -d " " -f2` + allwavs_output=`cat ${dir}/${task}_simu_1ch_wpe_wav.scp | cut -d " " -f2` + echo $allwavs | tr ' ' '\n' > $wdir/channels_input + echo $allwavs_output | tr ' ' '\n' > $wdir/channels_output + paste -d" " $wdir/channels_input $wdir/channels_output > $arrays + elif [ ${nch} == 2 ]; then + allwavs=`cat ${dir}/${task}_simu_2ch_wav.scp | cut -d " " -f2` + allwavs_output=`cat ${dir}/${task}_simu_2ch_wpe_wav.scp | cut -d " " -f2` + echo $allwavs | tr ' ' '\n' | grep 'ch1' | sort > $wdir/channels.1st + echo $allwavs | tr ' ' '\n' | grep 'ch2' | sort > $wdir/channels.2nd + echo $allwavs_output | tr ' ' '\n' | grep 'ch1' | sort > $wdir/channels_output.1st + echo $allwavs_output | tr ' ' '\n' | grep 'ch2' | sort > $wdir/channels_output.2nd + paste -d" " $wdir/channels.1st $wdir/channels.2nd $wdir/channels_output.1st $wdir/channels_output.2nd > $arrays + elif [ ${nch} == 8 ]; then + allwavs=`cat ${dir}/${task}_simu_8ch_wav.scp | cut -d " " -f2` + allwavs_output=`cat ${dir}/${task}_simu_8ch_wpe_wav.scp | cut -d " " -f2` + echo $allwavs | tr ' ' '\n' | grep 'ch1' | sort > $wdir/channels.1st + echo $allwavs | tr ' ' '\n' | grep 'ch2' | sort > $wdir/channels.2nd + echo $allwavs | tr ' ' '\n' | grep 'ch3' | sort > $wdir/channels.3rd + echo $allwavs | tr ' ' '\n' | grep 'ch4' | sort > $wdir/channels.4th + echo $allwavs | tr ' ' '\n' | grep 'ch5' | sort > $wdir/channels.5th + echo $allwavs | tr ' ' '\n' | grep 'ch6' | sort > $wdir/channels.6th + echo $allwavs | tr ' ' '\n' | grep 'ch7' | sort > $wdir/channels.7th + echo $allwavs | tr ' ' '\n' | grep 'ch8' | sort > $wdir/channels.8th + echo $allwavs_output | tr ' ' '\n' | grep 'ch1' | sort > $wdir/channels_output.1st + echo $allwavs_output | tr ' ' '\n' | grep 'ch2' | sort > $wdir/channels_output.2nd + echo $allwavs_output | tr ' ' '\n' | grep 'ch3' | sort > $wdir/channels_output.3rd + echo $allwavs_output | tr ' ' '\n' | grep 'ch4' | sort > $wdir/channels_output.4th + echo $allwavs_output | tr ' ' '\n' | grep 'ch5' | sort > $wdir/channels_output.5th + echo $allwavs_output | tr ' ' '\n' | grep 'ch6' | sort > $wdir/channels_output.6th + echo $allwavs_output | tr ' ' '\n' | grep 'ch7' | sort > $wdir/channels_output.7th + echo $allwavs_output | tr ' ' '\n' | grep 'ch8' | sort > $wdir/channels_output.8th + paste -d" " $wdir/channels.1st $wdir/channels.2nd $wdir/channels.3rd $wdir/channels.4th $wdir/channels.5th $wdir/channels.6th $wdir/channels.7th $wdir/channels.8th $wdir/channels_output.1st $wdir/channels_output.2nd $wdir/channels_output.3rd $wdir/channels_output.4th $wdir/channels_output.5th $wdir/channels_output.6th $wdir/channels_output.7th $wdir/channels_output.8th > $arrays + fi + + # split the list for parallel processing + split_wavfiles="" + for n in `seq $nj`; do + split_wavfiles="$split_wavfiles $output_wavfiles.$n" + done + utils/split_scp.pl $arrays $split_wavfiles || exit 1; + + echo -e "Dereverberation - $task - simu - $nch ch\n" + # making a shell 
script for each job + for n in `seq $nj`; do + cat <<-EOF > $wdir/log/wpe.$n.sh + while read line; do + $HOME/miniconda3/bin/python local/run_wpe.py \ + --file \$line + done < $output_wavfiles.$n + EOF + done + + chmod a+x $wdir/log/wpe.*.sh + $cmd JOB=1:$nj $wdir/log/wpe.JOB.log \ + $wdir/log/wpe.JOB.sh + done +done +echo "`basename $0` Done." diff --git a/egs/reverb/s5/local/score.sh b/egs/reverb/s5/local/score.sh index abd8149a672..66bc976333f 100755 --- a/egs/reverb/s5/local/score.sh +++ b/egs/reverb/s5/local/score.sh @@ -1,23 +1,29 @@ #!/bin/bash -# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# Copyright 2012-2014 Johns Hopkins University (Author: Daniel Povey, Yenda Trmal) # Apache 2.0 +# See the script steps/scoring/score_kaldi_cer.sh in case you need to evalutate CER + [ -f ./path.sh ] && . ./path.sh # begin configuration section. cmd=run.pl stage=0 -decode_mbr=true -word_ins_penalty=0.0 +decode_mbr=false +stats=true +beam=6 +word_ins_penalty=0.0,0.5,1.0 min_lmwt=7 max_lmwt=17 +iter=final #end configuration section. +echo "$0 $@" # Print the command line for logging [ -f ./path.sh ] && . ./path.sh . parse_options.sh || exit 1; if [ $# -ne 3 ]; then - echo "Usage: local/score.sh [--cmd (run.pl|queue.pl...)] " + echo "Usage: $0 [--cmd (run.pl|queue.pl...)] " echo " Options:" echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." echo " --stage (0|1|2) # start scoring script from part-way through." @@ -37,21 +43,122 @@ for f in $symtab $dir/lat.1.gz $data/text; do [ ! -f $f ] && echo "score.sh: no such file $f" && exit 1; done -mkdir -p $dir/scoring/log -cat $data/text | sed 's:::g' | sed 's:::g' > $dir/scoring/test_filt.txt +ref_filtering_cmd="cat" +[ -x local/wer_output_filter ] && ref_filtering_cmd="local/wer_output_filter" +[ -x local/wer_ref_filter ] && ref_filtering_cmd="local/wer_ref_filter" +hyp_filtering_cmd="cat" +[ -x local/wer_output_filter ] && hyp_filtering_cmd="local/wer_output_filter" +[ -x local/wer_hyp_filter ] && hyp_filtering_cmd="local/wer_hyp_filter" + + +if $decode_mbr ; then + echo "$0: scoring with MBR, word insertion penalty=$word_ins_penalty" +else + echo "$0: scoring with word insertion penalty=$word_ins_penalty" +fi + + +mkdir -p $dir/scoring_kaldi +if echo $data | grep -q "real"; then + tasks="\ + near_room1 far_room1" +elif echo $data | grep -q "cln"; then + tasks="\ + cln_room1 cln_room2 cln_room3" +else + tasks="\ + near_room1 far_room1 \ + near_room2 far_room2 \ + near_room3 far_room3" +fi +for task in ${tasks}; do + grep $task $data/text | $ref_filtering_cmd > $dir/scoring_kaldi/test_filt_${task}.txt || exit 1; +done + +if [ $stage -le 0 ]; then + + for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + mkdir -p $dir/scoring_kaldi/penalty_$wip/log + + if $decode_mbr ; then + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring_kaldi/penalty_$wip/log/best_path.LMWT.log \ + acwt=\`perl -e \"print 1.0/LMWT\"\`\; \ + lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \ + lattice-prune --beam=$beam ark:- ark:- \| \ + lattice-mbr-decode --word-symbol-table=$symtab \ + ark:- ark,t:- \| \ + utils/int2sym.pl -f 2- $symtab \| \ + $hyp_filtering_cmd '>' $dir/scoring_kaldi/penalty_$wip/LMWT.txt || exit 1; -$cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/best_path.LMWT.log \ - lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ - lattice-add-penalty --word-ins-penalty=$word_ins_penalty ark:- ark:- \| \ - 
lattice-best-path --word-symbol-table=$symtab \ - ark:- ark,t:$dir/scoring/LMWT.tra || exit 1; + else + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring_kaldi/penalty_$wip/log/best_path.LMWT.log \ + lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \ + lattice-best-path --word-symbol-table=$symtab ark:- ark,t:- \| \ + utils/int2sym.pl -f 2- $symtab \| \ + $hyp_filtering_cmd '>' $dir/scoring_kaldi/penalty_$wip/LMWT.txt || exit 1; + fi + for task in ${tasks}; do + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring_kaldi/penalty_$wip/log/score.LMWT.log \ + grep $task $dir/scoring_kaldi/penalty_$wip/LMWT.txt \| \ + compute-wer --text --mode=present \ + ark:$dir/scoring_kaldi/test_filt_${task}.txt ark,p:- ">&" $dir/wer_LMWT_${wip}_${task} || exit 1; + done + done +fi + + + +if [ $stage -le 1 ]; then + for task in ${tasks}; do + for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + for lmwt in $(seq $min_lmwt $max_lmwt); do + # adding /dev/null to the command list below forces grep to output the filename + grep WER $dir/wer_${lmwt}_${wip}_${task} /dev/null + done + done | utils/best_wer.sh >& $dir/scoring_kaldi/best_wer_${task} || exit 1 + + best_wer_file=$(awk '{print $NF}' $dir/scoring_kaldi/best_wer_${task}) + best_wip=$(echo $best_wer_file | awk -F_ '{N=NF-2; print $N}') + best_lmwt=$(echo $best_wer_file | awk -F_ '{N=NF-3; print $N}') + + if [ -z "$best_lmwt" ]; then + echo "$0: we could not get the details of the best WER from the file $dir/wer_*. Probably something went wrong." + exit 1; + fi + if $stats; then + mkdir -p $dir/scoring_kaldi/wer_details + echo $best_lmwt > $dir/scoring_kaldi/wer_details/lmwt # record best language model weight + echo $best_wip > $dir/scoring_kaldi/wer_details/wip # record best word insertion penalty + + $cmd $dir/scoring_kaldi/log/stats1.log \ + cat $dir/scoring_kaldi/penalty_$best_wip/$best_lmwt.txt \| \ + align-text --special-symbol="'***'" ark:$dir/scoring_kaldi/test_filt_${task}.txt ark:- ark,t:- \| \ + utils/scoring/wer_per_utt_details.pl --special-symbol "'***'" \| tee $dir/scoring_kaldi/wer_details/per_utt \|\ + utils/scoring/wer_per_spk_details.pl $data/utt2spk \> $dir/scoring_kaldi/wer_details/per_spk || exit 1; + + $cmd $dir/scoring_kaldi/log/stats2.log \ + cat $dir/scoring_kaldi/wer_details/per_utt \| \ + utils/scoring/wer_ops_details.pl --special-symbol "'***'" \| \ + sort -b -i -k 1,1 -k 4,4rn -k 2,2 -k 3,3 \> $dir/scoring_kaldi/wer_details/ops || exit 1; + + $cmd $dir/scoring_kaldi/log/wer_bootci.log \ + compute-wer-bootci --mode=present \ + ark:$dir/scoring_kaldi/test_filt_${task}.txt ark:$dir/scoring_kaldi/penalty_$best_wip/$best_lmwt.txt \ + '>' $dir/scoring_kaldi/wer_details/wer_bootci || exit 1; + + fi + done +fi -# Note: the double level of quoting for the sed command -$cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.log \ - cat $dir/scoring/LMWT.tra \| \ - utils/int2sym.pl -f 2- $symtab \| sed 's:\::g' \| \ - compute-wer --text --mode=present \ - ark:$dir/scoring/test_filt.txt ark,p:- ">&" $dir/wer_LMWT || exit 1; +# If we got here, the scoring was successful. 
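# Note on the best_wip/best_lmwt parsing in stage 1 above: the per-condition WER
# files are named wer_<lmwt>_<wip>_<task>, and because the task itself contains an
# underscore (e.g. far_room1) the fields are counted from the end of the name,
# which also lets the same awk work on the full path. For a hypothetical file:
#   echo wer_14_0.5_far_room1 | awk -F_ '{N=NF-2; print $N}'   # -> 0.5 (insertion penalty)
#   echo wer_14_0.5_far_room1 | awk -F_ '{N=NF-3; print $N}'   # -> 14  (LM weight)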
+# As a small aid to prevent confusion, we remove all wer_{?,??} files; +# these originate from the previous version of the scoring files +# i keep both statement here because it could lead to confusion about +# the capabilities of the script (we don't do cer in the script) +rm $dir/wer_{?,??} 2>/dev/null +rm $dir/cer_{?,??} 2>/dev/null exit 0; diff --git a/egs/reverb/s5/local/score_RealData.patch b/egs/reverb/s5/local/score_RealData.patch new file mode 100644 index 00000000000..cafa521d483 --- /dev/null +++ b/egs/reverb/s5/local/score_RealData.patch @@ -0,0 +1,14 @@ +11c11 +< clear all; +--- +> function score_RealData(download_from_ldc,senhroot) +26c26,27 +< srmrdir = 'SRMRtoolbox-ReverbChallenge'; +--- +> srmrdir = 'SRMRToolbox'; +> addpath(genpath('SRMRToolbox/libs')); +32d32 +< senhroot = '../output/RealData'; +129a130,131 +> +> end diff --git a/egs/reverb/s5/local/score_SimData.patch b/egs/reverb/s5/local/score_SimData.patch new file mode 100644 index 00000000000..4fb0d9f48ac --- /dev/null +++ b/egs/reverb/s5/local/score_SimData.patch @@ -0,0 +1,23 @@ +11c11 +< clear all; +--- +> function score_SimData(download_from_ldc,senhroot,pesqdir,compute_pesq) +26,27c26,27 +< srmrdir = 'SRMRtoolbox-ReverbChallenge'; +< % pesqdir = '/directory/where/pesq/executable/is/stored'; +--- +> srmrdir = 'SRMRToolbox'; +> addpath(genpath('SRMRToolbox/libs')); +36d35 +< senhroot = '../output/SimData'; +39c38 +< if exist('pesqdir', 'var') +--- +> if exist('pesqdir', 'var') && compute_pesq~=0 +471c470,472 +< fclose(fid); +\ No newline at end of file +--- +> fclose(fid); +> +> end diff --git a/egs/reverb/s5/local/score_mbr.sh b/egs/reverb/s5/local/score_mbr.sh deleted file mode 120000 index 2573fadf042..00000000000 --- a/egs/reverb/s5/local/score_mbr.sh +++ /dev/null @@ -1 +0,0 @@ -../../../wsj/s5/local/score_mbr.sh \ No newline at end of file diff --git a/egs/reverb/s5/path.sh b/egs/reverb/s5/path.sh index 1a6fb5f891b..f46c5d8cb72 100644 --- a/egs/reverb/s5/path.sh +++ b/egs/reverb/s5/path.sh @@ -1,4 +1,6 @@ export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +[ -f $KALDI_ROOT/tools/extras/env.sh ] && . $KALDI_ROOT/tools/extras/env.sh export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 . $KALDI_ROOT/tools/config/common_path.sh diff --git a/egs/reverb/s5/run.sh b/egs/reverb/s5/run.sh index cb0b00c19b6..999ec98e637 100755 --- a/egs/reverb/s5/run.sh +++ b/egs/reverb/s5/run.sh @@ -1,6 +1,8 @@ #!/bin/bash # Copyright 2013-2014 MERL (author: Felix Weninger and Shinji Watanabe) +# Johns Hopkins University (author: Szu-Jui Chen) +# Johns Hopkins University (author: Aswin Shanmugam Subramanian) # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,7 +35,13 @@ fi . ./cmd.sh . ./path.sh -stage=1 +stage=0 +nch_se=8 +# flag for turing on computation of dereverberation measures +compute_se=true +# please make sure that you or your institution have the license to report PESQ before turning on the below flag +enable_pesq=false + . 
utils/parse_options.sh # Set bash to 'debug' mode, it prints the commands (option '-x') and exits on : # -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline', @@ -41,297 +49,141 @@ set -euxo pipefail # please make sure to set the paths of the REVERB and WSJ0 data if [[ $(hostname -f) == *.clsp.jhu.edu ]] ; then - REVERB_home=/export/corpora5/REVERB_2014/REVERB + reverb=/export/corpora5/REVERB_2014/REVERB export wsjcam0=/export/corpora3/LDC/LDC95S24/wsjcam0 # set LDC WSJ0 directory to obtain LMs # REVERB data directory only provides bi-gram (bcb05cnp), but this recipe also uses 3-gram (tcb05cnp.z) export wsj0=/export/corpora5/LDC/LDC93S6A/11-13.1 #LDC93S6A or LDC93S6B # It is assumed that there will be a 'wsj0' subdirectory # within the top-level corpus directory -elif [[ $(hostname -f) == *.merl.com ]] ; then - REVERB_home=/db/laputa1/data/original/public/REVERB - export wsjcam0=$REVERB_home/wsjcam0 - # set LDC WSJ0 directory to obtain LMs - # REVERB data directory only provides bi-gram (bcb05cnp), but this recipe also uses 3-gram (tcb05cnp.z) - export wsj0=/db/laputa1/data/original/public/WSJ0/11-13.1 #LDC93S6A or LDC93S6B - # It is assumed that there will be a 'wsj0' subdirectory - # within the top-level corpus directory else echo "Set the data directory locations." && exit 1; fi -export reverb_dt=$REVERB_home/REVERB_WSJCAM0_dt -export reverb_et=$REVERB_home/REVERB_WSJCAM0_et -export reverb_real_dt=$REVERB_home/MC_WSJ_AV_Dev -export reverb_real_et=$REVERB_home/MC_WSJ_AV_Eval - -# set the directory of the multi-condition training data to be generated -reverb_tr=`pwd`/data_tr_cut/REVERB_WSJCAM0_tr_cut -# LDA context size (left/right) (4 is default) -context_size=4 +#training set and test set +train_set=tr_simu_8ch +test_sets="dt_real_8ch_beamformit dt_simu_8ch_beamformit et_real_8ch_beamformit et_simu_8ch_beamformit dt_real_1ch_wpe dt_simu_1ch_wpe et_real_1ch_wpe et_simu_1ch_wpe dt_cln et_cln" # The language models with which to decode (tg_5k or bg_5k) lm="tg_5k" # number of jobs for feature extraction and model training -nj_train=30 - +nj=92 # number of jobs for decoding -nj_decode=8 - -# set to true if you want the tri2a systems (re-implementation of the HTK baselines) -do_tri2a=true +decode_nj=10 -if [ $stage -le 1 ]; then - # Generate multi-condition training data - # Note that utterance lengths match the original set. - # This enables using clean alignments in multi-condition training (stereo training) - local/REVERB_create_mcdata.sh $wsjcam0 $reverb_tr +wavdir=${PWD}/wav +pesqdir=${PWD}/local +if [ ${stage} -le 1 ]; then + # data preparation + echo "stage 0: Data preparation" + local/generate_data.sh --wavdir ${wavdir} ${wsjcam0} + local/prepare_simu_data.sh --wavdir ${wavdir} ${reverb} ${wsjcam0} + local/prepare_real_data.sh --wavdir ${wavdir} ${reverb} fi if [ $stage -le 2 ]; then + local/run_wpe.sh --cmd "$train_cmd" + local/run_beamform.sh --cmd "$train_cmd" ${wavdir}/WPE/ +fi + +# Compute dereverberation scores +if [ $stage -le 3 ] && $compute_se; then + if [ ! -d local/REVERB_scores_source ] || [ ! -d local/REVERB_scores_source/REVERB-SPEENHA.Release04Oct/evaltools/SRMRToolbox ] || [ ! 
-f local/PESQ ]; then + # download and install speech enhancement evaluation tools + local/download_se_eval_tool.sh + fi + local/compute_se_scores.sh --nch $nch_se --enable_pesq $enable_pesq $reverb $wavdir $pesqdir + cat exp/compute_se_${nch_se}ch/scores/score_SimData + cat exp/compute_se_${nch_se}ch/scores/score_RealData +fi + +if [ $stage -le 4 ]; then # Prepare wsjcam0 clean data and wsj0 language model. local/wsjcam0_data_prep.sh $wsjcam0 $wsj0 - + # Prepare merged BEEP/CMU dictionary. local/wsj_prepare_beep_dict.sh # Prepare wordlists, etc. - utils/prepare_lang.sh data/local/dict "" data/local/lang_tmp data/lang + utils/prepare_lang.sh data/local/dict "" data/local/lang_tmp data/lang # Prepare directory structure for clean data. Apply some language model fixes. local/wsjcam0_format_data.sh +fi - # Now it's getting more interesting. - # Prepare the multi-condition training data and the REVERB dt set. - # This also extracts MFCC features (!!!) - # This creates the data sets called REVERB_tr_cut and REVERB_dt. - # If you have processed waveforms, this is a good starting point to integrate them. - # For example, you could have something like - # local/REVERB_wsjcam0_data_prep.sh /path/to/processed/REVERB_WSJCAM0_dt processed_REVERB_dt dt - # The first argument is supposed to point to a folder that has the same structure - # as the REVERB corpus. - local/REVERB_wsjcam0_data_prep.sh $reverb_tr REVERB_tr_cut tr - local/REVERB_wsjcam0_data_prep.sh $reverb_dt REVERB_dt dt - local/REVERB_wsjcam0_data_prep.sh $reverb_et REVERB_et et - - # Prepare the REVERB "real" dt set from MCWSJAV corpus. - # This corpus is *never* used for training. - # This creates the data set called REVERB_Real_dt and its subfolders - local/REVERB_mcwsjav_data_prep.sh $reverb_real_dt REVERB_Real_dt dt - # The MLF file exists only once in the corpus, namely in the real_dt directory - # so we pass it as 4th argument - local/REVERB_mcwsjav_data_prep.sh $reverb_real_et REVERB_Real_et et $reverb_real_dt/mlf/WSJ.mlf +if [ $stage -le 5 ]; then + for dset in ${train_set} ${test_sets}; do + utils/copy_data_dir.sh data/${dset} data/${dset}_nosplit + utils/data/modify_speaker_info.sh --seconds-per-spk-max 180 data/${dset}_nosplit data/${dset} + done fi -if [ $stage -le 3 ]; then - # Extract MFCC features for clean sets. - # For the non-clean data sets, this is outsourced to the data preparation scripts. +if [ $stage -le 6 ]; then + # Extract MFCC features for train and test sets. mfccdir=mfcc - ### for x in si_tr si_dt; do it seems that the number of transcriptions of si_dt is not correct. - for x in si_tr; do - steps/make_mfcc.sh --cmd "$train_cmd" --nj $nj_train \ + for x in ${train_set} ${test_sets}; do + steps/make_mfcc.sh --cmd "$train_cmd" --nj 30 \ data/$x exp/make_mfcc/$x $mfccdir steps/compute_cmvn_stats.sh data/$x exp/make_mfcc/$x $mfccdir done fi -if [ $stage -le 4 ]; then - # Train monophone model on clean data (si_tr). - echo "### TRAINING mono0a ###" - steps/train_mono.sh --boost-silence 1.25 --nj $nj_train --cmd "$train_cmd" \ - data/si_tr data/lang exp/mono0a - - # Align monophones with clean data. - echo "### ALIGNING mono0a_ali ###" - steps/align_si.sh --boost-silence 1.25 --nj $nj_train --cmd "$train_cmd" \ - data/si_tr data/lang exp/mono0a exp/mono0a_ali - - # Create first triphone recognizer. - echo "### TRAINING tri1 ###" - steps/train_deltas.sh --boost-silence 1.25 --cmd "$train_cmd" \ - 2000 10000 data/si_tr data/lang exp/mono0a_ali exp/tri1 - - echo "### ALIGNING tri1_ali ###" - # Re-align triphones. 
- steps/align_si.sh --nj $nj_train --cmd "$train_cmd" \ - data/si_tr data/lang exp/tri1 exp/tri1_ali +if [ $stage -le 7 ]; then + # Starting basic training on MFCC features + steps/train_mono.sh --nj $nj --cmd "$train_cmd" \ + data/${train_set} data/lang exp/mono fi -# The following code trains and evaluates a delta feature recognizer, which is similar to the HTK -# baseline (but using per-utterance basis fMLLR instead of batch MLLR). This is for reference only. -if $do_tri2a; then -if [ $stage -le 5 ]; then - # Train tri2a, which is deltas + delta-deltas, on clean data. - steps/train_deltas.sh --cmd "$train_cmd" \ - 2500 15000 data/si_tr data/lang exp/tri1_ali exp/tri2a - - # Re-align triphones using clean data. This gives a smallish performance gain. - steps/align_si.sh --nj $nj_train --cmd "$train_cmd" \ - data/si_tr data/lang exp/tri2a exp/tri2a_ali +if [ $stage -le 8 ]; then + steps/align_si.sh --nj $nj --cmd "$train_cmd" \ + data/${train_set} data/lang exp/mono exp/mono_ali - # Train a multi-condition triphone recognizer. - # This uses alignments on *clean* data, which is allowed for REVERB. - # However, we have to use the "cut" version so that the length of the - # waveforms match. - # It is actually asserted by the Challenge that clean and multi-condition waves are aligned. steps/train_deltas.sh --cmd "$train_cmd" \ - 2500 15000 data/REVERB_tr_cut/SimData_tr_for_1ch_A data/lang exp/tri2a_ali exp/tri2a_mc - - # Prepare clean and mc tri2a models for decoding. - utils/mkgraph.sh data/lang_test_bg_5k exp/tri2a exp/tri2a/graph_bg_5k & - utils/mkgraph.sh data/lang_test_bg_5k exp/tri2a_mc exp/tri2a_mc/graph_bg_5k & - wait + 2500 30000 data/${train_set} data/lang exp/mono_ali exp/tri1 fi -if [ $stage -le 6 ]; then - # decode REVERB dt using tri2a, clean - for dataset in data/REVERB_*{dt,et}/*; do - steps/decode.sh --nj $nj_decode --cmd "$decode_cmd" \ - exp/tri2a/graph_bg_5k $dataset exp/tri2a/decode_bg_5k_`echo $dataset | awk -F '/' '{print $2 "_" $3}'` & - done - - # decode REVERB dt using tri2a, mc - for dataset in data/REVERB_*{dt,et}/*; do - steps/decode.sh --nj $nj_decode --cmd "$decode_cmd" \ - exp/tri2a_mc/graph_bg_5k $dataset exp/tri2a_mc/decode_bg_5k_`echo $dataset | awk -F '/' '{print $2 "_" $3}'` & - done - - # basis fMLLR for tri2a_mc system - # This computes a transform for every training utterance and computes a basis from that. - steps/get_fmllr_basis.sh --cmd "$train_cmd" --per-utt true data/REVERB_tr_cut/SimData_tr_for_1ch_A data/lang exp/tri2a_mc - - # Recognition using fMLLR adaptation (per-utterance processing). - for dataset in data/REVERB_*{dt,et}/*; do - steps/decode_basis_fmllr.sh --nj $nj_decode --cmd "$decode_cmd" \ - exp/tri2a_mc/graph_bg_5k $dataset exp/tri2a_mc/decode_basis_fmllr_bg_5k_`echo $dataset | awk -F '/' '{print $2 "_" $3}'` & - done - wait -fi -fi +if [ $stage -le 9 ]; then + steps/align_si.sh --nj $nj --cmd "$train_cmd" \ + data/${train_set} data/lang exp/tri1 exp/tri1_ali -if [ $stage -le 7 ]; then - # Train tri2b recognizer, which uses LDA-MLLT, using the default parameters from the WSJ recipe. - echo "### TRAINING tri2b ###" steps/train_lda_mllt.sh --cmd "$train_cmd" \ - --splice-opts "--left-context=$context_size --right-context=$context_size" \ - 2500 15000 data/si_tr data/lang exp/tri1_ali exp/tri2b - - # tri2b (LDA-MLLT system) with multi-condition training, using default parameters. 
- echo "### TRAINING tri2b_mc ###" - steps/train_lda_mllt.sh --cmd "$train_cmd"\ - --splice-opts "--left-context=$context_size --right-context=$context_size" \ - 2500 15000 data/REVERB_tr_cut/SimData_tr_for_1ch_A data/lang exp/tri1_ali exp/tri2b_mc + 4000 50000 data/${train_set} data/lang exp/tri1_ali exp/tri2 fi -# Prepare tri2b* systems for decoding. -if [ $stage -le 8 ]; then - echo "### MAKING GRAPH {tri2b,tri2b_mc}/graph_$lm ###" - for recog in tri2b tri2b_mc; do - utils/mkgraph.sh data/lang_test_$lm exp/$recog exp/$recog/graph_$lm & +if [ $stage -le 10 ]; then + utils/mkgraph.sh data/lang_test_$lm exp/tri2 exp/tri2/graph + for dset in ${test_sets}; do + steps/decode.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + exp/tri2/graph data/${dset} exp/tri2/decode_${dset} & done wait fi -# discriminative training on top of multi-condition systems -# one could also add tri2b here to have a DT clean recognizer for reference -if [ $stage -le 9 ]; then - base_recog=tri2b_mc - bmmi_recog=${base_recog}_mmi_b0.1 - echo "### DT $base_recog --> $bmmi_recog ###" - - # get alignments from base recognizer - steps/align_si.sh --nj $nj_train --cmd "$train_cmd" \ - --use-graphs true data/REVERB_tr_cut/SimData_tr_for_1ch_A data/lang exp/$base_recog exp/${base_recog}_ali - - # get lattices from base recognizer - denlats_dir=${base_recog}_denlats - subsplit=`echo $nj_train \* 2 | bc` - # DT with multi-condition data ... - steps/make_denlats.sh --sub-split $subsplit --nj $nj_train --cmd "$decode_cmd" \ - data/REVERB_tr_cut/SimData_tr_for_1ch_A data/lang exp/$base_recog exp/$denlats_dir +if [ $stage -le 11 ]; then + steps/align_si.sh --nj $nj --cmd "$train_cmd" \ + data/${train_set} data/lang exp/tri2 exp/tri2_ali - # boosted MMI training - steps/train_mmi.sh --boost 0.1 --cmd "$train_cmd" \ - data/REVERB_tr_cut/SimData_tr_for_1ch_A \ - data/lang \ - exp/${base_recog}_ali \ - exp/$denlats_dir \ - exp/$bmmi_recog - cp exp/$base_recog/ali.* exp/$bmmi_recog + steps/train_sat.sh --cmd "$train_cmd" \ + 5000 100000 data/${train_set} data/lang exp/tri2_ali exp/tri3 fi -# decoding using various recognizers -if [ $stage -le 10 ]; then - # put tri2b last since it takes longest due to the large mismatch. - for recog in tri2b_mc tri2b_mc_mmi_b0.1 tri2b; do - # The graph from the ML directory is used in recipe - recog2=`echo $recog | sed s/_mmi.*//` - graph=exp/$recog2/graph_$lm - - echo "### DECODING with $recog, noadapt, $lm ###" - for dataset in data/REVERB_*{dt,et}/*; do - decode_suff=${lm}_`echo $dataset | awk -F '/' '{print $2 "_" $3}'` - steps/decode.sh --nj $nj_decode --cmd "$decode_cmd" \ - $graph $dataset \ - exp/$recog/decode_$decode_suff & - done - wait - - echo " ## MBR RESCORING with $recog, noadapt ##" - for dataset in data/REVERB_*{dt,et}/*; do - decode_suff=${lm}_`echo $dataset | awk -F '/' '{print $2 "_" $3}'` - mkdir -p exp/$recog/decode_mbr_$decode_suff - cp exp/$recog/decode_$decode_suff/lat.*.gz exp/$recog/decode_mbr_$decode_suff - local/score_mbr.sh --cmd "$decode_cmd" \ - $dataset data/lang_test_$lm/ exp/$recog/decode_mbr_$decode_suff & - done - wait - - done # loop recog +if [ $stage -le 12 ]; then + utils/mkgraph.sh data/lang_test_$lm exp/tri3 exp/tri3/graph + for dset in ${test_sets}; do + steps/decode_fmllr.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + exp/tri3/graph data/${dset} exp/tri3/decode_${dset} & + done + wait fi -# decoding using various recognizers with adaptation -if [ $stage -le 11 ]; then - # put tri2b last since it takes longest due to the large mismatch. 
- for recog in tri2b_mc tri2b_mc_mmi_b0.1 tri2b; do - # The graph from the ML directory is used in recipe - recog2=`echo $recog | sed s/_mmi.*//` - graph=exp/$recog2/graph_$lm - - # set the adaptation data - if [[ "$recog" =~ _mc ]]; then - tr_dataset=REVERB_tr_cut/SimData_tr_for_1ch_A - else - tr_dataset=si_tr - fi - - echo "### DECODING with $recog, basis_fmllr, $lm ###" - steps/get_fmllr_basis.sh --cmd "$train_cmd" --per-utt true data/$tr_dataset data/lang exp/$recog - for dataset in data/REVERB_*{dt,et}/*; do - ( - decode_suff=${lm}_`echo $dataset | awk -F '/' '{print $2 "_" $3}'` - steps/decode_basis_fmllr.sh --nj $nj_decode --cmd "$decode_cmd" \ - $graph $dataset \ - exp/$recog/decode_basis_fmllr_$decode_suff - ) & - done - wait - - echo " ## MBR RESCORING with $recog, basis_fmllr ##" - for dataset in data/REVERB_*{dt,et}/*; do - decode_suff=${lm}_`echo $dataset | awk -F '/' '{print $2 "_" $3}'` - mkdir -p exp/$recog/decode_mbr_basis_fmllr_$decode_suff - cp exp/$recog/decode_basis_fmllr_$decode_suff/lat.*.gz exp/$recog/decode_mbr_basis_fmllr_$decode_suff - local/score_mbr.sh --cmd "$decode_cmd" \ - $dataset data/lang_test_$lm/ exp/$recog/decode_mbr_basis_fmllr_$decode_suff & - done - wait - - done # loop recog +if [ $stage -le 13 ]; then + # chain TDNN + local/chain/run_tdnn.sh --nj ${nj} --train-set ${train_set} --test-sets "$test_sets" --gmm tri3 --nnet3-affix _${train_set} \ + --lm-suffix _test_$lm fi -# get all WERs with lmw=15 -if [ $stage -le 12 ]; then +# get all WERs. +if [ $stage -le 14 ]; then local/get_results.sh fi diff --git a/egs/rimes/README.txt b/egs/rimes/README.txt new file mode 100644 index 00000000000..d201c5fec4e --- /dev/null +++ b/egs/rimes/README.txt @@ -0,0 +1,13 @@ +Rimes is a French handwriting recognition database created by A2iA. +The database was created by asking individuals to write letters on a given scenario like +a change of personal information, payment difficulty, damage declaration. The +dataset has been used in several international research including ICFHR 2008, +ICDAR-2009, ICDAR-2011 competitions for isolated word level and +line level recognition tasks. + +It contains 11333 training lines and 788 test lines. It does not include +a validation split but in a recent publication a 10% sampling of the total +training lines for validation purposes were performed +(http://www.jpuigcerver.net/pubs/jpuigcerver_icdar2017.pdf). +We have used a similar train, test and validation split. +More info: http://www.a2ialab.com/doku.php?id=rimes_database:start diff --git a/egs/rimes/v1/cmd.sh b/egs/rimes/v1/cmd.sh new file mode 100755 index 00000000000..6080a8bab68 --- /dev/null +++ b/egs/rimes/v1/cmd.sh @@ -0,0 +1,13 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. 
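# For example, on a single machine with no queueing system the export below
# would typically become:
#   export cmd="run.pl"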
+ +export cmd="retry.pl queue.pl" diff --git a/egs/rimes/v1/image b/egs/rimes/v1/image new file mode 120000 index 00000000000..1668ee99922 --- /dev/null +++ b/egs/rimes/v1/image @@ -0,0 +1 @@ +../../cifar/v1/image/ \ No newline at end of file diff --git a/egs/rimes/v1/local/chain/compare_wer.sh b/egs/rimes/v1/local/chain/compare_wer.sh new file mode 100755 index 00000000000..4a2cc29481c --- /dev/null +++ b/egs/rimes/v1/local/chain/compare_wer.sh @@ -0,0 +1,88 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/cnn{1a,1b} + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... ]" + echo "e.g.: $0 exp/chain/cnn{1a,1b}" + exit 1 +fi +. ./path.sh + +echo "# $0 $*" +used_epochs=false + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +echo -n "# WER " +for x in $*; do + wer=$(cat $x/decode_test/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER " +for x in $*; do + cer=$(cat $x/decode_test/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +echo -n "# WER val " +for x in $*; do + wer=$(cat $x/decode_val/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER val " +for x in $*; do + cer=$(cat $x/decode_val/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Parameters " +for x in $*; do + params=$(nnet3-info $x/final.mdl 2>/dev/null | grep num-parameters | cut -d' ' -f2 | awk '{printf "%0.2fM\n",$1/1000000}') + printf "% 10s" $params +done +echo diff --git a/egs/rimes/v1/local/chain/run_cnn_e2eali.sh b/egs/rimes/v1/local/chain/run_cnn_e2eali.sh new file mode 120000 index 00000000000..e2545b0186e --- /dev/null +++ b/egs/rimes/v1/local/chain/run_cnn_e2eali.sh @@ -0,0 +1 @@ +tuning/run_cnn_e2eali_1a.sh \ No newline at end of file diff --git a/egs/rimes/v1/local/chain/run_e2e_cnn.sh b/egs/rimes/v1/local/chain/run_e2e_cnn.sh new file mode 120000 index 00000000000..d26ba0182ce --- /dev/null +++ b/egs/rimes/v1/local/chain/run_e2e_cnn.sh @@ -0,0 +1 @@ +tuning/run_e2e_cnn_1a.sh \ No newline at end of file diff --git a/egs/rimes/v1/local/chain/tuning/run_cnn_e2eali_1a.sh b/egs/rimes/v1/local/chain/tuning/run_cnn_e2eali_1a.sh new file mode 100755 index 00000000000..33eb9dcb98c --- /dev/null +++ b/egs/rimes/v1/local/chain/tuning/run_cnn_e2eali_1a.sh @@ -0,0 +1,257 @@ +#!/bin/bash + +# e2eali_1a is a 6 cnn layer 3 tdnn layer model with dropout, l2-regularization, batch-normalization + +# local/chain/compare_wer.sh 
exp/chain/cnn_e2eali_1a +# System cnn_e2eali_1a +# WER 7.75 +# CER 2.68 +# Final train prob -0.0779 +# Final valid prob -0.0860 +# Final train prob (xent) -0.7744 +# Final valid prob (xent) -0.8111 +# Parameters 4.96M + +# steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1a +# exp/chain/cnn_e2eali_1a: num-iters=36 nj=3..8 num-params=5.0M dim=40->944 combine=-0.076->-0.076 (over 1) xent:train/valid[23,35,final]=(-1.48,-0.871,-0.774/-1.46,-0.888,-0.811) logprob:train/valid[23,35,final]=(-0.208,-0.102,-0.078/-0.189,-0.104,-0.086) + +# line level scoring result +# WER 7.75 [ 437 / 5639, 62 ins, 55 del, 320 sub ] exp/chain/cnn_e2eali_1d/decode_test/wer_7_1.0 +# paragraph scoring result +# WER 6.69 [ 377 / 5639, 44 ins, 37 del, 296 sub ] exp/chain/cnn_e2eali_1a/decode_test/para/wer_7_1.0 + +set -e -o pipefail + +stage=0 + +nj=50 +train_set=train +decode_val=true +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +e2echain_model_dir=exp/chain/e2e_cnn_1a +tree_affix=_1a +bnf_chain_model_dir=exp/chain/e2e_cnn_1a +bnf_layer_name=tdnn6.affine +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=1000 +# we don't need extra left/right context for TDNN systems. +tdnn_dim=550 +# training options +srand=0 +remove_egs=true +lang_decode=data/lang +if $decode_val; then maybe_val=val; else maybe_val= ; fi +dropout_schedule='0,0@0.20,0.2@0.50,0' +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 --generate-ali-from-lats true \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts +fi + +bnf_data_dir=$bnf_chain_model_dir/$(basename $train_data_dir) +if [ $stage -le 3 ]; then + if [ -f $bnf_data_dir/feats.scp ]; then + echo "$0: $bnf_data_dir/feats.scp exists. Refusing to dump features!" + exit 1 + fi + + steps/nnet3/make_bottleneck_features.sh --cmd "$cmd" --use-gpu true \ + --compress false --nj $nj \ + $bnf_layer_name ${train_data_dir} ${bnf_data_dir} $bnf_chain_model_dir || exit 1 +fi + +if [ $stage -le 4 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 4 \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${bnf_data_dir} \ + $lang $lat_dir $tree_dir +fi + + +if [ $stage -le 5 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.03 dropout-proportion=0.0" + tdnn_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.04" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-dropout-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-dropout-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common3 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 6 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.apply-deriv-weights=true \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ + --chain.frame-subsampling-factor=4 \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=10 \ + --trainer.frames-per-iter=2000000 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=8 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 7 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 8 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + done +fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/rimes/v1/local/chain/tuning/run_e2e_cnn_1a.sh b/egs/rimes/v1/local/chain/tuning/run_e2e_cnn_1a.sh new file mode 100755 index 00000000000..9d28a41316d --- /dev/null +++ b/egs/rimes/v1/local/chain/tuning/run_e2e_cnn_1a.sh @@ -0,0 +1,156 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +# This script does end2end chain training (i.e. 
from scratch) +# local/chain/compare_wer.sh exp/chain/e2e_cnn_1a +# System e2e_cnn_1d +# WER 10.07 +# CER 3.95 +# Final train prob 0.0369 +# Final valid prob -0.0129 +# Final train prob (xent) +# Final valid prob (xent) +# Parameters 12.73M + +# steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1a +# exp/chain/e2e_cnn_1a: num-iters=20 nj=2..4 num-params=12.7M dim=40->19404 combine=0.079->0.079 (over 3) logprob:train/valid[12,19,final]=(0.017,0.034,0.037/-0.024,-0.013,-0.013) + +set -e + +# configs for 'chain' +stage=0 +train_stage=-10 +get_egs_stage=-10 +affix=1a +nj=50 + +# training options +tdnn_dim=450 +minibatch_size=150=100,64/300=50,32/600=25,16/1200=16,8 +common_egs_dir= +train_set=train +decode_val=true +lang_decode=data/lang +if $decode_val; then maybe_val=val; else maybe_val= ; fi +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj 30 --cmd "$cmd" \ + --shared-phones true \ + --type biphone \ + data/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat data/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py data/lang \| \ + utils/sym2int.pl -f 2- data/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + common1="height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="height-offsets=-2,-1,0,1,2 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + relu-batchnorm-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn4 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn6 input=Append(-4,0,4) dim=200 + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $output_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + # no need to store the egs in a shared storage because we always + # remove them. Anyway, it takes only 5 minutes to generate them. 
+ steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 2000000 \ + --trainer.num-epochs 3 \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 4 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir data/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi + +if [ $stage -le 4 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 5 ]; then + for decode_set in test $maybe_val; do + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/$decode_set $dir/decode_$decode_set || exit 1; + done +fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/rimes/v1/local/combine_line_txt_to_paragraph.py b/egs/rimes/v1/local/combine_line_txt_to_paragraph.py new file mode 100755 index 00000000000..5a794506b47 --- /dev/null +++ b/egs/rimes/v1/local/combine_line_txt_to_paragraph.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +""" This script creates paragraph level text file. It reads + the line level text file and combines them to get + paragraph level file. + Eg. local/combine_line_txt_to_paragraph.py + Eg. 
Input: writer000000_eval2011-0_000001 Comme indiqué dans + writer000000_eval2011-0_000002 habitation n° DVT 36 + writer000000_eval2011-0_000003 de mon domicile + Output: writer000000_eval2011-0 Comme indiqué dans habitation n° DVT 36 de mon domicile +""" + +import argparse +import os +import io +import sys +### main ### +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') + +paragraph_txt_dict = dict() +for line in infile: + line_vect = line.strip().split(' ') + line_id = int(line_vect[0].split('_')[-1]) + paragraph_id = line_vect[0].split('-')[-1] + paragraph_id = int(paragraph_id.split('_')[0]) + line_text = " ".join(line_vect[1:]) + if paragraph_id not in paragraph_txt_dict.keys(): + paragraph_txt_dict[paragraph_id] = dict() + paragraph_txt_dict[paragraph_id][line_id] = line_text + + +para_txt_dict = dict() +for para_id in sorted(paragraph_txt_dict.keys()): + para_txt = "" + for line_id in sorted(paragraph_txt_dict[para_id]): + text = paragraph_txt_dict[para_id][line_id] + para_txt = para_txt + " " + text + para_txt_dict[para_id] = para_txt + utt_id = 'writer' + str(para_id).zfill(6) + '_' + 'eval2011-' + str(para_id) + output.write(utt_id + ' ' + para_txt + '\n') diff --git a/egs/rimes/v1/local/extract_features.sh b/egs/rimes/v1/local/extract_features.sh new file mode 100755 index 00000000000..ec3bc8a268c --- /dev/null +++ b/egs/rimes/v1/local/extract_features.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# Copyright 2017 Yiwen Shao +# 2018 Ashish Arora + +# Apache 2.0 +# This script runs the make features script in parallel. + +nj=4 +cmd=run.pl +feat_dim=40 +augment_type=no_aug +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +data=$1 +featdir=$data/data +scp=$data/images.scp +logdir=$data/log + +mkdir -p $logdir +mkdir -p $featdir + +# make $featdir an absolute pathname +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +for n in $(seq $nj); do + split_scps="$split_scps $logdir/images.$n.scp" +done + +# split images.scp +utils/split_scp.pl $scp $split_scps || exit 1; + +$cmd JOB=1:$nj $logdir/extract_features.JOB.log \ + image/ocr/make_features.py $logdir/images.JOB.scp \ + --allowed_len_file_path $data/allowed_lengths.txt \| \ + copy-feats --compress=true --compression-method=7 \ + ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp + +## aggregates the output scp's to get feats.scp +for n in $(seq $nj); do + cat $featdir/images.$n.scp || exit 1; +done > $data/feats.scp || exit 1 diff --git a/egs/rimes/v1/local/prepare_data.sh b/egs/rimes/v1/local/prepare_data.sh new file mode 100755 index 00000000000..502718e7777 --- /dev/null +++ b/egs/rimes/v1/local/prepare_data.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +# This script creates traing and validations splits, downloads text corpus for language modeling, +# prepares the training, validation and test data for rimes dataset +# (i.e text, images.scp, utt2spk and spk2utt). It calls process_data.py. + +# Eg. local/prepare_data.sh +# Eg. 
text file: writer000150_train2011-150_000001 J'ai perdu mon emploi depuis 3 mois et je me +# utt2spk file: writer000150_train2011-150_000001 writer000150 +# images.scp file: writer000150_train2011-150_000001 data/local/rimes_data/line_image/train/train2011-150_000001.png + +stage=0 +download_dir=data/local/rimes_data +data_dir=data/local/rimes_data +page_image=$data_dir/page_image +xml=$data_dir/xml +train_img_url="http://www.a2ialab.com/lib/exe/fetch.php?media=rimes_database:data:icdar2011:line:training_2011.tar"; +train_xml_url="http://www.a2ialab.com/lib/exe/fetch.php?media=rimes_database:data:icdar2011:line:training_2011.xml"; +test_xml_url="http://www.a2ialab.com/lib/exe/fetch.php?media=rimes_database:data:icdar2011:line:eval_2011_annotated.xml"; +test_img_url="http://www.a2ialab.com/lib/exe/fetch.php?media=rimes_database:data:icdar2011:line:eval_2011.tar"; +text_url="http://opus.nlpl.eu/download.php?f=OfisPublik.tar.gz" +use_extra_corpus_text=true +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +mkdir -p data/{train,test,val} + +if [ -d $page_image ]; then + echo "$0: Not downloading data as it is already there." +else + mkdir -p $data_dir/{page_image,xml,line_image}/{train_total,test,val,train} + tar -xf $download_dir/training_2011.tar -C $page_image/train_total || exit 1; + tar -xf $download_dir/eval_2011.tar -C $page_image/test || exit 1; + cp -r $download_dir/training_2011.xml $xml/train_total/rimes_2011.xml + cp -r $download_dir/eval_2011_annotated.xml $xml/test/rimes_2011.xml + echo "$0: Done downloading and extracting data" + + #First 150 training page images are used for validation + cat $xml/train_total/rimes_2011.xml | head -n451 > $xml/val/rimes_2011.xml + cat $xml/train_total/rimes_2011.xml | tail -1 >> $xml/val/rimes_2011.xml + cp -r $page_image/train_total/* $page_image/train + + #Remaining training page images are used for training + cat $xml/train_total/rimes_2011.xml | head -1 > $xml/train/rimes_2011.xml + cat $xml/train_total/rimes_2011.xml | tail -n+452 >> $xml/train/rimes_2011.xml + cp -r $page_image/train_total/* $page_image/val +fi + +if $use_extra_corpus_text; then + # using freely available french text corpus for language modeling + mkdir -p data/local/text_data + wget -P data/local/text_data $text_url || exit 1; + tar -xf data/local/text_data/download.php?f=OfisPublik.tar.gz -C data/local/text_data || exit 1; + zcat data/local/text_data/OfisPublik/raw/fr/*.gz > data/local/text_data/fr_text +fi + +if [ $stage -le 0 ]; then + echo "$0: Processing train, val and test data... $(date)." + local/process_data.py $data_dir train --augment true || exit 1 + local/process_data.py $data_dir val || exit 1 + local/process_data.py $data_dir test || exit 1 + for dataset in test train val; do + echo "$0: Fixing data directory for dataset: $dataset $(date)." + image/fix_data_dir.sh data/$dataset + done +fi diff --git a/egs/rimes/v1/local/prepare_dict.sh b/egs/rimes/v1/local/prepare_dict.sh new file mode 100755 index 00000000000..d8093658c30 --- /dev/null +++ b/egs/rimes/v1/local/prepare_dict.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +# Copyright 2017 Hossein Hadian +# 2017 Chun Chieh Chang +# 2017 Ashish Arora + +# This script prepares the dictionary. 
+ +set -e +dir=data/local/dict +mkdir -p $dir + +local/prepare_lexicon.py $dir + +cut -d' ' -f2- $dir/lexicon.txt | sed 's/SIL//g' | tr ' ' '\n' | sort -u | sed '/^$/d' >$dir/nonsilence_phones.txt || exit 1; + +echo ' SIL' >> $dir/lexicon.txt + +echo SIL > $dir/silence_phones.txt + +echo SIL >$dir/optional_silence.txt + +echo -n "" >$dir/extra_questions.txt diff --git a/egs/rimes/v1/local/prepare_lexicon.py b/egs/rimes/v1/local/prepare_lexicon.py new file mode 100755 index 00000000000..5a6ac5b6dbf --- /dev/null +++ b/egs/rimes/v1/local/prepare_lexicon.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora + +import argparse +import os + +parser = argparse.ArgumentParser(description="""Creates the list of characters and words in lexicon""") +parser.add_argument('dir', type=str, help='output path') +args = parser.parse_args() + +### main ### +lex = {} +text_path = os.path.join('data', 'train', 'text') +text_fh = open(text_path, 'r', encoding='utf-8') + +with open(text_path, 'r', encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split(' ') + for i in range(1, len(line_vect)): + characters = list(line_vect[i]) + # Put SIL instead of "|". Because every "|" in the beginning of the words is for initial-space of that word + characters = " ".join(['SIL' if char == '|' else char for char in characters]) + lex[line_vect[i]] = characters + if line_vect[i] == '#': + lex[line_vect[i]] = "" + +with open(os.path.join(args.dir, 'lexicon.txt'), 'w', encoding='utf-8') as fp: + for key in sorted(lex): + fp.write(key + " " + lex[key] + "\n") diff --git a/egs/rimes/v1/local/process_data.py b/egs/rimes/v1/local/process_data.py new file mode 100755 index 00000000000..b87d9fbc5e2 --- /dev/null +++ b/egs/rimes/v1/local/process_data.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 + +""" This script reads xml file and creates the following files :text, utt2spk, images.scp. + It also creates line images from page image and stores it into + data/local/rimes_data/train/lines. + Eg. local/process_data.py data/local/rimes_data/train train + Eg. text file: writer000000_train2011-0_000001 Je vous adresse ce courrier afin + utt2spk file: writer000000_train2011-0_000001 writer000000 + images.scp file: writer000000_train2011-0_000001 \ + data/local/rimes_data/train/lines/train2011-0_000001.png +""" + +import argparse +import xml.dom.minidom as minidom +from PIL import Image +import os +import random +parser = argparse.ArgumentParser(description="""Creates line images from page image.""") +parser.add_argument('database_path', type=str, + help='Path to the downloaded (and extracted) mdacat data') +parser.add_argument('dataset', type=str, + help='Subset of data to process.') +parser.add_argument("--augment", type=lambda x: (str(x).lower()=='true'), default=False, + help="performs image augmentation") +parser.add_argument('--pixel-scaling', type=int, default=20, + help='padding across horizontal/verticle direction') +args = parser.parse_args() + +def expand_aabb(left, right, top, bottom, delta_pixel): + """ Increases size of axis aligned bounding box (aabb). + """ + left = left - delta_pixel + right = right + delta_pixel + top = top - delta_pixel + bottom = bottom + delta_pixel + return left, right, top, bottom + +def get_line_images_from_page_image(file_name, left, right, top, bottom, line_id): + """ Given a page image, extracts the line images from it. + Input + ----- + file_name (string): name of the page image. + left, right, top, bottom (int): coordinates corresponding to the line image. 
+ line_id (int): line number on the page image. + """ + page_image_path = os.path.join(page_image_folder, file_name) + im = Image.open(page_image_path) + box = (left, top, right, bottom) + region = im.crop(box) + base_name = os.path.splitext(os.path.basename(file_name))[0] + line_image_file_name = base_name + '_' + str(line_id).zfill(6) + '.png' + imgray = region.convert('L') + line_image_path = os.path.join(args.database_path, 'line_image', args.dataset, line_image_file_name) + imgray.save(line_image_path) + return base_name, line_image_path + +def write_kaldi_process_data_files(base_name, line_id, text): + """creates files requires for dictionary and feats.scp. + Input + ----- + image_path (string): name of the page image. + line_id (str): line number on the page image. + text: transcription of the line image. + base_name (string): + """ + writer_id = str(base_name.split('-')[1]) + writer_id = str(writer_id).zfill(6) + writer_id = 'writer' + writer_id + utt_id = writer_id + '_' + base_name + '_' + str(line_id).zfill(6) + line_image_file_name = base_name + '_' + str(line_id).zfill(6) + '.png' + image_path = os.path.join(args.database_path, 'line_image', args.dataset, line_image_file_name) + text_fh.write(utt_id + ' ' + text + '\n') + utt2spk_fh.write(utt_id + ' ' + writer_id + '\n') + image_fh.write(utt_id + ' ' + image_path + '\n') + +### main ### +text_file = os.path.join('data', args.dataset, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') +utt2spk_file = os.path.join('data', args.dataset, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join('data', args.dataset, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') + +xml_path = os.path.join(args.database_path, 'xml', args.dataset) + '/rimes_2011.xml' +page_image_folder = os.path.join(args.database_path, 'page_image', args.dataset) +doc = minidom.parse(xml_path) +single_page = doc.getElementsByTagName('SinglePage') +for page in single_page: + file_name = page.getAttribute('FileName') + line = page.getElementsByTagName('Line') + id = 0 + for node in line: + id += 1 + bottom = int(node.getAttribute('Bottom')) + left = int(node.getAttribute('Left')) + right = int(node.getAttribute('Right')) + top = int(node.getAttribute('Top')) + text = node.getAttribute('Value') + text_vect = text.split() # this is to avoid non-utf-8 spaces + text = " ".join(text_vect) + if args.augment: + base_name, image_path = get_line_images_from_page_image(file_name, left, right, top, bottom, str(id)) + write_kaldi_process_data_files(base_name, str(id), text) + additional_pixel = random.randint(1, args.pixel_scaling) + left, right, top, bottom = expand_aabb(left, right, top, bottom, args.pixel_scaling + additional_pixel + 1) + line_id = str(id) + '_scale' + str(2) + base_name, image_path = get_line_images_from_page_image(file_name, left, right, top, bottom, line_id) + write_kaldi_process_data_files(base_name, line_id, text) + else: + base_name, image_path = get_line_images_from_page_image(file_name, left, right, top, bottom, str(id)) + write_kaldi_process_data_files(base_name, str(id), text) diff --git a/egs/rimes/v1/local/score.sh b/egs/rimes/v1/local/score.sh new file mode 100755 index 00000000000..0cfbda9b556 --- /dev/null +++ b/egs/rimes/v1/local/score.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +set -e +cmd=run.pl +stage=0 +decode_mbr=false +stats=true +beam=6 +word_ins_penalty=0.0,0.5,1.0 +min_lmwt=7 +max_lmwt=17 +iter=final + +echo "$0 $@" # Print the command line for logging +[ -f ./path.sh ] && . 
./path.sh +. parse_options.sh || exit 1; + +decode_dir=$3 +steps/scoring/score_kaldi_wer.sh --word_ins_penalty $word_ins_penalty \ + --min_lmwt $min_lmwt --max_lmwt $max_lmwt "$@" + +steps/scoring/score_kaldi_cer.sh --word_ins_penalty $word_ins_penalty \ + --min_lmwt $min_lmwt --max_lmwt $max_lmwt --stage 2 "$@" + +local/score_paragraph.sh --word_ins_penalty $word_ins_penalty \ + --min_lmwt $min_lmwt --max_lmwt $max_lmwt $decode_dir diff --git a/egs/rimes/v1/local/score_paragraph.sh b/egs/rimes/v1/local/score_paragraph.sh new file mode 100755 index 00000000000..c6ef4da1d5b --- /dev/null +++ b/egs/rimes/v1/local/score_paragraph.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +min_lmwt=7 +max_lmwt=17 +word_ins_penalty=0.0,0.5,1.0 + +set -e +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +decode_dir=$1 +test_para=$decode_dir/scoring_kaldi/test_filt_para.txt + +cat $decode_dir/scoring_kaldi/test_filt.txt | \ + local/combine_line_txt_to_paragraph.py > $test_para + +for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + for LMWT in $(seq $min_lmwt $max_lmwt); do + mkdir -p $decode_dir/para/penalty_$wip + cat $decode_dir/scoring_kaldi/penalty_$wip/$LMWT.txt | \ + local/combine_line_txt_to_paragraph.py > $decode_dir/para/penalty_$wip/$LMWT.txt + done +done + +for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + for LMWT in $(seq $min_lmwt $max_lmwt); do + compute-wer --text --mode=present \ + ark:$test_para ark:$decode_dir/para/penalty_$wip/$LMWT.txt &> $decode_dir/para/wer_${LMWT}_${wip} || exit 1; + done +done + +for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + for lmwt in $(seq $min_lmwt $max_lmwt); do + # adding /dev/null to the command list below forces grep to output the filename + grep WER $decode_dir/para/wer_${lmwt}_${wip} /dev/null + done +done | utils/best_wer.sh >& $decode_dir/para/best_wer || exit 1 diff --git a/egs/rimes/v1/local/train_lm.sh b/egs/rimes/v1/local/train_lm.sh new file mode 100755 index 00000000000..51927b7a97e --- /dev/null +++ b/egs/rimes/v1/local/train_lm.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +order=6 +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. 
+# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # use the validation data as the dev set. + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + head -2000 data/train/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + + # use the training data as an additional data source. + # we can later fold the dev data into this. + tail -n +2000 data/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + + if [ -d "data/local/text_data" ]; then + cat data/local/text_data/fr_text | \ + utils/lang/bpe/prepend_words.py | utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > ${dir}/data/text/corpus_text.txt + fi + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < data/test/text > ${dir}/data/real_dev_set.txt + cat ${dir}/data/text/{train,corpus_text}.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='corpus_text=2 train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=20 --warm-start-ratio=20 \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi diff --git a/egs/rimes/v1/local/wer_output_filter b/egs/rimes/v1/local/wer_output_filter new file mode 100755 index 00000000000..d9cf1f4072e --- /dev/null +++ b/egs/rimes/v1/local/wer_output_filter @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Hossein Hadian + +# Apache 2.0 +# This script converts a BPE-encoded text to normal text. It is used in scoring + +import sys, io +import string +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') + +for line in infile: + words = line.strip().split() + uttid = words[0] + transcript = ''.join(words[1:]) + transcript = transcript.replace('|', ' ') + output.write(uttid + ' ' + transcript + '\n') diff --git a/egs/rimes/v1/path.sh b/egs/rimes/v1/path.sh new file mode 100755 index 00000000000..c7ebe7f2abf --- /dev/null +++ b/egs/rimes/v1/path.sh @@ -0,0 +1,7 @@ +export KALDI_ROOT=`pwd`/../../.. 
+[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LD_LIBRARY_PATH=$KALDI_ROOT/tools/openfst/lib:$LD_LIBRARY_PATH +export LC_ALL=C diff --git a/egs/rimes/v1/run_end2end.sh b/egs/rimes/v1/run_end2end.sh new file mode 100755 index 00000000000..d3e3da2be13 --- /dev/null +++ b/egs/rimes/v1/run_end2end.sh @@ -0,0 +1,113 @@ +#!/bin/bash + +# Copyright 2018 Hossein Hadian +# Ashish Arora +# Jonathan Chang +# Apache 2.0 + +set -e +stage=0 +nj=50 +overwrite=false +rimes_database=/export/corpora5/handwriting_ocr/RIMES +train_set=train +use_extra_corpus_text=true +. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. + ## This relates to the queue. +. ./path.sh +. ./utils/parse_options.sh # e.g. this parses the above options + # if supplied. + +if [ $stage -le 0 ]; then + if [ -f data/train/text ] && ! $overwrite; then + echo "$0: Not processing, probably script have run from wrong stage" + echo "Exiting with status 1 to avoid data corruption" + exit 1; + fi + + echo "$0: Preparing data..." + local/prepare_data.sh --download-dir "$rimes_database" \ + --use_extra_corpus_text $use_extra_corpus_text + +fi + +mkdir -p data/{train,test,val}/data +if [ $stage -le 1 ]; then + echo "$(date) stage 1: getting allowed image widths for e2e training..." + image/get_image2num_frames.py --feat-dim 40 data/train + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train + echo "$(date) Extracting features, creating feats.scp file" + for set in train test val; do + local/extract_features.sh --nj $nj --cmd "$cmd" data/${set} + steps/compute_cmvn_stats.sh data/${set} || exit 1; + done + utils/fix_data_dir.sh data/train +fi + +if [ $stage -le 3 ]; then + echo "$0: Preparing BPE..." + # getting non-silence phones. + cut -d' ' -f2- data/train/text | \ +python3 <( +cat << "END" +import os, sys, io; +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8'); +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8'); +phone_dict = dict(); +for line in infile: + line_vect = line.strip().split(); + for word in line_vect: + for phone in word: + phone_dict[phone] = phone; +for phone in phone_dict.keys(): + output.write(phone+ '\n'); +END + ) > data/local/phones.txt + + cut -d' ' -f2- data/train/text > data/local/train_data.txt + cat data/local/phones.txt data/local/train_data.txt | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/learn_bpe.py -s 700 > data/local/bpe.txt + + for set in test train val; do + cut -d' ' -f1 data/$set/text > data/$set/ids + cut -d' ' -f2- data/$set/text | \ + utils/lang/bpe/prepend_words.py | utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > data/$set/bpe_text + mv data/$set/text data/$set/text.old + paste -d' ' data/$set/ids data/$set/bpe_text > data/$set/text + rm -f data/$set/bpe_text data/$set/ids + done +fi + +if [ $stage -le 4 ]; then + echo "$0: Preparing dictionary and lang..." + local/prepare_dict.sh + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ + data/local/dict "" data/lang/temp data/lang + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 data/lang +fi + +if [ $stage -le 5 ]; then + echo "$0: Estimating a language model for decoding..." 
+ local/train_lm.sh + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/6gram_unpruned.arpa.gz \ + data/local/dict/lexicon.txt data/lang +fi + +if [ $stage -le 6 ]; then + echo "$0: Calling the flat-start chain recipe..." + local/chain/run_e2e_cnn.sh --train_set $train_set +fi + +if [ $stage -le 7 ]; then + echo "$0: Aligning the training data using the e2e chain model..." + steps/nnet3/align.sh --nj 50 --cmd "$cmd" \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0 --acoustic-scale=1.0' \ + data/$train_set data/lang exp/chain/e2e_cnn_1a exp/chain/e2e_ali_train +fi + +if [ $stage -le 8 ]; then + echo "$0: Building a tree and training a regular chain model using the e2e alignments..." + local/chain/run_cnn_e2eali.sh --train_set $train_set +fi diff --git a/egs/rimes/v1/steps b/egs/rimes/v1/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/rimes/v1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/rimes/v1/utils b/egs/rimes/v1/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/rimes/v1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/rm/README.txt b/egs/rm/README.txt index ed588e481c6..4fa3d7c87e8 100644 --- a/egs/rm/README.txt +++ b/egs/rm/README.txt @@ -9,7 +9,7 @@ About the Resource Management corpus: Each subdirectory of this directory contains the scripts for a sequence of experiments. -s5 is the currently recommmended setup. +s5 is the currently recommended setup. s5: This is the "new-new-style" recipe. It is now finished. All further work will be on top of this style of recipe. Note: diff --git a/egs/rm/s5/local/chain/tuning/run_tdnn_wsj_rm_1a.sh b/egs/rm/s5/local/chain/tuning/run_tdnn_wsj_rm_1a.sh index 6b6c08e779a..2fd2556c19b 100755 --- a/egs/rm/s5/local/chain/tuning/run_tdnn_wsj_rm_1a.sh +++ b/egs/rm/s5/local/chain/tuning/run_tdnn_wsj_rm_1a.sh @@ -130,7 +130,7 @@ if [ $stage -le 7 ]; then echo " generating new layers, that are specific to rm. These layers "; echo " are added to the transferred part of the wsj network."; num_targets=$(tree-info --print-args=false $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/rm/s5/local/run_raw_fmllr.sh b/egs/rm/s5/local/run_raw_fmllr.sh index 20b475fa32e..e02002aa1d0 100755 --- a/egs/rm/s5/local/run_raw_fmllr.sh +++ b/egs/rm/s5/local/run_raw_fmllr.sh @@ -25,8 +25,8 @@ steps/decode_raw_fmllr.sh --use-normal-fmllr true --config conf/decode.config -- steps/align_raw_fmllr.sh --nj 8 --cmd "$train_cmd" data/train data/lang exp/tri3c exp/tri3c_ali - - + + if [ ! 
-f exp/ubm4c/final.mdl ]; then steps/train_ubm.sh --silence-weight 0.5 --cmd "$train_cmd" 400 data/train data/lang exp/tri3c_ali exp/ubm4c || exit 1; fi @@ -43,7 +43,7 @@ steps/decode_sgmm2.sh --config conf/decode.config --nj 20 --cmd "$decode_cmd" \ steps/decode_sgmm2.sh --use-fmllr true --config conf/decode.config --nj 20 --cmd "$decode_cmd" \ --transform-dir exp/tri3c/decode exp/sgmm2_4c/graph data/test exp/sgmm2_4c/decode_fmllr || exit 1; - + exit 0; @@ -61,7 +61,7 @@ exit 0; # awk -v scale=30.0 '{printf("%s [ ", $1); for (n=3;n exp/sgmm2_4c_x30/decode_ug/vecs.1 # ) # exit 0; -# ## +# ## # steps/decode_sgmm2.sh --config conf/decode.config --nj 20 --cmd "$decode_cmd" \ # exp/sgmm2_4c.no_transform/graph data/test exp/sgmm2_4c.no_transform/decode || exit 1; diff --git a/egs/rm/s5/local/run_sgmm2.sh b/egs/rm/s5/local/run_sgmm2.sh index 5fa683ff0c2..95a40141892 100755 --- a/egs/rm/s5/local/run_sgmm2.sh +++ b/egs/rm/s5/local/run_sgmm2.sh @@ -9,7 +9,7 @@ if [ ! -f exp/ubm4a/final.ubm ] || [ ! data/train/feats.scp -nt exp/ubm4a/final.ubm ]; then steps/train_ubm.sh --silence-weight 0.5 --cmd "$train_cmd" 400 data/train data/lang exp/tri3b_ali exp/ubm4a || exit 1; fi - + steps/train_sgmm2.sh --cmd "$train_cmd" 5000 7000 data/train data/lang exp/tri3b_ali exp/ubm4a/final.ubm exp/sgmm2_4a || exit 1; utils/mkgraph.sh data/lang exp/sgmm2_4a exp/sgmm2_4a/graph || exit 1; @@ -26,22 +26,22 @@ steps/decode_sgmm2.sh --use-fmllr true --config conf/decode.config --nj 20 --cmd steps/make_denlats_sgmm2.sh --nj 8 --sub-split 20 --cmd "$decode_cmd" --transform-dir exp/tri3b \ data/train data/lang exp/sgmm2_4a_ali exp/sgmm2_4a_denlats steps/train_mmi_sgmm2.sh --cmd "$decode_cmd" --transform-dir exp/tri3b --boost 0.2 \ - data/train data/lang exp/sgmm2_4a_ali exp/sgmm2_4a_denlats exp/sgmm2_4a_mmi_b0.2 + data/train data/lang exp/sgmm2_4a_ali exp/sgmm2_4a_denlats exp/sgmm2_4a_mmi_b0.2 for iter in 1 2 3 4; do steps/decode_sgmm2_rescore.sh --cmd "$decode_cmd" --iter $iter \ --transform-dir exp/tri3b/decode data/lang data/test exp/sgmm2_4a/decode exp/sgmm2_4a_mmi_b0.2/decode_it$iter & - done + done ( steps/train_mmi_sgmm2.sh --cmd "$decode_cmd" --transform-dir exp/tri3b --boost 0.2 --drop-frames true \ - data/train data/lang exp/sgmm2_4a_ali exp/sgmm2_4a_denlats exp/sgmm2_4a_mmi_b0.2_x + data/train data/lang exp/sgmm2_4a_ali exp/sgmm2_4a_denlats exp/sgmm2_4a_mmi_b0.2_x for iter in 1 2 3 4; do steps/decode_sgmm2_rescore.sh --cmd "$decode_cmd" --iter $iter \ --transform-dir exp/tri3b/decode data/lang data/test exp/sgmm2_4a/decode exp/sgmm2_4a_mmi_b0.2_x/decode_it$iter & - done + done ) -wait +wait steps/decode_combine.sh data/test data/lang exp/tri1/decode exp/tri2a/decode exp/combine_1_2a/decode || exit 1; steps/decode_combine.sh data/test data/lang exp/sgmm2_4a/decode exp/tri3b_mmi/decode exp/combine_sgmm2_4a_3b/decode || exit 1; # combining the sgmm run and the best MMI+fMMI run. 
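A note on a change that recurs throughout this patch (in run_tdnn_wsj_rm_1a.sh above and again in the sprakbanken tuning scripts further down): the expression fed to python in the learning_rate_factor line is wrapped in parentheses, which keeps the one-liner valid whether "python" resolves to Python 2 or Python 3. Under Python 2 it is a print statement applied to a parenthesized expression, while under Python 3 it is a call to the print() function. A minimal sketch of the behavior, with xent_regularize=0.1 assumed purely for illustration:

    # "print (0.5/0.1)" is a print statement in Python 2 and a print() call in
    # Python 3; both write 5.0, so the shell substitution
    #   learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python)
    # gives the same result with either interpreter.
    xent_regularize = 0.1            # assumed value, for illustration only
    print (0.5 / xent_regularize)    # -> 5.0
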
diff --git a/egs/rm/s5/local/run_sgmm2x.sh b/egs/rm/s5/local/run_sgmm2x.sh index deea4feb13f..00730697693 100755 --- a/egs/rm/s5/local/run_sgmm2x.sh +++ b/egs/rm/s5/local/run_sgmm2x.sh @@ -26,14 +26,14 @@ steps/decode_sgmm2.sh --use-fmllr true --config conf/decode.config --nj 20 --cmd steps/make_denlats_sgmm2.sh --nj 8 --sub-split 20 --cmd "$decode_cmd" --transform-dir exp/tri3b \ data/train data/lang exp/sgmm2x_4a_ali exp/sgmm2x_4a_denlats steps/train_mmi_sgmm2.sh --cmd "$decode_cmd" --transform-dir exp/tri3b --boost 0.2 \ - data/train data/lang exp/sgmm2x_4a_ali exp/sgmm2x_4a_denlats exp/sgmm2x_4a_mmi_b0.2 + data/train data/lang exp/sgmm2x_4a_ali exp/sgmm2x_4a_denlats exp/sgmm2x_4a_mmi_b0.2 for iter in 1 2 3 4; do steps/decode_sgmm2_rescore.sh --cmd "$decode_cmd" --iter $iter \ --transform-dir exp/tri3b/decode data/lang data/test exp/sgmm2x_4a/decode exp/sgmm2x_4a_mmi_b0.2/decode_it$iter & - done + done -wait +wait steps/decode_combine.sh data/test data/lang exp/tri1/decode exp/tri2a/decode exp/combine_1_2a/decode || exit 1; steps/decode_combine.sh data/test data/lang exp/sgmm2x_4a/decode exp/tri3b_mmi/decode exp/combine_sgmm2x_4a_3b/decode || exit 1; # combining the sgmm run and the best MMI+fMMI run. diff --git a/egs/rm/s5/local/run_sgmm_multiling.sh b/egs/rm/s5/local/run_sgmm_multiling.sh index 2b2af7f5ca6..42369cd2937 100755 --- a/egs/rm/s5/local/run_sgmm_multiling.sh +++ b/egs/rm/s5/local/run_sgmm_multiling.sh @@ -45,7 +45,7 @@ utils/convert_models.sh exp/tri2b data_ml/lang_rm exp_ml/tri2b_rm data_ml/lang_r utils/convert_models.sh ../../wsj/exp/tri4b data_ml/lang_wsj exp_ml/tri4b_wsj data_ml/lang -# Re-do the alignment of the RM tri2b setup with the converted models +# Re-do the alignment of the RM tri2b setup with the converted models # (this avoids the hassle of converting the alignment.) steps/align_si.sh --nj 8 --cmd "$train_cmd" data_ml/train_rm data_ml/lang exp_ml/tri2b_rm \ exp_ml/tri2b_rm_ali || exit 1; @@ -66,7 +66,7 @@ steps/train_sat.sh 1800 9000 data_ml/train_rm data_ml/lang exp_ml/tri2b_rm_ali e # "merge-tree" program will need, for each tree, a record of which sets of # phones it was supposed to handle, since this is not recorded in the tree # itself-- we can get this from the transition models which do record this. -# probably the "merge-tree" program will have usage: +# probably the "merge-tree" program will have usage: # merge-tree ... # where the phone-set-n's will probably be filenames that contain lists of # the phones. 
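The make_musan.py hunks just below (and the similar sprakbanken edits later in this patch) replace print's comma-separated arguments with str.format, presumably so the messages stay readable if the script is ever run under Python 2, where print is a statement and the comma form renders a tuple; the formatted string prints identically under either interpreter. A small sketch with a made-up utterance id and counts, purely for illustration:

    utt = "music-fma-0001"                   # hypothetical utterance id
    num_good_files, num_bad_files = 661, 2   # made-up counts
    # Prints the same text under Python 2 and Python 3; the old form
    # print("Missing file", utt) would show ('Missing file', 'music-fma-0001')
    # under a bare Python 2 interpreter.
    print("Missing file {}".format(utt))
    print("In music directory, processed {} files: {} had missing wav data".format(
        num_good_files, num_bad_files))
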
diff --git a/egs/rm/s5/local/run_vtln2.sh b/egs/rm/s5/local/run_vtln2.sh index 6437032ca61..b87030d2e3d 100755 --- a/egs/rm/s5/local/run_vtln2.sh +++ b/egs/rm/s5/local/run_vtln2.sh @@ -59,4 +59,4 @@ steps/compute_cmvn_stats.sh data/test_vtln exp/make_mfcc/test_vtln $featdir # %WER 3.13 [ 392 / 12533, 59 ins, 64 del, 269 sub ] exp/tri3b/decode.si/wer_3 # %WER 10.36 [ 1298 / 12533, 147 ins, 192 del, 959 sub ] exp/tri3b/decode_ug/wer_12 # %WER 13.48 [ 1689 / 12533, 159 ins, 277 del, 1253 sub ] exp/tri3b/decode_ug.si/wer_13 -# a04:s5: \ No newline at end of file +# a04:s5: diff --git a/egs/sitw/v1/local/make_musan.py b/egs/sitw/v1/local/make_musan.py index 74c434990fb..833da0619c9 100755 --- a/egs/sitw/v1/local/make_musan.py +++ b/egs/sitw/v1/local/make_musan.py @@ -47,9 +47,9 @@ def prepare_music(root_dir, use_vocals): utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file {}".format(utt)) num_bad_files += 1 - print("In music directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In music directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def prepare_speech(root_dir): @@ -73,9 +73,9 @@ def prepare_speech(root_dir): utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file {}".format(utt)) num_bad_files += 1 - print("In speech directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In speech directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def prepare_noise(root_dir): @@ -99,9 +99,9 @@ def prepare_noise(root_dir): utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file {}".format(utt)) num_bad_files += 1 - print("In noise directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In noise directory, processed {} files: {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def main(): diff --git a/egs/sitw/v1/local/make_voxceleb1.pl b/egs/sitw/v1/local/make_voxceleb1.pl index e56483563b8..279f8d6cbfe 100755 --- a/egs/sitw/v1/local/make_voxceleb1.pl +++ b/egs/sitw/v1/local/make_voxceleb1.pl @@ -26,6 +26,10 @@ system("wget -O $out_dir/voxceleb1_sitw_overlap.txt http://www.openslr.org/resources/49/voxceleb1_sitw_overlap.txt"); } +if (! -e "$data_base/vox1_meta.csv") { + system("wget -O $data_base/vox1_meta.csv http://www.openslr.org/resources/49/vox1_meta.csv"); +} + # sitw_overlap contains the list of speakers that also exist in our evaluation set, SITW. my %sitw_overlap = (); open(OVERLAP, "<", "$out_dir/voxceleb1_sitw_overlap.txt") or die "Could not open the overlap file $out_dir/voxceleb1_sitw_overlap.txt"; @@ -34,6 +38,20 @@ my $spkr_id = $_; $sitw_overlap{$spkr_id} = (); } +close(OVERLAP) or die; + +open(META_IN, "<", "$data_base/vox1_meta.csv") or die "Could not open the meta data file $data_base/vox1_meta.csv"; + +# Also add the banned speakers to sitw_overlap using their ID format in the +# newest version of VoxCeleb. 
+while () { + chomp; + my ($vox_id, $spkr_id, $gender, $nation, $set) = split; + if (exists($sitw_overlap{$spkr_id})) { + $sitw_overlap{$vox_id} = (); + } +} +close(META_IN) or die; opendir my $dh, "$data_base/voxceleb1_wav" or die "Cannot open directory: $!"; my @spkr_dirs = grep {-d "$data_base/voxceleb1_wav/$_" && ! /^\.{1,2}$/} readdir($dh); diff --git a/egs/sitw/v1/run.sh b/egs/sitw/v1/run.sh index 68d08dfc170..e016f8a4752 100755 --- a/egs/sitw/v1/run.sh +++ b/egs/sitw/v1/run.sh @@ -122,7 +122,7 @@ if [ $stage -le 4 ]; then # Make a reverberated version of the VoxCeleb2 list. Note that we don't add any # additive noise here. - python steps/data/reverberate_data_dir.py \ + steps/data/reverberate_data_dir.py \ "${rvb_opts[@]}" \ --speech-rvb-probability 1 \ --pointsource-noise-addition-probability 0 \ @@ -147,11 +147,11 @@ if [ $stage -le 4 ]; then done # Augment with musan_noise - python steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/train_100k data/train_100k_noise + steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/train_100k data/train_100k_noise # Augment with musan_music - python steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/train_100k data/train_100k_music + steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/train_100k data/train_100k_music # Augment with musan_speech - python steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/train_100k data/train_100k_babble + steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/train_100k data/train_100k_babble # Combine reverb, noise, music, and babble into one directory. utils/combine_data.sh data/train_aug data/train_100k_reverb data/train_100k_noise data/train_100k_music data/train_100k_babble diff --git a/egs/sitw/v2/run.sh b/egs/sitw/v2/run.sh index 499d436366a..8aeecc18b3f 100755 --- a/egs/sitw/v2/run.sh +++ b/egs/sitw/v2/run.sh @@ -88,7 +88,7 @@ if [ $stage -le 2 ]; then # Make a reverberated version of the VoxCeleb2 list. Note that we don't add any # additive noise here. 
- python steps/data/reverberate_data_dir.py \ + steps/data/reverberate_data_dir.py \ "${rvb_opts[@]}" \ --speech-rvb-probability 1 \ --pointsource-noise-addition-probability 0 \ @@ -113,11 +113,11 @@ if [ $stage -le 2 ]; then done # Augment with musan_noise - python steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/train data/train_noise + steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/train data/train_noise # Augment with musan_music - python steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/train data/train_music + steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/train data/train_music # Augment with musan_speech - python steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/train data/train_babble + steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/train data/train_babble # Combine reverb, noise, music, and babble into one directory. utils/combine_data.sh data/train_aug data/train_reverb data/train_noise data/train_music data/train_babble diff --git a/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1a.sh b/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1a.sh index ec6b8941955..47557f93696 100755 --- a/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1a.sh +++ b/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1a.sh @@ -152,7 +152,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1b.sh b/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1b.sh index 53aa92710e8..7afa1b7f902 100755 --- a/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1b.sh +++ b/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1b.sh @@ -153,7 +153,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1c.sh b/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1c.sh index 83c2f3607f0..e69e499e152 100755 --- a/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1c.sh +++ b/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1c.sh @@ -151,7 +151,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1d.sh 
b/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1d.sh index 2665ea91ff8..86e0352828c 100755 --- a/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1d.sh +++ b/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1d.sh @@ -164,7 +164,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1e.sh b/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1e.sh index 80f67d34ba9..313f899a471 100755 --- a/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1e.sh +++ b/egs/sprakbanken/s5/local/chain/tuning/run_lstm_1e.sh @@ -152,7 +152,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/sprakbanken/s5/local/chain/tuning/run_tdnn_1b.sh b/egs/sprakbanken/s5/local/chain/tuning/run_tdnn_1b.sh index e242660a10e..600f27ddf86 100755 --- a/egs/sprakbanken/s5/local/chain/tuning/run_tdnn_1b.sh +++ b/egs/sprakbanken/s5/local/chain/tuning/run_tdnn_1b.sh @@ -135,7 +135,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/sprakbanken/s5/local/chain/tuning/run_tdnn_lstm_1a.sh b/egs/sprakbanken/s5/local/chain/tuning/run_tdnn_lstm_1a.sh index 86dc4b75a24..cedc448464a 100755 --- a/egs/sprakbanken/s5/local/chain/tuning/run_tdnn_lstm_1a.sh +++ b/egs/sprakbanken/s5/local/chain/tuning/run_tdnn_lstm_1a.sh @@ -145,7 +145,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/sprakbanken/s5/local/norm_dk/write_punct.sh b/egs/sprakbanken/s5/local/norm_dk/write_punct.sh index 57726bd44cb..3b8decaf376 100755 --- a/egs/sprakbanken/s5/local/norm_dk/write_punct.sh +++ b/egs/sprakbanken/s5/local/norm_dk/write_punct.sh @@ -22,4 +22,4 @@ perl -pe 's/([\n ])\;([ \n])/\1SEMIKOLON\2/g' | \ perl -pe 's/([\n ])_NL_([ \n])/\1NY LINJE\2/g' | \ perl -pe 's/([\n ])_NS_([ \n])/\1NYT AFSNIT\2/g' | \ -tr -s ' ' \ No newline at end of file +tr -s ' ' diff --git a/egs/sprakbanken/s5/local/normalize_transcript.py b/egs/sprakbanken/s5/local/normalize_transcript.py index 2374418bee7..21d70864f04 100755 --- a/egs/sprakbanken/s5/local/normalize_transcript.py +++ b/egs/sprakbanken/s5/local/normalize_transcript.py @@ -17,8 +17,8 @@ "\t": " " } -from_chars = ''.join(normdict.keys()) -to_chars = ''.join(normdict.values()) +from_chars = ''.join(list(normdict.keys())) +to_chars = ''.join(list(normdict.values())) #t_table = 
maketrans(from_chars, to_chars) diff --git a/egs/sprakbanken/s5/local/sprak2kaldi.py b/egs/sprakbanken/s5/local/sprak2kaldi.py index f3abf1d9a38..5fa4baa1fa2 100755 --- a/egs/sprakbanken/s5/local/sprak2kaldi.py +++ b/egs/sprakbanken/s5/local/sprak2kaldi.py @@ -16,6 +16,7 @@ # limitations under the License. ''' +from __future__ import print_function import sys @@ -59,8 +60,8 @@ def create_parallel_file_list(session, sndlist, txtlist): if len(os.listdir(session.sessiondir)) != 0: # Check if there are files in the directory global n n += 1 - session.sessiondir = session.sessiondir + "_" + str(n) - session.speaker_id = session.speaker_id + "_" + str(n) + session.sessiondir = "{}_{}".format(session.sessiondir, n) + session.speaker_id = "{}_{}".format(session.speaker_id, n) os.mkdir(session.sessiondir) shadow = True else: diff --git a/egs/sprakbanken/s5/local/sprak2parallel.py b/egs/sprakbanken/s5/local/sprak2parallel.py index b5fe56fd60f..3dc82e30ac2 100755 --- a/egs/sprakbanken/s5/local/sprak2parallel.py +++ b/egs/sprakbanken/s5/local/sprak2parallel.py @@ -76,8 +76,8 @@ def make_speech_corpus(top, dest, srcfolder): session.sessiondir = os.path.join(dest, session.filestem) +"."+ session.speaker_id if os.path.exists(session.sessiondir): n += 1 - session.sessiondir = session.sessiondir+ "_" +str(n) - session.speaker_id+ "_" +str(n) + session.sessiondir = "{}_{}".format(session.sessiondir, n) + session.speaker_id = "{}_{}".format(session.speaker_id, n) os.mkdir(session.sessiondir) create_parallel_files(session) diff --git a/egs/sprakbanken/s5/local/sprakparser.py b/egs/sprakbanken/s5/local/sprakparser.py index 7bdf6ac94e3..1221cf0b023 100755 --- a/egs/sprakbanken/s5/local/sprakparser.py +++ b/egs/sprakbanken/s5/local/sprakparser.py @@ -22,11 +22,12 @@ ''' +from __future__ import print_function import codecs import os -class Session: +class Session(object): delimit = ">-<" @@ -151,7 +152,7 @@ def set_channel_vars(self, handle): pass def create_filename(self, uid, file_ending): - return self.filestem+ "." +self.speaker_id+ "." +str(uid)+ "." +file_ending + return "{}.{}.{}.{}".format(self.filestem, self.speaker_id, uid, file_ending) def wavpath(self, topfolder): prefix, suffix = topfolder.rsplit('/data/', 1) diff --git a/egs/sprakbanken/s5/local/writenumbers.py b/egs/sprakbanken/s5/local/writenumbers.py index df3235243d4..c419b3c7550 100755 --- a/egs/sprakbanken/s5/local/writenumbers.py +++ b/egs/sprakbanken/s5/local/writenumbers.py @@ -22,6 +22,7 @@ Changed to write output to file to prevent problems with shell ascii codec. 
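The sprakbanken hunks above are Python 2/3 compatibility cleanups: print becomes a function via `from __future__ import print_function`, string concatenation becomes str.format, and dict.keys()/values() get wrapped in list(). They also fix a genuine bug in sprak2parallel.py, where `session.speaker_id+ "_" +str(n)` was a bare expression whose result was discarded rather than an assignment. A minimal sketch (dictionary contents invented) of why the list() wrapping matters:

  # In Python 2, dict.keys() returns a list; in Python 3 it returns a view
  # object that cannot be indexed, so code expecting list behaviour must ask
  # for it explicitly.
  normdict = {",": " ", ".": " ", "\t": " "}   # invented, loosely mirrors normalize_transcript.py

  keys = list(normdict.keys())     # indexable on both interpreters
  first_key = keys[0]              # a plain Python 3 keys() view would raise TypeError here

  from_chars = ''.join(list(normdict.keys()))
  to_chars = ''.join(list(normdict.values()))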
''' +from __future__ import print_function import sys import os @@ -215,7 +216,7 @@ def rmPvAnnotation(string): def normNumber(line, table): tokens = line.split() - keys = table.keys() + keys = list(table.keys()) for num, tok in enumerate(tokens): newtoks = splitNumeric(tok) if newtoks != False: diff --git a/egs/sprakbanken_swe/s5/local/normalize_transcript.py b/egs/sprakbanken_swe/s5/local/normalize_transcript.py index 90e45744e2a..150a9563aba 100755 --- a/egs/sprakbanken_swe/s5/local/normalize_transcript.py +++ b/egs/sprakbanken_swe/s5/local/normalize_transcript.py @@ -18,8 +18,8 @@ } #removes all the above signs -from_chars = ''.join(normdict.keys()) -to_chars = ''.join(normdict.values()) +from_chars = ''.join(list(normdict.keys())) +to_chars = ''.join(list(normdict.values())) t_table = str.maketrans(normdict) diff --git a/egs/sprakbanken_swe/s5/local/sprak2kaldi.py b/egs/sprakbanken_swe/s5/local/sprak2kaldi.py index cc67344c36e..8f723762e50 100755 --- a/egs/sprakbanken_swe/s5/local/sprak2kaldi.py +++ b/egs/sprakbanken_swe/s5/local/sprak2kaldi.py @@ -16,6 +16,7 @@ # limitations under the License. ''' +from __future__ import print_function import sys @@ -59,8 +60,8 @@ def create_parallel_file_list(session, sndlist, txtlist): if len(os.listdir(session.sessiondir)) != 0: # Check if there are files in the directory global n n += 1 - session.sessiondir = session.sessiondir + "_" + str(n) - session.speaker_id = session.speaker_id + "_" + str(n) + session.sessiondir = "{}_{}".format(session.sessiondir, n) + session.speaker_id = "{}_{}".format(session.speaker_id, n) os.mkdir(session.sessiondir) shadow = True else: diff --git a/egs/sprakbanken_swe/s5/local/sprakparser.py b/egs/sprakbanken_swe/s5/local/sprakparser.py index 4775328b56b..0951f7f39e7 100755 --- a/egs/sprakbanken_swe/s5/local/sprakparser.py +++ b/egs/sprakbanken_swe/s5/local/sprakparser.py @@ -26,7 +26,7 @@ import codecs import os -class Session: +class Session(object): delimit = ">-<" @@ -151,7 +151,7 @@ def set_channel_vars(self, handle): pass def create_filename(self, uid, file_ending): - return self.filestem+ "." +self.speaker_id+ "." +str(uid)+ "." +file_ending + return "{}.{}.{}.{}".format(self.filestem, self.speaker_id, uid, file_ending) def wavpath(self, topfolder): prefix, suffix = topfolder.rsplit('/data/', 1) diff --git a/egs/sre08/v1/local/score_sre08.sh b/egs/sre08/v1/local/score_sre08.sh index 92831502f45..c1584946735 100755 --- a/egs/sre08/v1/local/score_sre08.sh +++ b/egs/sre08/v1/local/score_sre08.sh @@ -35,11 +35,11 @@ tot_eer=0.0 printf '% 12s' 'EER:' for condition in $(seq 8); do eer=$(awk '{print $3}' $scores | paste - $trials | awk -v c=$condition '{n=4+c; if ($n == "Y") print $1, $4}' | compute-eer - 2>/dev/null) - tot_eer=$(echo "$tot_eer+$eer" | bc) + tot_eer=$(perl -e "print ($tot_eer+$eer);") eers[$condition]=$eer done -eers[0]=$(echo "$tot_eer/8" | bc -l) +eers[0]=$(perl -e "print ($tot_eer/8.0);") for i in $(seq 0 8); do printf '% 7.2f' ${eers[$i]} diff --git a/egs/sre08/v1/sid/compute_vad_decision.sh b/egs/sre08/v1/sid/compute_vad_decision.sh deleted file mode 100755 index 7099d063c7f..00000000000 --- a/egs/sre08/v1/sid/compute_vad_decision.sh +++ /dev/null @@ -1,72 +0,0 @@ -#!/bin/bash - -# Copyright 2013 Daniel Povey -# Apache 2.0 -# To be run from .. (one directory up from here) -# see ../run.sh for example - -# Compute energy based VAD output -# We do this in just one job; it's fast. 
-# - -nj=2 -cmd=run.pl -vad_config=conf/vad.conf - -echo "$0 $@" # Print the command line for logging - -if [ -f path.sh ]; then . ./path.sh; fi -. parse_options.sh || exit 1; - -if [ $# != 3 ]; then - echo "Usage: $0 [options] "; - echo "e.g.: $0 data/train exp/make_vad mfcc" - echo " Options:" - echo " --vad-config # config passed to compute-vad-energy" - echo " --nj # number of parallel jobs" - echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." - exit 1; -fi - -data=$1 -logdir=$2 -vaddir=$3 - -# make $vaddir an absolute pathname. -vaddir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $vaddir ${PWD}` - -# use "name" as part of name of the archive. -name=`basename $data` - -mkdir -p $vaddir || exit 1; -mkdir -p $logdir || exit 1; - - -for f in $data/feats.scp "$vad_config"; do - if [ ! -f $f ]; then - echo "compute_vad_decision.sh: no such file $f" - exit 1; - fi -done - -utils/split_data.sh $data $nj || exit 1; -sdata=$data/split$nj; - -$cmd JOB=1:$nj $logdir/vad_${name}.JOB.log \ - compute-vad --config=$vad_config scp:$sdata/JOB/feats.scp ark,scp:$vaddir/vad_${name}.JOB.ark,$vaddir/vad_${name}.JOB.scp \ - || exit 1; - -for ((n=1; n<=nj; n++)); do - cat $vaddir/vad_${name}.$n.scp || exit 1; -done > $data/vad.scp - -nc=`cat $data/vad.scp | wc -l` -nu=`cat $data/feats.scp | wc -l` -if [ $nc -ne $nu ]; then - echo "**Warning it seems not all of the speakers got VAD output ($nc != $nu);" - echo "**validate_data_dir.sh will fail; you might want to use fix_data_dir.sh" - [ $nc -eq 0 ] && exit 1; -fi - - -echo "Created VAD output for $name" diff --git a/egs/sre08/v1/sid/compute_vad_decision.sh b/egs/sre08/v1/sid/compute_vad_decision.sh new file mode 120000 index 00000000000..174321b847e --- /dev/null +++ b/egs/sre08/v1/sid/compute_vad_decision.sh @@ -0,0 +1 @@ +../steps/compute_vad_decision.sh \ No newline at end of file diff --git a/egs/sre08/v1/sid/nnet3/xvector/allocate_egs.py b/egs/sre08/v1/sid/nnet3/xvector/allocate_egs.py index 72a4572d9a0..e1a4fc534e0 100755 --- a/egs/sre08/v1/sid/nnet3/xvector/allocate_egs.py +++ b/egs/sre08/v1/sid/nnet3/xvector/allocate_egs.py @@ -65,6 +65,7 @@ # We're using python 3.x style print but want it to work in python 2.x. 
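The allocate_egs.py hunk below adds `from __future__ import division` next to the existing print_function import and makes the egs-count computation use an explicit float(). A minimal sketch, with made-up numbers, of the behaviour being pinned down:

  from __future__ import print_function, division

  # Invented values standing in for args.frames_per_iter and a sampled chunk length.
  frames_per_iter = 1000000
  length = 331

  # Without the __future__ import, Python 2 would evaluate 1000000 / 331 as
  # integer division (3021); with true division it is ~3021.15. The explicit
  # float() keeps the intent obvious and identical on both interpreters.
  this_num_egs = int(float(frames_per_iter) / length + 1)
  print(this_num_egs)  # 3022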
from __future__ import print_function +from __future__ import division import re, os, argparse, sys, math, warnings, random def get_args(): @@ -196,7 +197,7 @@ def deterministic_chunk_length(archive_id, num_archives, min_frames_per_chunk, m elif num_archives == 1: return int(max_frames_per_chunk); else: - return int(math.pow(float(max_frames_per_chunk) / + return int(math.pow(float(max_frames_per_chunk)/ min_frames_per_chunk, float(archive_id) / (num_archives-1)) * min_frames_per_chunk + 0.5) @@ -247,7 +248,7 @@ def main(): length = deterministic_chunk_length(archive_index, args.num_archives, args.min_frames_per_chunk, args.max_frames_per_chunk); print("{0} {1}".format(archive_index + 1, length), file=info_f) archive_chunk_lengths.append(length) - this_num_egs = int((args.frames_per_iter / length) + 1) + this_num_egs = int(float(args.frames_per_iter) / length + 1) this_egs = [ ] # A 2-tuple of the form (utt-id, start-frame) spkrs = args.num_repeats * list(spk2utt.keys()) random.shuffle(spkrs) diff --git a/egs/sre10/v1/local/prepare_for_eer.py b/egs/sre10/v1/local/prepare_for_eer.py index 59d2985e7c2..bb4e666f0ab 100755 --- a/egs/sre10/v1/local/prepare_for_eer.py +++ b/egs/sre10/v1/local/prepare_for_eer.py @@ -1,3 +1,4 @@ +from __future__ import print_function # Copyright 2015 David Snyder # Apache 2.0. # @@ -12,4 +13,4 @@ spkrutt2target[spkr+utt]=target for line in scores: spkr, utt, score = line.strip().split() - print score, spkrutt2target[spkr+utt] + print("{} {}".format(score, spkrutt2target[spkr+utt])) diff --git a/egs/sre16/v1/local/make_musan.py b/egs/sre16/v1/local/make_musan.py index b3f6652ba40..7735bd28818 100755 --- a/egs/sre16/v1/local/make_musan.py +++ b/egs/sre16/v1/local/make_musan.py @@ -43,9 +43,9 @@ def prepare_music(root_dir, use_vocals): utt2wav_str = utt2wav_str + utt + " sox -t wav " + utt2wav[utt] + " -r 8k -t wav - |\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file {}".format(utt)) num_bad_files += 1 - print("In music directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In music directory, processed {} files; {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def prepare_speech(root_dir): @@ -69,9 +69,9 @@ def prepare_speech(root_dir): utt2wav_str = utt2wav_str + utt + " sox -t wav " + utt2wav[utt] + " -r 8k -t wav - |\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file {}".format(utt)) num_bad_files += 1 - print("In speech directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In speech directory, processed {} files; {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def prepare_noise(root_dir): @@ -95,9 +95,9 @@ def prepare_noise(root_dir): utt2wav_str = utt2wav_str + utt + " sox -t wav " + utt2wav[utt] + " -r 8k -t wav - |\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file {}".format(utt)) num_bad_files += 1 - print("In noise directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In noise directory, processed {} files; {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def main(): diff --git a/egs/sre16/v1/run.sh b/egs/sre16/v1/run.sh index 52ee86ec5b2..28481e27c3a 100755 --- a/egs/sre16/v1/run.sh +++ b/egs/sre16/v1/run.sh @@ -130,7 +130,7 @@ if [ $stage -le 4 ]; then # Make a reverberated version of the SRE list. 
Note that we don't add any # additive noise here. - python steps/data/reverberate_data_dir.py \ + steps/data/reverberate_data_dir.py \ "${rvb_opts[@]}" \ --speech-rvb-probability 1 \ --pointsource-noise-addition-probability 0 \ @@ -155,11 +155,11 @@ if [ $stage -le 4 ]; then done # Augment with musan_noise - python steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/sre data/sre_noise + steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/sre data/sre_noise # Augment with musan_music - python steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/sre data/sre_music + steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/sre data/sre_music # Augment with musan_speech - python steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/sre data/sre_babble + steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/sre data/sre_babble # Combine reverb, noise, music, and babble into one directory. utils/combine_data.sh data/sre_aug data/sre_reverb data/sre_noise data/sre_music data/sre_babble diff --git a/egs/sre16/v2/run.sh b/egs/sre16/v2/run.sh index 0bc06431138..b2072dfd69d 100755 --- a/egs/sre16/v2/run.sh +++ b/egs/sre16/v2/run.sh @@ -82,7 +82,7 @@ if [ $stage -le 0 ]; then fi if [ $stage -le 1 ]; then - # Make filterbanks and compute the energy-based VAD for each dataset + # Make MFCCs and compute the energy-based VAD for each dataset if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then utils/create_split_dir.pl \ /export/b{14,15,16,17}/$USER/kaldi-data/egs/sre16/v2/xvector-$(date +'%m_%d_%H_%M')/mfccs/storage $mfccdir/storage @@ -120,7 +120,7 @@ if [ $stage -le 2 ]; then # Make a reverberated version of the SWBD+SRE list. Note that we don't add any # additive noise here. 
- python steps/data/reverberate_data_dir.py \ + steps/data/reverberate_data_dir.py \ "${rvb_opts[@]}" \ --speech-rvb-probability 1 \ --pointsource-noise-addition-probability 0 \ @@ -145,11 +145,11 @@ if [ $stage -le 2 ]; then done # Augment with musan_noise - python steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/swbd_sre data/swbd_sre_noise + steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/swbd_sre data/swbd_sre_noise # Augment with musan_music - python steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/swbd_sre data/swbd_sre_music + steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/swbd_sre data/swbd_sre_music # Augment with musan_speech - python steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/swbd_sre data/swbd_sre_babble + steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/swbd_sre data/swbd_sre_babble # Combine reverb, noise, music, and babble into one directory. utils/combine_data.sh data/swbd_sre_aug data/swbd_sre_reverb data/swbd_sre_noise data/swbd_sre_music data/swbd_sre_babble @@ -159,7 +159,7 @@ if [ $stage -le 2 ]; then utils/subset_data_dir.sh data/swbd_sre_aug 128000 data/swbd_sre_aug_128k utils/fix_data_dir.sh data/swbd_sre_aug_128k - # Make filterbanks for the augmented data. Note that we do not compute a new + # Make MFCCs for the augmented data. Note that we do not compute a new # vad.scp file here. Instead, we use the vad.scp from the clean version of # the list. steps/make_mfcc.sh --mfcc-config conf/mfcc.conf --nj 40 --cmd "$train_cmd" \ diff --git a/egs/svhn/v1/local/process_data.py b/egs/svhn/v1/local/process_data.py index f6ea85118f9..2a5bfc9a0d6 100755 --- a/egs/svhn/v1/local/process_data.py +++ b/egs/svhn/v1/local/process_data.py @@ -6,6 +6,7 @@ """ This script prepares the training and test data for SVHN. 
""" +from __future__ import division import argparse import os @@ -16,11 +17,11 @@ parser = argparse.ArgumentParser(description="""Converts train/test data of SVHN (Street View House Numbers) dataset to Kaldi feature format""") -parser.add_argument('matlab_file', type=str, +parser.add_argument('matlab_file', help='path to SVHN matlab data file (cropped version)') -parser.add_argument('dir', type=str, +parser.add_argument('dir', help='output dir') -parser.add_argument('--out-ark', type=str, +parser.add_argument('--out-ark', default='-', help='where to write output feature data') args = parser.parse_args() @@ -48,7 +49,7 @@ def write_kaldi_matrix(file_handle, matrix, key): if num_cols != len(matrix[row_index]): raise Exception("All the rows of a matrix are expected to " "have the same length") - file_handle.write(" ".join(map(lambda x: str(x), matrix[row_index]))) + file_handle.write(" ".join([str(x) for x in matrix[row_index]])) if row_index != num_rows - 1: file_handle.write("\n") file_handle.write(" ]\n") @@ -80,7 +81,7 @@ def zeropad(x, length): lbl = labels[i, 0] if lbl == 10: lbl = 0 - labels_fh.write(key + ' ' + str(lbl) + '\n') + labels_fh.write("{} {}\n".format(key, lbl)) img = data[i] write_kaldi_matrix(out_fh, img, key) img_id += 1 diff --git a/egs/swahili/s5/run.sh b/egs/swahili/s5/run.sh index 018ac84ea84..3da3a30cc3b 100755 --- a/egs/swahili/s5/run.sh +++ b/egs/swahili/s5/run.sh @@ -117,7 +117,7 @@ echo -e "fMMI+MMI training done.\n" ### Triphone + LDA and MLLT + SGMM ## SGMM # Training -steps/train_ubm.sh --cmd "$train_cmd" 500 data/train data/lang exp/system1/tri3b_ali exp/system1/ubm5b2 || exit 1; +steps/train_ubm.sh --cmd "$train_cmd" 500 data/train data/lang exp/system1/tri3b_ali exp/system1/ubm5b2 || exit 1; steps/train_sgmm2.sh --cmd "$train_cmd" 5000 12000 data/train data/lang exp/system1/tri3b_ali exp/system1/ubm5b2/final.ubm exp/system1/sgmm2_5b2 || exit 1; # Graph compilation utils/mkgraph.sh data/lang exp/system1/sgmm2_5b2 exp/system1/sgmm2_5b2/graph @@ -146,7 +146,7 @@ for iter in 1 2 3 4; do done ## MBR -rm -r exp/system1/sgmm2_5b2_mmi_b0.1/decode_test_it3.mbr 2>/dev/null +rm -r exp/system1/sgmm2_5b2_mmi_b0.1/decode_test_it3.mbr 2>/dev/null cp -r exp/system1/sgmm2_5b2_mmi_b0.1/decode_test_it3{,.mbr} local/score_mbr.sh data/test data/lang exp/system1/sgmm2_5b2_mmi_b0.1/decode_test_it3.mbr diff --git a/egs/swbd/s5/local/run_sgmm2.sh b/egs/swbd/s5/local/run_sgmm2.sh index 1884e327db0..194dfa05e61 100755 --- a/egs/swbd/s5/local/run_sgmm2.sh +++ b/egs/swbd/s5/local/run_sgmm2.sh @@ -9,7 +9,7 @@ if [ ! -f exp/ubm5a/final.ubm ]; then steps/train_ubm.sh --cmd "$train_cmd" 700 data/train_100k_nodup data/lang \ exp/tri4a_ali_100k_nodup exp/ubm5a || exit 1; -fi +fi steps/train_sgmm2.sh --cmd "$train_cmd" \ 9000 30000 data/train_100k_nodup data/lang exp/tri4a_ali_100k_nodup \ @@ -65,4 +65,4 @@ done done wait ) - + diff --git a/egs/swbd/s5/run.sh b/egs/swbd/s5/run.sh index 66aa1d99866..79fe7703314 100755 --- a/egs/swbd/s5/run.sh +++ b/egs/swbd/s5/run.sh @@ -14,7 +14,7 @@ exit 1; #local/swbd_p1_data_prep.sh /mnt/matylda2/data/SWITCHBOARD_1R2 local/swbd_p1_data_prep.sh /data/corpora0/LDC97S62/ -#local/swbd_p1_data_prep.sh /export/corpora3/LDC/LDC97S62 +#local/swbd_p1_data_prep.sh /export/corpora3/LDC/LDC97S62 local/swbd_p1_prepare_dict.sh @@ -33,13 +33,13 @@ local/eval2000_data_prep.sh /data/corpora0/LDC2002S09/hub5e_00 /data/corpora0/L . ./cmd.sh # mfccdir should be some place with a largish disk where you -# want to store MFCC features. +# want to store MFCC features. 
mfccdir=`pwd`/mfcc steps/make_mfcc.sh --nj 20 --cmd "$train_cmd" data/train exp/make_mfcc/train $mfccdir || exit 1; -# Don't do "|| exit 1" because actually some speakers don't have data, +# Don't do "|| exit 1" because actually some speakers don't have data, # we'll get rid of them later. Ignore this error. -steps/compute_cmvn_stats.sh data/train exp/make_mfcc/train $mfccdir +steps/compute_cmvn_stats.sh data/train exp/make_mfcc/train $mfccdir # after this, the next command will remove the small number of utterances # that couldn't be extracted for some reason (e.g. too short; no such file). @@ -77,22 +77,22 @@ utils/data/remove_dup_utts.sh 300 data/train_nodev data/train_nodup utils/subset_data_dir.sh --first data/train_nodev 100000 data/train_100k utils/data/remove_dup_utts.sh 200 data/train_100k data/train_100k_nodup -# The next commands are not necessary for the scripts to run, but increase -# efficiency of data access by putting the mfcc's of the subset +# The next commands are not necessary for the scripts to run, but increase +# efficiency of data access by putting the mfcc's of the subset # in a contiguous place in a file. ( . ./path.sh; # make sure mfccdir is defined as above.. - cp data/train_10k_nodup/feats.scp{,.bak} + cp data/train_10k_nodup/feats.scp{,.bak} copy-feats scp:data/train_10k_nodup/feats.scp ark,scp:$mfccdir/kaldi_swbd_10k_nodup.ark,$mfccdir/kaldi_swbd_10k_nodup.scp \ && cp $mfccdir/kaldi_swbd_10k_nodup.scp data/train_10k_nodup/feats.scp ) ( . ./path.sh; # make sure mfccdir is defined as above.. - cp data/train_30k_nodup/feats.scp{,.bak} + cp data/train_30k_nodup/feats.scp{,.bak} copy-feats scp:data/train_30k_nodup/feats.scp ark,scp:$mfccdir/kaldi_swbd_30k_nodup.ark,$mfccdir/kaldi_swbd_30k_nodup.scp \ && cp $mfccdir/kaldi_swbd_30k_nodup.scp data/train_30k_nodup/feats.scp ) - + steps/train_mono.sh --nj 10 --cmd "$train_cmd" \ data/train_10k_nodup data/lang exp/mono0a || exit 1; @@ -102,7 +102,7 @@ steps/align_si.sh --nj 30 --cmd "$train_cmd" \ steps/train_deltas.sh --cmd "$train_cmd" \ 2500 20000 data/train_30k_nodup data/lang exp/mono0a_ali exp/tri1 || exit 1; - + utils/mkgraph.sh data/lang_test exp/tri1 exp/tri1/graph steps/decode.sh --nj 30 --cmd "$decode_cmd" --config conf/decode.config \ diff --git a/egs/swbd/s5b/local/run_sgmm2.sh b/egs/swbd/s5b/local/run_sgmm2.sh index eda22786d82..0cddc13bbd4 100755 --- a/egs/swbd/s5b/local/run_sgmm2.sh +++ b/egs/swbd/s5b/local/run_sgmm2.sh @@ -10,7 +10,7 @@ set -e if [ ! -f exp/ubm5b/final.ubm ]; then steps/train_ubm.sh --cmd "$train_cmd" 1400 data/train_nodup data/lang \ exp/tri4b_ali_nodup exp/ubm5b || exit 1; -fi +fi steps/train_sgmm2.sh --cmd "$train_cmd" \ 18000 60000 data/train_nodup data/lang exp/tri4b_ali_nodup \ diff --git a/egs/swbd/s5b/run.sh b/egs/swbd/s5b/run.sh index a1424e1fa34..ba447e6f972 100755 --- a/egs/swbd/s5b/run.sh +++ b/egs/swbd/s5b/run.sh @@ -15,7 +15,7 @@ exit 1; . ./path.sh set -e # exit on error # mfccdir should be some place with a largish disk where you -# want to store MFCC features. +# want to store MFCC features. mfccdir=mfcc if [ -z $IRSTLM ] ; then @@ -36,7 +36,7 @@ fi # which specifies the directory to Switchboard documentations. Specifically, if # this argument is given, the script will look for the conv.tab file and correct # speaker IDs to the actual speaker personal identification numbers released in -# the documentations. The documentations can be found here: +# the documentations. 
The documentations can be found here: # https://catalog.ldc.upenn.edu/docs/LDC97S62/ # Note: if you are using this link, make sure you rename conv_tab.csv to conv.tab # after downloading. @@ -52,7 +52,7 @@ local/swbd1_prepare_dict.sh utils/prepare_lang.sh data/local/dict "" data/local/lang data/lang # Now train the language models. We are using SRILM and interpolating with an -# LM trained on the Fisher transcripts (part 2 disk is currently missing; so +# LM trained on the Fisher transcripts (part 2 disk is currently missing; so # only part 1 transcripts ~700hr are used) # If you have the Fisher data, you can set this "fisher_dir" variable. @@ -75,10 +75,10 @@ for order in 3 4; do LM=data/local/lm/sw1.o${order}g.kn.gz utils/format_lm_sri.sh --srilm-opts "$srilm_opts" \ data/lang $LM data/local/dict/lexicon.txt data/lang_sw1_$lm_suffix - + LM=data/local/lm/sw1_fsh.o${order}g.kn.gz utils/build_const_arpa_lm.sh $LM data/lang data/lang_sw1_fsh_$lm_suffix - + # For some funny reason we are still using IRSTLM for doing LM pruning :) prune-lm --threshold=1e-7 data/local/lm/sw1_fsh.o${order}g.kn.gz /dev/stdout \ | gzip -c > data/local/lm/sw1_fsh.o${order}g.pr1-7.kn.gz || exit 1 @@ -98,9 +98,9 @@ done local/eval2000_data_prep.sh /export/corpora2/LDC/LDC2002S09/hub5e_00 /export/corpora2/LDC/LDC2002T43 steps/make_mfcc.sh --nj 50 --cmd "$train_cmd" data/train exp/make_mfcc/train $mfccdir -steps/compute_cmvn_stats.sh data/train exp/make_mfcc/train $mfccdir +steps/compute_cmvn_stats.sh data/train exp/make_mfcc/train $mfccdir -# Remove the small number of utterances that couldn't be extracted for some +# Remove the small number of utterances that couldn't be extracted for some # reason (e.g. too short; no such file). utils/fix_data_dir.sh data/train @@ -120,10 +120,10 @@ utils/subset_data_dir.sh --last data/train $n data/train_nodev # perl -ne 'split; $s+=($_[3]-$_[2]); END{$h=int($s/3600); $r=($s-$h*3600); $m=int($r/60); $r-=$m*60; printf "%.1f sec -- %d:%d:%.1f\n", $s, $h, $m, $r;}' data/local/train/segments -# Now-- there are 260k utterances (313hr 23min), and we want to start the -# monophone training on relatively short utterances (easier to align), but not +# Now-- there are 260k utterances (313hr 23min), and we want to start the +# monophone training on relatively short utterances (easier to align), but not # only the shortest ones (mostly uh-huh). 
So take the 100k shortest ones; -# remove most of the repeated utterances (these are the uh-huh type ones), and +# remove most of the repeated utterances (these are the uh-huh type ones), and # then take 10k random utterances from those (about 4hr 40mins) utils/subset_data_dir.sh --shortest data/train_nodev 100000 data/train_100kshort @@ -144,13 +144,13 @@ utils/data/remove_dup_utts.sh 300 data/train_nodev data/train_nodup # 286hr ## Starting basic training on MFCC features steps/train_mono.sh --nj 10 --cmd "$train_cmd" \ - data/train_10k_nodup data/lang exp/mono + data/train_10k_nodup data/lang exp/mono steps/align_si.sh --nj 30 --cmd "$train_cmd" \ - data/train_30k_nodup data/lang exp/mono exp/mono_ali + data/train_30k_nodup data/lang exp/mono exp/mono_ali steps/train_deltas.sh --cmd "$train_cmd" \ - 3200 30000 data/train_30k_nodup data/lang exp/mono_ali exp/tri1 + 3200 30000 data/train_30k_nodup data/lang exp/mono_ali exp/tri1 for lm_suffix in tg fsh_tgpr; do ( @@ -163,10 +163,10 @@ for lm_suffix in tg fsh_tgpr; do done steps/align_si.sh --nj 30 --cmd "$train_cmd" \ - data/train_30k_nodup data/lang exp/tri1 exp/tri1_ali + data/train_30k_nodup data/lang exp/tri1 exp/tri1_ali steps/train_deltas.sh --cmd "$train_cmd" \ - 3200 30000 data/train_30k_nodup data/lang exp/tri1_ali exp/tri2 + 3200 30000 data/train_30k_nodup data/lang exp/tri1_ali exp/tri2 for lm_suffix in tg fsh_tgpr; do @@ -183,14 +183,14 @@ for lm_suffix in tg fsh_tgpr; do ) & done -# From now, we start building a bigger system (on train_100k_nodup, which has +# From now, we start building a bigger system (on train_100k_nodup, which has # 110hrs of data). We start with the LDA+MLLT system steps/align_si.sh --nj 30 --cmd "$train_cmd" \ - data/train_100k_nodup data/lang exp/tri2 exp/tri2_ali_100k_nodup + data/train_100k_nodup data/lang exp/tri2 exp/tri2_ali_100k_nodup # Train tri3b, which is LDA+MLLT, on 100k_nodup data. steps/train_lda_mllt.sh --cmd "$train_cmd" \ - 5500 90000 data/train_100k_nodup data/lang exp/tri2_ali_100k_nodup exp/tri3b + 5500 90000 data/train_100k_nodup data/lang exp/tri2_ali_100k_nodup exp/tri3b for lm_suffix in tg fsh_tgpr; do ( @@ -204,12 +204,12 @@ done # Train tri4a, which is LDA+MLLT+SAT, on 100k_nodup data. steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ - data/train_100k_nodup data/lang exp/tri3b exp/tri3b_ali_100k_nodup + data/train_100k_nodup data/lang exp/tri3b exp/tri3b_ali_100k_nodup steps/train_sat.sh --cmd "$train_cmd" \ 5500 90000 data/train_100k_nodup data/lang exp/tri3b_ali_100k_nodup \ - exp/tri4a + exp/tri4a for lm_suffix in tg fsh_tgpr; do ( @@ -226,11 +226,11 @@ done # both train and test data. # local/run_resegment.sh -# Now train a LDA+MLLT+SAT model on the entire training data (train_nodup; +# Now train a LDA+MLLT+SAT model on the entire training data (train_nodup; # 286 hours) # Train tri4b, which is LDA+MLLT+SAT, on train_nodup data. 
steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ - data/train_nodup data/lang exp/tri3b exp/tri3b_ali_nodup + data/train_nodup data/lang exp/tri3b exp/tri3b_ali_nodup steps/train_sat.sh --cmd "$train_cmd" \ @@ -257,7 +257,7 @@ steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ data/lang_sw1_fsh_{tgpr,fg} data/eval2000 \ exp/tri4b/decode_eval2000_sw1_fsh_{tgpr,fg} || exit 1; -# MMI training starting from the LDA+MLLT+SAT systems on both the +# MMI training starting from the LDA+MLLT+SAT systems on both the # train_100k_nodup (110hr) and train_nodup (286hr) sets steps/align_fmllr.sh --nj 50 --cmd "$train_cmd" \ data/train_100k_nodup data/lang exp/tri4a exp/tri4a_ali_100k_nodup || exit 1 @@ -268,11 +268,11 @@ steps/align_fmllr.sh --nj 100 --cmd "$train_cmd" \ steps/make_denlats.sh --nj 50 --cmd "$decode_cmd" --config conf/decode.config \ --transform-dir exp/tri4a_ali_100k_nodup \ data/train_100k_nodup data/lang exp/tri4a exp/tri4a_denlats_100k_nodup \ - + steps/make_denlats.sh --nj 100 --cmd "$decode_cmd" --config conf/decode.config \ --transform-dir exp/tri4b_ali_nodup \ - data/train_nodup data/lang exp/tri4b exp/tri4b_denlats_nodup + data/train_nodup data/lang exp/tri4b exp/tri4b_denlats_nodup # 4 iterations of MMI seems to work well overall. The number of iterations is # used as an explicit argument even though train_mmi.sh will use 4 iterations by @@ -280,11 +280,11 @@ steps/make_denlats.sh --nj 100 --cmd "$decode_cmd" --config conf/decode.config \ num_mmi_iters=4 steps/train_mmi.sh --cmd "$decode_cmd" --boost 0.1 --num-iters $num_mmi_iters \ data/train_100k_nodup data/lang exp/tri4a_{ali,denlats}_100k_nodup \ - exp/tri4a_mmi_b0.1 + exp/tri4a_mmi_b0.1 steps/train_mmi.sh --cmd "$decode_cmd" --boost 0.1 --num-iters $num_mmi_iters \ data/train_nodup data/lang exp/tri4b_{ali,denlats}_nodup \ - exp/tri4b_mmi_b0.1 + exp/tri4b_mmi_b0.1 for iter in 1 2 3 4; do for lm_suffix in tg fsh_tgpr; do @@ -336,11 +336,11 @@ steps/train_diag_ubm.sh --silence-weight 0.5 --nj 100 --cmd "$train_cmd" \ steps/train_mmi_fmmi.sh --learning-rate 0.005 --boost 0.1 --cmd "$train_cmd" \ data/train_100k_nodup data/lang exp/tri4a_ali_100k_nodup exp/tri4a_dubm \ - exp/tri4a_denlats_100k_nodup exp/tri4a_fmmi_b0.1 + exp/tri4a_denlats_100k_nodup exp/tri4a_fmmi_b0.1 steps/train_mmi_fmmi.sh --learning-rate 0.005 --boost 0.1 --cmd "$train_cmd" \ data/train_nodup data/lang exp/tri4b_ali_nodup exp/tri4b_dubm \ - exp/tri4b_denlats_nodup exp/tri4b_fmmi_b0.1 + exp/tri4b_denlats_nodup exp/tri4b_fmmi_b0.1 for iter in 4 5 6 7 8; do for lm_suffix in tg fsh_tgpr; do diff --git a/egs/swbd/s5c/local/chain/multi_condition/run_tdnn_7k.sh b/egs/swbd/s5c/local/chain/multi_condition/run_tdnn_7k.sh index 6792332da56..20dcab8eb50 100755 --- a/egs/swbd/s5c/local/chain/multi_condition/run_tdnn_7k.sh +++ b/egs/swbd/s5c/local/chain/multi_condition/run_tdnn_7k.sh @@ -152,7 +152,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/run_cnn_tdnn.sh b/egs/swbd/s5c/local/chain/run_cnn_tdnn.sh new file mode 120000 index 00000000000..ab83f3c43e8 --- /dev/null +++ b/egs/swbd/s5c/local/chain/run_cnn_tdnn.sh @@ -0,0 +1 @@ +tuning/run_cnn_tdnn_1a.sh \ No newline at end of file diff --git 
a/egs/swbd/s5c/local/chain/tuning/run_blstm_6j.sh b/egs/swbd/s5c/local/chain/tuning/run_blstm_6j.sh index ae7c97e7d08..acdae844b65 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_blstm_6j.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_blstm_6j.sh @@ -120,7 +120,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_blstm_6k.sh b/egs/swbd/s5c/local/chain/tuning/run_blstm_6k.sh index 90d672b9ae9..bbd8cb63697 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_blstm_6k.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_blstm_6k.sh @@ -116,7 +116,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_blstm_6l.sh b/egs/swbd/s5c/local/chain/tuning/run_blstm_6l.sh index 68daf81ab01..16f2ea211d0 100644 --- a/egs/swbd/s5c/local/chain/tuning/run_blstm_6l.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_blstm_6l.sh @@ -125,7 +125,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20 dropout-proportion=0.0" diff --git a/egs/swbd/s5c/local/chain/tuning/run_blstm_6m.sh b/egs/swbd/s5c/local/chain/tuning/run_blstm_6m.sh index 4668aac9ebc..09f7d72434c 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_blstm_6m.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_blstm_6m.sh @@ -124,7 +124,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_blstm_6n.sh b/egs/swbd/s5c/local/chain/tuning/run_blstm_6n.sh index 22316d56ed2..8e44d0bc114 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_blstm_6n.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_blstm_6n.sh @@ -123,7 +123,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_blstm_6o.sh b/egs/swbd/s5c/local/chain/tuning/run_blstm_6o.sh index ad2ac4bf043..6a836e81b09 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_blstm_6o.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_blstm_6o.sh @@ -125,7 +125,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - 
learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_cnn_tdnn_1a.sh b/egs/swbd/s5c/local/chain/tuning/run_cnn_tdnn_1a.sh new file mode 100755 index 00000000000..d1a61360f85 --- /dev/null +++ b/egs/swbd/s5c/local/chain/tuning/run_cnn_tdnn_1a.sh @@ -0,0 +1,274 @@ +#!/bin/bash + +# This is based on tdnn_7q, but adding cnn as the front-end. +# The cnn-tdnn-f (cnn_tdnn_1a) outperforms the tdnn-f (tdnn_7q). + +# local/chain/compare_wer_general.sh --rt03 tdnn7q_sp cnn_tdnn1a_sp +# System tdnn7q_sp cnn_tdnn1a_sp +# WER on train_dev(tg) 12.08 11.97 +# WER on train_dev(fg) 11.15 11.12 +# WER on eval2000(tg) 14.1 13.9 +# WER on eval2000(fg) 12.8 12.5 +# WER on rt03(tg) 17.5 17.1 +# WER on rt03(fg) 15.3 14.9 +# Final train prob -0.055 -0.056 +# Final valid prob -0.072 -0.075 +# Final train prob (xent) -0.875 -0.871 +# Final valid prob (xent) -0.9064 -0.9110 +# Num-parameters 18725244 15187100 + +# steps/info/chain_dir_info.pl exp/chain/cnn_tdnn1a_sp +# exp/chain/cnn_tdnn1a_sp: num-iters=394 nj=3..16 num-params=15.2M dim=40+100->6078 combine=-0.054->-0.054 (over 7) xent:train/valid[261,393,final]=(-1.03,-0.878,-0.871/-1.06,-0.918,-0.911) logprob:train/valid[261,393,final]=(-0.076,-0.057,-0.056/-0.091,-0.076,-0.075) +set -e + +# configs for 'chain' +stage=0 +train_stage=-10 +get_egs_stage=-10 +speed_perturb=true +affix=1a +if [ -e data/rt03 ]; then maybe_rt03=rt03; else maybe_rt03= ; fi + +decode_iter= +decode_nj=50 + +# training options +frames_per_eg=150,110,100 +remove_egs=false +common_egs_dir= +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.3@0.50,0' + +test_online_decoding=false # if true, it will run the last decoding stage. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +suffix= +$speed_perturb && suffix=_sp +dir=exp/chain/cnn_tdnn${affix}${suffix} + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 11 ]; then + # Build a tree using our new topology. This is the critically different + # step compared with other recipes. + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 7000 data/$train_set $lang $ali_dir $treedir +fi + +if [ $stage -le 12 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + + cnn_opts="l2-regularize=0.01" + ivector_affine_opts="l2-regularize=0.01" + tdnnf_first_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.0" + tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.01" + output_opts="l2-regularize=0.002" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + # this takes the MFCCs and generates filterbank coefficients. The MFCCs + # are more compressible so we prefer to dump the MFCCs to disk rather + # than filterbanks. 
+ idct-layer name=idct input=input dim=40 cepstral-lifter=22 affine-transform-file=$dir/configs/idct.mat + linear-component name=ivector-linear $ivector_affine_opts dim=200 input=ReplaceIndex(ivector, t, 0) + batchnorm-component name=ivector-batchnorm target-rms=0.025 + batchnorm-component name=idct-batchnorm input=idct + combine-feature-maps-layer name=combine_inputs input=Append(idct-batchnorm, ivector-batchnorm) num-filters1=1 num-filters2=5 height=40 + conv-relu-batchnorm-layer name=cnn1 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn2 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn3 $cnn_opts height-in=40 height-out=20 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=128 + conv-relu-batchnorm-layer name=cnn4 $cnn_opts height-in=20 height-out=20 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=128 + conv-relu-batchnorm-layer name=cnn5 $cnn_opts height-in=20 height-out=10 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=256 + conv-relu-batchnorm-layer name=cnn6 $cnn_opts height-in=10 height-out=10 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=256 + # the first TDNN-F layer has no bypass (since dims don't match), and a larger bottleneck so the + # information bottleneck doesn't become a problem. (we use time-stride=0 so no splicing, to + # limit the num-parameters). + tdnnf-layer name=tdnnf7 $tdnnf_first_opts dim=1536 bottleneck-dim=256 time-stride=0 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf14 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf15 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + linear-component name=prefinal-l dim=256 $linear_opts + ## adding the layers for chain branch + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts small-dim=256 big-dim=1536 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + # adding the layers for xent branch + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts small-dim=256 big-dim=1536 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/swbd-$(date +'%m_%d_%H_%M')/s5c/$dir/egs/storage $dir/egs/storage + fi + +# --cmd "queue.pl --config /home/dpovey/queue_conly.conf" \ + + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$train_cmd" \ + --feat.online-ivector-dir exp/nnet3/ivectors_${train_set} \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.0 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0 --constrained false" \ + --egs.chunk-width $frames_per_eg \ + --trainer.num-chunk-per-minibatch 64 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs 6 \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.00025 \ + --trainer.optimization.final-effective-lrate 0.000025 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs $remove_egs \ + --feat-dir data/${train_set}_hires \ + --tree-dir $treedir \ + --lat-dir exp/tri4_lats_nodup$suffix \ + --dir $dir || exit 1; + +fi + +if [ $stage -le 14 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_sw1_tg $dir $dir/graph_sw1_tg +fi + + +graph_dir=$dir/graph_sw1_tg +iter_opts= +if [ ! -z $decode_iter ]; then + iter_opts=" --iter $decode_iter " +fi +if [ $stage -le 15 ]; then + rm $dir/.error 2>/dev/null || true + for decode_set in train_dev eval2000 $maybe_rt03; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $decode_nj --cmd "$decode_cmd" $iter_opts \ + --online-ivector-dir exp/nnet3/ivectors_${decode_set} \ + $graph_dir data/${decode_set}_hires \ + $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_tg || exit 1; + if $has_fisher; then + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \ + $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_{tg,fsh_fg} || exit 1; + fi + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi + +if $test_online_decoding && [ $stage -le 16 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. + steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/nnet3/extractor $dir ${dir}_online + + rm $dir/.error 2>/dev/null || true + for decode_set in train_dev eval2000 $maybe_rt03; do + ( + # note: we just give it "$decode_set" as it only uses the wav.scp, the + # feature type does not matter. 
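The xconfig earlier in this new cnn_tdnn_1a.sh feeds the 40-dimensional MFCC input through an idct-layer before the convolutional front end; as its comment says, MFCCs are what get dumped to disk because they compress better, and the fixed inverse-DCT transform recovers filterbank-like features for the CNN. A hedged numerical sketch of that relationship, ignoring the cepstral lifter and using random data:

  import numpy as np
  from scipy.fftpack import dct, idct

  fbank = np.random.rand(40)                    # stand-in for one log mel filterbank frame
  mfcc = dct(fbank, type=2, norm='ortho')       # roughly what lands in the MFCC archives
  recovered = idct(mfcc, type=2, norm='ortho')  # what a fixed inverse-DCT affine transform undoes
  assert np.allclose(fbank, recovered)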
+ + steps/online/nnet3/decode.sh --nj $decode_nj --cmd "$decode_cmd" \ + --acwt 1.0 --post-decode-acwt 10.0 \ + $graph_dir data/${decode_set}_hires \ + ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_tg || exit 1; + if $has_fisher; then + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \ + ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_{tg,fsh_fg} || exit 1; + fi + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi + + +exit 0; diff --git a/egs/swbd/s5c/local/chain/tuning/run_lstm_6j.sh b/egs/swbd/s5c/local/chain/tuning/run_lstm_6j.sh index e432435a551..48db81f586f 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_lstm_6j.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_lstm_6j.sh @@ -119,7 +119,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_lstm_6k.sh b/egs/swbd/s5c/local/chain/tuning/run_lstm_6k.sh index b9b7152dcbe..021eab09506 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_lstm_6k.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_lstm_6k.sh @@ -121,7 +121,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_lstm_6l.sh b/egs/swbd/s5c/local/chain/tuning/run_lstm_6l.sh index 12564c4faae..f219167f9ec 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_lstm_6l.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_lstm_6l.sh @@ -131,7 +131,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7g.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7g.sh index fa6518a9ad9..0623d26a9e4 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7g.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7g.sh @@ -117,7 +117,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7h.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7h.sh index 9dfaa1d4509..dbbe3c1e6fd 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7h.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7h.sh @@ -120,7 +120,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - 
learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7i.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7i.sh index c5b5633d94c..2a8a658bf6b 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7i.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7i.sh @@ -113,7 +113,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7j.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7j.sh index 793b40f7fe3..a9eba36ddaa 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7j.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7j.sh @@ -112,7 +112,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7k.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7k.sh index bd47ed61f23..8e0b290cf87 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7k.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7k.sh @@ -114,7 +114,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7l.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7l.sh index f7681a743e1..bb9ddf209d6 100644 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7l.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7l.sh @@ -112,7 +112,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7m.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7m.sh index 03b1ee3c97f..97f92c14f1f 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7m.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7m.sh @@ -122,7 +122,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7m25l.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7m25l.sh index 0fa7353edb2..d9fe106e5d7 100755 --- 
a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7m25l.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7m25l.sh @@ -452,7 +452,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.002 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" linear_opts="orthonormal-constraint=1.0" output_opts="l2-regularize=0.0005" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7n.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7n.sh index cf4855db611..99e43443f99 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7n.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7n.sh @@ -119,7 +119,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.002" linear_opts="orthonormal-constraint=1.0" output_opts="l2-regularize=0.0005 bottleneck-dim=256" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7o.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7o.sh index fb47b1e88ad..44ca3b3d279 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7o.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7o.sh @@ -126,7 +126,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.004 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" linear_opts="orthonormal-constraint=-1.0 l2-regularize=0.004" output_opts="l2-regularize=0.002" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7p.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7p.sh index 096ed9c54fd..d19a4ef4c0b 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7p.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7p.sh @@ -114,7 +114,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.004 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" linear_opts="orthonormal-constraint=-1.0 l2-regularize=0.004" output_opts="l2-regularize=0.002" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7q.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7q.sh index 8eab54a9dc2..cea0891d5d7 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7q.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7q.sh @@ -118,7 +118,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) affine_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" 
tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_attention_1a.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_attention_1a.sh index 3ce4fa68397..d4febd61e94 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_attention_1a.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_attention_1a.sh @@ -122,7 +122,7 @@ fi if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1a.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1a.sh index 7854bac44c5..4414147bf0e 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1a.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1a.sh @@ -120,7 +120,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1b.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1b.sh index 3929cdc432e..cd9d4dc6f2b 100644 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1b.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1b.sh @@ -122,7 +122,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20 dropout-proportion=0.0" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1c.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1c.sh index 311fe15d895..18b660b4080 100644 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1c.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1c.sh @@ -119,7 +119,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1d.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1d.sh index 4894e492542..be615e0e361 100644 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1d.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_blstm_1d.sh @@ -112,7 +112,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20 dropout-proportion=0.0" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1a.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1a.sh index 89ed8ad1d72..43855e6f7ce 100755 --- 
a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1a.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1a.sh @@ -33,7 +33,6 @@ chunk_width=150 chunk_left_context=40 chunk_right_context=0 xent_regularize=0.025 -self_repair_scale=0.00001 label_delay=5 # decode options extra_left_context=50 @@ -119,7 +118,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1b.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1b.sh index f0c88368245..5c82ed0eb11 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1b.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1b.sh @@ -29,7 +29,6 @@ chunk_width=150 chunk_left_context=40 chunk_right_context=0 xent_regularize=0.025 -self_repair_scale=0.00001 label_delay=5 # decode options extra_left_context=50 @@ -115,7 +114,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1c.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1c.sh index d71301eb102..c3df0bf2b2c 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1c.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1c.sh @@ -36,7 +36,6 @@ chunk_width=150 chunk_left_context=40 chunk_right_context=0 xent_regularize=0.025 -self_repair_scale=0.00001 label_delay=5 # decode options extra_left_context=50 @@ -122,7 +121,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1d.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1d.sh index 22c7d2e582d..3d353387239 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1d.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1d.sh @@ -48,7 +48,6 @@ decode_iter=final # training options xent_regularize=0.025 -self_repair_scale=0.00001 label_delay=5 chunk_left_context=40 @@ -141,7 +140,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1e.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1e.sh index 6987757757a..2a2d508ecdd 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1e.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1e.sh @@ -41,7 +41,6 @@ decode_nj=50 # training options xent_regularize=0.01 -self_repair_scale=0.00001 label_delay=5 chunk_left_context=40 @@ -136,7 +135,7 @@ if [ $stage -le 12 
]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1f.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1f.sh index 90e179379e4..5af5463b372 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1f.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1f.sh @@ -60,7 +60,6 @@ decode_iter=final # training options xent_regularize=0.01 -self_repair_scale=0.00001 label_delay=5 chunk_left_context=40 @@ -153,7 +152,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1g.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1g.sh index cb73f020e3e..28105a587ec 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1g.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1g.sh @@ -42,7 +42,6 @@ decode_iter=final # training options xent_regularize=0.01 -self_repair_scale=0.00001 label_delay=5 chunk_left_context=40 @@ -135,7 +134,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=15" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1h.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1h.sh index b12be22ce3d..d6e81f2d8eb 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1h.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1h.sh @@ -39,7 +39,6 @@ decode_iter=final # training options xent_regularize=0.01 -self_repair_scale=0.00001 label_delay=5 chunk_left_context=40 @@ -132,7 +131,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1i.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1i.sh index 7e05834c1fb..060d98c9d05 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1i.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1i.sh @@ -60,7 +60,6 @@ decode_iter=final # training options xent_regularize=0.01 -self_repair_scale=0.00001 label_delay=5 chunk_left_context=40 @@ -153,7 +152,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1j.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1j.sh index 6a6a4ba30e1..9bd39a262c5 100755 --- 
a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1j.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1j.sh @@ -25,7 +25,6 @@ decode_nj=50 # training options xent_regularize=0.01 -self_repair_scale=0.00001 label_delay=5 chunk_left_context=40 @@ -120,7 +119,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1k.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1k.sh index 21cb4fa9373..ccd6138da6e 100644 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1k.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1k.sh @@ -35,7 +35,6 @@ decode_nj=50 # training options xent_regularize=0.01 -self_repair_scale=0.00001 label_delay=5 chunk_left_context=40 @@ -130,7 +129,7 @@ if [ $stage -le 12 ]; then num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=20" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1l.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1l.sh index e88e199839c..f702033377a 100644 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1l.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1l.sh @@ -34,7 +34,6 @@ chunk_width=150 chunk_left_context=40 chunk_right_context=0 xent_regularize=0.025 -self_repair_scale=0.00001 label_delay=5 dropout_schedule='0,0@0.20,0.3@0.50,0' # decode options @@ -121,7 +120,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1m.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1m.sh index b50692616c4..b43577bd76c 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1m.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1m.sh @@ -42,7 +42,6 @@ frames_per_chunk_primary=$(echo $frames_per_chunk | cut -d, -f1) chunk_left_context=40 chunk_right_context=0 xent_regularize=0.025 -self_repair_scale=0.00001 label_delay=5 # decode options extra_left_context=50 @@ -129,7 +128,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) lstm_opts="decay-time=40" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1n.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1n.sh index 9cb182b2915..5bb6e7da152 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1n.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1n.sh @@ -43,7 +43,6 @@ frames_per_chunk_primary=$(echo $frames_per_chunk | cut -d, -f1) chunk_left_context=40 chunk_right_context=0 xent_regularize=0.025 -self_repair_scale=0.00001 label_delay=5 # decode options 
extra_left_context=50 @@ -126,7 +125,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.002" linear_opts="orthonormal-constraint=1.0" diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_opgru_1a.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_opgru_1a.sh index b1426bc22b7..4db38d74508 100755 --- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_opgru_1a.sh +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_opgru_1a.sh @@ -4,31 +4,36 @@ # This is based on TDNN_LSTM_1b, but using the NormOPGRU to replace the LSTMP, # and adding chunk-{left,right}-context-initial=0 +# For the details of OPGRU structure, please check the paper +# "Output-Gate Projected Gated Recurrent Unit for Speech Recognition" +# by Gaofeng Cheng et al, +# http://www.danielpovey.com/files/2018_interspeech_opgru.pdf + # Different from the vanilla OPGRU, Norm-OPGRU adds batchnorm in its output (forward direction) # and renorm in its recurrence. Experiments show that the TDNN-NormOPGRU could achieve similar # results than TDNN-LSTMP and BLSTMP in both large or small data sets (80 ~ 2300 Hrs). # ./local/chain/compare_wer_general.sh --looped tdnn_lstm_1e_sp tdnn_opgru_1a_sp # System tdnn_lstm_1e_sp tdnn_opgru_1a_sp -# WER on train_dev(tg) 12.81 12.39 -# [looped:] 12.93 12.32 -# WER on train_dev(fg) 11.92 11.39 -# [looped:] 12.07 11.35 +# WER on train_dev(tg) 12.81 12.31 +# [looped:] 12.93 12.26 +# WER on train_dev(fg) 11.92 11.60 +# [looped:] 12.07 11.65 # WER on eval2000(tg) 15.6 15.1 # [looped:] 16.0 15.1 -# WER on eval2000(fg) 14.1 13.6 +# WER on eval2000(fg) 14.1 13.5 # [looped:] 14.5 13.5 -# Final train prob -0.065 -0.066 -# Final valid prob -0.087 -0.085 -# Final train prob (xent) -0.918 -0.889 -# Final valid prob (xent) -1.0309 -0.9837 +# Final train prob -0.065 -0.068 +# Final valid prob -0.087 -0.091 +# Final train prob (xent) -0.918 -0.879 +# Final valid prob (xent) -1.0309 -0.9667 set -e # configs for 'chain' -stage=12 +stage=0 train_stage=-10 get_egs_stage=-10 speed_perturb=true @@ -129,7 +134,7 @@ if [ $stage -le 12 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) gru_opts="dropout-per-frame=true dropout-proportion=0.0" mkdir -p $dir/configs diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_opgru_1b.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_opgru_1b.sh new file mode 100755 index 00000000000..7e9dec67068 --- /dev/null +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_opgru_1b.sh @@ -0,0 +1,315 @@ +#!/bin/bash +# Apache 2.0 + +# This is based on TDNN_OPGRU_1A, but using the FastNormOPGRU to replace the NormPGRU. +# For the details of OPGRU structure, please check the paper +# "Output-Gate Projected Gated Recurrent Unit for Speech Recognition" +# by Gaofeng Cheng et al, +# http://www.danielpovey.com/files/2018_interspeech_opgru.pdf + +# Different from the vanilla OPGRU, Norm-OPGRU adds batchnorm in its output (forward direction) +# and renorm in its recurrence. Experiments show that the TDNN-NormOPGRU could achieve similar +# results than TDNN-LSTMP and BLSTMP in both large or small data sets (80 ~ 2300 Hrs). 
+ +# ./local/chain/compare_wer_general.sh --looped tdnn_opgru_1a_sp tdnn_opgru_1b_sp +# System tdnn_opgru_1a_sp tdnn_opgru_1b_sp +# WER on train_dev(tg) 12.31 12.41 +# [looped:] 12.26 12.38 +# WER on train_dev(fg) 11.49 11.60 +# [looped:] 11.43 11.65 +# WER on eval2000(tg) 14.9 15.1 +# [looped:] 15.0 15.1 +# WER on eval2000(fg) 13.5 13.7 +# [looped:] 13.5 13.7 +# Final train prob -0.068 -0.070 +# Final valid prob -0.091 -0.092 +# Final train prob (xent) -0.879 -0.889 +# Final valid prob (xent) -0.9667 -0.9723 + + + +set -e + +# configs for 'chain' +stage=0 +train_stage=-10 +get_egs_stage=-10 +speed_perturb=true +dir=exp/chain/tdnn_opgru_1b # Note: _sp will get added to this if $speed_perturb == true. +decode_iter= +decode_dir_affix= + +# training options +leftmost_questions_truncate=-1 +chunk_width=150 +chunk_left_context=40 +chunk_right_context=0 +xent_regularize=0.025 +self_repair_scale=0.00001 +label_delay=5 +dropout_schedule='0,0@0.20,0.2@0.50,0' +# decode options +extra_left_context=50 +extra_right_context=0 +frames_per_chunk= +test_online_decoding= + +remove_egs=false +common_egs_dir= + +affix= +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 11 ]; then + # Build a tree using our new topology. + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --leftmost-questions-truncate $leftmost_questions_truncate \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 7000 data/$train_set $lang $ali_dir $treedir +fi + +if [ $stage -le 12 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + gru_opts="dropout-per-frame=true dropout-proportion=0.0 gru-nonlinearity-options=\"max-change=0.75\"" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=1024 + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=1024 + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=1024 + + # check steps/libs/nnet3/xconfig/gru.py for the other options and defaults + fast-norm-opgru-layer name=opgru1 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $gru_opts + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=1024 + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=1024 + fast-norm-opgru-layer name=opgru2 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $gru_opts + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=1024 + relu-batchnorm-layer name=tdnn7 input=Append(-3,0,3) dim=1024 + fast-norm-opgru-layer name=opgru3 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $gru_opts + + ## adding the layers for chain branch + output-layer name=output input=opgru3 output-delay=$label_delay include-log-softmax=false dim=$num_targets 
max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + output-layer name=output-xent input=opgru3 output-delay=$label_delay dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/swbd-$(date +'%m_%d_%H_%M')/s5c/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir exp/nnet3/ivectors_${train_set} \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.num-chunk-per-minibatch 64 \ + --trainer.frames-per-iter 1200000 \ + --trainer.max-param-change 2.0 \ + --trainer.num-epochs 4 \ + --trainer.optimization.shrink-value 0.99 \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.momentum 0.0 \ + --trainer.deriv-truncate-margin 8 \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $chunk_width \ + --egs.chunk-left-context $chunk_left_context \ + --egs.chunk-right-context $chunk_right_context \ + --trainer.dropout-schedule $dropout_schedule \ + --egs.chunk-left-context-initial 0 \ + --egs.chunk-right-context-final 0 \ + --egs.dir "$common_egs_dir" \ + --cleanup.remove-egs $remove_egs \ + --feat-dir data/${train_set}_hires \ + --tree-dir $treedir \ + --lat-dir exp/tri4_lats_nodup$suffix \ + --dir $dir || exit 1; +fi + +if [ $stage -le 14 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_sw1_tg $dir $dir/graph_sw1_tg +fi + +decode_suff=sw1_tg +graph_dir=$dir/graph_sw1_tg +if [ $stage -le 15 ]; then + [ -z $extra_left_context ] && extra_left_context=$chunk_left_context; + [ -z $extra_right_context ] && extra_right_context=$chunk_right_context; + [ -z $frames_per_chunk ] && frames_per_chunk=$chunk_width; + iter_opts= + if [ ! 
-z $decode_iter ]; then + iter_opts=" --iter $decode_iter " + fi + for decode_set in train_dev eval2000; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 50 --cmd "$decode_cmd" $iter_opts \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk "$frames_per_chunk" \ + --online-ivector-dir exp/nnet3/ivectors_${decode_set} \ + $graph_dir data/${decode_set}_hires \ + $dir/decode_${decode_set}${decode_dir_affix:+_$decode_dir_affix}_${decode_suff} || exit 1; + if $has_fisher; then + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \ + $dir/decode_${decode_set}${decode_dir_affix:+_$decode_dir_affix}_sw1_{tg,fsh_fg} || exit 1; + fi + ) & + done +fi + +if $test_online_decoding && [ $stage -le 16 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. + steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/nnet3/extractor $dir ${dir}_online + + rm $dir/.error 2>/dev/null || true + for decode_set in train_dev eval2000; do + ( + # note: we just give it "$decode_set" as it only uses the wav.scp, the + # feature type does not matter. + steps/online/nnet3/decode.sh --nj 50 --cmd "$decode_cmd" $iter_opts \ + --acwt 1.0 --post-decode-acwt 10.0 \ + $graph_dir data/${decode_set}_hires \ + ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_tg || exit 1; + if $has_fisher; then + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \ + ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_{tg,fsh_fg} || exit 1; + fi + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in online decoding" + exit 1 + fi +fi + +if [ $stage -le 17 ]; then + rm $dir/.error 2>/dev/null || true + for decode_set in train_dev eval2000; do + ( + steps/nnet3/decode_looped.sh \ + --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 50 --cmd "$decode_cmd" $iter_opts \ + --online-ivector-dir exp/nnet3/ivectors_${decode_set} \ + $graph_dir data/${decode_set}_hires \ + $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_tg_looped || exit 1; + if $has_fisher; then + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \ + $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_{tg,fsh_fg}_looped || exit 1; + fi + ) & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in looped decoding" + exit 1 + fi +fi + +wait; +exit 0; diff --git a/egs/swbd/s5c/local/map_acronyms_ctm.py b/egs/swbd/s5c/local/map_acronyms_ctm.py index bee488f73b0..7ae59d2a1d0 100755 --- a/egs/swbd/s5c/local/map_acronyms_ctm.py +++ b/egs/swbd/s5c/local/map_acronyms_ctm.py @@ -10,6 +10,7 @@ # en_4156 B 414.58 0.16 l # en_4156 B 414.74 0.17 a +from __future__ import division import argparse,re __author__ = 'Minhua Wu' diff --git a/egs/swbd/s5c/local/run_sgmm2.sh b/egs/swbd/s5c/local/run_sgmm2.sh index 97697e5251d..5410819dadb 100755 --- a/egs/swbd/s5c/local/run_sgmm2.sh +++ b/egs/swbd/s5c/local/run_sgmm2.sh @@ -12,7 +12,7 @@ has_fisher=$1 if [ ! 
-f exp/ubm5/final.ubm ]; then steps/train_ubm.sh --cmd "$train_cmd" 1400 data/train_nodup data/lang \ exp/tri4_ali_nodup exp/ubm5 || exit 1; -fi +fi # steps/train_sgmm2.sh --cmd "$train_cmd" \ steps/train_sgmm2_group.sh --cmd "$train_cmd" \ diff --git a/egs/swbd/s5c/local/score_sclite_conf.sh b/egs/swbd/s5c/local/score_sclite_conf.sh index 9a1fa5083bf..21da4520a4d 100755 --- a/egs/swbd/s5c/local/score_sclite_conf.sh +++ b/egs/swbd/s5c/local/score_sclite_conf.sh @@ -39,6 +39,12 @@ for f in $data/stm $data/glm $lang/words.txt $lang/phones/word_boundary.int \ [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; done +if [ -f $dir/../frame_subsampling_factor ]; then + factor=$(cat $dir/../frame_subsampling_factor) || exit 1 + frame_shift_opt="--frame-shift=0.0$factor" + echo "$0: $dir/../frame_subsampling_factor exists, using $frame_shift_opt" +fi + name=`basename $data`; # e.g. eval2000 mkdir -p $dir/scoring/log @@ -51,7 +57,7 @@ if [ $stage -le 0 ]; then ACWT=\`perl -e \"print 1.0/LMWT\;\"\` '&&' \ lattice-add-penalty --word-ins-penalty=$wip "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ lattice-align-words $lang/phones/word_boundary.int $model ark:- ark:- \| \ - lattice-to-ctm-conf --decode-mbr=$decode_mbr --acoustic-scale=\$ACWT ark:- - \| \ + lattice-to-ctm-conf $frame_shift_opt --decode-mbr=$decode_mbr --acoustic-scale=\$ACWT ark:- - \| \ utils/int2sym.pl -f 5 $lang/words.txt \| \ utils/convert_ctm.pl $data/segments $data/reco2file_and_channel \ '>' $dir/score_LMWT_${wip}/$name.ctm || exit 1; diff --git a/egs/swbd/s5c/local/swbd1_map_words.pl b/egs/swbd/s5c/local/swbd1_map_words.pl index 39f90d72816..125e4de0d61 100755 --- a/egs/swbd/s5c/local/swbd1_map_words.pl +++ b/egs/swbd/s5c/local/swbd1_map_words.pl @@ -44,7 +44,7 @@ # which is a mistake in the input. $a =~ s:^\{(.+)\}$:$1:; # e.g. {YUPPIEDOM} -> YUPPIEDOM $a =~ s:[A-Z]\[([^][])+\][A-Z]:$1-$3:i; # e.g. AMMU[N]IT- -> AMMU-IT- - $a =~ s:_\d$::; # e.g. THEM_1 -> THEM + $a =~ s:_\d::; # e.g. 
THEM_1 -> THEM, THEM_1's -> THEM's } $A[$n] = $a; } diff --git a/egs/tedlium/s5/local/join_suffix.py b/egs/tedlium/s5/local/join_suffix.py index 64c62964331..c36b96a07f9 100755 --- a/egs/tedlium/s5/local/join_suffix.py +++ b/egs/tedlium/s5/local/join_suffix.py @@ -5,6 +5,7 @@ # Apache 2.0 +from __future__ import print_function import sys from codecs import open diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_blstm_1a.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_blstm_1a.sh index 5e60ee1178c..2ac8c09dad1 100644 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_blstm_1a.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_blstm_1a.sh @@ -139,7 +139,7 @@ if [ $stage -le 17 ]; then lstm_opts="decay-time=20" num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1a.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1a.sh index ec6b8941955..47557f93696 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1a.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1a.sh @@ -152,7 +152,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1b.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1b.sh index 53aa92710e8..7afa1b7f902 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1b.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1b.sh @@ -153,7 +153,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1c.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1c.sh index 83c2f3607f0..e69e499e152 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1c.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1c.sh @@ -151,7 +151,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1d.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1d.sh index 2665ea91ff8..86e0352828c 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1d.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1d.sh @@ -164,7 +164,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p 
$dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1e.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1e.sh index f768c7659d7..0fdb2b3b63e 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1e.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_lstm_1e.sh @@ -154,7 +154,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1b.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1b.sh index 3384b085114..492d3efb804 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1b.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1b.sh @@ -143,7 +143,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1c.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1c.sh index 5dd838a15e3..01768c3875f 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1c.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1c.sh @@ -160,7 +160,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1d.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1d.sh index 4f86691b752..bb5007f4c9f 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1d.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1d.sh @@ -151,7 +151,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1e.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1e.sh index e32c08562c6..1476ed1fd40 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1e.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1e.sh @@ -143,7 +143,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1f.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1f.sh index 2eab0285828..47f939fea1c 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1f.sh +++ 
b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1f.sh @@ -141,7 +141,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1g.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1g.sh index 64ce1f02fdd..f02025674e8 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1g.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1g.sh @@ -142,7 +142,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) affine_opts="l2-regularize=0.008 dropout-proportion=0.0 dropout-per-dim-continuous=true" tdnnf_opts="l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66" linear_opts="l2-regularize=0.008 orthonormal-constraint=-1.0" diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1a.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1a.sh index 8f0be130e27..b03da27e760 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1a.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1a.sh @@ -156,7 +156,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1b.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1b.sh index fef021c6482..e896a7867b3 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1b.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1b.sh @@ -169,7 +169,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1c.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1c.sh index d05ae15dfec..00f72fab796 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1c.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1c.sh @@ -160,7 +160,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1d.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1d.sh index 29d8e69b04c..80a9ed1c4d0 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1d.sh +++ 
b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1d.sh @@ -165,7 +165,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1e.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1e.sh index db3fde91656..031978f878a 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1e.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1e.sh @@ -213,7 +213,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1f.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1f.sh index f6a1d49890d..c60b8f7fefc 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1f.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1f.sh @@ -167,7 +167,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1g.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1g.sh index ff2c302fdf6..2d2048a6869 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1g.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1g.sh @@ -170,7 +170,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1h.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1h.sh index d4cb5e85657..a074e128270 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1h.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1h.sh @@ -168,7 +168,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1i.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1i.sh index 40b1bf7f54a..3bfe175806f 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1i.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1i.sh @@ -189,7 +189,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; 
num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1j.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1j.sh index 838f49f977f..acbef783823 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1j.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1j.sh @@ -186,7 +186,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1k.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1k.sh index b1abfdcf525..173be863608 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1k.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1k.sh @@ -184,7 +184,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) # note: the value of the dropout-proportion is not important, as it's # controlled by the dropout schedule; what's important is that we set it. diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1l.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1l.sh index ef151d72875..94955d0472c 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1l.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1l.sh @@ -174,7 +174,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) # note: the value of the dropout-proportion is not important, as it's # controlled by the dropout schedule; what's important is that we set it. diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1m.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1m.sh index c2aac3f6e20..efd3bc98725 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1m.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1m.sh @@ -174,7 +174,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) # note: the value of the dropout-proportion is not important, as it's # controlled by the dropout schedule; what's important is that we set it. 
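Several of the tedlium tuning hunks above carry the same note that the dropout-proportion value written into the config is irrelevant because it is overridden by the dropout schedule (for example, the swbd run_tdnn_lstm_1l.sh hunk earlier sets dropout_schedule='0,0@0.20,0.3@0.50,0'). As a rough sketch of how such a schedule string is read (data points of the form proportion@fraction-of-training, linearly interpolated, with the first and last entries pinned to fractions 0.0 and 1.0), the throwaway awk snippet below evaluates one at an arbitrary point; it illustrates the format and is not the actual parser used by Kaldi's training scripts.

schedule='0,0@0.20,0.3@0.50,0'   # 0 until 20% of training, peak 0.3 at 50%, back to 0 at the end
frac=0.35                        # fraction of training at which to evaluate the schedule
echo "$schedule" | awk -v f="$frac" -F, '{
  for (i = 1; i <= NF; i++) {
    m = split($i, a, "@");
    p[i] = a[1];                                  # dropout proportion at this data point
    x[i] = (m == 2 ? a[2] : (i == 1 ? 0 : 1));    # training fraction (defaults: first=0, last=1)
  }
  for (i = 1; i < NF; i++)
    if (f >= x[i] && f <= x[i+1]) {               # linear interpolation between neighbouring points
      print p[i] + (p[i+1] - p[i]) * (f - x[i]) / (x[i+1] - x[i]);
      exit;
    }
}'
# -> 0.15, i.e. halfway between 0.0 at 20% of training and 0.3 at 50%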
diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1n.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1n.sh index ed6cb66957d..c0559e8d389 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1n.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1n.sh @@ -185,7 +185,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) # note: the value of the dropout-proportion is not important, as it's # controlled by the dropout schedule; what's important is that we set it. diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1o.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1o.sh index 8a4b7468058..5a6dbaef8af 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1o.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1o.sh @@ -189,7 +189,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) # note: the value of the dropout-proportion is not important, as it's # controlled by the dropout schedule; what's important is that we set it. diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1r.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1r.sh index 8f80a6885ca..dd38d56759f 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1r.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1r.sh @@ -187,7 +187,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) tdnn_opts='ng-affine-options="update-period=1"' lstmp_opts='ng-affine-options="update-period=1" decay-time=20' diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1s.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1s.sh index ef1c7fc196f..1378d2d176d 100644 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1s.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1s.sh @@ -151,7 +151,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1t.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1t.sh index 19479de41aa..3c4882ec2c6 100644 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1t.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1t.sh @@ -152,7 +152,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p 
$dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1u.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1u.sh index 85c0e4a0661..23ea14ae151 100644 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1u.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1u.sh @@ -145,7 +145,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1v.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1v.sh index e0431a83ceb..7c44d963504 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1v.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1v.sh @@ -149,7 +149,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_1a.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_1a.sh index e1543c0120f..042ef346578 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_1a.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_1a.sh @@ -159,7 +159,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_bs_1a.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_bs_1a.sh index d08a7ad5e86..905e1845183 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_bs_1a.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_bs_1a.sh @@ -163,7 +163,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_bs_1b.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_bs_1b.sh index d256150484b..7bd96e7d82c 100755 --- a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_bs_1b.sh +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_bs_1b.sh @@ -150,7 +150,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < 
$dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r2/local/join_suffix.py b/egs/tedlium/s5_r2/local/join_suffix.py index 64c62964331..c36b96a07f9 100755 --- a/egs/tedlium/s5_r2/local/join_suffix.py +++ b/egs/tedlium/s5_r2/local/join_suffix.py @@ -5,6 +5,7 @@ # Apache 2.0 +from __future__ import print_function import sys from codecs import open diff --git a/egs/tedlium/s5_r2/local/rnnlm/tuning/run_lstm_tdnn.sh b/egs/tedlium/s5_r2/local/rnnlm/tuning/run_lstm_tdnn.sh index cc0410c3519..87f99f651bf 100755 --- a/egs/tedlium/s5_r2/local/rnnlm/tuning/run_lstm_tdnn.sh +++ b/egs/tedlium/s5_r2/local/rnnlm/tuning/run_lstm_tdnn.sh @@ -2,29 +2,52 @@ # Copyright 2012 Johns Hopkins University (author: Daniel Povey) Tony Robinson # 2017 Hainan Xu -# 2017 Ke Li +# 2018 Ke Li -# rnnlm/train_rnnlm.sh: best iteration (out of 10) was 8, linking it to final iteration. -# rnnlm/train_rnnlm.sh: train/dev perplexity was 78.4 / 147.8. -# Train objf: -1556.00 -5.43 -5.15 -5.00 -4.90 -4.82 -4.75 -4.69 -4.63 -4.58 -# Dev objf: -11.92 -5.70 -5.29 -5.16 -5.08 -5.04 -5.02 -5.00 -5.00 -5.00 +# rnnlm/train_rnnlm.sh: best iteration (out of 9) was 8, linking it to final iteration. +# rnnlm/train_rnnlm.sh: train/dev perplexity was 94.1 / 155.1. +# Train objf: -6.24 -5.45 -5.12 -4.95 -4.84 -4.74 -4.66 -4.59 -4.52 -4.46 +# Dev objf: -11.92 -5.80 -5.32 -5.17 -5.10 -5.07 -5.05 -5.05 -5.04 -5.06 + +# 1-pass results +# %WER 8.3 | 1155 27500 | 92.7 4.9 2.4 1.0 8.3 68.8 | -0.019 | /export/a12/ywang/kaldi/egs/tedlium/s5_r2/exp/chain_cleaned/tdnn_lstm1i_adversarial1.0_interval4_epoches7_lin_to_5_sp_bi/decode_looped_test/score_10_0.0/ctm.filt.filt.sys + +# 4-gram rescoring +# %WER 7.8 | 1155 27500 | 93.1 4.5 2.4 0.9 7.8 66.4 | -0.089 | /export/a12/ywang/kaldi/egs/tedlium/s5_r2/exp/chain_cleaned/tdnn_lstm1i_adversarial1.0_interval4_epoches7_lin_to_5_sp_bi/decode_looped_test_rescore/score_10_0.0/ctm.filt.filt.sys + +# RNNLM lattice rescoring +# %WER 7.2 | 1155 27500 | 93.6 4.0 2.3 0.8 7.2 64.3 | -0.927 | exp/decode_looped_test_rnnlm_tedlium_rescore//score_10_0.0/ctm.filt.filt.sys + +# RNNLM nbest rescoring +# %WER 7.4 | 1155 27500 | 93.4 4.3 2.3 0.9 7.4 64.8 | -0.863 | exp/decode_looped_test_rnnlm_tedlium_nbest_rescore/score_8_0.0/ctm.filt.filt.sys # Begin configuration section. +cmd=run.pl +decode_cmd=run.pl dir=exp/rnnlm_lstm_tdnn -embedding_dim=800 -lstm_rpd=200 -lstm_nrpd=200 -stage=-10 +embedding_dim=1024 +lstm_rpd=256 +lstm_nrpd=256 +stage=0 train_stage=-10 epochs=20 -. ./cmd.sh -. utils/parse_options.sh -[ -z "$cmd" ] && cmd=$train_cmd +# variables for lattice rescoring +run_lat_rescore=true +run_nbest_rescore=true +decode_dir_suffix=rnnlm_tedlium +ac_model_dir=exp/chain_cleaned/tdnn_lstm1i_adversarial1.0_interval4_epoches7_lin_to_5_sp_bi +ngram_order=4 # approximate the lattice-rescoring by limiting the max-ngram-order + # if it's set, it merges histories in the lattice if they share + # the same ngram history and this prevents the lattice from + # exploding exponentially +pruned_rescore=true +. ./cmd.sh +. 
./utils/parse_options.sh -text=data/train/text wordlist=data/lang/words.txt +text=data/train/text dev_sents=10000 text_dir=data/rnnlm/text mkdir -p $dir/config @@ -37,14 +60,14 @@ done if [ $stage -le 0 ]; then mkdir -p $text_dir - cat $text | cut -d ' ' -f2- | head -n $dev_sents> $text_dir/dev.txt + cat $text | cut -d ' ' -f2- | head -n $dev_sents > $text_dir/dev.txt cat $text | cut -d ' ' -f2- | tail -n +$[$dev_sents+1] > $text_dir/ted.txt fi if [ $stage -le 1 ]; then cp $wordlist $dir/config/ - n=`cat $dir/config/words.txt | wc -l` - echo " $n" >> $dir/config/words.txt + n=`cat $dir/config/words.txt | wc -l` + echo " $n" >> $dir/config/words.txt # words that are not present in words.txt but are in the training or dev data, will be # mapped to during training. @@ -66,8 +89,9 @@ EOF --min-frequency 1.0e-03 \ --special-words=',,,' \ $dir/config/words.txt > $dir/config/features.txt +fi - cat >$dir/config/xconfig <$dir/config/xconfig < \ + training-monolingual/news.20XX.en.shuffled.sorted.tokenized + echo "Done tokenizing corpus." + cd ../../.. +fi + +if [ $stage -le 1 ]; then + mkdir -p $text_dir + cat $train_text | cut -d ' ' -f2- | head -n $dev_sents > $text_dir/dev.txt + cat $train_text | cut -d ' ' -f2- | tail -n +$[$dev_sents+1] > $text_dir/ted.txt + cp $lm1b_dir/training-monolingual/news.20XX.en.shuffled.sorted.tokenized $text_dir/lm1b.txt +fi + +if [ $stage -le 2 ]; then + cp $wordlist $dir/config/ + n=`cat $dir/config/words.txt | wc -l` + echo " $n" >> $dir/config/words.txt + + # words that are not present in words.txt but are in the training or dev data, will be + # mapped to during training. + echo "" >$dir/config/oov.txt + + cat > $dir/config/data_weights.txt <$dir/config/unigram_probs.txt + + # choose features + rnnlm/choose_features.py --unigram-probs=$dir/config/unigram_probs.txt \ + --use-constant-feature=true \ + --top-word-features=10000 \ + --min-frequency 1.0e-03 \ + --special-words=',,,' \ + $dir/config/words.txt > $dir/config/features.txt +fi + + cat >$dir/config/xconfig < $ref_vocab || exit 1; + +# Get a G2P generated lexicon for oov words (w.r.t the reference lexicon) +# in acoustic training data. +if [ $stage -le 0 ]; then + if [ -z $g2p_mdl_dir ]; then + g2p_mdl_dir=exp/g2p_phonetisaurus + steps/dict/train_g2p_phonetisaurus.sh $ref_dict/lexicon.txt $g2p_mdl_dir || exit 1; + fi + awk '{for (n=2;n<=NF;n++) vocab[$n]=1;} END{for (w in vocab) printf "%s\n",w;}' \ + $data/text | sort -u > $data/train_vocab.txt || exit 1; + awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $1}' $ref_vocab \ + $data/train_vocab.txt | sort > $data/oov_train.txt || exit 1; + steps/dict/apply_g2p_phonetisaurus.sh --nbest 5 $data/train_vocab.txt $g2p_mdl_dir \ + exp/g2p_phonetisaurus/lex_train || exit 1; +fi + +# Learn a lexicon based on the acoustic training data and the reference lexicon. +if [ $stage -le 1 ]; then + steps/dict/learn_lexicon_greedy.sh --lexiconp-g2p "exp/g2p_phonetisaurus/lex_train/lexicon.lex" \ + --alpha $alpha --beta $beta --delta $delta \ + --min-prob $min_prob --cmd "$train_cmd" \ + --variant-counts-ratio $vcr \ + --stage $lexlearn_stage --nj 60 --oov-symbol $oov_symbol --retrain-src-mdl false \ + $ref_dict $ref_vocab $data exp/tri3 data/lang data/local/dict_${affix}_nosp \ + $dir || exit 1; +fi + +# Add pronounciation probs to the learned lexicon. 
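The stage that follows attaches pronunciation probabilities to the learned lexicon via steps/get_prons.sh and utils/dict_dir_add_pronprobs.sh; with --max-normalize true the per-word counts are, roughly speaking, scaled so that each word's most frequent pronunciation gets probability 1.0. A rough Python sketch of that normalization, using made-up (word, pronunciation) counts rather than the real pron_counts_nowb.txt output:

from collections import defaultdict

# made-up counts standing in for pron_counts_nowb.txt; not real data
counts = {("read", "r eh d"): 30, ("read", "r iy d"): 70, ("the", "dh ah"): 90}

# max-normalize per word: the most frequent pronunciation of each word gets 1.0
best = defaultdict(int)
for (word, pron), c in counts.items():
    best[word] = max(best[word], c)

pron_probs = {(word, pron): c / best[word] for (word, pron), c in counts.items()}
# e.g. ('read', 'r iy d') -> 1.0, ('read', 'r eh d') -> ~0.43, ('the', 'dh ah') -> 1.0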
+if [ $stage -le 2 ]; then + utils/prepare_lang.sh --phone-symbol-table data/lang/phones.txt \ + data/local/dict_${affix}_nosp $oov_symbol data/local/lang_${affix}_nosp data/lang_${affix}_nosp || exit 1; + + steps/align_si.sh --nj $nj --cmd "$train_cmd" \ + $data data/lang_${affix}_nosp exp/tri2 exp/tri2_ali_${affix}_nosp || exit 1; + + steps/get_prons.sh --cmd "$train_cmd" data/train data/lang_${affix}_nosp exp/tri2_ali_${affix}_nosp || exit 1; + + utils/dict_dir_add_pronprobs.sh --max-normalize true \ + data/local/dict_${affix}_nosp exp/tri2_ali_${affix}_nosp/pron_counts_nowb.txt \ + exp/tri2_ali_${affix}_nosp/sil_counts_nowb.txt \ + exp/tri2_ali_${affix}_nosp/pron_bigram_counts_nowb.txt data/local/dict_${affix} || exit 1; + + utils/prepare_lang.sh --phone-symbol-table data/lang/phones.txt \ + data/local/dict_${affix} $oov_symbol data/local/lang_${affix} data/lang_${affix} || exit 1; +fi + +# Re-decode +if [ $stage -le 3 ]; then + ! cmp data/lang_nosp/words.txt data/lang_${affix}/words.txt &&\ + echo "$0: The vocab of the affix lexicon and the reference vocab may be incompatible." + cp data/lang_nosp/G.fst data/lang_${affix}/ + utils/mkgraph.sh data/lang_${affix} exp/tri3 exp/tri3/graph_${affix} || exit 1; + + for dset in dev test; do + ( steps/decode_fmllr.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + exp/tri3/graph_${affix} data/${dset} exp/tri3/decode_${affix}_${dset} || exit 1; + ) & + done +fi + +# RESULTS: +# Baseline: +# %WER 18.7 | 507 17783 | 83.9 11.4 4.7 2.6 18.7 92.3 | -0.006 | exp/tri3/decode_dev/score_17_0.0/ctm.filt.filt.sys +# %WER 17.6 | 1155 27500 | 84.7 11.6 3.7 2.4 17.6 87.2 | 0.013 | exp/tri3/decode_test/score_15_0.0/ctm.filt.filt.sys + +# Re-decoding with the learned lexicon: +# %WER 18.5 | 507 17783 | 84.3 11.2 4.5 2.8 18.5 92.3 | -0.007 | exp/tri3/decode_learned_greedy_dev/score_16_0.0/ctm.filt.filt.sys +# %WER 17.5 | 1155 27500 | 84.9 11.5 3.6 2.4 17.5 87.5 | 0.035 | exp/tri3/decode_learned_greedy_test/score_14_0.0/ctm.filt.filt.sys + +# To see the effect to neural-net results, one should re-train NN with the learned lexicon. +# Experiments have shown that, with the new lang dir, one should just re-run NN training +# starting from the supervision generation (steps/align_fmllr_lats.sh) stage, and should +# expect improved overall WERs and word recognition performance on words whose pronunciations +# were changed. + +exit +wait diff --git a/egs/tedlium/s5_r2_wsj/local/lm/merge_word_counts.py b/egs/tedlium/s5_r2_wsj/local/lm/merge_word_counts.py index 6338cbbf875..85e15d8dc07 100755 --- a/egs/tedlium/s5_r2_wsj/local/lm/merge_word_counts.py +++ b/egs/tedlium/s5_r2_wsj/local/lm/merge_word_counts.py @@ -7,6 +7,7 @@ A min-count argument is required to only write counts that are above the specified minimum count. 
""" +from __future__ import print_function import sys @@ -21,7 +22,7 @@ def main(): parts = line.strip().split() words[parts[1]] = words.get(parts[1], 0) + int(parts[0]) - for word, count in words.iteritems(): + for word, count in words.items(): if count >= int(sys.argv[1]): print ("{0} {1}".format(count, word)) diff --git a/egs/tedlium/s5_r3/.gitignore b/egs/tedlium/s5_r3/.gitignore new file mode 100644 index 00000000000..65eef93d691 --- /dev/null +++ b/egs/tedlium/s5_r3/.gitignore @@ -0,0 +1 @@ +db diff --git a/egs/tedlium/s5_r3/RESULTS b/egs/tedlium/s5_r3/RESULTS new file mode 100644 index 00000000000..b2f9526a8fd --- /dev/null +++ b/egs/tedlium/s5_r3/RESULTS @@ -0,0 +1,32 @@ +# This RESULTS file was obtained by running ./run.sh and then ./result.sh + +%WER 28.32 [ 5037 / 17783, 615 ins, 1171 del, 3251 sub ] exp/tri1/decode_nosp_dev/wer_10 +%WER 26.99 [ 4799 / 17783, 603 ins, 1169 del, 3027 sub ] exp/tri1/decode_nosp_dev_rescore/wer_10 +%WER 27.76 [ 7634 / 27500, 776 ins, 1689 del, 5169 sub ] exp/tri1/decode_nosp_test/wer_11 +%WER 26.52 [ 7292 / 27500, 766 ins, 1611 del, 4915 sub ] exp/tri1/decode_nosp_test_rescore/wer_11 +%WER 23.38 [ 4158 / 17783, 603 ins, 953 del, 2602 sub ] exp/tri2/decode_dev/wer_14 +%WER 21.98 [ 3909 / 17783, 597 ins, 910 del, 2402 sub ] exp/tri2/decode_dev_rescore/wer_14 +%WER 24.12 [ 4289 / 17783, 600 ins, 1014 del, 2675 sub ] exp/tri2/decode_nosp_dev/wer_12 +%WER 22.96 [ 4083 / 17783, 631 ins, 931 del, 2521 sub ] exp/tri2/decode_nosp_dev_rescore/wer_11 +%WER 23.30 [ 6408 / 27500, 727 ins, 1375 del, 4306 sub ] exp/tri2/decode_nosp_test/wer_13 +%WER 22.10 [ 6078 / 27500, 746 ins, 1281 del, 4051 sub ] exp/tri2/decode_nosp_test_rescore/wer_12 +%WER 22.31 [ 6134 / 27500, 794 ins, 1148 del, 4192 sub ] exp/tri2/decode_test/wer_13 +%WER 21.06 [ 5791 / 27500, 737 ins, 1147 del, 3907 sub ] exp/tri2/decode_test_rescore/wer_14 +%WER 19.99 [ 3554 / 17783, 570 ins, 816 del, 2168 sub ] exp/tri3_cleaned/decode_dev/wer_16 +%WER 18.92 [ 3364 / 17783, 588 ins, 791 del, 1985 sub ] exp/tri3_cleaned/decode_dev_rescore/wer_15 +%WER 23.85 [ 4241 / 17783, 686 ins, 874 del, 2681 sub ] exp/tri3_cleaned/decode_dev.si/wer_13 +%WER 17.73 [ 4876 / 27500, 700 ins, 935 del, 3241 sub ] exp/tri3_cleaned/decode_test/wer_16 +%WER 16.72 [ 4599 / 27500, 686 ins, 906 del, 3007 sub ] exp/tri3_cleaned/decode_test_rescore/wer_16 +%WER 22.10 [ 6077 / 27500, 864 ins, 1075 del, 4138 sub ] exp/tri3_cleaned/decode_test.si/wer_13 +%WER 19.63 [ 3490 / 17783, 585 ins, 809 del, 2096 sub ] exp/tri3/decode_dev/wer_15 +%WER 18.56 [ 3300 / 17783, 558 ins, 817 del, 1925 sub ] exp/tri3/decode_dev_rescore/wer_16 +%WER 23.75 [ 4224 / 17783, 661 ins, 917 del, 2646 sub ] exp/tri3/decode_dev.si/wer_14 +%WER 17.92 [ 4928 / 27500, 730 ins, 921 del, 3277 sub ] exp/tri3/decode_test/wer_14 +%WER 16.80 [ 4621 / 27500, 650 ins, 973 del, 2998 sub ] exp/tri3/decode_test_rescore/wer_17 +%WER 22.16 [ 6095 / 27500, 849 ins, 1070 del, 4176 sub ] exp/tri3/decode_test.si/wer_13 +%WER 8.17 [ 1453 / 17783, 242 ins, 310 del, 901 sub ] exp/chain_cleaned/tdnnf_1a/decode_dev/wer_9 +%WER 7.61 [ 1354 / 17783, 236 ins, 300 del, 818 sub ] exp/chain_cleaned/tdnnf_1a/decode_dev_rescore/wer_9 +%WER 6.17 [ 1097 / 17783, 207 ins, 292 del, 598 sub ] exp/chain_cleaned/tdnnf_1a/decode_dev_rnnlm_lstm_tdnn_a_averaged/wer_10 +%WER 8.16 [ 2245 / 27500, 288 ins, 605 del, 1352 sub ] exp/chain_cleaned/tdnnf_1a/decode_test/wer_9 +%WER 7.75 [ 2131 / 27500, 264 ins, 643 del, 1224 sub ] exp/chain_cleaned/tdnnf_1a/decode_test_rescore/wer_10 +%WER 6.84 [ 1880 / 27500, 
283 ins, 533 del, 1064 sub ] exp/chain_cleaned/tdnnf_1a/decode_test_rnnlm_lstm_tdnn_a_averaged/wer_8 diff --git a/egs/tedlium/s5_r3/local/chain/tuning/run_tdnn_1a.sh b/egs/tedlium/s5_r3/local/chain/tuning/run_tdnn_1a.sh index 40cdcb5b5ff..1204ff6ce4c 100755 --- a/egs/tedlium/s5_r3/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/tedlium/s5_r3/local/chain/tuning/run_tdnn_1a.sh @@ -143,7 +143,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r3/local/chain/tuning/run_tdnn_1b.sh b/egs/tedlium/s5_r3/local/chain/tuning/run_tdnn_1b.sh index f8eec8c5213..f06ba3fa195 100755 --- a/egs/tedlium/s5_r3/local/chain/tuning/run_tdnn_1b.sh +++ b/egs/tedlium/s5_r3/local/chain/tuning/run_tdnn_1b.sh @@ -77,7 +77,6 @@ fi local/nnet3/run_ivector_common.sh --stage $stage \ --nj $nj \ - --min-seg-len $min_seg_len \ --train-set $train_set \ --gmm $gmm \ --num-threads-ubm $num_threads_ubm \ @@ -149,7 +148,7 @@ if [ $stage -le 17 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/tedlium/s5_r3/local/download_data.sh b/egs/tedlium/s5_r3/local/download_data.sh index 49de5b12372..c51effdd6fa 100755 --- a/egs/tedlium/s5_r3/local/download_data.sh +++ b/egs/tedlium/s5_r3/local/download_data.sh @@ -21,7 +21,9 @@ else # the following command won't re-get it if it's already there # because of the --continue switch. wget --continue http://www.openslr.org/resources/51/TEDLIUM_release-3.tgz || exit 1 - tar xf "TEDLIUM_release-3.tar.gz" + + echo "$0: extracting TEDLIUM_release-3 data" + tar xf "TEDLIUM_release-3.tgz" else echo "$0: not downloading or un-tarring TEDLIUM_release2 because it already exists." fi diff --git a/egs/tedlium/s5_r3/local/join_suffix.py b/egs/tedlium/s5_r3/local/join_suffix.py new file mode 100755 index 00000000000..c36b96a07f9 --- /dev/null +++ b/egs/tedlium/s5_r3/local/join_suffix.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# +# Copyright 2014 Nickolay V. Shmyrev +# 2016 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + + +from __future__ import print_function +import sys +from codecs import open + +# This script joins together pairs of split-up words like "you 're" -> "you're". +# The TEDLIUM transcripts are normalized in a way that's not traditional for +# speech recognition. 
+ +for line in sys.stdin: + items = line.split() + new_items = [] + i = 1 + while i < len(items): + if i < len(items) - 1 and items[i+1][0] == '\'': + new_items.append(items[i] + items[i+1]) + i = i + 1 + else: + new_items.append(items[i]) + i = i + 1 + print(items[0] + ' ' + ' '.join(new_items)) diff --git a/egs/tedlium/s5_r3/local/rnnlm/tuning/run_lstm_tdnn_a.sh b/egs/tedlium/s5_r3/local/rnnlm/tuning/run_lstm_tdnn_a.sh index 32252db937d..73a684b6379 100755 --- a/egs/tedlium/s5_r3/local/rnnlm/tuning/run_lstm_tdnn_a.sh +++ b/egs/tedlium/s5_r3/local/rnnlm/tuning/run_lstm_tdnn_a.sh @@ -30,7 +30,6 @@ epochs=20 [ -z "$cmd" ] && cmd=$train_cmd text_from_audio=data/train/text -text=data/LM/train.txt wordlist=data/lang_chain/words.txt dev_sents=10000 text_dir=data/rnnlm/text @@ -44,8 +43,9 @@ done if [ $stage -le 0 ]; then mkdir -p $text_dir + gunzip -c db/TEDLIUM_release-3/LM/*.en.gz | sed 's/ <\/s>//g' > $text_dir/train.txt # shuffle text from audio and lm - cat $text_from_audio | cut -d ' ' -f2- | cat $text |\ + cat $text_from_audio | cut -d ' ' -f2- | cat $text_dir/train.txt |\ shuf > data/rnnlm/full_lm_data.shuffled # create dev and train sets based on audio and LM data cat data/rnnlm/full_lm_data.shuffled | head -n $dev_sents> $text_dir/dev.txt diff --git a/egs/tedlium/s5_r3/local/ted_download_lm.sh b/egs/tedlium/s5_r3/local/ted_download_lm.sh index ad833555b5f..6118876a0ab 100755 --- a/egs/tedlium/s5_r3/local/ted_download_lm.sh +++ b/egs/tedlium/s5_r3/local/ted_download_lm.sh @@ -13,4 +13,4 @@ echo "$0: downloading Tedlium 4 gram language models (it won't re-download if it wget --continue http://kaldi-asr.org/models/5/4gram_small.arpa.gz -P data/local/local_lm/data/arpa || exit 1 wget --continue http://kaldi-asr.org/models/5/4gram_big.arpa.gz -P data/local/local_lm/data/arpa || exit 1 -exit 0 \ No newline at end of file +exit 0 diff --git a/egs/tedlium/s5_r3/local/ted_download_rnnlm.sh b/egs/tedlium/s5_r3/local/ted_download_rnnlm.sh index 431d44c6ff6..6cbcaaa85ee 100755 --- a/egs/tedlium/s5_r3/local/ted_download_rnnlm.sh +++ b/egs/tedlium/s5_r3/local/ted_download_rnnlm.sh @@ -14,7 +14,7 @@ wget --continue http://kaldi-asr.org/models/5/tedlium_rnnlm.tgz -P exp/rnnlm_lst cd exp/rnnlm_lstm_tdnn_a_averaged tar -xvzf tedlium_rnnlm.tgz || exit 1 rm tedlium_rnnlm.tgz -mkdir config +mkdir -p config cd ../.. cp data/lang/words.txt exp/rnnlm_lstm_tdnn_a_averaged/config/words.txt echo " 152217" >> exp/rnnlm_lstm_tdnn_a_averaged/config/words.txt diff --git a/egs/tedlium/s5_r3/results.sh b/egs/tedlium/s5_r3/results.sh index 98bcab94ec5..3e318cb4bc7 100755 --- a/egs/tedlium/s5_r3/results.sh +++ b/egs/tedlium/s5_r3/results.sh @@ -1,10 +1,25 @@ #!/bin/bash +# The output of this script (after successfully running ./run.sh) can be found in the RESULTS file. + filter_regexp=. 
[ $# -ge 1 ] && filter_regexp=$1 -for x in exp/*/decode*; do [ -d $x ] && grep WER $x/wer_* | utils/best_wer.sh; done 2>/dev/null - for x in exp/{mono,tri,sgmm,nnet,dnn,lstm,chain}*/decode*; do [ -d $x ] && grep Sum $x/score_*/*.sys | utils/best_wer.sh; done 2>/dev/null | grep $filter_regexp - for x in exp/{mono,tri,sgmm,nnet,dnn,lstm,chain}*/*/decode*; do [ -d $x ] && grep Sum $x/score_*/*.sys | utils/best_wer.sh; done 2>/dev/null | grep $filter_regexp +for x in exp/*/decode*; do + [ -d $x ] && grep WER $x/wer_* | utils/best_wer.sh; +done 2>/dev/null + +for x in exp/{mono,tri,sgmm,nnet,dnn,lstm,chain}*/decode*; do + [ -d $x ] && grep Sum $x/score_*/*.sys | utils/best_wer.sh; +done 2>/dev/null | grep $filter_regexp + +for x in exp/{mono,tri,sgmm,nnet,dnn,lstm,chain}*/*/decode*; do + [ -d $x ] && grep Sum $x/score_*/*.sys | utils/best_wer.sh; +done 2>/dev/null | grep $filter_regexp + +for x in exp/{mono,tri,sgmm,nnet,dnn,lstm,chain}*/*/decode*; do + [ -d $x ] && grep WER $x/wer_* | utils/best_wer.sh; +done 2>/dev/null | grep $filter_regexp + exit 0 diff --git a/egs/tedlium/s5_r3/run.sh b/egs/tedlium/s5_r3/run.sh index d4f3a38fd49..ecb2cdf4633 100755 --- a/egs/tedlium/s5_r3/run.sh +++ b/egs/tedlium/s5_r3/run.sh @@ -207,7 +207,7 @@ if [ $stage -le 19 ]; then for dset in dev test; do data_dir=data/${dset}_hires - decoding_dir=exp/chain_cleaned/tdnnf_1a + decoding_dir=exp/chain_cleaned/tdnnf_1a/decode_${dset} suffix=$(basename $rnnlm_dir) output_dir=${decoding_dir}_$suffix diff --git a/egs/thchs30/s5/local/dae/add-noise-mod.py b/egs/thchs30/s5/local/dae/add-noise-mod.py index 8327fc325ee..4486fd0fdc7 100755 --- a/egs/thchs30/s5/local/dae/add-noise-mod.py +++ b/egs/thchs30/s5/local/dae/add-noise-mod.py @@ -3,6 +3,7 @@ from __future__ import print_function +from __future__ import division import optparse import random import bisect @@ -26,7 +27,7 @@ def energy(mat): def mix(mat, noise, pos, scale): ret = [] l = len(noise) - for i in xrange(len(mat)): + for i in range(len(mat)): x = mat[i] d = int(x + scale * noise[pos]) #if d > 32767 or d < -32768: @@ -41,8 +42,8 @@ def mix(mat, noise, pos, scale): def dirichlet(params): samples = [random.gammavariate(x, 1) if x > 0 else 0. for x in params] - samples = [x / sum(samples) for x in samples] - for x in xrange(1, len(samples)): + samples = [(x / sum(samples)) for x in samples] + for x in range(1, len(samples)): samples[x] += samples[x - 1] return bisect.bisect_left(samples, random.random()) @@ -125,7 +126,7 @@ def main(): mat = wave_mat(wav) signal = energy(mat) logging.debug('signal energy: %f', signal) - noise = signal / (10 ** (noise_level / 10.)) + noise = signal / (10 ** (noise_level / 10)) logging.debug('noise energy: %f', noise) type = dirichlet(params) logging.debug('selected type: %d', type) diff --git a/egs/timit/s5/local/timit_data_prep.sh b/egs/timit/s5/local/timit_data_prep.sh index f8f288ffccc..be2d6725952 100755 --- a/egs/timit/s5/local/timit_data_prep.sh +++ b/egs/timit/s5/local/timit_data_prep.sh @@ -70,7 +70,7 @@ for x in train dev test; do find $*/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.WAV' \ | grep -f $tmpdir/${x}_spk > ${x}_sph.flist - sed -e 's:.*/\(.*\)/\(.*\).WAV$:\1_\2:i' ${x}_sph.flist \ + sed -e 's:.*/\(.*\)/\(.*\).\(WAV\|wav\)$:\1_\2:' ${x}_sph.flist \ > $tmpdir/${x}_sph.uttids paste $tmpdir/${x}_sph.uttids ${x}_sph.flist \ | sort -k1,1 > ${x}_sph.scp @@ -82,7 +82,7 @@ for x in train dev test; do # ID followed by the transcript. 
find $*/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.PHN' \ | grep -f $tmpdir/${x}_spk > $tmpdir/${x}_phn.flist - sed -e 's:.*/\(.*\)/\(.*\).PHN$:\1_\2:i' $tmpdir/${x}_phn.flist \ + sed -e 's:.*/\(.*\)/\(.*\).\(PHN\|phn\)$:\1_\2:' $tmpdir/${x}_phn.flist \ > $tmpdir/${x}_phn.uttids while read line; do [ -f $line ] || error_exit "Cannot find transcription file '$line'"; diff --git a/egs/tunisian_msa/s5/README b/egs/tunisian_msa/s5/README new file mode 100644 index 00000000000..ae2aa2bc452 --- /dev/null +++ b/egs/tunisian_msa/s5/README @@ -0,0 +1,24 @@ +A Kaldi recipe for Arabic using the Tunisian_MSA corpus. + +Extra Requirements: +This recipe uses the QCRI lexicon which uses the Buckwalter encoding. +In order to convert the Buckwalter to utf-8, the Encode::Arabic::Buckwalter perl module is required. +On ubuntu install the package: libencode-arabic-perl. +On Mac OSX use cpanm (cpanminus) to install the perl module. + +Description of the Tunisian_MSA Corpus +The Tunisian_MSA corpus was originally collected to train acoustic models for pronunciation modeling in Arabic language learning applications. +The data collection took place near Tunis the capital of the Republic of Tunisia in 2003 at the Military Academy of Fondouk Jedied . +The Tunisian_MSA corpus is divided into recited and prompted speech subcorpora. +The recited speech appears under the recordings directory and the prompted speech under the answers directory. +Each of the 118 informants contributed to both subcorpora by reciting sentences and providing answers to prompted questions. +The Tunisian_MSA corpus has 11.2 hours of speech. + +With the exception of speech from two speakers , all the corpus was used for training. + +A small corpus was collected for testing. + +A pronunciation dictionary is also available from openslrm.org. +It covers all the words uttered in the Tunisian_MSA corpus and the test corpus. +The QCRI lexicon was used as a starting point for writing this lexicon. +The phones are the same as those used in the QCRI lexicon. diff --git a/egs/tunisian_msa/s5/cmd.sh b/egs/tunisian_msa/s5/cmd.sh new file mode 100644 index 00000000000..71dd849a93b --- /dev/null +++ b/egs/tunisian_msa/s5/cmd.sh @@ -0,0 +1,15 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export train_cmd="queue.pl --mem 2G" +export decode_cmd="queue.pl --mem 4G" +export mkgraph_cmd="queue.pl --mem 8G" diff --git a/egs/tunisian_msa/s5/conf/mfcc.conf b/egs/tunisian_msa/s5/conf/mfcc.conf new file mode 100644 index 00000000000..7361509099f --- /dev/null +++ b/egs/tunisian_msa/s5/conf/mfcc.conf @@ -0,0 +1 @@ +--use-energy=false # only non-default option. 
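The two sed edits to timit_data_prep.sh earlier in this patch drop the case-insensitive substitution flag in favor of an explicit WAV|wav / PHN|phn alternation; either way the expression turns a corpus path into a speaker_utterance ID. Roughly equivalent Python, with an illustrative path (the real paths come from the find/grep pipeline shown in that hunk):

import re

path = "TIMIT/TRAIN/DR1/FCJF0/SI1027.WAV"  # illustrative path only

m = re.match(r".*/(.*)/(.*)\.(WAV|wav)$", path)
utt_id = "{}_{}".format(m.group(1), m.group(2)) if m else None
# -> "FCJF0_SI1027": speaker directory plus file basename, as written to ${x}_sph.uttids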
diff --git a/egs/tunisian_msa/s5/conf/mfcc_hires.conf b/egs/tunisian_msa/s5/conf/mfcc_hires.conf new file mode 100644 index 00000000000..434834a6725 --- /dev/null +++ b/egs/tunisian_msa/s5/conf/mfcc_hires.conf @@ -0,0 +1,10 @@ +# config for high-resolution MFCC features, intended for neural network training +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--num-mel-bins=40 # similar to Google's setup. +--num-ceps=40 # there is no dimensionality reduction. +--low-freq=20 # low cutoff frequency for mel bins... this is high-bandwidth data, so + # there might be some information at the low end. +--high-freq=-400 # high cutoff frequently, relative to Nyquist of 8000 (=7600) diff --git a/egs/tunisian_msa/s5/conf/online_cmvn.conf b/egs/tunisian_msa/s5/conf/online_cmvn.conf new file mode 100644 index 00000000000..7748a4a4dd3 --- /dev/null +++ b/egs/tunisian_msa/s5/conf/online_cmvn.conf @@ -0,0 +1 @@ +# configuration file for apply-cmvn-online, used in the script ../local/run_online_decoding.sh diff --git a/egs/tunisian_msa/s5/local/answers_make_lists.pl b/egs/tunisian_msa/s5/local/answers_make_lists.pl new file mode 100755 index 00000000000..55ee5751d9b --- /dev/null +++ b/egs/tunisian_msa/s5/local/answers_make_lists.pl @@ -0,0 +1,77 @@ +#!/usr/bin/env perl + +# Copyright 2018 John Morgan +# Apache 2.0. + +# answers_make_lists.pl - make acoustic model training lists + +use strict; +use warnings; +use Carp; + +use File::Spec; +use File::Copy; +use File::Basename; + +my $tmpdir = 'data/local/tmp/tunis'; + +system "mkdir -p $tmpdir/answers"; + +# input wav file list +my $wav_list = "$tmpdir/answers_wav.txt"; + +# output temporary wav.scp files +my $wav_scp = "$tmpdir/answers/wav.scp"; + +# output temporary utt2spk files +my $u = "$tmpdir/answers/utt2spk"; + +# output temporary text files +my $t = "$tmpdir/answers/text"; + +# initialize hash for prompts +my %prompt = (); + +# store prompts in hash +LINEA: while ( my $line = <> ) { + chomp $line; + my ($num,$sent) = split /\t/sxm, $line, 2; + + my ($machine,$s,$mode,$language,$i) = split /\_/sxm, $num; + # the utterance name + my $utt = $machine . '_' . $s . '_' . 'a' . '_' . $i; + $prompt{$utt} = $sent; +} + +# Write wav.scp, utt2spk and text files. +open my $W, '<', $wav_list or croak "problem with $wav_list $!"; +open my $O, '+>', $wav_scp or croak "problem with $wav_scp $!"; +open my $U, '+>', $u or croak "problem with $u"; +open my $T, '+>', $t or croak "problem with $t"; + + LINE: while ( my $line = <$W> ) { + chomp $line; + next LINE if ( $line !~ /Answers/sxm ); + next LINE if ( $line =~ /Recordings/sxm ); + my ($volume,$directories,$file) = File::Spec->splitpath( $line ); + my @dirs = split /\//sxm, $directories; + my $r = basename $line, '.wav'; + my $machine = $dirs[-3]; + my $s = $dirs[-1]; + my $rid = $machine . '_' . $s . '_' . 'a' . '_' . 
$r; + if ( exists $prompt{$rid} ) { + print ${T} "$rid\t$prompt{$rid}\n" or croak; + } elsif ( defined $rid ) { + print STDERR "problem\t$rid" or croak; + next LINE; + } else { + croak "$line"; + } + + print ${O} "$rid sox $line -t wav - |\n" or croak; + print ${U} "$rid ${machine}_${s}_a\n" or croak; +} +close $U or croak; +close $T or croak; +close $W or croak; +close $O or croak; diff --git a/egs/tunisian_msa/s5/local/buckwalter2unicode.py b/egs/tunisian_msa/s5/local/buckwalter2unicode.py new file mode 100755 index 00000000000..f81841261ce --- /dev/null +++ b/egs/tunisian_msa/s5/local/buckwalter2unicode.py @@ -0,0 +1,454 @@ +#!/usr/bin/python + +# buckwalter2unicode.py - A script to convert transliterated Arabic +# (using the Buckwalter system) to Unicode. +# +# Version 0.2 - 15th September 2004 +# +# Andrew Roberts (andyr [at] comp (dot) leeds [dot] ac (dot) uk) +# +# Project homepage: http://www.comp.leeds.ac.uk/andyr/software/ +# +# Now, listen carefully... +# +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +# + +from __future__ import print_function +import sys, getopt, codecs, os, re + +# Declare a dictionary with Buckwalter's ASCII symbols as the keys, and +# their unicode equivalents as values. + +buck2uni = {"'": u"\u0621", # hamza-on-the-line + "|": u"\u0622", # madda + ">": u"\u0623", # hamza-on-'alif + "&": u"\u0624", # hamza-on-waaw + "<": u"\u0625", # hamza-under-'alif + "}": u"\u0626", # hamza-on-yaa' + "A": u"\u0627", # bare 'alif + "b": u"\u0628", # baa' + "p": u"\u0629", # taa' marbuuTa + "t": u"\u062A", # taa' + "v": u"\u062B", # thaa' + "j": u"\u062C", # jiim + "H": u"\u062D", # Haa' + "x": u"\u062E", # khaa' + "d": u"\u062F", # daal + "*": u"\u0630", # dhaal + "r": u"\u0631", # raa' + "z": u"\u0632", # zaay + "s": u"\u0633", # siin + "$": u"\u0634", # shiin + "S": u"\u0635", # Saad + "D": u"\u0636", # Daad + "T": u"\u0637", # Taa' + "Z": u"\u0638", # Zaa' (DHaa') + "E": u"\u0639", # cayn + "g": u"\u063A", # ghayn + "_": u"\u0640", # taTwiil + "f": u"\u0641", # faa' + "q": u"\u0642", # qaaf + "k": u"\u0643", # kaaf + "l": u"\u0644", # laam + "m": u"\u0645", # miim + "n": u"\u0646", # nuun + "h": u"\u0647", # haa' + "w": u"\u0648", # waaw + "Y": u"\u0649", # 'alif maqSuura + "y": u"\u064A", # yaa' + "F": u"\u064B", # fatHatayn + "N": u"\u064C", # Dammatayn + "K": u"\u064D", # kasratayn + "a": u"\u064E", # fatHa + "u": u"\u064F", # Damma + "i": u"\u0650", # kasra + "~": u"\u0651", # shaddah + "o": u"\u0652", # sukuun + "`": u"\u0670", # dagger 'alif + "{": u"\u0671", # waSla +} + +# For a reverse transliteration (Unicode -> Buckwalter), a dictionary +# which is the reverse of the above buck2uni is essential. + +uni2buck = {} + +# Iterate through all the items in the buck2uni dict. +for (key, value) in buck2uni.items(): + # The value from buck2uni becomes a key in uni2buck, and vice + # versa for the keys. 
+ uni2buck[value] = key + +# Declare some global variables... + + +inFilename = "" # Name of filename containing input. +outFilename = "" # Name of filename to send the output +inEnc = "" # The text encoding of the input file +outEnc = "" # The text encoding for the output file +ignoreChars = "" # If lines begin with these symbols, ignore. +columnRange = "" # Holds columns numbers to transliterate. +delimiter = "" # Holds user-defined column delimiter. +reverse = 0 # When equal to 1, perform reverse transliteration, i.e., + # Unicode -> Buckwalter. + +# A function to print to screen the usage details of this script. + +def usage(): + print("Usage: {} -i INFILE -o OUTFILE [-g CHARS -c RANGE -d CHAR".format(sys.argv[0])) + print(" -r -e INPUT_ENCODING, -E OUTPUT ENCODING]") + print(" {} -l".format(sys.argv[0])) + print(" {} -h".format(sys.argv[0])) + print("") + print(" -i INFILE, --input=INFILE:") + print(" Path to text file to be transliterated to Unicode.") + print(" -o OUTFILE, --output=OUTFILE:") + print(" Path of file to output the newly transliterated text.") + print(" -e ENC, --input-encoding=ENC:") + print(" Specify the text encoding of the source file. Default: latin_1.") + print(" -E ENC, --output-encoding=ENC:") + print(" Specify the text encoding of the target file. Default: utf_8.") + print(" -g CHARS, --ignore-lines=CHARS:") + print(" Will not transliterate lines that start with any of the CHARS") + print(" given. E.g., -g #; will not alter lines starting with # or ;.") + print(" (May need to be -g \#\; on some platforms. See README.txt.)") + print(" -c RANGE, --columns=RANGE:") + print(" If in columns, select columns to apply transliteration. Can be") + print(" comma separated numbers, or a range. E.g., -c 1, -c 1-3, -c 1,3.") + print(" -d CHAR, --delimiter=CHAR:") + print(" Specify the delimiter that defines the column if using the -c") + print(" option above. Default is ' ' (space).") + print(" -r, --reverse:") + print(" Reverses the transliteration, i.e., Arabic to Buckwalter.") + print(" When used, it will change the default input encoding to utf_8 and") + print(" output encoding to latin_1") + print(" -l, --list-encodings:") + print(" Displays all supported file encodings.") + print(" -h, --help:") + print(" Displays this page.") + print("") + +# A function to print to screen all the available encodings supported by +# Python. 
+ +def displayEncodings(): + print("Codec Aliases Languages") + print("ascii 646, us-ascii English") + print("cp037 IBM037, IBM039 English") + print("cp424 EBCDIC-CP-HE, IBM424 Hebrew") + print("cp437 437, IBM437 English") + print("cp500 EBCDIC-CP-BE, EBCDIC-CP-CH, IBM500 Western Europe") + print("cp737 Greek") + print("cp775 IBM775 Baltic languages") + print("cp850 850, IBM850 Western Europe") + print("cp852 852, IBM852 Central and Eastern Europe") + print("cp855 855, IBM855 Bulgarian, Byelorussian, Macedonian, Russian, Serbian") + print("cp856 Hebrew") + print("cp857 857, IBM857 Turkish") + print("cp860 860, IBM860 Portuguese") + print("cp861 861, CP-IS, IBM861 Icelandic") + print("cp862 862, IBM862 Hebrew") + print("cp863 863, IBM863 Canadian") + print("cp864 IBM864 Arabic") + print("cp865 865, IBM865 Danish, Norwegian") + print("cp869 869, CP-GR, IBM869 Greek") + print("cp874 Thai") + print("cp875 Greek") + print("cp1006 Urdu") + print("cp1026 ibm1026 Turkish") + print("cp1140 ibm1140 Western Europe") + print("cp1250 windows-1250 Central and Eastern Europe") + print("cp1251 windows-1251 Bulgarian, Byelorussian, Macedonian, Russian, Serbian") + print("cp1252 windows-1252 Western Europe") + print("cp1253 windows-1253 Greek") + print("cp1254 windows-1254 Turkish") + print("cp1255 windows-1255 Hebrew") + print("cp1256 windows-1256 Arabic") + print("cp1257 windows-1257 Baltic languages") + print("cp1258 windows-1258 Vietnamese") + print("latin_1 iso-8859-1, iso8859-1, 8859, cp819, latin, latin1, L1 West Europe") + print("iso8859_2 iso-8859-2, latin2, L2 Central and Eastern Europe") + print("iso8859_3 iso-8859-3, latin3, L3 Esperanto, Maltese") + print("iso8859_4 iso-8859-4, latin4, L4 Baltic languagues") + print("iso8859_5 iso-8859-5, cyrillic Bulgarian, Byelorussian, Macedonian, Russian, Serbian") + print("iso8859_6 iso-8859-6, arabic Arabic") + print("iso8859_7 iso-8859-7, greek, greek8 Greek") + print("iso8859_8 iso-8859-8, hebrew Hebrew") + print("iso8859_9 iso-8859-9, latin5, L5 Turkish") + print("iso8859_10 iso-8859-10, latin6, L6 Nordic languages") + print("iso8859_13 iso-8859-13 Baltic languages") + print("iso8859_14 iso-8859-14, latin8, L8 Celtic languages") + print("iso8859_15 iso-8859-15 Western Europe") + print("koi8_r Russian") + print("koi8_u Ukrainian") + print("mac_cyrillic maccyrillic Bulgarian, Byelorussian, Macedonian, Russian, Serbian") + print("mac_greek macgreek Greek") + print("mac_iceland maciceland Icelandic") + print("mac_latin2 maclatin2, maccentraleurope Central and Eastern Europe") + print("mac_roman macroman Western Europe") + print("mac_turkish macturkish Turkish") + print("utf_16 U16, utf16 all languages") + print("utf_16_be UTF-16BE all languages (BMP only)") + print("utf_16_le UTF-16LE all languages (BMP only)") + print("utf_7 U7 all languages") + print("utf_8 U8, UTF, utf8 all languages") + +def parseIgnoreString(string): + + symbols = [] + + for char in string: + symbols.append(char) + + return symbols + +# Begin parsing the command-line arguments... + +try: + (options, args) = getopt.getopt(sys.argv[1:], "i:o:e:E:g:c:d:rlh", + ["input=","output=", "input-encoding=", "output-encoding=", + "ignore-lines=", "columns=", "delimiter=" "reverse", "list-encodings", + "help"]) + +except getopt.GetoptError: + # print help information and exit: + usage() + sys.exit(1) + +# Loop over all arguments supplied by the user. 
+for (x, y) in options: + if x in ("-h", "--help"): + usage() + sys.exit(0) + + if x in ("-l", "--list-encodings"): + displayEncodings() + sys.exit(0) + + if x in ("-i", "--input"): inFilename = y + if x in ("-o", "--output"): outFilename = y + if x in ("-e", "--input-encoding"): inEnc= y + if x in ("-E", "--output-encoding"): outEnc= y + if x in ("-r", "--reverse"): reverse = 1 + if x in ("-g", "--ignore-lines"): ignoreChars = y + if x in ("-c", "--columns"): columnRange = y + if x in ("-d", "--delimiter"): + delimiter = y + # Tabs come in off the command line from "\\t" to "\t". However, + # that's equivalent to "\\t" from python's point of view. + # Therefore replace any inputted "tabs" with proper tabs before + # proceeding. + delimiter = delimiter.replace("\\t", "\t") + # Do some error checking + if len(delimiter) > 1: + print("Delimeter should only be a single character. Using first character" + delimiter[0], file=sys.stderr) + delimiter = delimiter[0] + + if buck2uni.get(delimiter): + print("Invalid delimiter. \"" + delimiter + "\" is part of the Buckwalter character set.", file=sys.stderr) + print("This will obviously cause much confusion as a delimiter!", file=sys.stderr) + print("Please try again. Aborting...", file=sys.stderr) + sys.exit(1) + +# If no delimiter was set then, set the default to " " (space) +if not delimiter: + delimiter = " " + +# If user didn't specify the encoding of the input file, then revert to +# defaults. The defaults can depending on the direction of +# transliteration: +# +# Buckwalter -> Unicode, default = latin1 +# Unicode -> Buckwalter, default = utf_8 + + +if not inEnc: + if reverse: + inEnc = "utf_8" + else: + inEnc = "latin_1" + +# Similarly, if user didn't specify the encoding of the output file, +# then revert to defaults. The defaults can depending on the direction +# of transliteration: +# +# Buckwalter -> Unicode, default = utf_8 +# Unicode -> Buckwalter, default # = latin_1 + +if not outEnc: + if reverse: + outEnc = "latin_1" + else: + outEnc = "utf_8" + +# Ok, let's get the files open! + +# Providing a file for output was specified... +if outFilename: + try: + # Create a file object, set it to "write" mode using the + # specified output encoding. + outFile = codecs.open(outFilename, "w", outEnc) + + except IOError as msg: + # A problem occurred when trying to open this file. Report to + # user... + print(msg) + sys.exit(1) + +# Script can not work without somewhere to store the transliteration. +# Exit. +else: + print("Must specify a file to use store the output! Aborting...") + sys.exit(1) + +# Providing a file for input was specified... +if inFilename: + try: + # Create a file object, set it to "read" mode using the + # specified input encoding. + inFile = codecs.open(inFilename, "r", inEnc) + + except IOError as msg: + # A problem occurred when trying to open this file. Report to + # user... + print(msg) + sys.exit(1) + +# This script requires a file to read from. Exit. +else: + print("Must specify a file to use as input! Aborting...") + sys.exit(1) + +def getColsFromRange(cRange): + + columns = [] + hyphenSearch = re.compile(r'-') + + rangeElements = cRange.split(",") + + for i in rangeElements: + # If it contains a hyphen (e.g., 1-3) + if hyphenSearch.search(i): + [start, end] = i.split("-") + columns = columns + list(range(int(start)-1,int(end))) + else: + columns.append(int(i)-1) + + return columns + +# This function transliterates a given string. 
It checks the direction +# of the transliteration and then uses the appropriate dictionary. A +# transliterated string is returned. + +def transliterate(inString, lineNumber): + out = "" + + if columnRange: + columns = getColsFromRange(columnRange) + + # Split the line on the delimiter + lineCols = inString.split(delimiter) + + # Iterate over each column. If it's one of the ones in the range + # specified, then transliterate, otherwise just output column + # unchanged. + + for i in range(len(lineCols)): + + # If first column, then don't prefix the delimiter + if i == 0: + if i in columns: + out = transliterateString(lineCols[i]) + else : + out = lineCols[i] + else : + if i in columns: + out = out + delimiter + transliterateString(lineCols[i]) + else : + out = out + delimiter + lineCols[i] + + else: + out = transliterateString(inString) + + + + return out + +def transliterateString(inString): + + out = "" + + # For normal Buckwalter -> Unicode transliteration.. + if not reverse: + + # Loop over each character in the string, inString. + for char in inString: + # Look up current char in the dictionary to get its + # respective value. If there is no match, e.g., chars like + # spaces, then just stick with the current char without any + # conversion. + out = out + buck2uni.get(char, char) + + # Same as above, just in the other direction. + else: + + for char in inString: + out = out + uni2buck.get(char, char) + + return out + +#while 1: +# line = inFile.readline().strip() +# line = line.decode(inEnc) +# if not line: +# break + + # process string +# outFile.write(transliterate(line) + os.linesep) + +# Read in the lines of the input file. +lines = inFile.readlines() + +currentLineNumber = 1 +# Loop over each line +for line in lines: + line = line.strip() + try: + # Transliterate the current line, and then write the output to + # file. + + if not ignoreChars: + outFile.write(transliterate(line, currentLineNumber) + " " + os.linesep) + else: + if line[0] in parseIgnoreString(ignoreChars): + outFile.write(line + " " + os.linesep) + else: + outFile.write(transliterate(line, currentLineNumber) + " " + os.linesep) + + currentLineNumber = currentLineNumber + 1 + + except UnicodeError as msg: + # A problem when writing occurred. Report to user... + print(msg) + sys.exit(1) + +# All done! Better close the files used before terminating... +inFile.close() +outFile.close() + +# ... and relax! :) diff --git a/egs/tunisian_msa/s5/local/chain/compare_wer.sh b/egs/tunisian_msa/s5/local/chain/compare_wer.sh new file mode 100755 index 00000000000..c6a3a91ea69 --- /dev/null +++ b/egs/tunisian_msa/s5/local/chain/compare_wer.sh @@ -0,0 +1,133 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/tdnn_{c,d}_sp +# For use with discriminatively trained systems you specify the epochs after a colon: +# for instance, +# local/chain/compare_wer.sh exp/chain/tdnn_c_sp exp/chain/tdnn_c_sp_smbr:{1,2,3} + + +if [ $# == 0 ]; then + echo "Usage: $0: [--looped] [--online] [ ... 
]" + echo "e.g.: $0 exp/chain/tdnn_{b,c}_sp" + echo "or (with epoch numbers for discriminative training):" + echo "$0 exp/chain/tdnn_b_sp_disc:{1,2,3}" + exit 1 +fi + +echo "# $0 $*" + +include_looped=false +if [ "$1" == "--looped" ]; then + include_looped=true + shift +fi +include_online=false +if [ "$1" == "--online" ]; then + include_online=true + shift +fi + + +used_epochs=false + +# this function set_names is used to separate the epoch-related parts of the name +# [for discriminative training] and the regular parts of the name. +# If called with a colon-free directory name, like: +# set_names exp/chain/tdnn_lstm1e_sp_bi_smbr +# it will set dir=exp/chain/tdnn_lstm1e_sp_bi_smbr and epoch_infix="" +# If called with something like: +# set_names exp/chain/tdnn_d_sp_smbr:3 +# it will set dir=exp/chain/tdnn_d_sp_smbr and epoch_infix="_epoch3" + + +set_names() { + if [ $# != 1 ]; then + echo "compare_wer_general.sh: internal error" + exit 1 # exit the program + fi + dirname=$(echo $1 | cut -d: -f1) + epoch=$(echo $1 | cut -s -d: -f2) + if [ -z $epoch ]; then + epoch_infix="" + else + used_epochs=true + epoch_infix=_epoch${epoch} + fi +} + + + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +test_sets=(devtest test) + +for t in ${test_sets[@]}; do + printf '# %%WER % 14s ' $t + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat $dirname/decode_$t/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + if $include_looped; then + echo -n "# [looped:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat $dirname/decode_looped_$t/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi + if $include_online; then + echo -n "# [online:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat ${dirname}_online/decode_$t/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi +done + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. 
+fi + + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent)" +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent)" +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Num-params " +for x in $*; do + printf "% 10s" $(grep num-parameters $x/log/progress.1.log | awk '{print $2}') +done +echo diff --git a/egs/tunisian_msa/s5/local/chain/run_tdnn.sh b/egs/tunisian_msa/s5/local/chain/run_tdnn.sh new file mode 120000 index 00000000000..34499362831 --- /dev/null +++ b/egs/tunisian_msa/s5/local/chain/run_tdnn.sh @@ -0,0 +1 @@ +tuning/run_tdnn_1a.sh \ No newline at end of file diff --git a/egs/tunisian_msa/s5/local/chain/tuning/run_tdnn_1a.sh b/egs/tunisian_msa/s5/local/chain/tuning/run_tdnn_1a.sh new file mode 100755 index 00000000000..ab68ba6fb68 --- /dev/null +++ b/egs/tunisian_msa/s5/local/chain/tuning/run_tdnn_1a.sh @@ -0,0 +1,292 @@ +#!/bin/bash + +# Uses a resnet-style factored TDNN-F model. + +# ./local/chain/compare_wer.sh exp/chain/tdnn1a_sp +# System tdnn1a_sp +# %WER devtest 39.25 +# %WER test 49.74 +# Final train prob -0.0473 +# Final valid prob -0.0538 +# Final train prob (xent) -1.0935 +# Final valid prob (xent) -1.0817 +# Num-params 3466448 + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +decode_nj=10 +train_set=train +test_sets="devtest test" +gmm=tri3b +nnet3_affix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=1a # affix for the TDNN directory name +tree_affix= +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +num_leaves=3500 + +# training options +# training chunk-options +chunk_width=140,100,160 +# we don't need extra left/right context for TDNN systems. +dropout_schedule='0,0@0.20,0.3@0.50,0' +common_egs_dir= +xent_regularize=0.1 + +# training options +srand=0 +remove_egs=true +reporting_email= + +#decode options +test_online_decoding=true # if true, it will run the last decoding stage. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 11 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 20 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 12 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --cmd "$train_cmd" \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + $num_leaves \ + ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 13 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + affine_opts="l2-regularize=0.03 dropout-proportion=0.0 dropout-per-dim-continuous=true" + tdnnf_opts="l2-regularize=0.03 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.03 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.015" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-dropout-layer name=tdnn1 $affine_opts dim=768 + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=1 + tdnnf-layer name=tdnnf3 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=1 + tdnnf-layer name=tdnnf4 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=1 + tdnnf-layer name=tdnnf5 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=0 + tdnnf-layer name=tdnnf6 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + tdnnf-layer name=tdnnf7 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=768 bottleneck-dim=64 time-stride=3 + linear-component name=prefinal-l dim=192 $linear_opts + + ## adding the layers for chain branch + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=768 small-dim=192 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + # adding the layers for xent branch + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=768 small-dim=192 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 14 ]; then + steps/nnet3/chain/train.py \ + --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.0 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=8 \ + --trainer.frames-per-iter=3000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=5 \ + 
--trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.num-chunk-per-minibatch=128,64 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 15 ]; then + # Note: it's not important to give mkgraph.sh the lang directory with the + # matched topology (since it gets the topology file from the model). + utils/mkgraph.sh \ + --self-loop-scale 1.0 \ + data/lang_test \ + $tree_dir \ + $tree_dir/graph || exit 1; +fi + +if [ $stage -le 16 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l /dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l +example: +$0 Tunisian_MSA/data/transcripts/devtest/recordings.tsv 6 tunisia +"; +} + +my ($tr,$spk,$l) = @ARGV; + +open my $I, '<', $tr or croak "problems with $tr"; + +my $tmp_dir = "data/local/tmp/$l/$spk"; + +# input wav file list +my $wav_list = "$tmp_dir/wav.txt"; +croak "$!" unless ( -f $wav_list ); +# output temporary wav.scp files +my $wav_scp = "$tmp_dir/wav.scp"; + +# output temporary utt2spk files +my $u = "$tmp_dir/utt2spk"; + +# output temporary text files +my $t = "$tmp_dir/text"; + +# initialize hash for prompts +my %p = (); + +# store prompts in hash +LINEA: while ( my $line = <$I> ) { + chomp $line; + my ($s,$sent) = split /\t/, $line, 2; + $p{$s} = $sent; +} + +open my $W, '<', $wav_list or croak "problem with $wav_list $!"; +open my $O, '+>', $wav_scp or croak "problem with $wav_scp $!"; +open my $U, '+>', $u or croak "problem with $u $!"; +open my $T, '+>', $t or croak "problem with $t $!"; + + LINE: while ( my $line = <$W> ) { + chomp $line; + next LINE if ($line =~ /answers/ ); + next LINE unless ( $line =~ /Recordings/ ); + my ($volume,$directories,$file) = File::Spec->splitpath( $line ); + my @dirs = split /\//, $directories; + my $b = basename $line, ".wav"; + my $s = $dirs[-1]; + my $rid = $s . '_' . 'recording' . '_' . $b; + my $uid = $s . '_' . 'recording'; + if ( exists $p{$b} ) { + print $T "$rid\t$p{$b}\n"; + } elsif ( defined $s ) { + warn "problem\t$s"; + next LINE; + } else { + croak "$line"; + } + + print $O "$rid sox $line -t wav - |\n"; + print $U "$rid\t$uid\n"; +} +close $T; +close $O; +close $U; +close $W; diff --git a/egs/tunisian_msa/s5/local/nnet3/run_ivector_common.sh b/egs/tunisian_msa/s5/local/nnet3/run_ivector_common.sh new file mode 100755 index 00000000000..e8ff9a150ea --- /dev/null +++ b/egs/tunisian_msa/s5/local/nnet3/run_ivector_common.sh @@ -0,0 +1,185 @@ +#!/bin/bash + +set -euo pipefail + +# This script is called from local/nnet3/run_tdnn.sh and +# local/chain/run_tdnn.sh (and may eventually be called by more +# scripts). It contains the common feature preparation and +# iVector-related parts of the script. See those scripts for examples +# of usage. + +stage=0 +train_set=train +test_sets="devtest test" +gmm=tri3b + +nnet3_affix= + +. ./cmd.sh +. ./path.sh +. utils/parse_options.sh + +gmm_dir=exp/${gmm} +ali_dir=exp/${gmm}_ali_${train_set}_sp + +for f in data/${train_set}/feats.scp ${gmm_dir}/final.mdl; do + if [ ! 
-f $f ]; then + echo "$0: expected file $f to exist" + exit 1 + fi +done + +if [ $stage -le 1 ]; then + # perturb data to get alignments + # nnet will be trained by high resolution data + # _sp stands for speed-perturbed + echo "$0: preparing directory for low-resolution speed-perturbed data (for alignment)" + utils/data/perturb_data_dir_speed_3way.sh \ + data/${train_set} \ + data/${train_set}_sp + echo "$0: making mfcc features for low-resolution speed-perturbed data" + steps/make_mfcc.sh \ + --cmd "$train_cmd" \ + --nj 10 \ + data/${train_set}_sp + steps/compute_cmvn_stats.sh \ + data/${train_set}_sp + utils/fix_data_dir.sh \ + data/${train_set}_sp +fi + +if [ $stage -le 2 ]; then + echo "$0: aligning with the perturbed low-resolution data" + steps/align_fmllr.sh \ + --nj 20 \ + --cmd "$train_cmd" \ + data/${train_set}_sp \ + data/lang \ + $gmm_dir \ + $ali_dir +fi + +if [ $stage -le 3 ]; then + # Create high-resolution MFCC features (with 40 cepstra instead of 13). + + echo "$0: creating high-resolution MFCC features" + mfccdir=data/${train_set}_sp_hires/data + for datadir in ${train_set}_sp ${test_sets}; do + utils/copy_data_dir.sh \ + data/$datadir \ + data/${datadir}_hires + done + + # do volume-perturbation on the training data prior to extracting hires + # features; this helps make trained nnets more invariant to test data volume. + utils/data/perturb_data_dir_volume.sh \ + data/${train_set}_sp_hires + + for datadir in ${train_set}_sp ${test_sets}; do + steps/make_mfcc.sh \ + --nj 10 \ + --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" \ + data/${datadir}_hires + steps/compute_cmvn_stats.sh \ + data/${datadir}_hires + utils/fix_data_dir.sh \ + data/${datadir}_hires + done +fi + +if [ $stage -le 4 ]; then + echo "$0: computing a subset of data to train the diagonal UBM." + # We'll use about a quarter of the data. + mkdir -p exp/nnet3${nnet3_affix}/diag_ubm + temp_data_root=exp/nnet3${nnet3_affix}/diag_ubm + + num_utts_total=$(wc -l $tmp_tunis/$s/wav.txt + + local/devtest_recordings_make_lists.pl \ + $data_dir/transcripts/devtest/recordings.tsv $s tunis + + mkdir -p data/devtest + + for x in wav.scp utt2spk text; do + cat $tmp_tunis/$s/$x | tr " " " " >> data/devtest/$x + done +done + +utils/utt2spk_to_spk2utt.pl data/devtest/utt2spk | sort > data/devtest/spk2utt + +utils/fix_data_dir.sh data/devtest + +# training data consists of 2 parts: answers and recordings (recited) +answers_transcripts=$data_dir/transcripts/train/answers.tsv +recordings_transcripts=$data_dir/transcripts/train/recordings.tsv + +# location of test data +cls_rec_tr=$libyan_src/cls/data/transcripts/recordings/cls_recordings.tsv +lfi_rec_tr=$libyan_src/lfi/data/transcripts/recordings/lfi_recordings.tsv +srj_rec_tr=$libyan_src/srj/data/transcripts/recordings/srj_recordings.tsv +mbt_rec_tr=$data_dir/transcripts/test/mbt/recordings/mbt_recordings.tsv + +# make acoustic model training lists +mkdir -p $tmp_tunis + +# get wav file names + +# for recited speech +# the data collection laptops had names like CTELLONE CTELLTWO ... 
+for machine in CTELLONE CTELLTWO CTELLTHREE CTELLFOUR CTELLFIVE; do + find $data_dir/speech/train/$machine -type f -name "*.wav" | grep Recordings \ + >> $tmp_tunis/recordings_wav.txt +done + +# get file names for Answers +for machine in CTELLONE CTELLTWO CTELLTHREE CTELLFOUR CTELLFIVE; do + find $data_dir/speech/train/$machine -type f \ + -name "*.wav" \ + | grep Answers >> $tmp_tunis/answers_wav.txt +done + +# make separate transcription lists for answers and recordings +export LC_ALL=en_US.UTF-8 +local/answers_make_lists.pl $answers_transcripts + +utils/fix_data_dir.sh $tmp_tunis/answers + +local/recordings_make_lists.pl $recordings_transcripts + +utils/fix_data_dir.sh $tmp_tunis/recordings + +# consolidate lists +# acoustic models will be trained on both recited and prompted speech +mkdir -p $tmp_tunis/lists + +for x in wav.scp utt2spk text; do + cat $tmp_tunis/answers/$x $tmp_tunis/recordings/$x > $tmp_tunis/lists/$x +done + +utils/fix_data_dir.sh $tmp_tunis/lists + +# get training lists +mkdir -p data/train +for x in wav.scp utt2spk text; do + sort $tmp_tunis/lists/$x | tr " " " " > data/train/$x +done + +utils/utt2spk_to_spk2utt.pl data/train/utt2spk | sort > data/train/spk2utt + +utils/fix_data_dir.sh data/train + +# process the Libyan MSA data +mkdir -p $tmp_libyan + +for s in cls lfi srj; do + mkdir -p $tmp_libyan/$s + + # get list of wav files + find $libyan_src/$s -type f \ + -name "*.wav" \ + | grep recordings > $tmp_libyan/$s/recordings_wav.txt + + echo "$0: making recordings list for $s" + local/test_recordings_make_lists.pl \ + $libyan_src/$s/data/transcripts/recordings/${s}_recordings.tsv $s libyan +done + +# process the Tunisian MSA test data + +mkdir -p $tmp_tunis/mbt + +# get list of wav files +find $data_dir/speech/test/mbt -type f \ + -name "*.wav" \ + | grep recordings > $tmp_tunis/mbt/recordings_wav.txt + +echo "$0: making recordings list for mbt" +local/test_recordings_make_lists.pl \ + $data_dir/transcripts/test/mbt/recordings/mbt_recordings.tsv mbt tunis + +mkdir -p data/test +# get the Libyan files +for s in cls lfi srj; do + for x in wav.scp utt2spk text; do + cat $tmp_libyan/$s/recordings/$x | tr " " " " >> data/test/$x + done +done + +for x in wav.scp utt2spk text; do + cat $tmp_tunis/mbt/recordings/$x | tr " " " " >> data/test/$x +done + +utils/utt2spk_to_spk2utt.pl data/test/utt2spk | sort > data/test/spk2utt + +utils/fix_data_dir.sh data/test diff --git a/egs/tunisian_msa/s5/local/prepare_dict.sh b/egs/tunisian_msa/s5/local/prepare_dict.sh new file mode 100755 index 00000000000..f7d1ac3a619 --- /dev/null +++ b/egs/tunisian_msa/s5/local/prepare_dict.sh @@ -0,0 +1,43 @@ +#!/bin/bash -u + +# Copyright 2018 John Morgan +# Apache 2.0. + +set -o errexit + +[ -f ./path.sh ] && . ./path.sh + +if [ ! -d data/local/dict ]; then + mkdir -p data/local/dict +fi + +l=$1 +export LC_ALL=C + +cut -f2- -d " " $l | tr -s '[:space:]' '[\n*]' | grep -v SPN | \ + sort -u | tail -n+2 > data/local/dict/nonsilence_phones.txt + +expand -t 1 $l | sort -u | \ + sed "1d" > data/local/dict/lexicon.txt + +echo " SPN" >> data/local/dict/lexicon.txt + +# silence phones, one per line. +{ + echo SIL; + echo SPN; +} \ + > \ + data/local/dict/silence_phones.txt + +echo SIL > data/local/dict/optional_silence.txt + +# get the phone list from the lexicon file +( + tr '\n' ' ' < data/local/dict/silence_phones.txt; + echo; + tr '\n' ' ' < data/local/dict/nonsilence_phones.txt; + echo; +) >data/local/dict/extra_questions.txt + +echo "$0: Finished dictionary preparation." 
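The prepare_dict.sh diff just above derives the phone inventory directly from the lexicon: everything after the word field is split into one phone per token, SPN is filtered out, and the sorted unique result (minus the leading empty token) becomes nonsilence_phones.txt. A minimal Python sketch of that extraction, using a made-up two-entry lexicon rather than the real QCRI one:

# made-up lexicon entries in "word phone phone ..." format; not the real QCRI lexicon
lexicon = [
    "kitAb k i t A b",
    "qalam q a l a m",
]

# collect every phone after the word field, drop SPN, and sort uniquely
nonsilence_phones = sorted({
    phone
    for entry in lexicon
    for phone in entry.split()[1:]
    if phone != "SPN"
})
# -> ['A', 'a', 'b', 'i', 'k', 'l', 'm', 'q', 't']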
diff --git a/egs/tunisian_msa/s5/local/prepare_lm.sh b/egs/tunisian_msa/s5/local/prepare_lm.sh new file mode 100755 index 00000000000..4fc50b84d11 --- /dev/null +++ b/egs/tunisian_msa/s5/local/prepare_lm.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# Copyright 2018 John Morgan +# Apache 2.0. + +. ./cmd.sh +set -e +. ./path.sh +. $KALDI_ROOT/tools/env.sh +stage=0 +nsegs=1000000; # limit the number of training segments + +. ./utils/parse_options.sh + +if [ ! -d data/local/lm ]; then + mkdir -p data/local/lm +fi + +corpus=$1 + +if [ ! -f $corpus ]; then + echo "$0: input data $corpus not found." + exit 1 +fi + +perl -MList::Util=shuffle -e 'print shuffle();' < $corpus | \ + head -n $nsegs > data/local/lm/train.txt + +if ! command ngram-count >/dev/null; then + if uname -a | grep darwin >/dev/null; then # For MACOSX... + sdir=$KALDI_ROOT/tools/srilm/bin/macosx + elif uname -a | grep 64 >/dev/null; then # some kind of 64 bit... + sdir=$KALDI_ROOT/tools/srilm/bin/i686-m64 + else + sdir=$KALDI_ROOT/tools/srilm/bin/i686 + fi + if [ -f $sdir/ngram-count ]; then + echo Using SRILM tools from $sdir + export PATH=$PATH:$sdir + else + echo You appear to not have SRILM tools installed, either on your path, + echo or installed in $sdir. See tools/install_srilm.sh for installation + echo instructions. + exit 1 + fi +fi + + +ngram-count -order 3 -interpolate -unk -map-unk "" \ + -limit-vocab -text data/local/lm/train.txt -lm data/local/lm/trigram.arpa || exit 1; + +gzip -f data/local/lm/trigram.arpa diff --git a/egs/tunisian_msa/s5/local/qcri_buckwalter2utf8.sh b/egs/tunisian_msa/s5/local/qcri_buckwalter2utf8.sh new file mode 100755 index 00000000000..0468c04ebd8 --- /dev/null +++ b/egs/tunisian_msa/s5/local/qcri_buckwalter2utf8.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Copyright 2018 John Morgan +# Apache 2.0. + +# write separate files for word and pronunciation fields +cut -d " " -f 1 qcri.txt > qcri_words_buckwalter.txt +cut -d " " -f 2- qcri.txt > qcri_prons.txt + +# convert words to utf8 +local/buckwalter2unicode.py -i qcri_words_buckwalter.txt -o qcri_words_utf8.txt + +paste qcri_words_utf8.txt qcri_prons.txt + +rm qcri_words_buckwalter.txt qcri_words_utf8.txt qcri_prons.txt diff --git a/egs/tunisian_msa/s5/local/qcri_lexicon_download.sh b/egs/tunisian_msa/s5/local/qcri_lexicon_download.sh new file mode 100755 index 00000000000..29a9ca1eed6 --- /dev/null +++ b/egs/tunisian_msa/s5/local/qcri_lexicon_download.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Copyright 2018 John Morgan +# Apache 2.0. + +# configuration variables +lex=$1 +tmpdir=data/local/tmp +# where to put the downloaded speech corpus +downloaddir=$(pwd) +# Where to put the uncompressed file +datadir=$(pwd) +# end of configuration variable settings + +# download the corpus +if [ ! -f $downloaddir/qcri.txt.bz2 ]; then + wget -O $downloaddir/qcri.txt.bz2 $lex + ( + cd $downloaddir + bzcat qcri.txt.bz2 | tail -n+4 > $datadir/qcri.txt + ) +else + echo "$0: The corpus $lex was already downloaded." +fi diff --git a/egs/tunisian_msa/s5/local/recordings_make_lists.pl b/egs/tunisian_msa/s5/local/recordings_make_lists.pl new file mode 100755 index 00000000000..41fc15e0dd3 --- /dev/null +++ b/egs/tunisian_msa/s5/local/recordings_make_lists.pl @@ -0,0 +1,72 @@ +#!/usr/bin/env perl + +# Copyright 2018 John Morgan +# Apache 2.0. 
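
Before the list-making script continues, a note on prepare_lm.sh above: it shuffles the conditioned subtitle text, keeps at most nsegs segments, and trains an interpolated trigram with SRILM's ngram-count before gzipping the ARPA file. The shuffle-then-head step loads the whole corpus into memory; the sketch below shows a reservoir-sampling alternative that draws the same number of segments in a single streaming pass. It is hypothetical and not used by the recipe.

  #!/usr/bin/env python3
  # sample_segments.py - hypothetical alternative to the shuffle | head step in
  # prepare_lm.sh: reservoir-samples k lines from stdin without holding the
  # whole subtitle corpus in memory.
  import random
  import sys

  def reservoir_sample(stream, k):
      sample = []
      for i, line in enumerate(stream):
          if i < k:
              sample.append(line)
          else:
              j = random.randint(0, i)   # keeps each line with probability k/(i+1)
              if j < k:
                  sample[j] = line
      return sample

  if __name__ == '__main__':
      k = int(sys.argv[1]) if len(sys.argv) > 1 else 1000000
      sys.stdout.writelines(reservoir_sample(sys.stdin, k))
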
+ +# recordings_make_lists.pl - make acoustic model training lists + +use strict; +use warnings; +use Carp; + +use File::Spec; +use File::Copy; +use File::Basename; + +my $tmpdir = "data/local/tmp/tunis"; + +system "mkdir -p $tmpdir/recordings"; + +# input wav file list +my $w = "$tmpdir/recordings_wav.txt"; + +# output temporary wav.scp files +my $o = "$tmpdir/recordings/wav.scp"; + +# output temporary utt2spk files +my $u = "$tmpdir/recordings/utt2spk"; + +# output temporary text files +my $t = "$tmpdir/recordings/text"; + +# initialize hash for prompts +my %p = (); + +# store prompts in hash +LINEA: while ( my $line = <> ) { + chomp $line; + my ($s,$sent) = split /\t/, $line, 2; + $p{$s} = $sent; +} + +open my $W, '<', $w or croak "problem with $w $!"; +open my $O, '+>', $o or croak "problem with $o $!"; +open my $U, '+>', $u or croak "problem with $u $!"; +open my $T, '+>', $t or croak "problem with $t $!"; + + LINE: while ( my $line = <$W> ) { + chomp $line; + next LINE if ($line =~ /Answers/ ); + next LINE unless ( $line =~ /Recordings/ ); + my ($volume,$directories,$file) = File::Spec->splitpath( $line ); + my @dirs = split /\//, $directories; + my $machine = $dirs[-3]; + my $r = basename $line, ".wav"; + my $s = $dirs[-1]; + my $rid = $machine . '_' . $s . '_r_' . $r; + if ( exists $p{$r} ) { + print $T "$rid\t$p{$r}\n"; + } elsif ( defined $rid ) { + warn "problem\t$rid"; + next LINE; + } else { + croak "$line"; + } + + print $O "$rid sox $line -t wav - |\n"; + print $U "$rid\t${machine}_${s}_r\n"; +} +close $T; +close $O; +close $U; +close $W; diff --git a/egs/tunisian_msa/s5/local/score.sh b/egs/tunisian_msa/s5/local/score.sh new file mode 120000 index 00000000000..0afefc3158c --- /dev/null +++ b/egs/tunisian_msa/s5/local/score.sh @@ -0,0 +1 @@ +../steps/score_kaldi.sh \ No newline at end of file diff --git a/egs/tunisian_msa/s5/local/subs_download.sh b/egs/tunisian_msa/s5/local/subs_download.sh new file mode 100755 index 00000000000..7e46fd255aa --- /dev/null +++ b/egs/tunisian_msa/s5/local/subs_download.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Copyright 2018 John Morgan +# Apache 2.0. + +# Begin configuration +subs_src=$1 +tmpdir=data/local/tmp +download_dir=$(pwd) +datadir=$(pwd) +# End configuration + +# download the subs corpus +if [ ! -f $download_dir/subs.txt.gz ]; then + wget -O $download_dir/subs.txt.gz $subs_src +else + echo "$0: The corpus $subs_src was already downloaded." +fi + +if [ ! -f $datadir/subs.txt ]; then + ( + cd $datadir + zcat < ./subs.txt.gz > subs.txt + ) + else + echo "$0: subs file already extracted." +fi diff --git a/egs/tunisian_msa/s5/local/subs_prepare_data.pl b/egs/tunisian_msa/s5/local/subs_prepare_data.pl new file mode 100755 index 00000000000..e39f77a25cb --- /dev/null +++ b/egs/tunisian_msa/s5/local/subs_prepare_data.pl @@ -0,0 +1,115 @@ +#!/usr/bin/env perl + +# Copyright 2018 John Morgan +# Apache 2.0. 
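
The recordings_make_lists.pl script above derives each utterance ID from the machine name, the speaker directory and the wav basename, and emits the three Kaldi lists: wav.scp with a sox pipe, utt2spk, and text keyed on the prompt table. The sketch below restates that ID scheme in Python purely as an illustration; the directory layout it assumes (machine/Recordings/speaker/recording.wav) is a guess, and the helper itself is hypothetical.

  #!/usr/bin/env python3
  # make_recording_lists_sketch.py - hypothetical restatement of the utterance-ID
  # scheme in recordings_make_lists.pl; the directory layout is an assumption.
  import os

  def kaldi_entries(wav_path, prompts):
      """Return (wav.scp, utt2spk, text) lines for one recited-speech wav file."""
      parts = os.path.normpath(wav_path).split(os.sep)
      machine, speaker = parts[-4], parts[-2]
      rec = os.path.splitext(parts[-1])[0]
      rid = "{}_{}_r_{}".format(machine, speaker, rec)   # utterance (recording) ID
      wav_scp = "{} sox {} -t wav - |".format(rid, wav_path)
      utt2spk = "{}\t{}_{}_r".format(rid, machine, speaker)
      text = "{}\t{}".format(rid, prompts[rec])
      return wav_scp, utt2spk, text

  if __name__ == '__main__':
      prompts = {"001": "example prompt text"}   # would come from recordings.tsv
      path = "speech/train/CTELLONE/Recordings/spk01/001.wav"
      for line in kaldi_entries(path, prompts):
          print(line)
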
+ +# subs_prepare_data.pl - condition subs data for lm training + +use strict; +use warnings; +use Carp; + +use Encode; + +# set lower and upper bounds +my $low_bound = 8; +# only segments with at least $low_bound words will be written +my $up_bound = 16; +# only segments with fewer than $up_bound words will be written + +# input and output files +my $corp = "subs.txt"; +my $symtab = "data/lang/words.txt"; +my $conditioned = "data/local/tmp/subs/lm/ar.txt"; +my $oo = "data/local/tmp/subs/lm/oovs.txt"; +my $iv = "data/local/tmp/subs/lm/in_vocabulary.txt"; + +open my $CORP, '<', $corp or croak "problems with $corp $!"; +system "mkdir -p data/local/tmp/subs/lm"; +open my $COND, '+>:utf8', $conditioned or croak "problems with $conditioned $!"; + +if ( -s $conditioned ) { + croak "$conditioned already exists."; +} else { + LINE: while ( my $line = <$CORP> ) { + $line = decode_utf8 $line; + chomp $line; + + my @tokens = split /\s+/, $line; + + next LINE if ( ($#tokens < $low_bound) or ($#tokens > $up_bound )); + + # remove punctuation + $line =~ s/(\p{Punctuation}+|\p{Dash_Punctuation}+|\p{Close_Punctuation}+|\p{Open_Punctuation}+|\p{Initial_Punctuation}+|\p{Final_Punctuation}+|\p{Connector_Punctuation}+|\p{Other_Punctuation}+|[ ]+)/ /msxg; + #convert tabs to white space + $line =~ s/\t/ /g; + #hard to soft space + $line =~ s/ / /g; + #squeeze white space + $line =~ s/\s+/ /g; + #initial and final white space + $line =~ s/^\p{Separator}+//; + $line =~ s/\p{Separator}+$//; + #down case + $line = lc $line; + + print $COND "$line\n"; + } +}close $CORP; +close $COND; + +# find out of vocabulary words +# $symtab points to a file containing a map of symbols to integers + +# hash for word to integer map +my %sym2int = (); + +open my $F, '<', $symtab or croak "problem with $symtab $!"; + +# store words to int map in hash +while( my $line = <$F>) { + chomp $line; + my ($s,$i) = split /\s/, $line, 2; + $sym2int{$s} = $i; +} +close $F; + +open my $I, '<', $conditioned or croak "problem with $conditioned $!"; +open my $OO, '+>', $oo or croak "problems with $oo $!"; + +while ( my $line = <$I>) { + chomp $line; + my @A = split /\s/, $line; + foreach my $a (@A) { + if (!defined ($sym2int{$a})) { + print $OO "$a\n"; + } + } +} +close $OO; +close $I; + +# remove segments with OOVs + +# store OOVS in hash +my %oov = (); +open my $V, '<', $oo or croak "problems with $oo $!"; +while ( my $line = <$V> ) { + chomp $line; + $oov{$line} = 1; +} +close $V; + +open my $L, '<', $conditioned or croak "problems with $conditioned $!"; +open my $IV, '+>', $iv or croak "problems with $iv $!"; + +SEGMENT: while ( my $segment = <$L> ) { + chomp $segment; + my @words = split /\s+/, $segment; + foreach my $word ( sort @words ) { + next SEGMENT if ( $oov{$word} ); + } + print $IV "$segment\n"; +} +close $IV; +close $L; diff --git a/egs/tunisian_msa/s5/local/tamsa_download.sh b/egs/tunisian_msa/s5/local/tamsa_download.sh new file mode 100755 index 00000000000..5e4666482ab --- /dev/null +++ b/egs/tunisian_msa/s5/local/tamsa_download.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# Copyright 2018 John Morgan +# Apache 2.0. + +speech=$1 + +# where to put the downloaded speech corpus +download_dir=$(pwd) +data_dir=$download_dir/Tunisian_MSA/data + +# download the corpus from openslr +if [ ! -f $download_dir/tamsa.tar.gz ]; then + wget -O $download_dir/tamsa.tar.gz $speech +else + echo "$0: The corpus $speech was already downloaded." +fi + +if [ ! 
-d $download_dir/Tunisian_MSA ]; then + ( + cd $download_dir + tar -xzf tamsa.tar.gz + ) +else + echo "$0: The corpus was already unzipped." +fi diff --git a/egs/tunisian_msa/s5/local/test_answers_make_lists.pl b/egs/tunisian_msa/s5/local/test_answers_make_lists.pl new file mode 100755 index 00000000000..aa7d0e314f3 --- /dev/null +++ b/egs/tunisian_msa/s5/local/test_answers_make_lists.pl @@ -0,0 +1,83 @@ +#!/usr/bin/env perl + +# Copyright 2018 John Zac76 +# Apache 2.0. + +# test_answers_make_lists.pl - make acoustic model training lists + +use strict; +use warnings; +use Carp; + +use File::Spec; +use File::Copy; +use File::Basename; + +BEGIN { + @ARGV == 3 or croak "USAGE $0 +example: +$0 /home/zak76/Desktop/Kaldi/kaldi-master/tunisian_msa-master/Libyan_collected_test/TEST/Libyan_MSA/adel/data/transcripts/answers/adel_answers.tsv adel libyan +"; +} + +my ($tr,$spk,$l) = @ARGV; + +open my $I, '<', $tr or croak "problems with $tr"; + +my $tmp_dir = "data/local/tmp/$l/$spk"; + +system "mkdir -p $tmp_dir/answers"; + +# input wav file list +my $w = "$tmp_dir/answers_wav.txt"; + +# output temporary wav.scp files +my $o = "$tmp_dir/answers/wav.scp"; + +# output temporary utt2spk files +my $u = "$tmp_dir/answers/utt2spk"; + +# output temporary text files +my $t = "$tmp_dir/answers/text"; + +# initialize hash for prompts +my %p = (); + +# store prompts in hash +LINEA: while ( my $line = <$I> ) { + chomp $line; + my ($s,$sent) = split /\t/, $line, 2; + $p{$s} = $sent; +} + +open my $W, '<', $w or croak "problem with $w $!"; +open my $O, '+>', $o or croak "problem with $o $!"; +open my $U, '+>', $u or croak "problem with $u $!"; +open my $T, '+>', $t or croak "problem with $t $!"; + + LINE: while ( my $line = <$W> ) { + chomp $line; + next LINE if ($line =~ /recordings/ ); + next LINE unless ( $line =~ /answers/ ); + my ($volume,$directories,$file) = File::Spec->splitpath( $line ); + my @dirs = split /\//, $directories; + my $b = basename $line, ".wav"; + my ($sk,$r) = split /\_/, $b, 2; + my $s = $dirs[-1]; + my $rid = $sk . '_' . $r; + if ( exists $p{$b} ) { + print $T "$rid\t$p{$b}\n"; + } elsif ( defined $rid ) { + warn "problem\t$rid"; + next LINE; + } else { + croak "$line"; + } + + print $O "$rid sox $line -t wav - |\n"; + print $U "$rid\t${sk}_a\n"; +} +close $T; +close $O; +close $U; +close $W; diff --git a/egs/tunisian_msa/s5/local/test_recordings_make_lists.pl b/egs/tunisian_msa/s5/local/test_recordings_make_lists.pl new file mode 100755 index 00000000000..0b1323f2738 --- /dev/null +++ b/egs/tunisian_msa/s5/local/test_recordings_make_lists.pl @@ -0,0 +1,83 @@ +#!/usr/bin/env perl + +# Copyright 2018 John Morgan +# Apache 2.0. 
+ +# test_recordings_make_lists.pl - make acoustic model training lists + +use strict; +use warnings; +use Carp; + +use File::Spec; +use File::Copy; +use File::Basename; + +BEGIN { + @ARGV == 3 or croak "USAGE $0 +example: +$0 /mnt/disk01/Libyan_MSA/srj/data/transcripts/recordings/srj_recordings.tsv srj libyan +"; +} + +my ($tr,$spk,$l) = @ARGV; + +open my $I, '<', $tr or croak "problems with $tr"; + +my $tmp_dir = "data/local/tmp/$l/$spk"; + +system "mkdir -p $tmp_dir/recordings"; + +# input wav file list +my $w = "$tmp_dir/recordings_wav.txt"; + +# output temporary wav.scp files +my $o = "$tmp_dir/recordings/wav.scp"; + +# output temporary utt2spk files +my $u = "$tmp_dir/recordings/utt2spk"; + +# output temporary text files +my $t = "$tmp_dir/recordings/text"; + +# initialize hash for prompts +my %p = (); + +# store prompts in hash +LINEA: while ( my $line = <$I> ) { + chomp $line; + my ($s,$sent) = split /\t/, $line, 2; + $p{$s} = $sent; +} + +open my $W, '<', $w or croak "problem with $w $!"; +open my $O, '+>', $o or croak "problem with $o $!"; +open my $U, '+>', $u or croak "problem with $u $!"; +open my $T, '+>', $t or croak "problem with $t $!"; + + LINE: while ( my $line = <$W> ) { + chomp $line; + next LINE if ($line =~ /answers/ ); + next LINE unless ( $line =~ /recordings/ ); + my ($volume,$directories,$file) = File::Spec->splitpath( $line ); + my @dirs = split /\//, $directories; + my $b = basename $line, ".wav"; + my ($sk,$r) = split /\_/, $b, 2; + my $s = $dirs[-1]; + my $rid = $sk . '_' . $r; + if ( exists $p{$b} ) { + print $T "$rid\t$p{$b}\n"; + } elsif ( defined $rid ) { + warn "problem\t$rid"; + next LINE; + } else { + croak "$line"; + } + + print $O "$rid sox $line -t wav - |\n"; + print $U "$rid\t${sk}_r\n"; +} +close $T; +close $O; +close $U; +close $W; diff --git a/egs/tunisian_msa/s5/path.sh b/egs/tunisian_msa/s5/path.sh new file mode 100644 index 00000000000..705600ad47a --- /dev/null +++ b/egs/tunisian_msa/s5/path.sh @@ -0,0 +1,8 @@ +export KALDI_ROOT=`pwd`/../../.. +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C + +# For now, don't include any of the optional dependenices of the main +# librispeech recipe diff --git a/egs/tunisian_msa/s5/run.sh b/egs/tunisian_msa/s5/run.sh new file mode 100755 index 00000000000..107acdf271c --- /dev/null +++ b/egs/tunisian_msa/s5/run.sh @@ -0,0 +1,190 @@ +#!/bin/bash + +# Trains on 11 hours of speechfrom CTELL{ONE,TWO,THREE,FOUR,FIVE} +# Uses the QCRI vowelized Arabic lexicon. +# Converts the Buckwalter encoding to utf8. +. ./cmd.sh +. ./path.sh +stage=0 + +. ./utils/parse_options.sh + +set -e +set -o pipefail +set u + +# Do not change tmpdir, other scripts under local depend on it +tmpdir=data/local/tmp + +# The speech corpus is on openslr.org +speech="http://www.openslr.org/resources/46/Tunisian_MSA.tar.gz" + +# We use the QCRI lexicon. +lex="http://alt.qcri.org/resources/speech/dictionary/ar-ar_lexicon_2014-03-17.txt.bz2" + +# We train the lm on subtitles. 
+subs_src="http://opus.nlpl.eu/download.php?f=OpenSubtitles2018/mono/OpenSubtitles2018.ar.gz" + +if [ $stage -le 1 ]; then + # Downloads archive to this script's directory + local/tamsa_download.sh $speech + + local/qcri_lexicon_download.sh $lex + + local/subs_download.sh $subs_src +fi + +# preparation stages will store files under data/ +# Delete the entire data directory when restarting. +if [ $stage -le 2 ]; then + local/prepare_data.sh +fi + +if [ $stage -le 3 ]; then + mkdir -p $tmpdir/dict + local/qcri_buckwalter2utf8.sh > $tmpdir/dict/qcri_utf8.txt +fi + +if [ $stage -le 4 ]; then + local/prepare_dict.sh $tmpdir/dict/qcri_utf8.txt +fi + +if [ $stage -le 5 ]; then + # prepare the lang directory + utils/prepare_lang.sh data/local/dict "" data/local/lang data/lang +fi + +if [ $stage -le 6 ]; then + echo "Preparing the subs data for lm training." + local/subs_prepare_data.pl +fi + +if [ $stage -le 7 ]; then + echo "lm training." + local/prepare_lm.sh $tmpdir/subs/lm/in_vocabulary.txt +fi + +if [ $stage -le 8 ]; then + echo "Making grammar fst." + utils/format_lm.sh \ + data/lang data/local/lm/trigram.arpa.gz data/local/dict/lexicon.txt \ + data/lang_test +fi + +if [ $stage -le 9 ]; then + # extract acoustic features + for fld in devtest train test; do + steps/make_mfcc.sh data/$fld exp/make_mfcc/$fld mfcc + utils/fix_data_dir.sh data/$fld + steps/compute_cmvn_stats.sh data/$fld exp/make_mfcc mfcc + utils/fix_data_dir.sh data/$fld + done +fi + +if [ $stage -le 10 ]; then + echo "$0: monophone training" + steps/train_mono.sh data/train data/lang exp/mono +fi + +if [ $stage -le 11 ]; then + # monophone evaluation + ( + # make decoding graph for monophones + utils/mkgraph.sh data/lang_test exp/mono exp/mono/graph + + # test monophones + for x in devtest test; do + nspk=$(wc -l < data/$x/spk2utt) + steps/decode.sh --nj $nspk exp/mono/graph data/$x exp/mono/decode_${x} + done + ) & +fi + +if [ $stage -le 12 ]; then + # align with monophones + steps/align_si.sh data/train data/lang exp/mono exp/mono_ali +fi + +if [ $stage -le 13 ]; then + echo "$0: Starting triphone training in exp/tri1" + steps/train_deltas.sh \ + --boost-silence 1.25 1000 6000 data/train data/lang exp/mono_ali exp/tri1 +fi + +wait + +if [ $stage -le 14 ]; then + # test cd gmm hmm models + # make decoding graphs for tri1 + ( + utils/mkgraph.sh data/lang_test exp/tri1 exp/tri1/graph + + # decode test data with tri1 models + for x in devtest test; do + nspk=$(wc -l < data/$x/spk2utt) + steps/decode.sh --nj $nspk exp/tri1/graph data/$x exp/tri1/decode_${x} + done + ) & +fi + +if [ $stage -le 15 ]; then + # align with triphones + steps/align_si.sh data/train data/lang exp/tri1 exp/tri1_ali +fi + +if [ $stage -le 16 ]; then + echo "$0: Starting (lda_mllt) triphone training in exp/tri2b" + steps/train_lda_mllt.sh \ + --splice-opts "--left-context=3 --right-context=3" 500 5000 \ + data/train data/lang exp/tri1_ali exp/tri2b +fi + +wait + +if [ $stage -le 17 ]; then + ( + # make decoding FSTs for tri2b models + utils/mkgraph.sh data/lang_test exp/tri2b exp/tri2b/graph + + # decode test with tri2b models + for x in devtest test; do + nspk=$(wc -l < data/$x/spk2utt) + steps/decode.sh --nj $nspk exp/tri2b/graph data/$x exp/tri2b/decode_${x} + done + ) & +fi + +if [ $stage -le 18 ]; then + # align with lda and mllt adapted triphones + steps/align_si.sh \ + --use-graphs true data/train data/lang exp/tri2b exp/tri2b_ali +fi + +if [ $stage -le 19 ]; then + echo "$0: Starting (SAT) triphone training in exp/tri3b" + steps/train_sat.sh 800 
8000 data/train data/lang exp/tri2b_ali exp/tri3b +fi + +if [ $stage -le 20 ]; then + ( + # make decoding graphs for SAT models + utils/mkgraph.sh data/lang_test exp/tri3b exp/tri3b/graph + + # decode test sets with tri3b models + for x in devtest test; do + nspk=$(wc -l < data/$x/spk2utt) + steps/decode_fmllr.sh --nj $nspk exp/tri3b/graph data/$x exp/tri3b/decode_${x} + done + ) & +fi + +if [ $stage -le 21 ]; then + # align with tri3b models + echo "$0: Starting exp/tri3b_ali" + steps/align_fmllr.sh data/train data/lang exp/tri3b exp/tri3b_ali +fi + +if [ $stage -le 22 ]; then + # train and test chain models + local/chain/run_tdnn.sh +fi diff --git a/egs/tunisian_msa/s5/steps b/egs/tunisian_msa/s5/steps new file mode 120000 index 00000000000..6e99bf5b5ad --- /dev/null +++ b/egs/tunisian_msa/s5/steps @@ -0,0 +1 @@ +../../wsj/s5/steps \ No newline at end of file diff --git a/egs/tunisian_msa/s5/utils b/egs/tunisian_msa/s5/utils new file mode 120000 index 00000000000..b240885218f --- /dev/null +++ b/egs/tunisian_msa/s5/utils @@ -0,0 +1 @@ +../../wsj/s5/utils \ No newline at end of file diff --git a/egs/uw3/v1/local/chain/run_cnn_1a.sh b/egs/uw3/v1/local/chain/run_cnn_1a.sh index 582bfc90105..e3548609da7 100755 --- a/egs/uw3/v1/local/chain/run_cnn_1a.sh +++ b/egs/uw3/v1/local/chain/run_cnn_1a.sh @@ -130,7 +130,7 @@ if [ $stage -le 4 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) common1="required-time-offsets=0 height-offsets=-2,-1,0,1,2 num-filters-out=12" mkdir -p $dir/configs diff --git a/egs/uw3/v1/local/make_features.py b/egs/uw3/v1/local/make_features.py index dd0a30a19d7..e0211963e39 100755 --- a/egs/uw3/v1/local/make_features.py +++ b/egs/uw3/v1/local/make_features.py @@ -24,8 +24,8 @@ parser = argparse.ArgumentParser(description="""Converts images (in 'dir'/images.scp) to features and writes them to standard output in text format.""") -parser.add_argument('dir', type=str, help='data directory (should contain images.scp)') -parser.add_argument('--out-ark', type=str, default='-', help='where to write the output feature file.') +parser.add_argument('dir', help='data directory (should contain images.scp)') +parser.add_argument('--out-ark', default='-', help='where to write the output feature file.') parser.add_argument('--feat-dim', type=int, default=40, help='size to scale the height of all images (i.e. 
the dimension of the resulting features)') parser.add_argument('--pad', type=bool, default=False, help='pad the left and right of the images with 10 white pixels.') @@ -43,7 +43,7 @@ def write_kaldi_matrix(file_handle, matrix, key): if num_cols != len(matrix[row_index]): raise Exception("All the rows of a matrix are expected to " "have the same length") - file_handle.write(" ".join(map(lambda x: str(x), matrix[row_index]))) + file_handle.write(" ".join([str(x) for x in matrix[row_index]])) if row_index != num_rows - 1: file_handle.write("\n") file_handle.write(" ]\n") diff --git a/egs/uw3/v1/local/process_data.py b/egs/uw3/v1/local/process_data.py index f5b37b04c2f..23b8e5402cf 100755 --- a/egs/uw3/v1/local/process_data.py +++ b/egs/uw3/v1/local/process_data.py @@ -14,8 +14,8 @@ import random parser = argparse.ArgumentParser(description="""Creates data/train and data/test.""") -parser.add_argument('database_path', type=str, help='path to downloaded (and extracted) UW3 corpus') -parser.add_argument('out_dir', type=str, default='data', +parser.add_argument('database_path', help='path to downloaded (and extracted) UW3 corpus') +parser.add_argument('out_dir', default='data', help='where to create the train and test data directories') args = parser.parse_args() @@ -52,10 +52,10 @@ # The dataset is randomly split train 95% and test 5% coin = random.randint(0, 20) if coin >= 1: - train_text_fh.write(utt_id + ' ' + text + '\n') - train_utt2spk_fh.write(utt_id + ' ' + str(page_count) + '\n') - train_image_fh.write(utt_id + ' ' + image_path + '\n') + train_text_fh.write("{} {}\n".format(utt_id, text)) + train_utt2spk_fh.write("{} {}\n".format(utt_id, page_count)) + train_image_fh.write("{} {}\n".format(utt_id, image_path)) elif coin < 1: - test_text_fh.write(utt_id + ' ' + text + '\n') - test_utt2spk_fh.write(utt_id + ' ' + str(page_count) + '\n') - test_image_fh.write(utt_id + ' ' + image_path + '\n') + test_text_fh.write("{} {}\n".format(utt_id, text)) + test_utt2spk_fh.write("{} {}\n".format(utt_id, page_count)) + train_image_fh.write("{} {}\n".format(utt_id, image_path)) diff --git a/egs/uw3/v1/local/unk_arc_post_to_transcription.py b/egs/uw3/v1/local/unk_arc_post_to_transcription.py index c86d35e4b8a..f8b69820601 100755 --- a/egs/uw3/v1/local/unk_arc_post_to_transcription.py +++ b/egs/uw3/v1/local/unk_arc_post_to_transcription.py @@ -1,86 +1,107 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 -# Copyright 2017 Ashish Arora +#Copyright 2017 Ashish Arora +""" This module will be used by scripts for open vocabulary setup. + If the hypothesis transcription contains , then it will replace the + with the word predicted by model by concatenating phones decoded + from the unk-model. It is currently supported only for triphone setup. + Args: + phones: File name of a file that contains the phones.txt, (symbol-table for phones). + phone and phoneID, Eg. a 217, phoneID of 'a' is 217. + words: File name of a file that contains the words.txt, (symbol-table for words). + word and wordID. Eg. ACCOUNTANCY 234, wordID of 'ACCOUNTANCY' is 234. + unk: ID of . Eg. 231. + one-best-arc-post: A file in arc-post format, which is a list of timing info and posterior + of arcs along the one-best path from the lattice. + E.g. 506_m01-049-00 8 12 1 7722 282 272 288 231 + [] + [ ...] + output-text: File containing hypothesis transcription with recognized by the + unk-model. + E.g. A move to stop mr. gaitskell. + + Eg. 
local/unk_arc_post_to_transcription.py lang/phones.txt lang/words.txt + data/lang/oov.int +""" import argparse +import os import sys - parser = argparse.ArgumentParser(description="""uses phones to convert unk to word""") -parser.add_argument('phones', type=str, help='phones and phonesID') -parser.add_argument('words', type=str, help='word and wordID') -parser.add_argument('unk', type=str, default='-', help='location of unk file') -parser.add_argument('--input-ark', type=str, default='-', help='where to read the input data') -parser.add_argument('--out-ark', type=str, default='-', help='where to write the output data') +parser.add_argument('phones', type=str, help='File name of a file that contains the' + 'symbol-table for phones. Each line must be: ') +parser.add_argument('words', type=str, help='File name of a file that contains the' + 'symbol-table for words. Each line must be: ') +parser.add_argument('unk', type=str, default='-', help='File name of a file that' + 'contains the ID of . The content must be: , e.g. 231') +parser.add_argument('--one-best-arc-post', type=str, default='-', help='A file in arc-post' + 'format, which is a list of timing info and posterior of arcs' + 'along the one-best path from the lattice') +parser.add_argument('--output-text', type=str, default='-', help='File containing' + 'hypothesis transcription with recognized by the unk-model') args = parser.parse_args() + ### main ### -phone_fh = open(args.phones, 'r') -word_fh = open(args.words, 'r') -unk_fh = open(args.unk,'r') -if args.input_ark == '-': - input_fh = sys.stdin +phone_handle = open(args.phones, 'r', encoding='latin-1') # Create file handles +word_handle = open(args.words, 'r', encoding='latin-1') +unk_handle = open(args.unk,'r', encoding='latin-1') +if args.one_best_arc_post == '-': + arc_post_handle = sys.stdin else: - input_fh = open(args.input_ark,'r') -if args.out_ark == '-': - out_fh = sys.stdout + arc_post_handle = open(args.one_best_arc_post, 'r', encoding='latin-1') +if args.output_text == '-': + output_text_handle = sys.stdout else: - out_fh = open(args.out_ark,'wb') + output_text_handle = open(args.output_text, 'w', encoding='latin-1') -phone_dict = dict()# stores phoneID and phone mapping -phone_data_vect = phone_fh.read().strip().split("\n") -for key_val in phone_data_vect: +id2phone = dict() # Stores the mapping from phone_id (int) to phone (char) +phones_data = phone_handle.read().strip().split("\n") + +for key_val in phones_data: key_val = key_val.split(" ") - phone_dict[key_val[1]] = key_val[0] + id2phone[key_val[1]] = key_val[0] + word_dict = dict() -word_data_vect = word_fh.read().strip().split("\n") +word_data_vect = word_handle.read().strip().split("\n") + for key_val in word_data_vect: key_val = key_val.split(" ") word_dict[key_val[1]] = key_val[0] -unk_val = unk_fh.read().strip().split(" ")[0] +unk_val = unk_handle.read().strip().split(" ")[0] -utt_word_dict = dict() -utt_phone_dict = dict()# stores utteranceID and phoneID -unk_word_dict = dict() -count=0 -for line in input_fh: +utt_word_dict = dict() # Dict of list, stores mapping from utteranceID(int) to words(str) +for line in arc_post_handle: line_vect = line.strip().split("\t") - if len(line_vect) < 6: - print "IndexError" - print line_vect + if len(line_vect) < 6: # Check for 1best-arc-post output + print("Error: Bad line: '{}' Expecting 6 fields. 
Skipping...".format(line), + file=sys.stderr) continue - uttID = line_vect[0] + utt_id = line_vect[0] word = line_vect[4] phones = line_vect[5] - if uttID in utt_word_dict.keys(): - utt_word_dict[uttID][count] = word - utt_phone_dict[uttID][count] = phones - else: - count = 0 - utt_word_dict[uttID] = dict() - utt_phone_dict[uttID] = dict() - utt_word_dict[uttID][count] = word - utt_phone_dict[uttID][count] = phones - if word == unk_val: # get character sequence for unk - phone_key_vect = phones.split(" ") - phone_val_vect = list() - for pkey in phone_key_vect: - phone_val_vect.append(phone_dict[pkey]) + if utt_id not in list(utt_word_dict.keys()): + utt_word_dict[utt_id] = list() + + if word == unk_val: # Get the 1best phone sequence given by the unk-model + phone_id_seq = phones.split(" ") + phone_seq = list() + for pkey in phone_id_seq: + phone_seq.append(id2phone[pkey]) # Convert the phone-id sequence to a phone sequence. phone_2_word = list() - for phone_val in phone_val_vect: - phone_2_word.append(phone_val.split('_')[0]) - phone_2_word = ''.join(phone_2_word) - utt_word_dict[uttID][count] = phone_2_word + for phone_val in phone_seq: + phone_2_word.append(phone_val.split('_')[0]) # Removing the world-position markers(e.g. _B) + phone_2_word = ''.join(phone_2_word) # Concatnate phone sequence + utt_word_dict[utt_id].append(phone_2_word) # Store word from unk-model else: - if word == '0': + if word == '0': # Store space/silence word_val = ' ' else: word_val = word_dict[word] - utt_word_dict[uttID][count] = word_val - count += 1 + utt_word_dict[utt_id].append(word_val) # Store word from 1best-arc-post -transcription = "" -for key in sorted(utt_word_dict.iterkeys()): - transcription = key - for index in sorted(utt_word_dict[key].iterkeys()): - value = utt_word_dict[key][index] - transcription = transcription + " " + value - out_fh.write(transcription + '\n') +transcription = "" # Output transcription +for utt_key in sorted(utt_word_dict.keys()): + transcription = utt_key + for word in utt_word_dict[utt_key]: + transcription = transcription + " " + word + output_text_handle.write(transcription + '\n') diff --git a/egs/voxceleb/v1/local/make_musan.py b/egs/voxceleb/v1/local/make_musan.py index 74c434990fb..565bfce0cc9 100755 --- a/egs/voxceleb/v1/local/make_musan.py +++ b/egs/voxceleb/v1/local/make_musan.py @@ -47,9 +47,9 @@ def prepare_music(root_dir, use_vocals): utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file {}".format(utt)) num_bad_files += 1 - print("In music directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In music directory, processed {} files; {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def prepare_speech(root_dir): @@ -73,9 +73,9 @@ def prepare_speech(root_dir): utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file {}".format(utt)) num_bad_files += 1 - print("In speech directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In speech directory, processed {} files; {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def prepare_noise(root_dir): @@ -99,9 +99,9 @@ def prepare_noise(root_dir): utt2wav_str = utt2wav_str + utt + " " + utt2wav[utt] + "\n" num_good_files += 1 else: - print("Missing file", utt) + print("Missing file 
{}".format(utt)) num_bad_files += 1 - print("In noise directory, processed", num_good_files, "files;", num_bad_files, "had missing wav data") + print("In noise directory, processed {} files; {} had missing wav data".format(num_good_files, num_bad_files)) return utt2spk_str, utt2wav_str def main(): diff --git a/egs/voxceleb/v1/local/make_voxceleb1.pl b/egs/voxceleb/v1/local/make_voxceleb1.pl index 916e11020d2..2268c20ab52 100755 --- a/egs/voxceleb/v1/local/make_voxceleb1.pl +++ b/egs/voxceleb/v1/local/make_voxceleb1.pl @@ -15,10 +15,6 @@ my $out_test_dir = "$out_dir/voxceleb1_test"; my $out_train_dir = "$out_dir/voxceleb1_train"; -if (! -e "$data_base/voxceleb1_test.txt") { - system("wget -O $data_base/voxceleb1_test.txt http://www.openslr.org/resources/49/voxceleb1_test.txt"); -} - if (system("mkdir -p $out_test_dir") != 0) { die "Error making directory $out_test_dir"; } @@ -31,20 +27,35 @@ my @spkr_dirs = grep {-d "$data_base/voxceleb1_wav/$_" && ! /^\.{1,2}$/} readdir($dh); closedir $dh; +if (! -e "$data_base/voxceleb1_test.txt") { + system("wget -O $data_base/voxceleb1_test.txt http://www.openslr.org/resources/49/voxceleb1_test.txt"); +} + +if (! -e "$data_base/vox1_meta.csv") { + system("wget -O $data_base/vox1_meta.csv http://www.openslr.org/resources/49/vox1_meta.csv"); +} + open(TRIAL_IN, "<", "$data_base/voxceleb1_test.txt") or die "Could not open the verification trials file $data_base/voxceleb1_test.txt"; +open(META_IN, "<", "$data_base/vox1_meta.csv") or die "Could not open the meta data file $data_base/vox1_meta.csv"; open(SPKR_TEST, ">", "$out_test_dir/utt2spk") or die "Could not open the output file $out_test_dir/utt2spk"; open(WAV_TEST, ">", "$out_test_dir/wav.scp") or die "Could not open the output file $out_test_dir/wav.scp"; open(SPKR_TRAIN, ">", "$out_train_dir/utt2spk") or die "Could not open the output file $out_train_dir/utt2spk"; open(WAV_TRAIN, ">", "$out_train_dir/wav.scp") or die "Could not open the output file $out_train_dir/wav.scp"; open(TRIAL_OUT, ">", "$out_test_dir/trials") or die "Could not open the output file $out_test_dir/trials"; +my %id2spkr = (); +while () { + chomp; + my ($vox_id, $spkr_id, $gender, $nation, $set) = split; + $id2spkr{$vox_id} = $spkr_id; +} + my $test_spkrs = (); while () { chomp; - my ($tar_or_none, $path1, $path2) = split; + my ($tar_or_non, $path1, $path2) = split; # Create entry for left-hand side of trial - my $wav = "$data_base/voxceleb1_wav/$path1"; my ($spkr_id, $filename) = split('/', $path1); my $rec_id = substr($filename, 0, 11); my $segment = substr($filename, 12, 7); @@ -52,7 +63,6 @@ $test_spkrs{$spkr_id} = (); # Create entry for right-hand side of trial - my $wav = "$data_base/voxceleb1_wav/$path2"; my ($spkr_id, $filename) = split('/', $path2); my $rec_id = substr($filename, 0, 11); my $segment = substr($filename, 12, 7); @@ -60,7 +70,7 @@ $test_spkrs{$spkr_id} = (); my $target = "nontarget"; - if ($tar_or_none eq "1") { + if ($tar_or_non eq "1") { $target = "target"; } print TRIAL_OUT "$utt_id1 $utt_id2 $target\n"; @@ -68,6 +78,12 @@ foreach (@spkr_dirs) { my $spkr_id = $_; + my $new_spkr_id = $spkr_id; + # If we're using a newer version of VoxCeleb1, we need to "deanonymize" + # the speaker labels. 
+ if (exists $id2spkr{$spkr_id}) { + $new_spkr_id = $id2spkr{$spkr_id}; + } opendir my $dh, "$data_base/voxceleb1_wav/$spkr_id/" or die "Cannot open directory: $!"; my @files = map{s/\.[^.]+$//;$_}grep {/\.wav$/} readdir($dh); closedir $dh; @@ -75,14 +91,14 @@ my $filename = $_; my $rec_id = substr($filename, 0, 11); my $segment = substr($filename, 12, 7); - my $utt_id = "$spkr_id-$rec_id-$segment"; my $wav = "$data_base/voxceleb1_wav/$spkr_id/$filename.wav"; - if (exists $test_spkrs{$spkr_id}) { + my $utt_id = "$new_spkr_id-$rec_id-$segment"; + if (exists $test_spkrs{$new_spkr_id}) { print WAV_TEST "$utt_id", " $wav", "\n"; - print SPKR_TEST "$utt_id", " $spkr_id", "\n"; + print SPKR_TEST "$utt_id", " $new_spkr_id", "\n"; } else { print WAV_TRAIN "$utt_id", " $wav", "\n"; - print SPKR_TRAIN "$utt_id", " $spkr_id", "\n"; + print SPKR_TRAIN "$utt_id", " $new_spkr_id", "\n"; } } } @@ -93,6 +109,7 @@ close(WAV_TRAIN) or die; close(TRIAL_OUT) or die; close(TRIAL_IN) or die; +close(META_IN) or die; if (system( "utils/utt2spk_to_spk2utt.pl $out_test_dir/utt2spk >$out_test_dir/spk2utt") != 0) { diff --git a/egs/voxceleb/v1/local/prepare_for_eer.py b/egs/voxceleb/v1/local/prepare_for_eer.py index 6bfa04e011b..2f569b70bc5 100755 --- a/egs/voxceleb/v1/local/prepare_for_eer.py +++ b/egs/voxceleb/v1/local/prepare_for_eer.py @@ -16,4 +16,4 @@ spkrutt2target[spkr+utt]=target for line in scores: spkr, utt, score = line.strip().split() - print(score, spkrutt2target[spkr+utt]) + print("{} {}".format(score, spkrutt2target[spkr+utt])) diff --git a/egs/voxceleb/v2/run.sh b/egs/voxceleb/v2/run.sh index e57799cee27..37bb60fe35c 100755 --- a/egs/voxceleb/v2/run.sh +++ b/egs/voxceleb/v2/run.sh @@ -27,7 +27,7 @@ stage=0 if [ $stage -le 0 ]; then local/make_voxceleb2.pl $voxceleb2_root dev data/voxceleb2_train local/make_voxceleb2.pl $voxceleb2_root test data/voxceleb2_test - # This script reates data/voxceleb1_test and data/voxceleb1_train. + # This script creates data/voxceleb1_test and data/voxceleb1_train. # Our evaluation set is the test portion of VoxCeleb1. local/make_voxceleb1.pl $voxceleb1_root data # We'll train on all of VoxCeleb2, plus the training portion of VoxCeleb1. @@ -66,7 +66,7 @@ if [ $stage -le 2 ]; then # Make a reverberated version of the VoxCeleb2 list. Note that we don't add any # additive noise here. 
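
On the scoring side, prepare_for_eer.py above simply pairs each trial score with its target/nontarget label so that Kaldi's compute-eer can consume it. For reference, the equal error rate can be approximated by sweeping a decision threshold over those pairs; the sketch below only illustrates that computation and is not a replacement for compute-eer.

  #!/usr/bin/env python3
  # Hypothetical illustration of the metric computed after prepare_for_eer.py.
  def equal_error_rate(pairs):
      """pairs: iterable of (score, is_target); higher score = more likely target.
      Assumes both target and nontarget trials are present."""
      pairs = sorted(pairs, key=lambda p: p[0])
      n_tar = sum(1 for _, t in pairs if t)
      n_non = len(pairs) - n_tar
      fa, fr = n_non, 0            # threshold below every score: accept all trials
      best_gap, eer = 2.0, 1.0
      for _, is_target in pairs:   # move the threshold just above each score
          if is_target:
              fr += 1              # this target is now rejected
          else:
              fa -= 1              # this nontarget is now rejected
          far, frr = fa / n_non, fr / n_tar
          if abs(far - frr) < best_gap:
              best_gap, eer = abs(far - frr), (far + frr) / 2
      return eer

  if __name__ == '__main__':
      trials = [(2.5, True), (1.9, True), (0.7, True),
                (1.1, False), (0.2, False), (-0.4, False)]
      print("EER ~ {:.3f}".format(equal_error_rate(trials)))
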
- python steps/data/reverberate_data_dir.py \ + steps/data/reverberate_data_dir.py \ "${rvb_opts[@]}" \ --speech-rvb-probability 1 \ --pointsource-noise-addition-probability 0 \ @@ -91,11 +91,11 @@ if [ $stage -le 2 ]; then done # Augment with musan_noise - python steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/train data/train_noise + steps/data/augment_data_dir.py --utt-suffix "noise" --fg-interval 1 --fg-snrs "15:10:5:0" --fg-noise-dir "data/musan_noise" data/train data/train_noise # Augment with musan_music - python steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/train data/train_music + steps/data/augment_data_dir.py --utt-suffix "music" --bg-snrs "15:10:8:5" --num-bg-noises "1" --bg-noise-dir "data/musan_music" data/train data/train_music # Augment with musan_speech - python steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/train data/train_babble + steps/data/augment_data_dir.py --utt-suffix "babble" --bg-snrs "20:17:15:13" --num-bg-noises "3:4:5:6:7" --bg-noise-dir "data/musan_speech" data/train data/train_babble # Combine reverb, noise, music, and babble into one directory. utils/combine_data.sh data/train_aug data/train_reverb data/train_noise data/train_music data/train_babble diff --git a/egs/voxforge/gst_demo/run-live.py b/egs/voxforge/gst_demo/run-live.py index 725a306c42c..7876e5f2046 100755 --- a/egs/voxforge/gst_demo/run-live.py +++ b/egs/voxforge/gst_demo/run-live.py @@ -6,6 +6,7 @@ # # Apache 2.0 +from __future__ import print_function import sys import os import gi @@ -46,7 +47,7 @@ def init_gst(self): """Initialize the speech components""" self.pulsesrc = Gst.ElementFactory.make("pulsesrc", "pulsesrc") if self.pulsesrc == None: - print >> sys.stderr, "Error loading pulsesrc GST plugin. You probably need the gstreamer1.0-pulseaudio package" + print("Error loading pulsesrc GST plugin. You probably need the gstreamer1.0-pulseaudio package", file=sys.stderr) sys.exit() self.audioconvert = Gst.ElementFactory.make("audioconvert", "audioconvert") self.audioresample = Gst.ElementFactory.make("audioresample", "audioresample") @@ -56,7 +57,7 @@ def init_gst(self): if self.asr: model_dir = "online-data/models/tri2b_mmi/" if not os.path.isdir(model_dir): - print >> sys.stderr, "Model (%s) not downloaded. Run run-simulated.sh first" % model_dir + print("Model (%s) not downloaded. Run run-simulated.sh first" % model_dir, file=sys.stderr) sys.exit(1) self.asr.set_property("fst", model_dir + "HCLG.fst") self.asr.set_property("lda-mat", model_dir + "matrix") @@ -67,12 +68,12 @@ def init_gst(self): self.asr.set_property("beam", 12.0) self.asr.set_property("acoustic-scale", 0.0769) else: - print >> sys.stderr, "Couldn't create the onlinegmmfasterdecoder element. " + print("Couldn't create the onlinegmmfasterdecoder element. ", file=sys.stderr) if "GST_PLUGIN_PATH" in os.environ: - print >> sys.stderr, "Have you compiled the Kaldi GStreamer plugin?" 
+ print("Have you compiled the Kaldi GStreamer plugin?", file=sys.stderr) else: - print >> sys.stderr, "You probably need to set the GST_PLUGIN_PATH envoronment variable" - print >> sys.stderr, "Try running: GST_PLUGIN_PATH=../../../src/gst-plugin %s" % sys.argv[0] + print("You probably need to set the GST_PLUGIN_PATH envoronment variable", file=sys.stderr) + print("Try running: GST_PLUGIN_PATH=../../../src/gst-plugin %s" % sys.argv[0], file=sys.stderr) sys.exit(); # initially silence the decoder @@ -111,10 +112,10 @@ def button_clicked(self, button): if __name__ == '__main__': app = DemoApp() - print ''' + print(''' The (bigram) language model used to build the decoding graph was estimated on an audio book's text. The text in question is King Solomon's Mines" (http://www.gutenberg.org/ebooks/2166). - You may want to read some sentences from this book first ...''' + You may want to read some sentences from this book first ...''') Gtk.main() diff --git a/egs/voxforge/s5/local/make_trans.py b/egs/voxforge/s5/local/make_trans.py index 1b4f5c4136a..612755c8be4 100755 --- a/egs/voxforge/s5/local/make_trans.py +++ b/egs/voxforge/s5/local/make_trans.py @@ -12,11 +12,12 @@ if this is the case produces a transcript line for each file in the format: prefix_a0405 IT SEEMED THE ORDAINED ORDER OF THINGS THAT DOGS SHOULD WORK """ +from __future__ import print_function import sys def err(msg): - print >> sys.stderr, msg + print(msg, file=sys.stderr) if len(sys.argv) < 3: err("Usage: %s ... " % sys.argv[0]) @@ -46,5 +47,5 @@ def err(msg): if not uid in utt2trans: err("No transcript found for %s_%s" % (id_prefix, uid)) continue - print "%s-%s %s" % (id_prefix, uid, utt2trans[uid]) + print("%s-%s %s" % (id_prefix, uid, utt2trans[uid])) diff --git a/egs/voxforge/s5/local/run_mmi_tri2b.sh b/egs/voxforge/s5/local/run_mmi_tri2b.sh index 6517e46a1a7..8a4d03c59c4 100755 --- a/egs/voxforge/s5/local/run_mmi_tri2b.sh +++ b/egs/voxforge/s5/local/run_mmi_tri2b.sh @@ -38,7 +38,7 @@ steps/train_diag_ubm.sh --silence-weight 0.5 --nj 10 --cmd "$train_cmd" \ data/train_si84 data/lang exp/tri2b_ali_si84 exp/dubm2b exp/tri2b_denlats_si84 \ exp/tri2b_fmmi_b0.1 - for iter in `seq 3 8`; do + for iter in `seq 3 8`; do steps/decode_fmmi.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ exp/tri2b/graph_tgpr data/test_dev93 exp/tri2b_fmmi_b0.1/decode_tgpr_dev93_it$iter & done @@ -46,7 +46,7 @@ steps/train_diag_ubm.sh --silence-weight 0.5 --nj 10 --cmd "$train_cmd" \ steps/train_mmi_fmmi.sh --learning-rate 0.005 --boost 0.1 --cmd "$train_cmd" \ data/train_si84 data/lang exp/tri2b_ali_si84 exp/dubm2b exp/tri2b_denlats_si84 \ exp/tri2b_fmmi_b0.1_lr0.005 || exit 1; - for iter in `seq 3 8`; do + for iter in `seq 3 8`; do steps/decode_fmmi.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ exp/tri2b/graph_tgpr data/test_dev93 exp/tri2b_fmmi_b0.1_lr0.005/decode_tgpr_dev93_it$iter & done @@ -54,7 +54,7 @@ steps/train_diag_ubm.sh --silence-weight 0.5 --nj 10 --cmd "$train_cmd" \ steps/train_mmi_fmmi_indirect.sh --boost 0.1 --cmd "$train_cmd" \ data/train_si84 data/lang exp/tri2b_ali_si84 exp/dubm2b exp/tri2b_denlats_si84 \ exp/tri2b_fmmi_indirect_b0.1 - for iter in `seq 3 8`; do + for iter in `seq 3 8`; do steps/decode_fmmi.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ exp/tri2b/graph_tgpr data/test_dev93 exp/tri2b_fmmi_indirect_b0.1/decode_tgpr_dev93_it$iter & done diff --git a/egs/voxforge/s5/local/run_sgmm2x.sh b/egs/voxforge/s5/local/run_sgmm2x.sh index 96a17578203..c019bfdf3be 100755 --- a/egs/voxforge/s5/local/run_sgmm2x.sh +++ 
b/egs/voxforge/s5/local/run_sgmm2x.sh @@ -26,14 +26,14 @@ steps/decode_sgmm2.sh --use-fmllr true --config conf/decode.config --nj 20 --cmd steps/make_denlats_sgmm2.sh --nj 8 --sub-split 20 --cmd "$decode_cmd" --transform-dir exp/tri3b \ data/train data/lang exp/sgmm2x_4a_ali exp/sgmm2x_4a_denlats steps/train_mmi_sgmm2.sh --cmd "$decode_cmd" --transform-dir exp/tri3b --boost 0.2 \ - data/train data/lang exp/sgmm2x_4a_ali exp/sgmm2x_4a_denlats exp/sgmm2x_4a_mmi_b0.2 + data/train data/lang exp/sgmm2x_4a_ali exp/sgmm2x_4a_denlats exp/sgmm2x_4a_mmi_b0.2 for iter in 1 2 3 4; do steps/decode_sgmm2_rescore.sh --cmd "$decode_cmd" --iter $iter \ --transform-dir exp/tri3b/decode data/lang data/test exp/sgmm2x_4a/decode exp/sgmm2x_4a_mmi_b0.2/decode_it$iter & - done + done -wait +wait steps/decode_combine.sh data/test data/lang exp/tri1/decode exp/tri2a/decode exp/combine_1_2a/decode || exit 1; steps/decode_combine.sh data/test data/lang exp/sgmm2x_4a/decode exp/tri3b_mmi/decode exp/combine_sgmm2x_4a_3b/decode || exit 1; # combining the sgmm run and the best MMI+fMMI run. diff --git a/egs/voxforge/s5/local/voxforge_prepare_dict.sh b/egs/voxforge/s5/local/voxforge_prepare_dict.sh index 4242af29d25..daf4e2326e5 100755 --- a/egs/voxforge/s5/local/voxforge_prepare_dict.sh +++ b/egs/voxforge/s5/local/voxforge_prepare_dict.sh @@ -49,7 +49,7 @@ if [[ "$(uname)" == "Darwin" ]]; then alias readlink=greadlink fi -sequitur=$KALDI_ROOT/tools/sequitur +sequitur=$KALDI_ROOT/tools/sequitur-g2p export PATH=$PATH:$sequitur/bin export PYTHONPATH=$PYTHONPATH:`utils/make_absolute.sh $sequitur/lib/python*/site-packages` diff --git a/egs/voxforge/s5/run.sh b/egs/voxforge/s5/run.sh index 277d41039ea..86fc128469e 100755 --- a/egs/voxforge/s5/run.sh +++ b/egs/voxforge/s5/run.sh @@ -44,7 +44,7 @@ selected=${DATA_ROOT}/selected # /bin/bash run.sh --pos-dep-phones false . utils/parse_options.sh || exit 1 -[[ $# -ge 1 ]] && { echo "Unexpected arguments"; exit 1; } +[[ $# -ge 1 ]] && { echo "Unexpected arguments"; exit 1; } # Select a subset of the data to use # WARNING: the destination directory will be deleted if it already exists! @@ -75,7 +75,7 @@ local/voxforge_format_data.sh || exit 1 # mfccdir should be some place with a largish disk where you # want to store MFCC features. 
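
The loop that follows runs steps/make_mfcc.sh and then steps/compute_cmvn_stats.sh for each data directory. The CMVN step only accumulates per-speaker sums and sums of squares, which are later applied as cepstral mean (and optionally variance) normalization. The sketch below illustrates that bookkeeping with NumPy; it is an illustration only and does not use the actual Kaldi matrix I/O.

  #!/usr/bin/env python3
  # Hypothetical illustration of the statistics behind compute_cmvn_stats.sh.
  import numpy as np

  def cmvn_stats(feats):
      """feats: (num_frames, feat_dim) array for one speaker."""
      n = feats.shape[0]
      first = np.hstack([feats.sum(axis=0), n])          # feature sums and frame count
      second = np.hstack([(feats ** 2).sum(axis=0), 0])  # sums of squares
      return np.vstack([first, second])                  # Kaldi-style 2 x (dim+1) stats

  def apply_cmvn(feats, stats, norm_vars=False):
      n = stats[0, -1]
      mean = stats[0, :-1] / n
      out = feats - mean
      if norm_vars:
          var = stats[1, :-1] / n - mean ** 2
          out = out / np.sqrt(np.maximum(var, 1e-10))
      return out

  if __name__ == '__main__':
      feats = np.random.randn(200, 13)        # fake 13-dim MFCCs for one speaker
      stats = cmvn_stats(feats)
      normed = apply_cmvn(feats, stats, norm_vars=True)
      print(normed.mean(axis=0).round(6), normed.std(axis=0).round(3))
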
mfccdir=${DATA_ROOT}/mfcc -for x in train test; do +for x in train test; do steps/make_mfcc.sh --cmd "$train_cmd" --nj $njobs \ data/$x exp/make_mfcc/$x $mfccdir || exit 1; steps/compute_cmvn_stats.sh data/$x exp/make_mfcc/$x $mfccdir || exit 1; diff --git a/egs/vystadial_cz/online_demo/build_reference.py b/egs/vystadial_cz/online_demo/build_reference.py index 1be78391d2f..aea12a2c8bc 100755 --- a/egs/vystadial_cz/online_demo/build_reference.py +++ b/egs/vystadial_cz/online_demo/build_reference.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # encoding: utf-8 from __future__ import unicode_literals +from __future__ import print_function import glob import sys @@ -8,7 +9,7 @@ import codecs def build_reference(wav_scp, ref_path): - print wav_scp, ref_path + print(wav_scp, ref_path) with codecs.open(ref_path, 'w', 'utf-8') as w: with codecs.open(wav_scp, 'r', 'utf-8') as scp: for line in scp: @@ -31,8 +32,8 @@ def build_reference(wav_scp, ref_path): usage_args = {'exec': sys.argv[0]} if len(sys.argv) != 3: - print >> sys.stderr, "Wrong number of arguments" - print >> sys.stderr, usage % {'exec': sys.argv[0]} + print("Wrong number of arguments", file=sys.stderr) + print(usage % {'exec': sys.argv[0]}, file=sys.stderr) sys.exit(1) if sys.argv[1].endswith('scp'): @@ -41,12 +42,12 @@ def build_reference(wav_scp, ref_path): scps = glob.glob(os.path.join(sys.argv[1], '*.scp')) target_dir = sys.argv[2] if not len(scps): - print >> sys.stderr, "No '*.scp' files found" - print >> sys.stderr, usage % {'exec': sys.argv[0]} + print("No '*.scp' files found", file=sys.stderr) + print(usage % {'exec': sys.argv[0]}, file=sys.stderr) sys.exit(1) if not os.path.isdir(target_dir): - print >> sys.stderr, "No '*.scp' files found" - print >> sys.stderr, usage % {'exec': sys.argv[0]} + print("No '*.scp' files found", file=sys.stderr) + print(usage % {'exec': sys.argv[0]}, file=sys.stderr) sys.exit(1) refers = [os.path.join(target_dir, os.path.basename(scp) + '.tra') for scp in scps] diff --git a/egs/vystadial_cz/online_demo/live-demo.py b/egs/vystadial_cz/online_demo/live-demo.py index 6b41c12c739..320a930735f 100755 --- a/egs/vystadial_cz/online_demo/live-demo.py +++ b/egs/vystadial_cz/online_demo/live-demo.py @@ -15,6 +15,7 @@ # See the Apache 2 License for the specific language governing permissions and # limitations under the License. # from __future__ import unicode_literals +from __future__ import print_function import pyaudio from kaldi.decoders import PyOnlineLatgenRecogniser @@ -29,7 +30,7 @@ CHANNELS, RATE, FORMAT = 1, 16000, pyaudio.paInt16 -class LiveDemo: +class LiveDemo(object): def __init__(self, audio_batch_size, wst, dec_args): self.batch_size = audio_batch_size @@ -127,7 +128,7 @@ def save_wav(self): if __name__ == '__main__': audio_batch_size, wst_path = int(sys.argv[1]), sys.argv[2] argv = sys.argv[3:] - print >> sys.stderr, 'Python args: %s' % str(sys.argv) + print('Python args: %s' % str(sys.argv), file=sys.stderr) wst = wst2dict(wst_path) demo = LiveDemo(audio_batch_size, wst, argv) diff --git a/egs/vystadial_cz/online_demo/pykaldi-online-latgen-recogniser.py b/egs/vystadial_cz/online_demo/pykaldi-online-latgen-recogniser.py index 02a0400921c..0008a4c01f1 100755 --- a/egs/vystadial_cz/online_demo/pykaldi-online-latgen-recogniser.py +++ b/egs/vystadial_cz/online_demo/pykaldi-online-latgen-recogniser.py @@ -14,6 +14,8 @@ # See the Apache 2 License for the specific language governing permissions and # limitations under the License. 
# from __future__ import unicode_literals +from __future__ import division +from __future__ import print_function from kaldi.utils import load_wav, wst2dict, lattice_to_nbest from kaldi.decoders import PyOnlineLatgenRecogniser @@ -31,14 +33,14 @@ def write_decoded(f, wav_name, word_ids, wst): if wst is not None: decoded = [wst[w] for w in best_path] else: - decoded = [unicode(w) for w in best_path] + decoded = [str(w) for w in best_path] line = u' '.join([wav_name] + decoded + ['\n']) if DEBUG: - print '%s best path %s' % (wav_name, decoded.encode('UTF-8')) + print('%s best path %s' % (wav_name, decoded.encode('UTF-8'))) for i, s in enumerate(word_ids): if i > 0: break - print 'best path %d: %s' % (i, str(s)) + print('best path %d: %s' % (i, str(s))) f.write(line.encode('UTF-8')) @@ -55,11 +57,11 @@ def decode(d, pcm): while dec_t > 0: decoded_frames += dec_t dec_t = d.decode(max_frames=10) - print "forward decode: %s secs" % str(time.time() - start) + print("forward decode: %s secs" % str(time.time() - start)) start = time.time() d.prune_final() lik, lat = d.get_lattice() - print "backward decode: %s secs" % str(time.time() - start) + print("backward decode: %s secs" % str(time.time() - start)) d.reset(keep_buffer_data=False) return (lat, lik, decoded_frames) @@ -72,7 +74,7 @@ def decode_wrap(argv, audio_batch_size, wav_paths, for wav_name, wav_path in wav_paths: sw, sr = 2, 16000 # 16-bit audio so 1 sample_width = 2 chars pcm = load_wav(wav_path, def_sample_width=sw, def_sample_rate=sr) - print '%s has %f sec' % (wav_name, (float(len(pcm)) / sw) / sr) + print('%s has %f sec' % (wav_name, (float(len(pcm)) / sw) / sr)) lat, lik, decoded_frames = decode(d, pcm) lat.isyms = lat.osyms = fst.read_symbols_text(wst_path) if DEBUG: @@ -80,8 +82,8 @@ def decode_wrap(argv, audio_batch_size, wav_paths, f.write(lat._repr_svg_()) lat.write('%s_pykaldi.fst' % wav_name) - print "Log-likelihood per frame for utterance %s is %f over %d frames" % ( - wav_name, (lik / decoded_frames), decoded_frames) + print("Log-likelihood per frame for utterance %s is %f over %d frames" % ( + wav_name, int(lik / decoded_frames), decoded_frames)) word_ids = lattice_to_nbest(lat, n=10) write_decoded(file_output, wav_name, word_ids, wst) @@ -90,7 +92,7 @@ def decode_wrap(argv, audio_batch_size, wav_paths, audio_scp, audio_batch_size = sys.argv[1], int(sys.argv[2]) dec_hypo, wst_path = sys.argv[3], sys.argv[4] argv = sys.argv[5:] - print >> sys.stderr, 'Python args: %s' % str(sys.argv) + print('Python args: %s' % str(sys.argv), file=sys.stderr) # open audio_scp, decode and write to dec_hypo file with open(audio_scp, 'rb') as r: diff --git a/egs/vystadial_cz/s5/local/results.py b/egs/vystadial_cz/s5/local/results.py index a7c19af214c..f37109d5fcb 100755 --- a/egs/vystadial_cz/s5/local/results.py +++ b/egs/vystadial_cz/s5/local/results.py @@ -14,6 +14,8 @@ # MERCHANTABLITY OR NON-INFRINGEMENT. # See the Apache 2 License for the specific language governing permissions and # limitations under the License. 
# +from __future__ import division +from __future__ import print_function import argparse import glob import sys @@ -29,8 +31,8 @@ def extract_stat(wer_file): ser = float(s[2].split()[1]) except Exception as e: - print sys.stderr, 'Error parsing file %s' % wer_file - print sys.stderr, str(e) + print(sys.stderr, 'Error parsing file %s' % wer_file) + print(sys.stderr, str(e)) return wer, ser @@ -47,8 +49,8 @@ def extractResults(path): wer, ser = extract_stat(wf) table.append((exp, dataset, lm, lm_w, wer, ser)) except Exception as e: - print >> sys.stderr, 'failed to parse %s' % wf - print >> sys.stderr, str(e) + print('failed to parse %s' % wf, file=sys.stderr) + print(str(e), file=sys.stderr) return table @@ -105,7 +107,7 @@ def Table2LatexTable(table): def createSmallTable(r): d = [] - for k, v in r.iteritems(): + for k, v in r.items(): w, s, r = v if w == []: minw = None @@ -115,7 +117,7 @@ def createSmallTable(r): mins = None else: mins = min(s) # returns tuple if s is list of tuples - mean_r = sum(r) / float(len(r)) + mean_r = float(sum(r)) / len(r) d.append([k, mean_r, minw, mins]) t = Table(d, ['exp', 'RT coef', 'WER', 'SER']) return t @@ -167,7 +169,7 @@ def createSmallTable(r): # remove duplicates: duplicates if equal mimimum wer in dev set min_dev_un = [(e, lm, lmw) for ((e, lm), lmw) in - dict([((e, lm), lmw) for e, lm, lmw in min_dev]).items()] + list(dict([((e, lm), lmw) for e, lm, lmw in min_dev]).items())] # sort according LM -> sort results according experiment & LMs min_dev_un.sort(key=lambda x: (x[1], x[0])) @@ -182,6 +184,6 @@ def createSmallTable(r): d.append(x[0]) t = Table(data=d, colnames=['exp', 'set', 'LM', 'LMW', 'WER', 'SER']) - print str(t) + print(str(t)) if args.latex: - print Table2LatexTable(t) + print(Table2LatexTable(t)) diff --git a/egs/vystadial_cz/s5b/local/chain/tuning/run_tdnn_1a.sh b/egs/vystadial_cz/s5b/local/chain/tuning/run_tdnn_1a.sh index 496ee5e84ca..844ccf80677 100755 --- a/egs/vystadial_cz/s5b/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/vystadial_cz/s5b/local/chain/tuning/run_tdnn_1a.sh @@ -148,7 +148,7 @@ if [ $stage -le 13 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.05 dropout-per-dim-continuous=true" output_opts="l2-regularize=0.02 bottleneck-dim=192" diff --git a/egs/vystadial_en/s5/local/results.py b/egs/vystadial_en/s5/local/results.py index a7c19af214c..f37109d5fcb 100755 --- a/egs/vystadial_en/s5/local/results.py +++ b/egs/vystadial_en/s5/local/results.py @@ -14,6 +14,8 @@ # MERCHANTABLITY OR NON-INFRINGEMENT. # See the Apache 2 License for the specific language governing permissions and # limitations under the License. 
# +from __future__ import division +from __future__ import print_function import argparse import glob import sys @@ -29,8 +31,8 @@ def extract_stat(wer_file): ser = float(s[2].split()[1]) except Exception as e: - print sys.stderr, 'Error parsing file %s' % wer_file - print sys.stderr, str(e) + print(sys.stderr, 'Error parsing file %s' % wer_file) + print(sys.stderr, str(e)) return wer, ser @@ -47,8 +49,8 @@ def extractResults(path): wer, ser = extract_stat(wf) table.append((exp, dataset, lm, lm_w, wer, ser)) except Exception as e: - print >> sys.stderr, 'failed to parse %s' % wf - print >> sys.stderr, str(e) + print('failed to parse %s' % wf, file=sys.stderr) + print(str(e), file=sys.stderr) return table @@ -105,7 +107,7 @@ def Table2LatexTable(table): def createSmallTable(r): d = [] - for k, v in r.iteritems(): + for k, v in r.items(): w, s, r = v if w == []: minw = None @@ -115,7 +117,7 @@ def createSmallTable(r): mins = None else: mins = min(s) # returns tuple if s is list of tuples - mean_r = sum(r) / float(len(r)) + mean_r = float(sum(r)) / len(r) d.append([k, mean_r, minw, mins]) t = Table(d, ['exp', 'RT coef', 'WER', 'SER']) return t @@ -167,7 +169,7 @@ def createSmallTable(r): # remove duplicates: duplicates if equal mimimum wer in dev set min_dev_un = [(e, lm, lmw) for ((e, lm), lmw) in - dict([((e, lm), lmw) for e, lm, lmw in min_dev]).items()] + list(dict([((e, lm), lmw) for e, lm, lmw in min_dev]).items())] # sort according LM -> sort results according experiment & LMs min_dev_un.sort(key=lambda x: (x[1], x[0])) @@ -182,6 +184,6 @@ def createSmallTable(r): d.append(x[0]) t = Table(data=d, colnames=['exp', 'set', 'LM', 'LMW', 'WER', 'SER']) - print str(t) + print(str(t)) if args.latex: - print Table2LatexTable(t) + print(Table2LatexTable(t)) diff --git a/egs/wsj/s5/local/chain/e2e/run_tdnn_flatstart.sh b/egs/wsj/s5/local/chain/e2e/run_tdnn_flatstart.sh index 9a4f0c87c8d..1ddb3c305ac 100755 --- a/egs/wsj/s5/local/chain/e2e/run_tdnn_flatstart.sh +++ b/egs/wsj/s5/local/chain/e2e/run_tdnn_flatstart.sh @@ -3,33 +3,31 @@ # This script performs chain training in a flat-start manner # and without building or using any context-dependency tree. -# It does not use ivecors or other forms of speaker adaptation -# except simple mean and variance normalization. +# It does not use ivecors or other forms of speaker adaptation. # It is called from run_e2e_phone.sh # Note: this script is configured as phone-based, if you want # to run it in character mode, you'll need to change _nosp -# to _char everywhere and also copy char_lm.fst instead -# of phone_lm.fst (in stage 1 below) - -# local/chain/compare_wer.sh exp/chain/e2e_tdnn_1a -# System e2e_tdnn_1a -#WER dev93 (tgpr) 9.63 -#WER dev93 (tg) 9.07 -#WER dev93 (big-dict,tgpr) 7.41 -#WER dev93 (big-dict,fg) 6.55 -#WER eval92 (tgpr) 5.90 -#WER eval92 (tg) 5.17 -#WER eval92 (big-dict,tgpr) 3.56 -#WER eval92 (big-dict,fg) 2.85 -# Final train prob -0.0726 -# Final valid prob -0.0884 +# to _char everywhere. 
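
A note on the learning_rate_factor line used throughout these chain recipes, $(echo "print (0.5/$xent_regularize)" | python): with the parentheses the expression is valid under both Python 2 and Python 3, since print is a function in Python 3, so the result no longer depends on which interpreter "python" resolves to. A small standalone check, hypothetical and with the value taken from the recipes above:

  #!/usr/bin/env python3
  # Hypothetical check of the learning_rate_factor one-liner.
  import subprocess
  import sys

  xent_regularize = 0.1                      # value set in the recipes above
  cmd = "print (0.5/%s)" % xent_regularize   # same string the shell pipes to python
  out = subprocess.run([sys.executable, "-c", cmd], capture_output=True, text=True)
  print("learning_rate_factor =", out.stdout.strip())   # expect 5.0
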
+ +# local/chain/compare_wer.sh exp/chain/e2e_tdnnf_1a +# System e2e_tdnnf_1a +#WER dev93 (tgpr) 8.77 +#WER dev93 (tg) 8.11 +#WER dev93 (big-dict,tgpr) 6.17 +#WER dev93 (big-dict,fg) 5.66 +#WER eval92 (tgpr) 5.62 +#WER eval92 (tg) 5.19 +#WER eval92 (big-dict,tgpr) 3.23 +#WER eval92 (big-dict,fg) 2.80 +# Final train prob -0.0618 +# Final valid prob -0.0825 # Final train prob (xent) # Final valid prob (xent) -# Num-params 3740934 +# Num-params 6772564 -# steps/info/chain_dir_info.pl exp/chain/e2e_tdnn_1a -# exp/chain/e2e_tdnn_1a: num-iters=102 nj=2..5 num-params=3.7M dim=40->84 combine=-0.117->-0.116 (over 3) logprob:train/valid[67,101,final]=(-0.080,-0.073,-0.073/-0.090,-0.089,-0.088) +# steps/info/chain_dir_info.pl exp/chain/e2e_tdnnf_1a +# exp/chain/e2e_tdnnf_1a: num-iters=180 nj=2..8 num-params=6.8M dim=40->84 combine=-0.060->-0.060 (over 3) logprob:train/valid[119,179,final]=(-0.080,-0.062,-0.062/-0.089,-0.083,-0.083) set -e @@ -40,15 +38,15 @@ get_egs_stage=-10 affix=1a # training options -num_epochs=4 +dropout_schedule='0,0@0.20,0.5@0.50,0' +num_epochs=10 num_jobs_initial=2 -num_jobs_final=5 -minibatch_size=150=128,64/300=100,64,32/600=50,32,16/1200=16,8 +num_jobs_final=8 +minibatch_size=150=128,64/300=64,32/600=32,16/1200=8 common_egs_dir= l2_regularize=0.00005 -dim=450 frames_per_iter=3000000 -cmvn_opts="--norm-means=true --norm-vars=true" +cmvn_opts="--norm-means=false --norm-vars=false" train_set=train_si284_spe2e_hires test_sets="test_dev93 test_eval92" @@ -69,7 +67,7 @@ fi lang=data/lang_e2e treedir=exp/chain/e2e_tree # it's actually just a trivial tree (no tree building) -dir=exp/chain/e2e_tdnn_${affix} +dir=exp/chain/e2e_tdnnf_${affix} if [ $stage -le 0 ]; then # Create a version of the lang/ directory that has one state per phone in the @@ -102,25 +100,35 @@ fi if [ $stage -le 2 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') - opts="l2-regularize=0.01" - output_opts="l2-regularize=0.0025" + tdnn_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim-continuous=true" + tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.01" + output_opts="l2-regularize=0.005" mkdir -p $dir/configs cat < $dir/configs/network.xconfig input dim=40 name=input - relu-batchnorm-layer name=tdnn1 input=Append(-1,0,1) dim=$dim - relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=$dim $opts - relu-batchnorm-layer name=tdnn3 dim=$dim $opts - relu-batchnorm-layer name=tdnn4 input=Append(-1,0,1) dim=$dim $opts - relu-batchnorm-layer name=tdnn5 dim=$dim $opts - relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=$dim $opts - relu-batchnorm-layer name=tdnn7 input=Append(-3,0,3) dim=$dim $opts - relu-batchnorm-layer name=tdnn8 input=Append(-3,0,3) dim=$dim $opts - - relu-batchnorm-layer name=prefinal-chain dim=$dim target-rms=0.5 $opts - output-layer name=output include-log-softmax=true dim=$num_targets $output_opts + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-1,0,1) $tdnn_opts dim=1024 + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + tdnnf-layer name=tdnnf3 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + tdnnf-layer name=tdnnf4 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + tdnnf-layer name=tdnnf5 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=0 + tdnnf-layer name=tdnnf6 $tdnnf_opts dim=1024 bottleneck-dim=128 
time-stride=3 + tdnnf-layer name=tdnnf7 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + linear-component name=prefinal-l dim=192 $linear_opts + + + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1024 small-dim=192 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts EOF steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs @@ -139,14 +147,15 @@ if [ $stage -le 3 ]; then --egs.dir "$common_egs_dir" \ --egs.stage $get_egs_stage \ --egs.opts "" \ + --trainer.dropout-schedule $dropout_schedule \ --trainer.num-chunk-per-minibatch $minibatch_size \ --trainer.frames-per-iter $frames_per_iter \ --trainer.num-epochs $num_epochs \ --trainer.optimization.momentum 0 \ --trainer.optimization.num-jobs-initial $num_jobs_initial \ --trainer.optimization.num-jobs-final $num_jobs_final \ - --trainer.optimization.initial-effective-lrate 0.001 \ - --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.initial-effective-lrate 0.0005 \ + --trainer.optimization.final-effective-lrate 0.00005 \ --trainer.optimization.shrink-value 1.0 \ --trainer.max-param-change 2.0 \ --cleanup.remove-egs true \ diff --git a/egs/wsj/s5/local/chain/e2e/run_tdnn_lstm_flatstart.sh b/egs/wsj/s5/local/chain/e2e/run_tdnn_lstm_flatstart.sh index cc7c64f3cc8..be82e80d5fe 100755 --- a/egs/wsj/s5/local/chain/e2e/run_tdnn_lstm_flatstart.sh +++ b/egs/wsj/s5/local/chain/e2e/run_tdnn_lstm_flatstart.sh @@ -6,31 +6,32 @@ # a full trivial biphone context-dependency tree. This is because this recipe is # meant for character-based (i.e. lexicon-free) modeling where context helps # significantly. -# It does not use ivecors or other forms of speaker adaptation -# except simple mean and variance normalization. +# It does not use ivecors or other forms of speaker adaptation. # It is called from run_e2e_char.sh # Note: this script is configured to run as character-based, if you want # to run it in phoneme mode, you'll need to change _char -# to _nosp everywhere and also copy phone_lm.fst instead -# of char_lm.fst (in stage 1 below) +# to _nosp everywhere. 
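
The dropout_schedule='0,0@0.20,0.5@0.50,0' used by these TDNN-F setups is a piecewise-linear schedule over the fraction of training completed: dropout proportion 0 at the start, still 0 at 20% of training, 0.5 at 50%, and back to 0 at the end, interpolated linearly in between. A rough illustration (my own sketch, not part of the recipe) of the value at 35% of training:

    # Interpolate between (0.20, 0.0) and (0.50, 0.5) at x = 0.35:
    awk 'BEGIN { x=0.35; x0=0.20; y0=0.0; x1=0.50; y1=0.5;
                 printf("%.2f\n", y0 + (y1 - y0) * (x - x0) / (x1 - x0)) }'
    # prints 0.25
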
+# local/chain/compare_wer.sh exp/chain/e2e_tdnn_lstm_bichar_1a # System e2e_tdnn_lstm_bichar_1a -# WER dev93 (tgpr) 9.42 -# WER dev93 (tg) 8.85 -# WER dev93 (big-dict,tgpr) 7.70 -# WER dev93 (big-dict,fg) 6.79 -# WER eval92 (tgpr) 6.42 -# WER eval92 (tg) 6.11 -# WER eval92 (big-dict,tgpr) 4.50 -# WER eval92 (big-dict,fg) 4.09 -# Final train prob -0.7535 -# Final valid prob -0.7786 +#WER dev93 (tgpr) 9.85 +#WER dev93 (tg) 9.32 +#WER dev93 (big-dict,tgpr) 8.19 +#WER dev93 (big-dict,fg) 7.27 +#WER eval92 (tgpr) 6.89 +#WER eval92 (tg) 6.70 +#WER eval92 (big-dict,tgpr) 5.14 +#WER eval92 (big-dict,fg) 4.29 +# Final train prob -0.0610 +# Final valid prob -0.0836 +# Final train prob (xent) +# Final valid prob (xent) +# Num-params 9219188 # steps/info/chain_dir_info.pl exp/chain/e2e_tdnn_lstm_bichar_1a/ -# exp/chain/e2e_tdnn_lstm_bichar_1a/: num-iters=138 nj=2..5 num-params=9.2M dim=40->3444 combine=-6.480->-6.478 logprob:train/valid[91,137,final]=(-0.766,-0.754,-0.754/-0.784,-0.779,-0.779) - +# exp/chain/e2e_tdnn_lstm_bichar_1a_nocmvn: num-iters=138 nj=2..5 num-params=9.2M dim=40->3444 combine=-1.211->-1.211 (over 3) logprob:train/valid[91,137,final]=(-0.079,-0.062,-0.061/-0.093,-0.084,-0.084) set -e @@ -50,7 +51,7 @@ common_egs_dir= l2_regularize=0.00001 dim=512 frames_per_iter=2500000 -cmvn_opts="--norm-means=true --norm-vars=true" +cmvn_opts="--norm-means=false --norm-vars=false" train_set=train_si284_spe2e_hires test_sets="test_dev93 test_eval92" @@ -96,8 +97,9 @@ if [ $stage -le 1 ]; then mkdir -p $treedir/log $train_cmd $treedir/log/make_phone_lm.log \ cat data/$train_set/text \| \ - steps/nnet3/chain/e2e/text_to_phones.py data/lang_nosp \| \ - utils/sym2int.pl -f 2- data/lang_nosp/phones.txt \| \ + steps/nnet3/chain/e2e/text_to_phones.py --between-silprob 0.1 \ + data/lang_char \| \ + utils/sym2int.pl -f 2- data/lang_char/phones.txt \| \ chain-est-phone-lm --num-extra-lm-states=2000 \ ark:- $treedir/phone_lm.fst steps/nnet3/chain/e2e/prepare_e2e.sh --nj 30 --cmd "$train_cmd" \ diff --git a/egs/wsj/s5/local/chain/e2e/run_tdnnf_flatstart_char.sh b/egs/wsj/s5/local/chain/e2e/run_tdnnf_flatstart_char.sh new file mode 120000 index 00000000000..b20849c2a48 --- /dev/null +++ b/egs/wsj/s5/local/chain/e2e/run_tdnnf_flatstart_char.sh @@ -0,0 +1 @@ +tuning/run_tdnnf_flatstart_char1b.sh \ No newline at end of file diff --git a/egs/wsj/s5/local/chain/e2e/tuning/run_tdnnf_flatstart_char1a.sh b/egs/wsj/s5/local/chain/e2e/tuning/run_tdnnf_flatstart_char1a.sh new file mode 100755 index 00000000000..4ab0cf58d53 --- /dev/null +++ b/egs/wsj/s5/local/chain/e2e/tuning/run_tdnnf_flatstart_char1a.sh @@ -0,0 +1,225 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +# This script performs chain training in a flat-start manner +# and without building or using any context-dependency tree. +# It does not use ivecors or other forms of speaker adaptation +# It is called from run_e2e_char.sh + +# Note: this script is configured as grapheme-based, if you want +# to run it in phoneme mode, you'll need to change _char +# to _nosp everywhere. + +# This is the same as run_tdnn_lstm_flatstart.sh except it uses +# TDNN-F (and CMVN is disabled). 
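
As the header notes, these TDNN-F flat-start setups run with CMVN disabled (cmvn_opts="--norm-means=false --norm-vars=false"). The training scripts store cmvn_opts alongside the model, so the setting of an already-trained system can be checked directly; a small sketch, with an illustrative experiment directory:

    # cmvn_opts is copied next to the model by the nnet3 training scripts:
    cat exp/chain/e2e_tdnnf_bichar1a/cmvn_opts
    # expected for these recipes: --norm-means=false --norm-vars=false
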
+ + +# local/chain/compare_wer.sh exp/chain/e2e_tdnn_lstm_bichar_1a exp/chain/e2e_tdnnf_bichar1a +# System e2e_tdnn_lstm_bichar_1a e2e_tdnnf_bichar1a +# WER dev93 (tgpr) 9.42 8.89 +# WER dev93 (tg) 8.85 8.20 +# WER dev93 (big-dict,tgpr) 7.70 6.96 +# WER dev93 (big-dict,fg) 6.79 6.01 +# WER eval92 (tgpr) 6.42 6.08 +# WER eval92 (tg) 6.11 5.79 +# WER eval92 (big-dict,tgpr) 4.50 4.39 +# WER eval92 (big-dict,fg) 4.09 3.88 +# Final train prob -0.0610 -0.0598 +# Final valid prob -0.0836 -0.0854 +# Final train prob (xent) +# Final valid prob (xent) +# Num-params 9219188 7421044 + +# steps/info/chain_dir_info.pl exp/chain/e2e_tdnnf_bichar1a +# exp/chain/e2e_tdnnf_bichar1a: num-iters=180 nj=2..8 num-params=7.4M dim=40->3444 combine=-0.064->-0.064 (over 3) logprob:train/valid[119,179,final]=(-0.093,-0.060,-0.060/-0.107,-0.086,-0.085) + + +set -e + +# configs for 'chain' +stage=0 +train_stage=-10 +get_egs_stage=-10 +affix=1a + +# training options +dropout_schedule='0,0@0.20,0.5@0.50,0' +num_epochs=10 +num_jobs_initial=2 +num_jobs_final=8 +minibatch_size=150=128,64/300=64,32/600=32,16/1200=8 +common_egs_dir= +l2_regularize=0.00005 +frames_per_iter=3000000 +cmvn_opts="--norm-means=false --norm-vars=false" +train_set=train_si284_spe2e_hires +test_sets="test_dev93 test_eval92" + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + echo "$0: Estimating a phone language model for the denominator graph..." + mkdir -p $treedir/log + $train_cmd $treedir/log/make_phone_lm.log \ + cat data/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py --between-silprob 0.1 \ + data/lang_char \| \ + utils/sym2int.pl -f 2- data/lang_char/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=2000 \ + ark:- $treedir/phone_lm.fst + steps/nnet3/chain/e2e/prepare_e2e.sh --nj 30 --cmd "$train_cmd" \ + --type biphone \ + --shared-phones true \ + data/$train_set $lang $treedir +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + tdnn_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim-continuous=true" + tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.01" + output_opts="l2-regularize=0.005" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + + input dim=40 name=input + + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-1,0,1) $tdnn_opts dim=1024 + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + tdnnf-layer name=tdnnf3 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + tdnnf-layer name=tdnnf4 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + tdnnf-layer name=tdnnf5 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=0 + tdnnf-layer name=tdnnf6 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf7 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1024 
bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + linear-component name=prefinal-l dim=192 $linear_opts + + + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1024 small-dim=192 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + # no need to store the egs in a shared storage because we always + # remove them. Anyway, it takes only 5 minutes to generate them. + + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.cmvn-opts "$cmvn_opts" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize $l2_regularize \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter $frames_per_iter \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate 0.0005 \ + --trainer.optimization.final-effective-lrate 0.00005 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir data/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi + +if [ $stage -le 4 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/lang/check_phones_compatible.sh \ + data/lang_char_test_tgpr/phones.txt $lang/phones.txt + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_char_test_tgpr \ + $dir $treedir/graph_tgpr || exit 1; + + utils/lang/check_phones_compatible.sh \ + data/lang_char_test_bd_tgpr/phones.txt $lang/phones.txt + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_char_test_bd_tgpr \ + $dir $treedir/graph_bd_tgpr || exit 1; +fi + +if [ $stage -le 5 ]; then + frames_per_chunk=150 + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + data_affix=$(echo $data | sed s/test_//) + nspk=$(wc -l 1397 combine=-0.064->-0.064 (over 2) logprob:train/valid[119,179,final]=(-0.086,-0.060,-0.060/-0.099,-0.087,-0.087) + + +set -e + +# configs for 'chain' +stage=0 +train_stage=-10 +get_egs_stage=-10 +affix=1b + +# training options +dropout_schedule='0,0@0.20,0.5@0.50,0' +num_epochs=10 +num_jobs_initial=2 +num_jobs_final=8 +minibatch_size=150=128,64/300=64,32/600=32,16/1200=8 +common_egs_dir= +l2_regularize=0.00005 +frames_per_iter=3000000 +cmvn_opts="--norm-means=false --norm-vars=false" +train_set=train_si284_spe2e_hires +test_sets="test_dev93 test_eval92" + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + echo "$0: Estimating a phone language model for the denominator graph..." 
+ mkdir -p $treedir/log + $train_cmd $treedir/log/make_phone_lm.log \ + cat data/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py --between-silprob 0.1 \ + data/lang_char \| \ + utils/sym2int.pl -f 2- data/lang_char/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=2000 \ + ark:- $treedir/phone_lm.fst + steps/nnet3/chain/e2e/prepare_e2e.sh --nj 30 --cmd "$train_cmd" \ + --type biphone \ + --shared-phones true \ + --tie true \ + --min-biphone-count 100 \ + --min-monophone-count 20 \ + data/$train_set $lang $treedir +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + tdnn_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim-continuous=true" + tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.01" + output_opts="l2-regularize=0.005" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + + input dim=40 name=input + + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-1,0,1) $tdnn_opts dim=1024 + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + tdnnf-layer name=tdnnf3 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + tdnnf-layer name=tdnnf4 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1 + tdnnf-layer name=tdnnf5 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=0 + tdnnf-layer name=tdnnf6 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf7 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + linear-component name=prefinal-l dim=192 $linear_opts + + + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1024 small-dim=192 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + # no need to store the egs in a shared storage because we always + # remove them. Anyway, it takes only 5 minutes to generate them. 
+ + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.cmvn-opts "$cmvn_opts" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize $l2_regularize \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter $frames_per_iter \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate 0.0005 \ + --trainer.optimization.final-effective-lrate 0.00005 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir data/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi + +if [ $stage -le 4 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/lang/check_phones_compatible.sh \ + data/lang_char_test_tgpr/phones.txt $lang/phones.txt + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_char_test_tgpr \ + $dir $treedir/graph_tgpr || exit 1; + + utils/lang/check_phones_compatible.sh \ + data/lang_char_test_bd_tgpr/phones.txt $lang/phones.txt + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_char_test_bd_tgpr \ + $dir $treedir/graph_bd_tgpr || exit 1; +fi + +if [ $stage -le 5 ]; then + frames_per_chunk=150 + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + data_affix=$(echo $data | sed s/test_//) + nspk=$(wc -l $dir/configs/network.xconfig diff --git a/egs/wsj/s5/local/chain/tuning/run_cnn_tdnn_1b.sh b/egs/wsj/s5/local/chain/tuning/run_cnn_tdnn_1b.sh index a3a747ed743..9db76e94430 100755 --- a/egs/wsj/s5/local/chain/tuning/run_cnn_tdnn_1b.sh +++ b/egs/wsj/s5/local/chain/tuning/run_cnn_tdnn_1b.sh @@ -170,7 +170,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/wsj/s5/local/chain/tuning/run_cnn_tdnn_1c.sh b/egs/wsj/s5/local/chain/tuning/run_cnn_tdnn_1c.sh new file mode 100755 index 00000000000..36ec5bb61af --- /dev/null +++ b/egs/wsj/s5/local/chain/tuning/run_cnn_tdnn_1c.sh @@ -0,0 +1,341 @@ +#!/bin/bash + +# 1c is as 1b but taking the first layers from the cnn_tdnn_1a setup in mini_librispeech. +# A little better than the baseline and overfits more. 
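
All of these chain scripts derive the network's output dimension from the tree directory via tree-info; the same one-liner is handy for inspecting a tree by hand (the directory below is illustrative):

    # Number of pdfs, i.e. the dimension of the chain output layer:
    num_targets=$(tree-info exp/chain/e2e_tree/tree | grep num-pdfs | awk '{print $2}')
    echo "$num_targets"
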
+# +# local/chain/compare_wer.sh exp/chain/tdnn1g_sp exp/chain/cnn_tdnn1c_sp +# System tdnn1g_sp cnn_tdnn1c_sp +#WER dev93 (tgpr) 6.68 6.55 +#WER dev93 (tg) 6.57 6.49 +#WER dev93 (big-dict,tgpr) 4.60 4.52 +#WER dev93 (big-dict,fg) 4.26 4.13 +#WER eval92 (tgpr) 4.54 4.47 +#WER eval92 (tg) 4.32 4.15 +#WER eval92 (big-dict,tgpr) 2.62 2.57 +#WER eval92 (big-dict,fg) 2.32 2.02 +# Final train prob -0.0417 -0.0409 +# Final valid prob -0.0487 -0.0486 +# Final train prob (xent) -0.6461 -0.6203 +# Final valid prob (xent) -0.6882 -0.6591 +# Num-params 8354636 6935084 + + +set -e -o pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +nj=30 +train_set=train_si284 +test_sets="test_dev93 test_eval92" +gmm=tri4b # this is the source gmm-dir that we'll use for alignments; it + # should have alignments for the specified training data. +num_threads_ubm=32 +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. + +# Options which are not passed through to run_ivector_common.sh +affix=1c #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# LSTM/chain options +train_stage=-10 +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.5@0.50,0' + +# training chunk-options +chunk_width=140,100,160 + +# training options +srand=0 +remove_egs=true + +#decode options +test_online_decoding=false # if true, it will run the last decoding stage. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 13 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 14 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 15 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.01" + ivector_affine_opts="l2-regularize=0.01" + tdnn_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim-continuous=true" + tdnnf_first_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.0" + tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.01" + output_opts="l2-regularize=0.005" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # this takes the MFCCs and generates filterbank coefficients. 
The MFCCs + # are more compressible so we prefer to dump the MFCCs to disk rather + # than filterbanks. + idct-layer name=idct input=input dim=40 cepstral-lifter=22 affine-transform-file=$dir/configs/idct.mat + + linear-component name=ivector-linear $ivector_affine_opts dim=200 input=ReplaceIndex(ivector, t, 0) + batchnorm-component name=ivector-batchnorm target-rms=0.025 + + batchnorm-component name=idct-batchnorm input=idct + combine-feature-maps-layer name=combine_inputs input=Append(idct-batchnorm, ivector-batchnorm) num-filters1=1 num-filters2=5 height=40 + + conv-relu-batchnorm-layer name=cnn1 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=48 learning-rate-factor=0.333 max-change=0.25 + conv-relu-batchnorm-layer name=cnn2 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=48 + conv-relu-batchnorm-layer name=cnn3 $cnn_opts height-in=40 height-out=20 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn4 $cnn_opts height-in=20 height-out=20 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn5 $cnn_opts height-in=20 height-out=10 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn6 $cnn_opts height-in=10 height-out=5 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=128 + + # the first TDNN-F layer has no bypass (since dims don't match), and a larger bottleneck so the + # information bottleneck doesn't become a problem. + tdnnf-layer name=tdnnf7 $tdnnf_first_opts dim=1024 bottleneck-dim=256 time-stride=0 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf14 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + tdnnf-layer name=tdnnf15 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=3 + linear-component name=prefinal-l dim=192 $linear_opts + + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1024 small-dim=192 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=1024 small-dim=192 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 16 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{4,5,6,7}/$USER/kaldi-data/egs/wsj-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.0 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=8 \ + --trainer.frames-per-iter=3000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=8 \ + --trainer.optimization.initial-effective-lrate=0.0005 \ + --trainer.optimization.final-effective-lrate=0.00005 \ + --trainer.num-chunk-per-minibatch=128,64 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 17 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/lang/check_phones_compatible.sh \ + data/lang_test_tgpr/phones.txt $lang/phones.txt + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_test_tgpr \ + $tree_dir $tree_dir/graph_tgpr || exit 1; + + utils/lang/check_phones_compatible.sh \ + data/lang_test_bd_tgpr/phones.txt $lang/phones.txt + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_test_bd_tgpr \ + $tree_dir $tree_dir/graph_bd_tgpr || exit 1; +fi + +if [ $stage -le 18 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + data_affix=$(echo $data | sed s/test_//) + nspk=$(wc -l /dev/null || true + + for data in $test_sets; do + ( + data_affix=$(echo $data | sed s/test_//) + nspk=$(wc -l $dir/configs/network.xconfig diff --git a/egs/wsj/s5/local/chain/tuning/run_tdnn_1b.sh b/egs/wsj/s5/local/chain/tuning/run_tdnn_1b.sh index a2bb7e93388..544b9b04a0a 100755 --- a/egs/wsj/s5/local/chain/tuning/run_tdnn_1b.sh +++ b/egs/wsj/s5/local/chain/tuning/run_tdnn_1b.sh @@ -158,7 +158,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/wsj/s5/local/chain/tuning/run_tdnn_1c.sh b/egs/wsj/s5/local/chain/tuning/run_tdnn_1c.sh index 7dc30ecf8fe..b268ed7feda 100755 --- a/egs/wsj/s5/local/chain/tuning/run_tdnn_1c.sh +++ b/egs/wsj/s5/local/chain/tuning/run_tdnn_1c.sh @@ -159,7 +159,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info 
$tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/wsj/s5/local/chain/tuning/run_tdnn_1d.sh b/egs/wsj/s5/local/chain/tuning/run_tdnn_1d.sh index 603e0f064b9..d1a7f9d0663 100755 --- a/egs/wsj/s5/local/chain/tuning/run_tdnn_1d.sh +++ b/egs/wsj/s5/local/chain/tuning/run_tdnn_1d.sh @@ -159,7 +159,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/wsj/s5/local/chain/tuning/run_tdnn_1e.sh b/egs/wsj/s5/local/chain/tuning/run_tdnn_1e.sh index 9808e274d83..e20069fbfa1 100755 --- a/egs/wsj/s5/local/chain/tuning/run_tdnn_1e.sh +++ b/egs/wsj/s5/local/chain/tuning/run_tdnn_1e.sh @@ -167,7 +167,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.01" output_opts="l2-regularize=0.0025" diff --git a/egs/wsj/s5/local/chain/tuning/run_tdnn_1f.sh b/egs/wsj/s5/local/chain/tuning/run_tdnn_1f.sh index e3d13ac1f65..86df0779841 100755 --- a/egs/wsj/s5/local/chain/tuning/run_tdnn_1f.sh +++ b/egs/wsj/s5/local/chain/tuning/run_tdnn_1f.sh @@ -161,7 +161,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) opts="l2-regularize=0.01" output_opts="l2-regularize=0.005 bottleneck-dim=320" diff --git a/egs/wsj/s5/local/chain/tuning/run_tdnn_1g.sh b/egs/wsj/s5/local/chain/tuning/run_tdnn_1g.sh index 1724c057e12..8f566ccfe6d 100755 --- a/egs/wsj/s5/local/chain/tuning/run_tdnn_1g.sh +++ b/egs/wsj/s5/local/chain/tuning/run_tdnn_1g.sh @@ -160,7 +160,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print(0.5/$xent_regularize)" | python) tdnn_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim-continuous=true" tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" @@ -220,6 +220,7 @@ if [ $stage -le 16 ]; then --chain.apply-deriv-weights=false \ --chain.lm-opts="--num-extra-lm-states=2000" \ --trainer.dropout-schedule $dropout_schedule \ + --trainer.add-option="--optimization.memory-compression-level=2" \ --trainer.srand=$srand \ --trainer.max-param-change=2.0 \ --trainer.num-epochs=10 \ diff --git a/egs/wsj/s5/local/chain/tuning/run_tdnn_lstm_1a.sh b/egs/wsj/s5/local/chain/tuning/run_tdnn_lstm_1a.sh index 4b752a55a4b..6e4f220c1f2 100755 --- a/egs/wsj/s5/local/chain/tuning/run_tdnn_lstm_1a.sh +++ b/egs/wsj/s5/local/chain/tuning/run_tdnn_lstm_1a.sh @@ -181,7 +181,7 @@ if 
[ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) mkdir -p $dir/configs cat < $dir/configs/network.xconfig diff --git a/egs/wsj/s5/local/chain/tuning/run_tdnn_lstm_1b.sh b/egs/wsj/s5/local/chain/tuning/run_tdnn_lstm_1b.sh index 51fefb9ca88..2d113e58a93 100755 --- a/egs/wsj/s5/local/chain/tuning/run_tdnn_lstm_1b.sh +++ b/egs/wsj/s5/local/chain/tuning/run_tdnn_lstm_1b.sh @@ -473,7 +473,7 @@ if [ $stage -le 15 ]; then echo "$0: creating neural net configs using the xconfig parser"; num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') - learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) tdnn_opts="l2-regularize=0.01" output_opts="l2-regularize=0.005 bottleneck-dim=256" lstm_opts="l2-regularize=0.005 self-scale=2.0" diff --git a/egs/wsj/s5/local/e2e/run_end2end_char.sh b/egs/wsj/s5/local/e2e/run_end2end_char.sh index e5c84c405e2..ff44802f2be 100755 --- a/egs/wsj/s5/local/e2e/run_end2end_char.sh +++ b/egs/wsj/s5/local/e2e/run_end2end_char.sh @@ -56,6 +56,7 @@ if [ $stage -le 1 ]; then local/wsj_train_lms.sh --dict-suffix "_char" local/wsj_format_local_lms.sh --lang-suffix "_char" echo "$0: Done extending the vocabulary." + exit 0; fi if [ $stage -le 2 ]; then @@ -102,5 +103,5 @@ fi if [ $stage -le 5 ]; then echo "$0: calling the flat-start chain recipe..." - local/chain/e2e/run_tdnn_lstm_flatstart.sh + local/chain/e2e/run_tdnnf_flatstart_char.sh fi diff --git a/egs/wsj/s5/local/nnet3/run_ivector_common.sh b/egs/wsj/s5/local/nnet3/run_ivector_common.sh index 2c218ae3673..813c6e14aed 100755 --- a/egs/wsj/s5/local/nnet3/run_ivector_common.sh +++ b/egs/wsj/s5/local/nnet3/run_ivector_common.sh @@ -149,8 +149,8 @@ if [ $stage -le 5 ]; then done fi -if [ -f data/${train_set}_sp/feats.scp ] && [ $stage -le 8 ]; then - echo "$0: $feats already exists. Refusing to overwrite the features " +if [ -f data/${train_set}_sp/feats.scp ] && [ $stage -le 7 ]; then + echo "$0: data/${train_set}_sp/feats.scp already exists. Refusing to overwrite the features " echo " to avoid wasting time. Please remove the file and continue if you really mean this." 
exit 1; fi diff --git a/egs/wsj/s5/local/rnnlm/run_rnnlm.sh b/egs/wsj/s5/local/rnnlm/run_rnnlm.sh new file mode 120000 index 00000000000..e638c4df523 --- /dev/null +++ b/egs/wsj/s5/local/rnnlm/run_rnnlm.sh @@ -0,0 +1 @@ +tuning/run_lstm_tdnn_1b.sh \ No newline at end of file diff --git a/egs/wsj/s5/local/rnnlm/tuning/run_lstm_tdnn_1b.sh b/egs/wsj/s5/local/rnnlm/tuning/run_lstm_tdnn_1b.sh index 0bf6b2a102f..8fe50b699cf 100755 --- a/egs/wsj/s5/local/rnnlm/tuning/run_lstm_tdnn_1b.sh +++ b/egs/wsj/s5/local/rnnlm/tuning/run_lstm_tdnn_1b.sh @@ -9,7 +9,18 @@ # Train objf: -1038.00 -5.35 -5.04 -4.87 -4.76 -4.68 -4.61 -4.56 -4.52 -4.47 -4.44 -4.41 -4.37 -4.35 -4.33 -4.31 -4.29 -4.27 -4.25 -4.24 -4.23 -4.21 -4.19 -4.17 -4.16 -4.15 -4.13 -4.12 -4.11 -4.10 -4.09 -4.07 -4.07 -4.06 -4.05 -4.04 -4.03 -4.02 -4.01 -4.00 -3.99 -3.98 -3.98 -3.97 -3.96 -3.96 -3.95 -3.94 -3.93 -3.93 -3.92 -3.92 -3.91 -3.91 -3.90 -3.90 -3.89 -3.88 -3.88 -3.88 -3.88 -3.88 -3.86 -3.86 -3.85 -3.85 -3.84 -3.83 -3.83 -3.83 -3.82 -3.82 -3.81 -3.81 -3.80 -3.80 -3.79 -3.79 -3.79 -3.79 # Dev objf: -11.73 -5.66 -5.18 -4.96 -4.82 -4.73 -4.66 -4.59 -4.54 -4.51 -4.47 -4.44 -4.40 -4.38 -4.36 -4.34 -4.32 -4.30 -4.28 -4.27 -4.26 -4.21 -4.19 -4.18 -4.16 -4.15 -4.14 -4.13 -4.12 -4.12 -4.11 -4.09 -4.09 -4.08 -4.07 -4.07 -4.06 -4.06 -4.05 -4.04 -4.04 -4.04 -4.03 -4.02 -4.02 -4.01 -4.01 -4.00 -4.00 -4.00 -3.99 -3.99 -3.98 -3.98 -3.98 -3.98 -3.97 -3.97 -3.97 -3.97 -3.96 -3.95 -3.95 -3.94 -3.94 -3.94 -3.94 -3.93 -3.93 -3.93 -3.93 -3.93 -3.93 -3.92 -3.92 -3.92 -3.92 -3.92 -3.91 -3.91 +# WER numbers + +# without RNNLM +# %WER 7.51 [ 618 / 8234, 82 ins, 112 del, 424 sub ] exp/chain/tdnn_lstm1b_sp/decode_looped_tgpr_dev93/wer_10_1.0 +# %WER 5.21 [ 294 / 5643, 55 ins, 34 del, 205 sub ] exp/chain/tdnn_lstm1b_sp/decode_looped_tgpr_eval92/wer_11_0.5 + +# with RNNLM +# %WER 5.74 [ 473 / 8234, 81 ins, 76 del, 316 sub ] exp/chain/tdnn_lstm1b_sp/decode_looped_tgpr_dev93_rnnlm/wer_14_1.0 +# %WER 4.27 [ 241 / 5643, 62 ins, 23 del, 156 sub ] exp/chain/tdnn_lstm1b_sp/decode_looped_tgpr_eval92_rnnlm/wer_12_1.0 + # Begin configuration section. + dir=exp/rnnlm_lstm_tdnn_1b embedding_dim=800 lstm_rpd=200 @@ -21,6 +32,11 @@ epochs=20 stage=-10 train_stage=-10 +# variables for rnnlm rescoring +ac_model_dir=exp/chain/tdnn_lstm1b_sp +ngram_order=4 +decode_dir_suffix=rnnlm + . ./cmd.sh . 
./utils/parse_options.sh [ -z "$cmd" ] && cmd=$train_cmd @@ -102,4 +118,20 @@ if [ $stage -le 3 ]; then --stage $train_stage --num-epochs $epochs --cmd "$cmd" $dir fi +LM=tgpr +if [ $stage -le 4 ]; then + for decode_set in dev93 eval92; do + decode_dir=${ac_model_dir}/decode_looped_${LM}_${decode_set} + + # Lattice rescoring + rnnlm/lmrescore_pruned.sh \ + --cmd "$decode_cmd --mem 4G" \ + --weight 0.8 --max-ngram-order $ngram_order \ + data/lang_test_$LM $dir \ + data/test_${decode_set}_hires ${decode_dir} \ + ${decode_dir}_${decode_dir_suffix} & + done + wait +fi + exit 0 diff --git a/egs/wsj/s5/local/run_sgmm2.sh b/egs/wsj/s5/local/run_sgmm2.sh index e2b12184c22..f391797ee58 100755 --- a/egs/wsj/s5/local/run_sgmm2.sh +++ b/egs/wsj/s5/local/run_sgmm2.sh @@ -144,7 +144,7 @@ local/score_combine.sh data/test_eval92 \ # %WER 3.76 [ 212 / 5643, 32 ins, 12 del, 168 sub ] exp/combine_tri4b_fmmi_a_sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it8_3/wer_12 # Checking MBR decode of baseline: -rm -r exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it3.mbr 2>/dev/null +rm -r exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it3.mbr 2>/dev/null cp -r exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it3{,.mbr} local/score_mbr.sh data/test_eval92 data/lang_test_bd_tgpr exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it3.mbr # MBR decoding did not seem to help (baseline was 3.85). I think this is normal at such low WERs. diff --git a/egs/wsj/s5/run.sh b/egs/wsj/s5/run.sh index 277252cecc3..4d88ff58f59 100755 --- a/egs/wsj/s5/run.sh +++ b/egs/wsj/s5/run.sh @@ -320,41 +320,48 @@ if [ $stage -le 6 ]; then fi fi +if [ $stage -le 7 ]; then + # Caution: this part needs a GPU. + local/chain/run_tdnn.sh +fi exit 0; -### Caution: the parts of the script below this statement are not run by default. -### - +# Below are some commented-out commands that demonstrate how to run various other things-- +# mainly outdated methods. # Train and test MMI, and boosted MMI, on tri4b (LDA+MLLT+SAT on # all the data). Use 30 jobs. -steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ - data/train_si284 data/lang exp/tri4b exp/tri4b_ali_si284 || exit 1; -local/run_mmi_tri4b.sh - -# These demonstrate how to build a sytem usable for online-decoding with the nnet2 setup. -# (see local/run_nnet2.sh for other, non-online nnet2 setups). -local/online/run_nnet2.sh -local/online/run_nnet2_baseline.sh -local/online/run_nnet2_discriminative.sh - -# Demonstration of RNNLM rescoring on TDNN models. We comment this out by -# default. +# Note: there isn't much use for this these days. +#steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ +# data/train_si284 data/lang exp/tri4b exp/tri4b_ali_si284 || exit 1; +#local/run_mmi_tri4b.sh + +# The following are the old nnet2 recipes. +#local/online/run_nnet2.sh +#local/online/run_nnet2_baseline.sh +#local/online/run_nnet2_discriminative.sh + +# The following is the + + +# Demonstration of RNNLM rescoring on nnet2 TDNN models. This is +# outdated now. # local/run_rnnlms.sh #local/run_nnet2.sh # You probably want to run the sgmm2 recipe as it's generally a bit better: -local/run_sgmm2.sh +# The SGMM2 recipe. This is better than GMMs but you probably just want the neural net. +# local/run_sgmm2.sh # We demonstrate MAP adaptation of GMMs to gender-dependent systems here. This also serves # as a generic way to demonstrate MAP adaptation to different domains. # local/run_gender_dep.sh -# You probably want to run the hybrid recipe as it is complementary: -local/nnet/run_dnn.sh +# This is the old "nnet1" neural net. 
+#local/nnet/run_dnn.sh # The following demonstrate how to re-segment long audios. # local/run_segmentation_long_utts.sh diff --git a/egs/wsj/s5/steps/align_basis_fmllr.sh b/egs/wsj/s5/steps/align_basis_fmllr.sh index d65986bd9ec..e5510c5ab7e 100755 --- a/egs/wsj/s5/steps/align_basis_fmllr.sh +++ b/egs/wsj/s5/steps/align_basis_fmllr.sh @@ -20,6 +20,7 @@ cmd=run.pl use_graphs=false # Begin configuration. scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" +basis_fmllr_opts="--fmllr-min-count=22 --num-iters=10 --size-scale=0.2 --step-size-iters=3" beam=10 retry_beam=40 boost_silence=1.5 # factor by which to boost silence during alignment. @@ -32,8 +33,11 @@ echo "$0 $@" # Print the command line for logging . parse_options.sh || exit 1; if [ $# != 4 ]; then - echo "usage: steps/align_fmllr.sh " - echo "e.g.: steps/align_fmllr.sh data/train data/lang exp/tri1 exp/tri1_ali" + echo "usage: steps/align_basis_fmllr.sh " + echo "e.g.: steps/align_basis_fmllr.sh data/train data/lang exp/tri4 exp/tri4_ali" + echo "Note: should ideally have been trained by steps/train_sat_basis.sh, or" + echo "if a non-SAT system (not recommended), the basis should have been computed" + echo "by steps/get_fmllr_basis.sh." echo "main options (for others, see top of script file)" echo " --config # config containing options" echo " --nj # number of parallel jobs" @@ -57,9 +61,19 @@ mkdir -p $dir/log echo $nj > $dir/num_jobs [[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1; + +for f in $srcdir/tree $srcdir/final.mdl $srcdir/fmllr.basis \ + $data/feats.scp $lang/phones.txt; do + if [ ! -f $f ]; then + echo "$0: expected file $f to exist" + exit 1 + fi +done + utils/lang/check_phones_compatible.sh $lang/phones.txt $srcdir/phones.txt || exit 1; cp $lang/phones.txt $dir || exit 1; + cp $srcdir/{tree,final.mdl} $dir || exit 1; cp $srcdir/final.occs $dir; splice_opts=`cat $srcdir/splice_opts 2>/dev/null` # frame-splicing options. @@ -123,22 +137,20 @@ if [ $stage -le 2 ]; then ali-to-post "ark:gunzip -c $dir/pre_ali.JOB.gz|" ark:- \| \ weight-silence-post 0.0 $silphonelist $alimdl ark:- ark:- \| \ gmm-post-to-gpost $alimdl "$sifeats" ark:- ark:- \| \ - gmm-est-basis-fmllr-gpost --fmllr-min-count=22 --num-iters=10 \ - --size-scale=0.2 --step-size-iters=3 \ - --write-weights=ark:$dir/pre_wgt.JOB \ + gmm-est-basis-fmllr-gpost $basis_fmllr_opts --spk2utt=ark:$sdata/JOB/spk2utt \ $mdl $srcdir/fmllr.basis "$sifeats" ark,s,cs:- \ ark:$dir/trans.JOB || exit 1; -# else -# $cmd JOB=1:$nj $dir/log/fmllr.JOB.log \ -# ali-to-post "ark:gunzip -c $dir/pre_ali.JOB.gz|" ark:- \| \ -# weight-silence-post 0.0 $silphonelist $alimdl ark:- ark:- \| \ -# gmm-est-fmllr --fmllr-update-type=$fmllr_update_type \ -# --spk2utt=ark:$sdata/JOB/spk2utt $mdl "$sifeats" \ -# ark,s,cs:- ark:$dir/trans.JOB || exit 1; + else + $cmd JOB=1:$nj $dir/log/fmllr.JOB.log \ + ali-to-post "ark:gunzip -c $dir/pre_ali.JOB.gz|" ark:- \| \ + weight-silence-post 0.0 $silphonelist $alimdl ark:- ark:- \| \ + gmm-est-basis-fmllr $basis_fmllr_opts --spk2utt=ark:$sdata/JOB/spk2utt \ + $mdl $srcdir/fmllr.basis "$sifeats" \ + ark,s,cs:- ark:$dir/trans.JOB || exit 1; fi fi -feats="$sifeats transform-feats ark:$dir/trans.JOB ark:- ark:- |" +feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark:$dir/trans.JOB ark:- ark:- |" if [ $stage -le 3 ]; then echo "$0: doing final alignment." 
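
The changes to align_basis_fmllr.sh above make the adaptation per-speaker: gmm-est-basis-fmllr(-gpost) is now given --spk2utt, so the transforms in trans.JOB are indexed by speaker, and transform-feats correspondingly needs the --utt2spk map to look up the right transform for each utterance. A standalone sketch of the same pattern for the delta-feature case, with illustrative paths:

    # Speaker-adapted features: per-speaker transforms applied to per-utterance feats.
    apply-cmvn --utt2spk=ark:data/train/split4/1/utt2spk \
        scp:data/train/split4/1/cmvn.scp scp:data/train/split4/1/feats.scp ark:- | \
      add-deltas ark:- ark:- | \
      transform-feats --utt2spk=ark:data/train/split4/1/utt2spk \
        ark:exp/tri4_ali/trans.1 ark:- ark,t:- | head -n 5
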
diff --git a/egs/wsj/s5/steps/align_basis_fmllr_lats.sh b/egs/wsj/s5/steps/align_basis_fmllr_lats.sh new file mode 100755 index 00000000000..426168496cc --- /dev/null +++ b/egs/wsj/s5/steps/align_basis_fmllr_lats.sh @@ -0,0 +1,184 @@ +#!/bin/bash +# +# Copyright 2012-2015 Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0 + +# Version of align_fmllr_lats.sh that uses "basis fMLLR", so it is suitable for +# situations where there is very little data per speaker (e.g. when there is a +# one-to-one mapping between utterances and speakers). Intended for use where +# the model was trained with basis-fMLLR (i.e. when you trained the model with +# train_sat_basis.sh where you normally would have trained with train_sat.sh), +# or when it was trained with SAT but you ran get_fmllr_basis.sh on the +# source-model directory. + +# Begin configuration section. +stage=0 +nj=4 +cmd=run.pl +# Begin configuration. +scale_opts="--transition-scale=1.0 --self-loop-scale=0.1" +acoustic_scale=0.1 +beam=10 +retry_beam=40 +final_beam=20 # For the lattice-generation phase there is no retry-beam. This + # is a limitation of gmm-latgen-faster. We just use an + # intermediate beam. We'll lose a little data and it will be + # slightly slower. (however, the min-active of 200 that + # gmm-latgen-faster defaults to may help.) +boost_silence=1.0 # factor by which to boost silence during alignment. +basis_fmllr_opts="--fmllr-min-count=22 --num-iters=10 --size-scale=0.2 --step-size-iters=3" + +generate_ali_from_lats=false # If true, alingments generated from lattices. +# End configuration options. + +echo "$0 $@" # Print the command line for logging + +[ -f path.sh ] && . ./path.sh # source the path. +. parse_options.sh || exit 1; + +if [ $# != 4 ]; then + echo "usage: steps/align_fmllr_lats.sh " + echo "e.g.: steps/align_fmllr_lats.sh data/train data/lang exp/tri1 exp/tri1_lats" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + exit 1; +fi + +data=$1 +lang=$2 +srcdir=$3 +dir=$4 + +if [ ! -f $srcdir/fmllr.basis ]; then + echo "$0: expected $srcdir/fmllr.basis to exist. Run get_fmllr_basis.sh on $srcdir." +fi + +for f in $data/feats.scp $lang/phones.txt $srcdir/final.mdl; do + [ ! -f $f ] && echo "$0: expected file $f to exist" && exit 1 +done + + +oov=`cat $lang/oov.int` || exit 1; +silphonelist=`cat $lang/phones/silence.csl` || exit 1; +sdata=$data/split$nj + +mkdir -p $dir/log +echo $nj > $dir/num_jobs +[[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1; + +utils/lang/check_phones_compatible.sh $lang/phones.txt $srcdir/phones.txt || exit 1; +cp $lang/phones.txt $dir || exit 1; + +cp $srcdir/{tree,final.mdl} $dir || exit 1; +cp $srcdir/final.alimdl $dir 2>/dev/null +cp $srcdir/final.occs $dir; +splice_opts=`cat $srcdir/splice_opts 2>/dev/null` # frame-splicing options. +cp $srcdir/splice_opts $dir 2>/dev/null # frame-splicing options. +cmvn_opts=`cat $srcdir/cmvn_opts 2>/dev/null` +cp $srcdir/cmvn_opts $dir 2>/dev/null # cmn/cmvn option. 
+delta_opts=`cat $srcdir/delta_opts 2>/dev/null` +cp $srcdir/delta_opts $dir 2>/dev/null + +if [ -f $srcdir/final.mat ]; then feat_type=lda; else feat_type=delta; fi +echo "$0: feature type is $feat_type" + +case $feat_type in + delta) sifeats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |";; + lda) sifeats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $srcdir/final.mat ark:- ark:- |" + cp $srcdir/final.mat $dir + cp $srcdir/full.mat $dir 2>/dev/null + ;; + *) echo "Invalid feature type $feat_type" && exit 1; +esac + +## Set up model and alignment model. +mdl=$srcdir/final.mdl +if [ -f $srcdir/final.alimdl ]; then + alimdl=$srcdir/final.alimdl +else + alimdl=$srcdir/final.mdl +fi +[ ! -f $mdl ] && echo "$0: no such model $mdl" && exit 1; +alimdl_cmd="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $alimdl - |" +mdl_cmd="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $mdl - |" + + +## because gmm-latgen-faster doesn't support adding the transition-probs to the +## graph itself, we need to bake them into the compiled graphs. This means we can't reuse previously compiled graphs, +## because the other scripts write them without transition probs. +if [ $stage -le 0 ]; then + echo "$0: compiling training graphs" + tra="ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt $sdata/JOB/text|"; + $cmd JOB=1:$nj $dir/log/compile_graphs.JOB.log \ + compile-train-graphs --read-disambig-syms=$lang/phones/disambig.int $scale_opts $dir/tree $dir/final.mdl $lang/L.fst "$tra" \ + "ark:|gzip -c >$dir/fsts.JOB.gz" || exit 1; +fi + + +if [ $stage -le 1 ]; then + # Note: we need to set --transition-scale=0.0 --self-loop-scale=0.0 because, + # as explained above, we compiled the transition probs into the training + # graphs. + echo "$0: aligning data in $data using $alimdl and speaker-independent features." 
+ $cmd JOB=1:$nj $dir/log/align_pass1.JOB.log \ + gmm-align-compiled --transition-scale=0.0 --self-loop-scale=0.0 --acoustic-scale=$acoustic_scale \ + --beam=$beam --retry-beam=$retry_beam "$alimdl_cmd" \ + "ark:gunzip -c $dir/fsts.JOB.gz|" "$sifeats" "ark:|gzip -c >$dir/pre_ali.JOB.gz" || exit 1; +fi + +if [ $stage -le 2 ]; then + echo "$0: computing fMLLR transforms" + if [ "$alimdl" != "$mdl" ]; then + $cmd JOB=1:$nj $dir/log/fmllr.JOB.log \ + ali-to-post "ark:gunzip -c $dir/pre_ali.JOB.gz|" ark:- \| \ + weight-silence-post 0.0 $silphonelist $alimdl ark:- ark:- \| \ + gmm-post-to-gpost $alimdl "$sifeats" ark:- ark:- \| \ + gmm-est-basis-fmllr-gpost $basis_fmllr_opts \ + --spk2utt=ark:$sdata/JOB/spk2utt $mdl $srcdir/fmllr.basis "$sifeats" \ + ark,s,cs:- ark:$dir/trans.JOB || exit 1; + else + $cmd JOB=1:$nj $dir/log/fmllr.JOB.log \ + ali-to-post "ark:gunzip -c $dir/pre_ali.JOB.gz|" ark:- \| \ + weight-silence-post 0.0 $silphonelist $alimdl ark:- ark:- \| \ + gmm-est-basis-fmllr $basis_fmllr_opts \ + --spk2utt=ark:$sdata/JOB/spk2utt $mdl $srcdir/fmllr.basis "$sifeats" \ + ark,s,cs:- ark:$dir/trans.JOB || exit 1; + fi +fi + +feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark:$dir/trans.JOB ark:- ark:- |" + +if [ $stage -le 3 ]; then + # Warning: gmm-latgen-faster doesn't support a retry-beam so you may get more + # alignment errors (however, it does have a default min-active=200 so this + # will tend to reduce alignment errors). + # --allow_partial=false makes sure we reach the end of the decoding graph. + # --word-determinize=false makes sure we retain the alternative pronunciations of + # words (including alternatives regarding optional silences). + # --lattice-beam=$beam keeps all the alternatives that were within the beam, + # it means we do no pruning of the lattice (lattices from a training transcription + # will be small anyway). + echo "$0: generating lattices containing alternate pronunciations." + $cmd JOB=1:$nj $dir/log/generate_lattices.JOB.log \ + gmm-latgen-faster --acoustic-scale=$acoustic_scale --beam=$final_beam \ + --lattice-beam=$final_beam --allow-partial=false --word-determinize=false \ + "$mdl_cmd" "ark:gunzip -c $dir/fsts.JOB.gz|" "$feats" \ + "ark:|gzip -c >$dir/lat.JOB.gz" || exit 1; +fi + +if [ $stage -le 4 ] && $generate_ali_from_lats; then + # If generate_alignments is true, ali.*.gz is generated in lats dir + $cmd JOB=1:$nj $dir/log/generate_alignments.JOB.log \ + lattice-best-path --acoustic-scale=$acoustic_scale "ark:gunzip -c $dir/lat.JOB.gz |" \ + ark:/dev/null "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; +fi + +rm $dir/pre_ali.*.gz 2>/dev/null || true + +echo "$0: done generating lattices from training transcripts." + +utils/summarize_warnings.pl $dir/log + +exit 0; diff --git a/egs/wsj/s5/steps/align_fmllr_lats.sh b/egs/wsj/s5/steps/align_fmllr_lats.sh index 187d9bf5687..b47b97ef994 100755 --- a/egs/wsj/s5/steps/align_fmllr_lats.sh +++ b/egs/wsj/s5/steps/align_fmllr_lats.sh @@ -5,7 +5,7 @@ # Version of align_fmllr.sh that generates lattices (lat.*.gz) with # alignments of alternative pronunciations in them. Mainly intended -# as a precursor to CTC training for now. +# as a precursor to LF-MMI/chain training for now. # Begin configuration section. 
stage=0 diff --git a/egs/wsj/s5/steps/cleanup/clean_and_segment_data.sh b/egs/wsj/s5/steps/cleanup/clean_and_segment_data.sh index 670e6c2b714..fb386fa244f 100755 --- a/egs/wsj/s5/steps/cleanup/clean_and_segment_data.sh +++ b/egs/wsj/s5/steps/cleanup/clean_and_segment_data.sh @@ -6,9 +6,9 @@ # This script demonstrates how to re-segment training data selecting only the # "good" audio that matches the transcripts. -# The basic idea is to decode with an existing in-domain acoustic model, and a -# biased language model built from the reference, and then work out the -# segmentation from a ctm like file. +# The basic idea is to decode with an existing in-domain GMM acoustic model, and +# a biased language model built from the reference transcript, and then work out +# the segmentation from a ctm like file. set -e -o pipefail @@ -179,7 +179,7 @@ if [ $stage -le 8 ]; then # the apply_map command below gives us lines of the form 'utt dur-from-$data/utt2dur dur-from-utt2dur.from_ctm', # e.g. AMI_EN2001a_H00_MEE068_0000557_0000594 0.37 0.35 utils/apply_map.pl -f 1 <(awk '{print $1,$1,$2}' <$data/utt2dur) <$dir/utt2dur.from_ctm | \ - awk '{printf("%.3f\n", $2 - $3); }' | sort | uniq -c > $dir/padding_frequencies + awk '{printf("%.3f\n", $2 - $3); }' | sort | uniq -c | sort -nr > $dir/padding_frequencies # there are values other than the most-frequent one (0.02) in there because # of wav files that were shorter than the segment info. padding=$(head -n 1 $dir/padding_frequencies | awk '{print $2}') @@ -206,7 +206,7 @@ fi if $cleanup; then echo "$0: cleaning up intermediate files" - rm -r $dir/fsts $dir/HCLG.fsts.scp || true + rm -r $dir/graphs/fsts $dir/graphs/HCLG.fsts.scp || true rm -r $dir/lats/lat.*.gz $dir/lats/split_fsts || true rm $dir/lattice_oracle/lat.*.gz || true fi diff --git a/egs/wsj/s5/steps/cleanup/clean_and_segment_data_nnet3.sh b/egs/wsj/s5/steps/cleanup/clean_and_segment_data_nnet3.sh new file mode 100755 index 00000000000..35b07d184f4 --- /dev/null +++ b/egs/wsj/s5/steps/cleanup/clean_and_segment_data_nnet3.sh @@ -0,0 +1,270 @@ +#!/bin/bash + +# Copyright 2016 Vimal Manohar +# 2016 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +# This script demonstrates how to re-segment training data selecting only the +# "good" audio that matches the transcripts. +# This script is like clean_and_segment_data.sh, but uses nnet3 model instead of +# a GMM for decoding. +# The basic idea is to decode with an existing in-domain nnet3 acoustic model, +# and a biased language model built from the reference transcript, and then work +# out the segmentation from a ctm like file. + +set -e +set -o pipefail +set -u + +stage=0 + +cmd=run.pl +cleanup=true # remove temporary directories and files +nj=4 +# Decode options +graph_opts= +beam=15.0 +lattice_beam=1.0 + +acwt=0.1 # Just a default value, used for adaptation and beam-pruning.. +post_decode_acwt=1.0 # can be used in 'chain' systems to scale acoustics by 10 so the + # regular scoring script works. + +# Contexts must ideally match training +extra_left_context=0 # Set to some large value, typically 40 for LSTM (must match training) +extra_right_context=0 +extra_left_context_initial=-1 +extra_right_context_final=-1 +frames_per_chunk=150 + +# i-vector options +extractor= # i-Vector extractor. If provided, will extract i-vectors. + # Required if the network was trained with i-vector extractor. +use_vad=false # Use energy-based VAD for i-vector extraction + +segmentation_opts= + +. ./path.sh +. 
utils/parse_options.sh + + +if [ $# -ne 5 ]; then + cat <] [options] + This script does data cleanup to remove bad portions of transcripts and + may do other minor modifications of transcripts such as allowing repetitions + for disfluencies, and adding or removing non-scored words (by default: + words that map to 'silence phones') + Note: is expected to contain a nnet3-based model. + and decoding options like --extra-left-context must match + the appropriate options used for training. + + e.g. $0 data/train data/lang exp/tri3 exp/tri3_cleanup data/train_cleaned + main options (for others, see top of script file): + --stage # stage to run from, to enable resuming from partially + # completed run (default: 0) + --cmd '$cmd' # command to submit jobs with (e.g. run.pl, queue.pl) + --nj # number of parallel jobs to use in graph creation and + # decoding + --graph-opts 'opts' # Additional options to make_biased_lm_graphs.sh. + # Please run steps/cleanup/make_biased_lm_graphs.sh + # without arguments to see allowed options. + --segmentation-opts 'opts' # Additional options to segment_ctm_edits.py. + # Please run steps/cleanup/internal/segment_ctm_edits.py + # without arguments to see allowed options. + --cleanup # Clean up intermediate files afterward. Default true. + --extractor # i-vector extractor directory if i-vector is + # to be used during decoding. Must match + # the extractor used for training neural-network. + --use-vad # If true, uses energy-based VAD to apply frame weights + # for i-vector stats extraction +EOF + exit 1 +fi + +data=$1 +lang=$2 +srcdir=$3 +dir=$4 +data_out=$5 + + +extra_files= +if [ ! -z "$extractor" ]; then + extra_files="$extractor/final.ie" +fi + +for f in $srcdir/{final.mdl,tree,cmvn_opts} $data/utt2spk $data/feats.scp \ + $lang/words.txt $lang/oov.txt $extra_files; do + if [ ! -f $f ]; then + echo "$0: expected file $f to exist." + exit 1 + fi +done + +mkdir -p $dir +cp $srcdir/final.mdl $dir +cp $srcdir/tree $dir +cp $srcdir/cmvn_opts $dir +cp $srcdir/{splice_opts,delta_opts,final.mat,final.alimdl} $dir 2>/dev/null || true +cp $srcdir/frame_subsampling_factor $dir 2>/dev/null || true + +utils/lang/check_phones_compatible.sh $lang/phones.txt $srcdir/phones.txt +cp $lang/phones.txt $dir + +if [ $stage -le 1 ]; then + echo "$0: Building biased-language-model decoding graphs..." + + + steps/cleanup/make_biased_lm_graphs.sh $graph_opts \ + --nj $nj --cmd "$cmd" \ + $data $lang $dir $dir/graphs +fi + +online_ivector_dir= +if [ ! -z "$extractor" ]; then + online_ivector_dir=$dir/ivectors_$(basename $data) + + if [ $stage -le 2 ]; then + # Compute energy-based VAD + if $use_vad; then + steps/compute_vad_decision.sh $data \ + $data/log $data/data + fi + + steps/online/nnet2/extract_ivectors_online.sh \ + --nj $nj --cmd "$cmd --mem 4G" --use-vad $use_vad \ + $data $extractor $online_ivector_dir + fi +fi + +if [ $stage -le 3 ]; then + echo "$0: Decoding with biased language models..." 
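  # Note: unlike ordinary nnet3 decoding, decode_segmentation_nnet3.sh reads one
  # decoding graph per utterance, via the HCLG.fsts.scp written by
  # make_biased_lm_graphs.sh above, rather than a single shared HCLG.fst.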
+ + steps/cleanup/decode_segmentation_nnet3.sh \ + --acwt $acwt --post-decode-acwt $post_decode_acwt \ + --beam $beam --lattice-beam $lattice_beam --nj $nj --cmd "$cmd --mem 4G" \ + --skip-scoring true --allow-partial false \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --extra-left-context-initial $extra_left_context_initial \ + --extra-right-context-final $extra_right_context_final \ + --frames-per-chunk $frames_per_chunk \ + ${online_ivector_dir:+--online-ivector-dir $online_ivector_dir} \ + $dir/graphs $data $dir/lats + + # the following is for diagnostics, e.g. it will give us the lattice depth. + steps/diagnostic/analyze_lats.sh --cmd "$cmd" $lang $dir/lats +fi + +frame_shift_opt= +if [ -f $srcdir/frame_subsampling_factor ]; then + frame_shift_opt="--frame-shift 0.0$(cat $srcdir/frame_subsampling_factor)" +fi + +if [ $stage -le 4 ]; then + echo "$0: Doing oracle alignment of lattices..." + steps/cleanup/lattice_oracle_align.sh --cmd "$cmd --mem 4G" $frame_shift_opt \ + $data $lang $dir/lats $dir/lattice_oracle +fi + + +if [ $stage -le 4 ]; then + echo "$0: using default values of non-scored words..." + + # At the level of this script we just hard-code it that non-scored words are + # those that map to silence phones (which is what get_non_scored_words.py + # gives us), although this could easily be made user-configurable. This list + # of non-scored words affects the behavior of several of the data-cleanup + # scripts; essentially, we view the non-scored words as negotiable when it + # comes to the reference transcript, so we'll consider changing the reference + # to match the hyp when it comes to these words. + steps/cleanup/internal/get_non_scored_words.py $lang > $dir/non_scored_words.txt +fi + +if [ $stage -le 5 ]; then + echo "$0: modifying ctm-edits file to allow repetitions [for dysfluencies] and " + echo " ... to fix reference mismatches involving non-scored words. " + + $cmd $dir/log/modify_ctm_edits.log \ + steps/cleanup/internal/modify_ctm_edits.py --verbose=3 $dir/non_scored_words.txt \ + $dir/lattice_oracle/ctm_edits $dir/ctm_edits.modified + + echo " ... See $dir/log/modify_ctm_edits.log for details and stats, including" + echo " a list of commonly-repeated words." +fi + +if [ $stage -le 6 ]; then + echo "$0: applying 'taint' markers to ctm-edits file to mark silences and" + echo " ... non-scored words that are next to errors." + $cmd $dir/log/taint_ctm_edits.log \ + steps/cleanup/internal/taint_ctm_edits.py $dir/ctm_edits.modified $dir/ctm_edits.tainted + echo "... Stats, including global cor/ins/del/sub stats, are in $dir/log/taint_ctm_edits.log." +fi + + +if [ $stage -le 7 ]; then + echo "$0: creating segmentation from ctm-edits file." + + $cmd $dir/log/segment_ctm_edits.log \ + steps/cleanup/internal/segment_ctm_edits.py \ + $segmentation_opts \ + --oov-symbol-file=$lang/oov.txt \ + --ctm-edits-out=$dir/ctm_edits.segmented \ + --word-stats-out=$dir/word_stats.txt \ + $dir/non_scored_words.txt \ + $dir/ctm_edits.tainted $dir/text $dir/segments + + echo "$0: contents of $dir/log/segment_ctm_edits.log are:" + cat $dir/log/segment_ctm_edits.log + echo "For word-level statistics on p(not-being-in-a-segment), with 'worst' words at the top," + echo "see $dir/word_stats.txt" + echo "For detailed utterance-level debugging information, see $dir/ctm_edits.segmented" +fi + +if [ $stage -le 8 ]; then + echo "$0: working out required segment padding to account for feature-generation edge effects." 
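  # (The padding is needed because the ctm-derived end times come from the
  # number of feature frames, which is normally slightly less than the true wav
  # duration because of the analysis window; the most frequent difference,
  # usually about 0.02 seconds with default frame settings, is used as padding.)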
+ # make sure $data/utt2dur exists. + utils/data/get_utt2dur.sh $data + # utt2dur.from_ctm contains lines of the form 'utt dur', e.g. + # AMI_EN2001a_H00_MEE068_0000557_0000594 0.35 + # where the times are ultimately derived from the num-frames in the features. + cat $dir/lattice_oracle/ctm_edits | \ + awk '{utt=$1; t=$3+$4; if (t > dur[$1]) dur[$1] = t; } END{for (k in dur) print k, dur[k];}' | \ + sort > $dir/utt2dur.from_ctm + # the apply_map command below gives us lines of the form 'utt dur-from-$data/utt2dur dur-from-utt2dur.from_ctm', + # e.g. AMI_EN2001a_H00_MEE068_0000557_0000594 0.37 0.35 + utils/apply_map.pl -f 1 <(awk '{print $1,$1,$2}' <$data/utt2dur) <$dir/utt2dur.from_ctm | \ + awk '{printf("%.3f\n", $2 - $3); }' | sort | uniq -c | sort -nr > $dir/padding_frequencies + # there are values other than the most-frequent one (0.02) in there because + # of wav files that were shorter than the segment info. + padding=$(head -n 1 $dir/padding_frequencies | awk '{print $2}') + echo "$0: we'll pad segments with $padding seconds at segment ends to correct for feature-generation end effects" + echo $padding >$dir/segment_end_padding +fi + + +if [ $stage -le 8 ]; then + echo "$0: based on the segments and text file in $dir/segments and $dir/text, creating new data-dir in $data_out" + padding=$(cat $dir/segment_end_padding) # e.g. 0.02 + utils/data/subsegment_data_dir.sh --segment-end-padding $padding ${data} $dir/segments $dir/text $data_out + # utils/data/subsegment_data_dir.sh can output directories that have e.g. too many entries left in wav.scp + # Clean this up with the fix_data_dir.sh script + utils/fix_data_dir.sh $data_out +fi + +if [ $stage -le 9 ]; then + echo "$0: recomputing CMVN stats for the new data" + # Caution: this script puts the CMVN stats in $data_out/data, + # e.g. data/train_cleaned/data. This is not the general pattern we use. + steps/compute_cmvn_stats.sh $data_out $data_out/log $data_out/data +fi + +if $cleanup; then + echo "$0: cleaning up intermediate files" + rm -r $dir/graphs/fsts $dir/graphs/HCLG.fsts.scp || true + rm -r $dir/lats/lat.*.gz $dir/lats/split_fsts || true + rm $dir/lattice_oracle/lat.*.gz || true +fi + +echo "$0: done."
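A minimal usage sketch for the script above (all paths are illustrative rather than taken from any recipe). The positional arguments are <data> <lang> <srcdir> <dir> <out-data> as in the usage message; --extractor should point at the i-vector extractor the nnet3 model was trained with (if any), and for a 'chain' model one would typically also pass --acwt 1.0 --post-decode-acwt 10.0:

  steps/cleanup/clean_and_segment_data_nnet3.sh --nj 40 --cmd "$train_cmd" \
    --extractor exp/nnet3/extractor \
    --acwt 1.0 --post-decode-acwt 10.0 \
    data/train data/lang exp/chain/tdnn_1a_sp exp/chain/tdnn_1a_sp_cleaned_work data/train_cleaned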
diff --git a/egs/wsj/s5/steps/cleanup/combine_short_segments.py b/egs/wsj/s5/steps/cleanup/combine_short_segments.py index 1d14bd2a57f..099b92882a9 100755 --- a/egs/wsj/s5/steps/cleanup/combine_short_segments.py +++ b/egs/wsj/s5/steps/cleanup/combine_short_segments.py @@ -284,7 +284,7 @@ def CombineSegments(input_dir, output_dir, minimum_duration): assert(cur_utt_dur == combined_duration) # now modify the utts list - combined_indices = range(left_index, right_index + 1) + combined_indices = list(range(left_index, right_index + 1)) # start popping from the largest index so that the lower # indexes are valid for i in combined_indices[::-1]: diff --git a/egs/wsj/s5/steps/cleanup/debug_lexicon.sh b/egs/wsj/s5/steps/cleanup/debug_lexicon.sh index 9091764924a..eca807ad247 100755 --- a/egs/wsj/s5/steps/cleanup/debug_lexicon.sh +++ b/egs/wsj/s5/steps/cleanup/debug_lexicon.sh @@ -113,23 +113,24 @@ if [ $stage -le 8 ]; then grep -v '' $phone_lang/phones.txt | awk '{print $1, $1}' | \ sed 's/_B$//' | sed 's/_I$//' | sed 's/_E$//' | sed 's/_S$//' >$dir/phone_map.txt - cat $dir/phone.ctm | utils/apply_map.pl -f 5 $dir/phone_map.txt > $dir/phone_text.ctm > $dir/phone_mapped.ctm export LC_ALL=C + cat $dir/phone.ctm | utils/apply_map.pl -f 5 $dir/phone_map.txt > $dir/phone_mapped.ctm + cat $dir/word.ctm | awk '{printf("%s-%s %010.0f START %s\n", $1, $2, 1000*$3, $5); printf("%s-%s %010.0f END %s\n", $1, $2, 1000*($3+$4), $5);}' | \ sort > $dir/word_processed.ctm # filter out those utteraces which only appea in phone_processed.ctm but not in word_processed.ctm cat $dir/phone_mapped.ctm | awk '{printf("%s-%s %010.0f PHONE %s\n", $1, $2, 1000*($3+(0.5*$4)), $5);}' | \ - awk 'NR==FNR{a[$1] = 1; next} {if($1 in a) print $0}' $dir/word_processed.ctm - \ - > $dir/phone_processed.ctm + awk 'NR==FNR{a[$1] = 1; next} {if($1 in a) print $0}' $dir/word_processed.ctm - | \ + sort > $dir/phone_processed.ctm # merge-sort both ctm's sort -m $dir/word_processed.ctm $dir/phone_processed.ctm > $dir/combined.ctm fi - # after merge-sort of the two ctm's, we add to cover "deserted" phones due to precision limits, and then merge all consecutive 's. +# after merge-sort of the two ctm's, we add to cover "deserted" phones due to precision limits, and then merge all consecutive 's. 
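  # (For concreteness, with made-up numbers: a word.ctm line 'utt-A 1 1.00 0.40 hello'
  # appears in word_processed.ctm as the keyed lines 'utt-A-1 0000001000 START hello'
  # and 'utt-A-1 0000001400 END hello', while each phone from phone_mapped.ctm appears
  # as 'utt-A-1 <midpoint-in-ms> PHONE <phone>'; the merge-sort therefore places every
  # phone line between the START and END lines of the word it belongs to.)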
if [ $stage -le 9 ]; then awk '{print $1, $3, $4}' $dir/combined.ctm | \ perl -e ' while (<>) { chop; @A = split(" ", $_); ($utt, $a,$b) = @A; @@ -137,14 +138,14 @@ if [ $stage -le 9 ]; then if ($a eq "END") { print $utt, " ", $cur_word, " ", join(" ", @phones), "\n"; } if ($a eq "PHONE") { if ($prev eq "END") {print $utt, " ", "", " ", $b, "\n";} else {push @phones, $b;}} $prev = $a;} ' |\ awk 'BEGIN{merge_prev=0;} {utt=$1;word=$2;pron=$3;for (i=4;i<=NF;i++) pron=pron" "$i; - if (word_prev == "" && word == "" && utt_prev == utt) {merge=0;pron_prev=pron_prev" "pron;} else {merge=1;} + if (word_prev == "" && word == "" && utt_prev == utt) {merge=0;pron_prev=pron_prev" "pron;} else {merge=1;} if(merge_prev==1) {print utt_prev, word_prev, pron_prev;}; merge_prev=merge; utt_prev=utt; word_prev=word; pron_prev=pron;} END{if(merge_prev==1) {print utt_prev, word_prev, pron_prev;}}' > $dir/ctm_prons.txt - + steps/cleanup/internal/get_non_scored_words.py $lang > $dir/non_scored_words steps/cleanup/internal/get_pron_stats.py $dir/ctm_prons.txt $phone_lang/phones/silence.txt $phone_lang/phones/optional_silence.txt $dir/non_scored_words - | \ - sort -nr > $dir/prons.txt + sort -nr > $dir/prons.txt fi if [ $stage -le 10 ]; then diff --git a/egs/wsj/s5/steps/cleanup/decode_segmentation_nnet3.sh b/egs/wsj/s5/steps/cleanup/decode_segmentation_nnet3.sh new file mode 100755 index 00000000000..02a9d87d26b --- /dev/null +++ b/egs/wsj/s5/steps/cleanup/decode_segmentation_nnet3.sh @@ -0,0 +1,174 @@ +#!/bin/bash + +# Copyright 2014 Guoguo Chen, 2015 GoVivace Inc. (Nagendra Goel) +# 2017 Vimal Manohar +# Apache 2.0 + +# This script is similar to steps/cleanup/decode_segmentation.sh, but +# does decoding using nnet3 model. + +set -e +set -o pipefail + +# Begin configuration section. +stage=-1 +nj=4 # number of decoding jobs. +acwt=0.1 # Just a default value, used for adaptation and beam-pruning.. +post_decode_acwt=1.0 # can be used in 'chain' systems to scale acoustics by 10 so the + # regular scoring script works. +cmd=run.pl +beam=15.0 +frames_per_chunk=50 +max_active=7000 +min_active=200 +ivector_scale=1.0 +lattice_beam=8.0 # Beam we use in lattice generation. We can reduce this if + # we only need the best path +iter=final +num_threads=1 # if >1, will use gmm-latgen-faster-parallel +scoring_opts= +skip_scoring=false +allow_partial=true +extra_left_context=0 +extra_right_context=0 +extra_left_context_initial=-1 +extra_right_context_final=-1 +online_ivector_dir= +minimize=false +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. utils/parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "$0: This is a special decoding script for segmentation where we" + echo "use one decoding graph per segment. We assume a file HCLG.fsts.scp exists" + echo "which is the scp file of the graphs for each segment." + echo "This will normally be obtained by steps/cleanup/make_biased_lm_graphs.sh." + echo "" + echo "Usage: $0 [options] " + echo " e.g.: $0 --online-ivector-dir exp/nnet3/ivectors_train_si284_split " + echo " exp/nnet3/tdnn/graph_train_si284_split \\" + echo " data/train_si284_split exp/nnet3/tdnn/decode_train_si284_split" + echo "" + echo "where is assumed to be a sub-directory of the directory" + echo "where the model is." + echo "" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --nj # number of parallel jobs" + echo " --iter # Iteration of model to test." 
+ echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --acwt # acoustic scale used for lattice generation " + echo " --scoring-opts # options to local/score.sh" + echo " --num-threads # number of threads to use, default 1." + exit 1; +fi + + +graphdir=$1 +data=$2 +dir=$3 + +mkdir -p $dir/log + +if [ -e $dir/$iter.mdl ]; then + srcdir=$dir +elif [ -e $dir/../$iter.mdl ]; then + srcdir=$(dirname $dir) +else + echo "$0: expected either $dir/$iter.mdl or $dir/../$iter.mdl to exist" + exit 1 +fi +model=$srcdir/$iter.mdl + + +extra_files= +if [ ! -z "$online_ivector_dir" ]; then + steps/nnet2/check_ivectors_compatible.sh $srcdir $online_ivector_dir || exit 1 + extra_files="$online_ivector_dir/ivector_online.scp $online_ivector_dir/ivector_period" +fi + +utils/lang/check_phones_compatible.sh $graph_dir/phones.txt $srcdir/phones.txt || exit 1 + +for f in $graphdir/HCLG.fsts.scp $data/feats.scp $model $extra_files; do + [ ! -f $f ] && echo "$0: no such file $f" && exit 1; +done + +sdata=$data/split$nj; +cmvn_opts=`cat $srcdir/cmvn_opts` || exit 1; +thread_string= +[ $num_threads -gt 1 ] && thread_string="-parallel --num-threads=$num_threads" + +mkdir -p $dir/log +[[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1; +echo $nj > $dir/num_jobs + +# Split HCLG.fsts.scp by input utterance +n1=$(cat $graphdir/HCLG.fsts.scp | wc -l) +n2=$(cat $data/feats.scp | wc -l) +if [ $n1 != $n2 ]; then + echo "$0: expected $n2 graphs in $graphdir/HCLG.fsts.scp, got $n1" +fi + +mkdir -p $dir/split_fsts +sort -k1,1 $graphdir/HCLG.fsts.scp > $dir/HCLG.fsts.sorted.scp +utils/filter_scps.pl --no-warn -f 1 JOB=1:$nj \ + $sdata/JOB/feats.scp $dir/HCLG.fsts.sorted.scp $dir/split_fsts/HCLG.fsts.JOB.scp +HCLG=scp:$dir/split_fsts/HCLG.fsts.JOB.scp + +## Set up features. +echo "$0: feature type is raw" + +feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- |" + +if [ ! -z "$online_ivector_dir" ]; then + ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; + ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period" +fi + +if [ "$post_decode_acwt" == 1.0 ]; then + lat_wspecifier="ark:|gzip -c >$dir/lat.JOB.gz" +else + lat_wspecifier="ark:|lattice-scale --acoustic-scale=$post_decode_acwt ark:- ark:- | gzip -c >$dir/lat.JOB.gz" +fi + +frame_subsampling_opt= +if [ -f $srcdir/frame_subsampling_factor ]; then + # e.g. 
for 'chain' systems + frame_subsampling_opt="--frame-subsampling-factor=$(cat $srcdir/frame_subsampling_factor)" +fi + +if [ $stage -le 1 ]; then + if [ -f "$graphdir/num_pdfs" ]; then + [ "`cat $graphdir/num_pdfs`" -eq `am-info --print-args=false $model | grep pdfs | awk '{print $NF}'` ] || \ + { echo "Mismatch in number of pdfs with $model"; exit 1; } + fi + $cmd --num-threads $num_threads JOB=1:$nj $dir/log/decode.JOB.log \ + nnet3-latgen-faster$thread_string $ivector_opts $frame_subsampling_opt \ + --frames-per-chunk=$frames_per_chunk \ + --extra-left-context=$extra_left_context \ + --extra-right-context=$extra_right_context \ + --extra-left-context-initial=$extra_left_context_initial \ + --extra-right-context-final=$extra_right_context_final \ + --minimize=$minimize --max-active=$max_active --min-active=$min_active --beam=$beam \ + --lattice-beam=$lattice_beam --acoustic-scale=$acwt --allow-partial=$allow_partial \ + --word-symbol-table=$graphdir/words.txt "$model" \ + "$HCLG" "$feats" "$lat_wspecifier" || exit 1; +fi + + +if [ $stage -le 2 ]; then + if ! $skip_scoring ; then + [ ! -x local/score.sh ] && \ + echo "$0: Not scoring because local/score.sh does not exist or not executable." && exit 1; + iter_opt= + [ "$iter" != "final" ] && iter_opt="--iter $iter" + local/score.sh $iter_opt $scoring_opts --cmd "$cmd" $data $graphdir $dir || + { echo "$0: Scoring failed. (ignore by '--skip-scoring true')"; exit 1; } + fi +fi +echo "Decoding done." +exit 0; diff --git a/egs/wsj/s5/steps/cleanup/internal/align_ctm_ref.py b/egs/wsj/s5/steps/cleanup/internal/align_ctm_ref.py index 848ca61ebe4..d3e012da13c 100755 --- a/egs/wsj/s5/steps/cleanup/internal/align_ctm_ref.py +++ b/egs/wsj/s5/steps/cleanup/internal/align_ctm_ref.py @@ -127,7 +127,7 @@ def read_text(text_file): "Did not get enough columns; line {0} in {1}" "".format(line, text_file.name)) elif len(parts) == 1: - logger.warn("Empty transcript for utterance %s in %s", + logger.warn("Empty transcript for utterance %s in %s", parts[0], text_file.name) yield parts[0], [] else: diff --git a/egs/wsj/s5/steps/cleanup/internal/get_ctm_edits.py b/egs/wsj/s5/steps/cleanup/internal/get_ctm_edits.py index a19c5344572..3032a4b434a 100755 --- a/egs/wsj/s5/steps/cleanup/internal/get_ctm_edits.py +++ b/egs/wsj/s5/steps/cleanup/internal/get_ctm_edits.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2016 Vimal Manohar # 2016 Johns Hopkins University (author: Daniel Povey) @@ -116,17 +116,17 @@ def OpenFiles(): global ctm_edits_out, edits_in, ctm_in, symbol_table, oov_word try: - ctm_edits_out = open(args.ctm_edits_out, 'w') + ctm_edits_out = open(args.ctm_edits_out, 'w', encoding='utf-8') except: sys.exit("get_ctm_edits.py: error opening ctm-edits file {0} for output".format( args.ctm_edits_out)) try: - edits_in = open(args.edits_in) + edits_in = open(args.edits_in, encoding='utf-8') except: sys.exit("get_ctm_edits.py: error opening edits file {0} for input".format( args.edits_in)) try: - ctm_in = open(args.ctm_in) + ctm_in = open(args.ctm_in, encoding='utf-8') except: sys.exit("get_ctm_edits.py: error opening ctm file {0} for input".format( args.ctm_in)) @@ -138,7 +138,7 @@ def OpenFiles(): print("get_ctm_edits.py: error: if you set the the --symbol-table option " "you must also set the --oov option", file = sys.stderr) try: - f = open(args.symbol_table, 'r') + f = open(args.symbol_table, 'r', encoding='utf-8') for line in f.readlines(): [ word, integer ] = line.split() if int(integer) == args.oov: diff --git 
a/egs/wsj/s5/steps/cleanup/internal/get_non_scored_words.py b/egs/wsj/s5/steps/cleanup/internal/get_non_scored_words.py index aa71fa47d84..69e0242eafb 100755 --- a/egs/wsj/s5/steps/cleanup/internal/get_non_scored_words.py +++ b/egs/wsj/s5/steps/cleanup/internal/get_non_scored_words.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2016 Vimal Manohar # 2016 Johns Hopkins University (author: Daniel Povey) @@ -90,7 +90,7 @@ def read_lang(lang_dir): raise try: - for line in open(lang_dir + '/words.txt').readlines(): + for line in open(lang_dir + '/words.txt', encoding='utf-8').readlines(): [ word, integer ] = line.split() if int(integer) in silence_word_ints: non_scored_words.add(word) diff --git a/egs/wsj/s5/steps/cleanup/internal/get_pron_stats.py b/egs/wsj/s5/steps/cleanup/internal/get_pron_stats.py index 414875f9013..3ea217b6589 100755 --- a/egs/wsj/s5/steps/cleanup/internal/get_pron_stats.py +++ b/egs/wsj/s5/steps/cleanup/internal/get_pron_stats.py @@ -4,6 +4,7 @@ # Apache 2.0. from __future__ import print_function +from __future__ import division import argparse import sys import warnings @@ -74,14 +75,14 @@ def ReadEntries(file_handle): # Each entry in the list represents the pronounciation candidate(s) of a word. # For each non- word, the entry is a list: [utt_id, word, set(pronunciation_candidates)]. e.g: # [911Mothers_2010W-0010916-0012901-1, other, set('AH DH ER', 'AH DH ER K AH N')] -# For each , we split the phones it aligns to into two parts: "nonsil_left", +# For each , we split the phones it aligns to into two parts: "nonsil_left", # which includes phones before the first silphone, and "nonsil_right", which includes -# phones after the last silphone. For example, for : 'V SIL B AH SIL', +# phones after the last silphone. For example, for : 'V SIL B AH SIL', # nonsil_left is 'V' and nonsil_right is empty ''. After processing an entry # in ctm_prons, we put it in "info" as an entry: [utt_id, word, nonsil_right] # only if it's nonsil_right segment is not empty, which may be used when processing # the next word. -# +# # Normally, one non- word is only aligned to one pronounciation candidate. However # when there is a preceding/following , like in the following example, we # assume the phones aligned to should be statistically distributed @@ -89,7 +90,7 @@ def ReadEntries(file_handle): # Thus we append the "nonsil_left" segment of these phones to the pronounciation # of the preceding word, if the last phone of this pronounciation is not a silence phone, # Similarly we can add a pron candidate to the following word. 
-# +# # For example, for the following part of a ctm_prons file: # 911Mothers_2010W-0010916-0012901-1 other AH DH ER # 911Mothers_2010W-0010916-0012901-1 K AH N SIL B @@ -98,11 +99,11 @@ def ReadEntries(file_handle): # 911Mothers_2010W-0010916-0012901-1 when W EH N # 911Mothers_2010W-0010916-0012901-1 people P IY P AH L # 911Mothers_2010W-0010916-0012901-1 SIL -# 911Mothers_2010W-0010916-0012901-1 heard HH ER +# 911Mothers_2010W-0010916-0012901-1 heard HH ER # 911Mothers_2010W-0010916-0012901-1 D # 911Mothers_2010W-0010916-0012901-1 that SIL DH AH T # 911Mothers_2010W-0010916-0012901-1 my M AY -# +# # The corresponding segment in the "info" list is: # [911Mothers_2010W-0010916-0012901-1, other, set('AH DH ER', 'AH DH ER K AH N')] # [911Mothers_2010W-0010916-0012901-1, , 'B' @@ -112,7 +113,7 @@ def ReadEntries(file_handle): # [911Mothers_2010W-0010916-0012901-1, , 'D'] # [911Mothers_2010W-0010916-0012901-1, that, set('SIL DH AH T')] # [911Mothers_2010W-0010916-0012901-1, my, set('M AY')] -# +# # Then we accumulate pronouciation stats from "info". Basically, for each occurence # of a word, each pronounciation candidate gets equal soft counts. e.g. In the above # example, each pron candidate of "because" gets a count of 1/4. The stats is stored @@ -138,20 +139,20 @@ def GetStatsFromCtmProns(silphones, optional_silence, non_scored_words, ctm_pron # So we apply the same merging method in these cases. if word == '' or (word in non_scored_words and word != '' and word != ''): nonsil_left = [] - nonsil_right = [] + nonsil_right = [] for phone in phones: if phone in silphones: break nonsil_left.append(phone) - + for phone in reversed(phones): if phone in silphones: break nonsil_right.insert(0, phone) - + # info[-1][0] is the utt_id of the last entry - if len(nonsil_left) > 0 and len(info) > 0 and utt == info[-1][0]: - # pron_ext is a set of extended pron candidates. + if len(nonsil_left) > 0 and len(info) > 0 and utt == info[-1][0]: + # pron_ext is a set of extended pron candidates. pron_ext = set() # info[-1][2] is the set of pron candidates of the last entry. for pron in info[-1][2]: @@ -210,8 +211,8 @@ def GetStatsFromCtmProns(silphones, optional_silence, non_scored_words, ctm_pron stats[(word, phones)] = stats.get((word, phones), 0) + count return stats -def WriteStats(stats, file_handle): - for word_pron, count in stats.iteritems(): +def WriteStats(stats, file_handle): + for word_pron, count in stats.items(): print('{0} {1} {2}'.format(count, word_pron[0], word_pron[1]), file=file_handle) file_handle.close() @@ -221,7 +222,7 @@ def Main(): non_scored_words = ReadEntries(args.non_scored_words_file_handle) optional_silence = ReadEntries(args.optional_silence_file_handle) stats = GetStatsFromCtmProns(silphones, optional_silence.pop(), non_scored_words, args.ctm_prons_file_handle) - WriteStats(stats, args.stats_file_handle) + WriteStats(stats, args.stats_file_handle) if __name__ == "__main__": Main() diff --git a/egs/wsj/s5/steps/cleanup/internal/make_one_biased_lm.py b/egs/wsj/s5/steps/cleanup/internal/make_one_biased_lm.py index f37fa866b0f..68055729fd9 100755 --- a/egs/wsj/s5/steps/cleanup/internal/make_one_biased_lm.py +++ b/egs/wsj/s5/steps/cleanup/internal/make_one_biased_lm.py @@ -1,9 +1,10 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2016 Johns Hopkins University (Author: Daniel Povey) # Apache 2.0. 
from __future__ import print_function +from __future__ import division import sys import argparse import math @@ -47,7 +48,7 @@ -class NgramCounts: +class NgramCounts(object): ## A note on data-structure. ## Firstly, all words are represented as integers. ## We store n-gram counts as an array, indexed by (history-length == n-gram order minus one) @@ -139,24 +140,26 @@ def GetHistToTotalCount(self): # LM-states that would back off to 'this' lm-state, in the total. def CompletelyDiscountLowCountStates(self, min_count): hist_to_total_count = self.GetHistToTotalCount() - for n in reversed(range(2, self.ngram_order)): + for n in reversed(list(range(2, self.ngram_order))): this_order_counts = self.counts[n] + to_delete = [] for hist in this_order_counts.keys(): if hist_to_total_count[hist] < min_count: # we need to completely back off this count. word_to_count = this_order_counts[hist] - del this_order_counts[hist] # delete the key from the dict. + # mark this key for deleting + to_delete.append(hist) backoff_hist = hist[1:] # this will be a tuple not a list. for word, count in word_to_count.items(): self.AddCount(backoff_hist, word, count) - - + for hist in to_delete: + del this_order_counts[hist] # This backs off the counts according to Kneser-Ney (unmodified, # with interpolation). def ApplyBackoff(self, D): assert D > 0.0 and D < 1.0 - for n in reversed(range(1, self.ngram_order)): + for n in reversed(list(range(1, self.ngram_order))): this_order_counts = self.counts[n] for hist, word_to_count in this_order_counts.items(): backoff_hist = hist[1:] @@ -182,7 +185,7 @@ def Print(self, info_string): for this_order_counts in self.counts: for hist, word_to_count in this_order_counts.items(): this_total_count = sum(word_to_count.values()) - print(str(hist) + ': total={0} '.format(this_total_count), + print('{0}: total={1} '.format(hist, this_total_count), end='', file=sys.stderr) print(' '.join(['{0} -> {1} '.format(word, count) for word, count in word_to_count.items() ]), @@ -199,7 +202,7 @@ def AddTopWords(self, top_words_file): word_to_count = self.counts[0][empty_history] total = sum(word_to_count.values()) try: - f = open(top_words_file) + f = open(top_words_file, mode='r', encoding='utf-8') except: sys.exit("make_one_biased_lm.py: error opening top-words file: " "--top-words=" + top_words_file) @@ -242,10 +245,10 @@ def GetHistToStateMap(self): def GetProb(self, hist, word, total_count_map): total_count = total_count_map[hist] word_to_count = self.counts[len(hist)][hist] - prob = word_to_count[word] / total_count + prob = float(word_to_count[word]) / total_count if len(hist) > 0 and word != self.backoff_symbol: prob_in_backoff = self.GetProb(hist[1:], word, total_count_map) - backoff_prob = word_to_count[self.backoff_symbol] / total_count + backoff_prob = float(word_to_count[self.backoff_symbol]) / total_count prob += backoff_prob * prob_in_backoff return prob @@ -262,7 +265,7 @@ def PrintAsFst(self, word_disambig_symbol): hist_to_state = self.GetHistToStateMap() total_count_map = self.GetTotalCountMap() - for n in [ 1, 0 ] + range(2, self.ngram_order): + for n in [ 1, 0 ] + list(range(2, self.ngram_order)): this_order_counts = self.counts[n] # For order 1, make sure the keys are sorted. 
keys = this_order_counts.keys() if n != 1 else sorted(this_order_counts.keys()) diff --git a/egs/wsj/s5/steps/cleanup/internal/modify_ctm_edits.py b/egs/wsj/s5/steps/cleanup/internal/modify_ctm_edits.py index d6f0d0f6b23..af63ca27d2b 100755 --- a/egs/wsj/s5/steps/cleanup/internal/modify_ctm_edits.py +++ b/egs/wsj/s5/steps/cleanup/internal/modify_ctm_edits.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2016 Vimal Manohar # 2016 Johns Hopkins University (author: Daniel Povey) @@ -105,7 +105,7 @@ def ReadNonScoredWords(non_scored_words_file): global non_scored_words try: - f = open(non_scored_words_file) + f = open(non_scored_words_file, encoding='utf-8') except: sys.exit("modify_ctm_edits.py: error opening file: " "--non-scored-words=" + non_scored_words_file) @@ -317,12 +317,12 @@ def ProcessUtterance(split_lines_of_utt): def ProcessData(): try: - f_in = open(args.ctm_edits_in) + f_in = open(args.ctm_edits_in, encoding='utf-8') except: sys.exit("modify_ctm_edits.py: error opening ctm-edits input " "file {0}".format(args.ctm_edits_in)) try: - f_out = open(args.ctm_edits_out, 'w') + f_out = open(args.ctm_edits_out, 'w', encoding='utf-8') except: sys.exit("modify_ctm_edits.py: error opening ctm-edits output " "file {0}".format(args.ctm_edits_out)) diff --git a/egs/wsj/s5/steps/cleanup/internal/resolve_ctm_edits_overlaps.py b/egs/wsj/s5/steps/cleanup/internal/resolve_ctm_edits_overlaps.py index ad03b557bfe..a123b13f532 100755 --- a/egs/wsj/s5/steps/cleanup/internal/resolve_ctm_edits_overlaps.py +++ b/egs/wsj/s5/steps/cleanup/internal/resolve_ctm_edits_overlaps.py @@ -15,6 +15,7 @@ """ from __future__ import print_function +from __future__ import division import argparse import collections import logging @@ -299,7 +300,7 @@ def run(args): segments, reco2utt = read_segments(args.segments) ctm_edits = read_ctm_edits(args.ctm_edits_in, segments) - for reco, utts in reco2utt.iteritems(): + for reco, utts in reco2utt.items(): ctm_edits_for_reco = [] for utt in sorted(utts, key=lambda x: segments[x][1]): if (reco, utt) in ctm_edits: diff --git a/egs/wsj/s5/steps/cleanup/internal/retrieve_similar_docs.py b/egs/wsj/s5/steps/cleanup/internal/retrieve_similar_docs.py index eb0b18f0408..9594d2ecc60 100755 --- a/egs/wsj/s5/steps/cleanup/internal/retrieve_similar_docs.py +++ b/egs/wsj/s5/steps/cleanup/internal/retrieve_similar_docs.py @@ -223,7 +223,7 @@ def read_map(file_handle, num_values_per_key=None, def get_document_ids(source_docs, indexes): indexes = sorted( - [(key, value[0], value[1]) for key, value in indexes.iteritems()], + [(key, value[0], value[1]) for key, value in indexes.items()], key=lambda x: x[0]) doc_ids = [] @@ -273,7 +273,7 @@ def run(args): "Did not get scores for query {0}".format(query_id)) if args.verbose > 2: - for tup, score in scores.iteritems(): + for tup, score in scores.items(): logger.debug("Score, {num}: {0} {1} {2}".format( tup[0], tup[1], score, num=num_queries)) diff --git a/egs/wsj/s5/steps/cleanup/internal/segment_ctm_edits.py b/egs/wsj/s5/steps/cleanup/internal/segment_ctm_edits.py index 39f6d38d6bf..2ea8f5f6070 100755 --- a/egs/wsj/s5/steps/cleanup/internal/segment_ctm_edits.py +++ b/egs/wsj/s5/steps/cleanup/internal/segment_ctm_edits.py @@ -1,10 +1,12 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 + # Copyright 2016 Vimal Manohar # 2016 Johns Hopkins University (author: Daniel Povey) # Apache 2.0 from __future__ import print_function +from __future__ import division import sys, operator, argparse, os from collections import 
defaultdict @@ -68,7 +70,7 @@ help="""Minimum duration of silence or non-scored word to be considered a viable split point when truncating based on junk proportion.""") -parser.add_argument("--max-deleted-words-kept-when-merging", type = str, default = 1, +parser.add_argument("--max-deleted-words-kept-when-merging", type = int, default = 1, help = "When merging segments that are found to be overlapping or " "adjacent after all other processing, keep in the transcript the " "reference words that were deleted between the segments [if any] " @@ -171,7 +173,7 @@ def ComputeSegmentCores(split_lines_of_utt): return segment_ranges -class Segment: +class Segment(object): def __init__(self, split_lines_of_utt, start_index, end_index, debug_str = None): self.split_lines_of_utt = split_lines_of_utt # start_index is the index of the first line that appears in this @@ -551,7 +553,7 @@ def PossiblyTruncateStartForJunkProportion(self): if candidate_start_index is None: return # Nothing to do as there is no place to split. candidate_removed_piece_duration = candidate_start_time - self.StartTime() - if begin_junk_duration / candidate_removed_piece_duration < args.max_junk_proportion: + if float(begin_junk_duration) / candidate_removed_piece_duration < args.max_junk_proportion: return # Nothing to do as the candidate piece to remove has too # little junk. # OK, remove the piece. @@ -593,7 +595,7 @@ def PossiblyTruncateEndForJunkProportion(self): if candidate_end_index is None: return # Nothing to do as there is no place to split. candidate_removed_piece_duration = self.EndTime() - candidate_end_time - if end_junk_duration / candidate_removed_piece_duration < args.max_junk_proportion: + if float(end_junk_duration) / candidate_removed_piece_duration < args.max_junk_proportion: return # Nothing to do as the candidate piece to remove has too # little junk. # OK, remove the piece. @@ -807,7 +809,7 @@ def TimeToString(time, frame_length): def WriteSegmentsForUtterance(text_output_handle, segments_output_handle, old_utterance_name, segments): - num_digits = len(str(len(segments))) + num_digits = len('{}'.format(len(segments))) for n in range(len(segments)): segment = segments[n] # split utterances will be named foo-bar-1 foo-bar-2, etc. @@ -840,24 +842,24 @@ def PrintDebugInfoForUtterance(ctm_edits_out_handle, info_to_print = [] for n in range(len(segments_for_utterance)): segment = segments_for_utterance[n] - start_string = 'start-segment-' + str(n+1) + '[' + segment.DebugInfo() + ']' + start_string = 'start-segment-{0}[{1}]'.format(n+1, segment.DebugInfo()) info_to_print.append( (segment.StartTime(), start_string) ) - end_string = 'end-segment-' + str(n+1) + end_string = 'end-segment-{}'.format(n+1) info_to_print.append( (segment.EndTime(), end_string) ) # for segments that were deleted we print info like start-deleted-segment-1, and # otherwise similar info to segments that were retained. 
for n in range(len(deleted_segments_for_utterance)): segment = deleted_segments_for_utterance[n] - start_string = 'start-deleted-segment-' + str(n+1) + '[' + segment.DebugInfo() + ']' + start_string = 'start-deleted-segment-{0}[{1}]'.format(n+1, segment.DebugInfo()) info_to_print.append( (segment.StartTime(), start_string) ) - end_string = 'end-deleted-segment-' + str(n+1) + end_string = 'end-deleted-segment-{}'.format(n+1) info_to_print.append( (segment.EndTime(), end_string) ) info_to_print = sorted(info_to_print) for i in range(len(split_lines_of_cur_utterance)): split_line=split_lines_of_cur_utterance[i] - split_line[0] += '[' + str(i) + ']' # add an index like [0], [1], to + split_line[0] += '[{}]'.format(i) # add an index like [0], [1], to # the utterance-id so we can easily # look up segment indexes. start_time = float(split_line[2]) @@ -893,7 +895,7 @@ def AccWordStatsForUtterance(split_lines_of_utt, def PrintWordStats(word_stats_out): try: - f = open(word_stats_out, 'w') + f = open(word_stats_out, 'w', encoding='utf-8') except: sys.exit("segment_ctm_edits.py: error opening word-stats file --word-stats-out={0} " "for writing".format(word_stats_out)) @@ -923,23 +925,23 @@ def PrintWordStats(word_stats_out): def ProcessData(): try: - f_in = open(args.ctm_edits_in) + f_in = open(args.ctm_edits_in, encoding='utf-8') except: sys.exit("segment_ctm_edits.py: error opening ctm-edits input " "file {0}".format(args.ctm_edits_in)) try: - text_output_handle = open(args.text_out, 'w') + text_output_handle = open(args.text_out, 'w', encoding='utf-8') except: sys.exit("segment_ctm_edits.py: error opening text output " "file {0}".format(args.text_out)) try: - segments_output_handle = open(args.segments_out, 'w') + segments_output_handle = open(args.segments_out, 'w', encoding='utf-8') except: sys.exit("segment_ctm_edits.py: error opening segments output " "file {0}".format(args.text_out)) if args.ctm_edits_out != None: try: - ctm_edits_output_handle = open(args.ctm_edits_out, 'w') + ctm_edits_output_handle = open(args.ctm_edits_out, 'w', encoding='utf-8') except: sys.exit("segment_ctm_edits.py: error opening ctm-edits output " "file {0}".format(args.ctm_edits_out)) @@ -993,7 +995,7 @@ def ProcessData(): def ReadNonScoredWords(non_scored_words_file): global non_scored_words try: - f = open(non_scored_words_file) + f = open(non_scored_words_file, encoding='utf-8') except: sys.exit("segment_ctm_edits.py: error opening file: " "--non-scored-words=" + non_scored_words_file) @@ -1014,7 +1016,7 @@ def ReadNonScoredWords(non_scored_words_file): oov_symbol = None if args.oov_symbol_file != None: try: - with open(args.oov_symbol_file) as f: + with open(args.oov_symbol_file, encoding='utf-8') as f: line = f.readline() assert len(line.split()) == 1 oov_symbol = line.split()[0] diff --git a/egs/wsj/s5/steps/cleanup/internal/segment_ctm_edits_mild.py b/egs/wsj/s5/steps/cleanup/internal/segment_ctm_edits_mild.py index 46a9369ae98..9fcc2e89360 100755 --- a/egs/wsj/s5/steps/cleanup/internal/segment_ctm_edits_mild.py +++ b/egs/wsj/s5/steps/cleanup/internal/segment_ctm_edits_mild.py @@ -5,6 +5,7 @@ # Apache 2.0 from __future__ import print_function +from __future__ import division import argparse import copy import logging @@ -869,8 +870,7 @@ def relax_boundary_truncation(self, min_segment_length, # a * (length_with_truncation - length_with_relaxed_boundaries) # -> a = (length_cutoff - length_with_relaxed_boundaries) # / (length_with_truncation - length_with_relaxed_boundaries) - a = ((length_cutoff - 
length_with_relaxed_boundaries) - / (length_with_truncation - length_with_relaxed_boundaries)) + a = (length_cutoff - length_with_relaxed_boundaries) / (length_with_truncation - length_with_relaxed_boundaries) if a < 0.0 or a > 1.0: # TODO(vimal): Should this be an error? _global_logger.warn("bad 'a' value = %.4f", a) @@ -1756,7 +1756,7 @@ def time_to_string(time, frame_length): """ Gives time in string form as an exact multiple of the frame-length, e.g. 0.01 (after rounding). """ - n = round(time / frame_length) + n = round(time /frame_length) assert n >= 0 # The next function call will remove trailing zeros while printing it, so # that e.g. 0.01 will be printed as 0.01 and not 0.0099999999999999. It diff --git a/egs/wsj/s5/steps/cleanup/internal/taint_ctm_edits.py b/egs/wsj/s5/steps/cleanup/internal/taint_ctm_edits.py index 85e1df997a7..4e0e1ae2283 100755 --- a/egs/wsj/s5/steps/cleanup/internal/taint_ctm_edits.py +++ b/egs/wsj/s5/steps/cleanup/internal/taint_ctm_edits.py @@ -201,7 +201,7 @@ def PrintNonScoredStats(): percent_modified, percent_of_incorrect_modified), file = sys.stderr) - keys = sorted(ref_change_stats.keys(), reverse=True, + keys = sorted(list(ref_change_stats.keys()), reverse=True, key = lambda x: ref_change_stats[x]) num_keys_to_print = 40 if args.verbose >= 2 else 10 @@ -219,7 +219,7 @@ def PrintStats(): return print("taint_ctm_edits.py: processed {0} input lines, whose edit-types were: ".format(tot_lines) + ', '.join([ '%s = %.2f%%' % (k, num_lines_of_type[k] * 100.0 / tot_lines) - for k in sorted(num_lines_of_type.keys(), reverse = True, + for k in sorted(list(num_lines_of_type.keys()), reverse = True, key = lambda k: num_lines_of_type[k]) ]), file = sys.stderr) diff --git a/egs/wsj/s5/steps/cleanup/internal/tf_idf.py b/egs/wsj/s5/steps/cleanup/internal/tf_idf.py index 9b2f4d693a6..a098d9f2a44 100644 --- a/egs/wsj/s5/steps/cleanup/internal/tf_idf.py +++ b/egs/wsj/s5/steps/cleanup/internal/tf_idf.py @@ -6,6 +6,7 @@ """ from __future__ import print_function +from __future__ import division import logging import math import re @@ -51,8 +52,7 @@ def get_inverse_document_frequency(self, term, weighting_scheme="log"): if weighting_scheme == "log-smoothed": return math.log(1.0 + float(self.num_docs) / (1.0 + n_t)) if weighting_scheme == "probabilitic": - return math.log((self.num_docs - n_t - 1) - / (1.0 + n_t)) + return math.log((self.num_docs - n_t - 1) / (1.0 + n_t)) def accumulate(self, term): """Adds one count to the number of docs containing the term "term". @@ -66,7 +66,7 @@ def write(self, file_handle): ... for n-gram (, ... ) """ - for term, num in self.num_docs_for_term.iteritems(): + for term, num in self.num_docs_for_term.items(): if num == 0: continue assert isinstance(term, tuple) @@ -135,7 +135,7 @@ def compute_term_stats(self, idf_stats=None): based on the stored raw counts.""" if len(self.raw_counts) == 0: raise RuntimeError("No (term, doc) found in tf-stats.") - for tup, counts in self.raw_counts.iteritems(): + for tup, counts in self.raw_counts.items(): term = tup[0] if counts > self.max_counts_for_term.get(term, 0): @@ -149,7 +149,7 @@ def __str__(self): ... 
""" lines = [] - for tup, counts in self.raw_counts.iteritems(): + for tup, counts in self.raw_counts.items(): term, doc = tup lines.append("{order} {term} {doc} {counts}".format( order=len(term), term=" ".join(term), @@ -225,7 +225,7 @@ def compute_similarity_scores(self, source_tfidf, source_docs=None, num_terms_per_doc = {} similarity_scores = {} - for tup, value in self.tf_idf.iteritems(): + for tup, value in self.tf_idf.items(): term, doc = tup num_terms_per_doc[doc] = num_terms_per_doc.get(doc, 0) + 1 @@ -253,19 +253,18 @@ def compute_similarity_scores(self, source_tfidf, source_docs=None, similarity_scores.get((doc, src_doc), 0) + src_value * value) else: - for src_tup, src_value in source_tfidf.tf_idf.iteritems(): + for src_tup, src_value in source_tfidf.tf_idf.items(): similarity_scores[(doc, src_doc)] = ( similarity_scores.get((doc, src_doc), 0) + src_value * value) if do_length_normalization: - for doc_pair, value in similarity_scores.iteritems(): + for doc_pair, value in similarity_scores.items(): doc, src_doc = doc_pair - similarity_scores[(doc, src_doc)] = (value - / num_terms_per_doc[doc]) + similarity_scores[(doc, src_doc)] = value / num_terms_per_doc[doc] if logger.isEnabledFor(logging.DEBUG): - for doc, count in num_terms_per_doc.iteritems(): + for doc, count in num_terms_per_doc.items(): logger.debug( 'Seen {0} terms in query document {1}'.format(count, doc)) @@ -329,7 +328,7 @@ def write(self, tf_idf_file): """Writes TFIDF object to file.""" print ("", file=tf_idf_file) - for tup, value in self.tf_idf.iteritems(): + for tup, value in self.tf_idf.items(): term, doc = tup print("{order} {term} {doc} {tfidf}".format( order=len(term), term=" ".join(term), diff --git a/egs/wsj/s5/steps/cleanup/make_biased_lms.py b/egs/wsj/s5/steps/cleanup/make_biased_lms.py index ab508eedc9c..4b1fd320221 100755 --- a/egs/wsj/s5/steps/cleanup/make_biased_lms.py +++ b/egs/wsj/s5/steps/cleanup/make_biased_lms.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 from __future__ import print_function import sys @@ -55,7 +55,7 @@ def ProcessGroupOfLines(group_of_lines): try: command = "steps/cleanup/internal/make_one_biased_lm.py " + args.lm_opts p = subprocess.Popen(command, shell = True, stdin = subprocess.PIPE, - stdout = sys.stdout, stderr = sys.stderr) + stdout = sys.stdout, stderr = sys.stderr) for line in group_of_lines: a = line.split() if len(a) == 0: @@ -63,13 +63,15 @@ def ProcessGroupOfLines(group_of_lines): utterance_id = a[0] # print to utterance-map file print(utterance_id, group_utterance_id, file = utterance_map_file) - rest_of_line = ' '.join(a[1:]) # get rid of utterance id. - print(rest_of_line, file=p.stdin) + rest_of_line = ' '.join(a[1:]) + '\n' # get rid of utterance id. + p.stdin.write(rest_of_line.encode('utf-8')) p.stdin.close() assert p.wait() == 0 - except Exception as e: - sys.exit("make_biased_lms.py: error calling subprocess, command was: " + - command + ", error was : " + str(e)) + except Exception: + sys.stderr.write( + "make_biased_lms.py: error calling subprocess, command was: " + + command) + raise # Print a blank line; this terminates the FST in the Kaldi fst-archive # format. 
print("") diff --git a/egs/wsj/s5/steps/cleanup/segment_long_utterances.sh b/egs/wsj/s5/steps/cleanup/segment_long_utterances.sh index 16350fdb032..7a16bdcdb12 100755 --- a/egs/wsj/s5/steps/cleanup/segment_long_utterances.sh +++ b/egs/wsj/s5/steps/cleanup/segment_long_utterances.sh @@ -4,6 +4,23 @@ # 2016 Vimal Manohar # Apache 2.0 +# This script performs segmentation of the input data based on the transcription +# and outputs segmented data along with the corresponding aligned transcription. +# The purpose of this script is to divide up the input data (which may consist +# of long recordings such as television shows or audiobooks) into segments which +# are of manageable length for further processing, along with the portion of the +# transcript that seems to match (aligns with) each segment. +# This the light-supervised training scenario where the input transcription is +# not expected to be completely clean and may have significant errors. +# See "JHU Kaldi System for Arabic MGB-3 ASR Challenge using Diarization, +# Audio-transcript Alignment and Transfer Learning": Vimal Manohar, Daniel +# Povey, Sanjeev Khudanpur, ASRU 2017 +# (http://www.danielpovey.com/files/2017_asru_mgb3.pdf) for details. +# The output data is not necessarily particularly clean; you can run +# steps/cleanup/clean_and_segment_data.sh on the output in order to +# further clean it and eliminate data where the transcript doesn't seem to +# match. + . ./path.sh set -e @@ -157,10 +174,17 @@ if [ $stage -le 3 ]; then cp $srcdir/phones.txt $dir 2>/dev/null || true mkdir -p $graph_dir + + n_reco=$(cat $text | wc -l) || exit 1 + nj_reco=$nj + + if [ $nj -gt $n_reco ]; then + nj_reco=$n_reco + fi # Make graphs w.r.t. to the original text (usually recording-level) steps/cleanup/make_biased_lm_graphs.sh $graph_opts \ - --nj $nj --cmd "$cmd" $text \ + --nj $nj_reco --cmd "$cmd" $text \ $lang $dir $dir/graphs if [ -z "$utt2text" ]; then # and then copy it to the sub-segments. @@ -380,7 +404,8 @@ if [ $stage -le 9 ]; then fi if [ $stage -le 10 ]; then - steps/cleanup/internal/resolve_ctm_edits_overlaps.py \ + $cmd $dir/log/resolve_ctm_edits.log \ + steps/cleanup/internal/resolve_ctm_edits_overlaps.py \ ${data_uniform_seg}/segments $decode_dir/ctm_$lmwt/ctm_edits $dir/ctm_edits fi diff --git a/egs/wsj/s5/steps/cleanup/segment_long_utterances_nnet3.sh b/egs/wsj/s5/steps/cleanup/segment_long_utterances_nnet3.sh new file mode 100755 index 00000000000..f0df1e7730c --- /dev/null +++ b/egs/wsj/s5/steps/cleanup/segment_long_utterances_nnet3.sh @@ -0,0 +1,552 @@ +#!/bin/bash + +# Copyright 2014 Guoguo Chen +# 2016 Vimal Manohar +# Apache 2.0 + + +# This script is similar to steps/cleanup/segment_long_utterances.sh, but +# uses nnet3 acoustic model instead of GMM acoustic model for decoding. +# This script performs segmentation of the input data based on the transcription +# and outputs segmented data along with the corresponding aligned transcription. +# The purpose of this script is to divide up the input data (which may consist +# of long recordings such as television shows or audiobooks) into segments which +# are of manageable length for further processing, along with the portion of the +# transcript that seems to match (aligns with) each segment. +# This the light-supervised training scenario where the input transcription is +# not expected to be completely clean and may have significant errors. 
+# See "JHU Kaldi System for Arabic MGB-3 ASR Challenge using Diarization, +# Audio-transcript Alignment and Transfer Learning": Vimal Manohar, Daniel +# Povey, Sanjeev Khudanpur, ASRU 2017 +# (http://www.danielpovey.com/files/2017_asru_mgb3.pdf) for details. +# The output data is not necessarily particularly clean; you can run +# steps/cleanup/clean_and_segment_data_nnet3.sh on the output in order to +# further clean it and eliminate data where the transcript doesn't seem to +# match. + + +set -e +set -o pipefail +set -u + +stage=-1 +cmd=run.pl +nj=4 + +# Uniform segmentation options +max_segment_duration=30 +overlap_duration=5 +seconds_per_spk_max=30 + +# Decode options +graph_opts= +scale_opts= # for making the graphs +beam=15.0 +lattice_beam=1.0 +lmwt=10 +acwt=0.1 # Just a default value, used for adaptation and beam-pruning.. + +# Contexts must ideally match training +extra_left_context=0 # Set to some large value, typically 40 for LSTM (must match training) +extra_right_context=0 +extra_left_context_initial=-1 +extra_right_context_final=-1 +frames_per_chunk=150 + +# i-vector options +extractor= # i-Vector extractor. If provided, will extract i-vectors. + # Required if the network was trained with i-vector extractor. +use_vad=false # Use energy-based VAD for i-vector extraction + +# TF-IDF similarity search options +max_words=1000 +num_neighbors_to_search=1 # Number of neighboring documents to search around the one retrieved based on maximum tf-idf similarity. +neighbor_tfidf_threshold=0.5 + +align_full_hyp=false # Align full hypothesis i.e. trackback from the end to get the alignment. + +# First-pass segmentation opts +# These options are passed to the script +# steps/cleanup/internal/segment_ctm_edits_mild.py +segmentation_extra_opts= +min_split_point_duration=0.1 +max_deleted_words_kept_when_merging=1 +max_wer=50 +max_segment_length_for_merging=60 +max_bad_proportion=0.75 +max_intersegment_incorrect_words_length=1 +max_segment_length_for_splitting=10 +hard_max_segment_length=15 +min_silence_length_to_split_at=0.3 +min_non_scored_length_to_split_at=0.3 + + +. ./path.sh +. utils/parse_options.sh + +if [ $# -ne 5 ] && [ $# -ne 7 ]; then + cat <] [options] [ ] + e.g.: $0 exp/wsj_tri2b data/lang_nosp data/train_long data/train_long/text data/train_reseg exp/segment_wsj_long_utts_train +This script performs segmentation of the data in and writes out the +segmented data (with a segments file) to + along with the corresponding aligned transcription. +Note: If is not provided, the "text" file in is used as the +raw transcripts to train biased LM for the utterances. +If is provided, then it should be a mapping from the utterance-ids in + to the transcript-keys in the file , which will be +used to train biased LMs for the utterances. +The purpose of this script is to divide up the input data (which may consist of +long recordings such as television shows or audiobooks) into segments which are +of manageable length for further processing, along with the portion of the +transcript that seems to match each segment. +The output data is not necessarily particularly clean; you are advised to run +steps/cleanup/clean_and_segment_data.sh on the output in order to further clean +it and eliminate data where the transcript doesn't seem to match. + main options (for others, see top of script file): + --stage # stage to run from, to enable resuming from partially + # completed run (default: 0) + --cmd '$cmd' # command to submit jobs with (e.g. 
run.pl, queue.pl) + --nj # number of parallel jobs to use in graph creation and + # decoding + --graph-opts 'opts' # Additional options to make_biased_lm_graphs.sh. + # Please run steps/cleanup/make_biased_lm_graphs.sh + # without arguments to see allowed options. + --segmentation-extra-opts 'opts' # Additional options to segment_ctm_edits_mild.py. + # Please run steps/cleanup/internal/segment_ctm_edits_mild.py + # without arguments to see allowed options. + --align-full-hyp # If true, align full hypothesis + i.e. trackback from the end to get the alignment. + This is different from the normal + Smith-Waterman alignment, where the + traceback will be from the maximum score. + --extractor # i-vector extractor directory if i-vector is + # to be used during decoding. Must match + # the extractor used for training neural-network. + --use-vad # If true, uses energy-based VAD to apply frame weights + # for i-vector stats extraction +EOF + exit 1 +fi + +srcdir=$1 +lang=$2 +data=$3 + +extra_files= +utt2text= +text=$data/text +if [ $# -eq 7 ]; then + text=$4 + utt2text=$5 + out_data=$6 + dir=$7 + extra_files="$utt2text" +else + out_data=$4 + dir=$5 +fi + +if [ ! -z "$extractor" ]; then + extra_files="$extra_files $extractor/final.ie" +fi + +for f in $data/feats.scp $text $extra_files $srcdir/tree \ + $srcdir/final.mdl $srcdir/cmvn_opts; do + if [ ! -f $f ]; then + echo "$0: Could not find file $f" + exit 1 + fi +done + +data_id=`basename $data` +mkdir -p $dir +cp $srcdir/final.mdl $dir +cp $srcdir/tree $dir +cp $srcdir/cmvn_opts $dir +cp $srcdir/{splice_opts,delta_opts,final.mat,final.alimdl} $dir 2>/dev/null || true +cp $srcdir/frame_subsampling_factor $dir 2>/dev/null || true + +if [ -f $srcdir/frame_subsampling_factor ]; then + echo "$0: guessing that this is a chain system, checking parameters." + if [ -z $scale_opts ]; then + echo "$0: setting scale_opts" + scale_opts="--self-loop-scale=1.0 --transition-scale=1.0" + fi + if [ $acwt == 0.1 ]; then + echo "$0: setting acwt=1.0" + acwt=1.0 + fi + if [ $lmwt == 10 ]; then + echo "$0: setting lmwt=1.0" + lmwt=1 + fi +fi + + +utils/lang/check_phones_compatible.sh $lang/phones.txt $srcdir/phones.txt +cp $lang/phones.txt $dir + +data_uniform_seg=$dir/${data_id}_uniform_seg + +# First we split the data into segments of around 30s long, on which +# it would be possible to do a decoding. +# A diarization step will be added in the future. +if [ $stage -le 1 ]; then + echo "$0: Stage 1 (Splitting data directory $data into uniform segments)" + + utils/data/get_utt2dur.sh $data + if [ ! -f $data/segments ]; then + utils/data/get_segments_for_data.sh $data > $data/segments + fi + + utils/data/get_uniform_subsegments.py \ + --max-segment-duration=$max_segment_duration \ + --overlap-duration=$overlap_duration \ + --max-remaining-duration=$(perl -e "print $max_segment_duration / 2.0") \ + $data/segments > $dir/uniform_sub_segments +fi + +if [ $stage -le 2 ]; then + echo "$0: Stage 2 (Prepare uniform sub-segmented data directory)" + rm -r $data_uniform_seg || true + + if [ ! 
-z "$seconds_per_spk_max" ]; then + utils/data/subsegment_data_dir.sh \ + $data $dir/uniform_sub_segments $dir/${data_id}_uniform_seg.temp + + utils/data/modify_speaker_info.sh --seconds-per-spk-max $seconds_per_spk_max \ + $dir/${data_id}_uniform_seg.temp $data_uniform_seg + else + utils/data/subsegment_data_dir.sh \ + $data $dir/uniform_sub_segments $data_uniform_seg + fi + + utils/fix_data_dir.sh $data_uniform_seg + + # Compute new cmvn stats for the segmented data directory + steps/compute_cmvn_stats.sh $data_uniform_seg/ +fi + +graph_dir=$dir/graphs_uniform_seg + +if [ $stage -le 3 ]; then + echo "$0: Stage 3 (Building biased-language-model decoding graphs)" + + mkdir -p $graph_dir + + n_reco=$(cat $text | wc -l) || exit 1 + nj_reco=$nj + + if [ $nj -gt $n_reco ]; then + nj_reco=$n_reco + fi + + # Make graphs w.r.t. to the original text (usually recording-level) + steps/cleanup/make_biased_lm_graphs.sh $graph_opts \ + --scale-opts "$scale_opts" \ + --nj $nj_reco --cmd "$cmd" $text \ + $lang $dir $dir/graphs + if [ -z "$utt2text" ]; then + # and then copy it to the sub-segments. + cat $dir/uniform_sub_segments | awk '{print $1" "$2}' | \ + utils/apply_map.pl -f 2 $dir/graphs/HCLG.fsts.scp | \ + sort -k1,1 > \ + $graph_dir/HCLG.fsts.scp + else + # and then copy it to the sub-segments. + cat $dir/uniform_sub_segments | awk '{print $1" "$2}' | \ + utils/apply_map.pl -f 2 $utt2text | \ + utils/apply_map.pl -f 2 $dir/graphs/HCLG.fsts.scp | \ + sort -k1,1 > \ + $graph_dir/HCLG.fsts.scp + fi + + cp $lang/words.txt $graph_dir + cp -r $lang/phones $graph_dir + [ -f $dir/graphs/num_pdfs ] && cp $dir/graphs/num_pdfs $graph_dir/ +fi + +decode_dir=$dir/lats +mkdir -p $decode_dir + +online_ivector_dir= +if [ ! -z "$extractor" ]; then + online_ivector_dir=$dir/ivectors_$(basename $data_uniform_seg) + + if [ $stage -le 4 ]; then + # Compute energy-based VAD + if $use_vad; then + steps/compute_vad_decision.sh $data_uniform_seg \ + $data_uniform_seg/log $data_uniform_seg/data + fi + + steps/online/nnet2/extract_ivectors_online.sh \ + --nj $nj --cmd "$cmd --mem 4G" --use-vad $use_vad \ + $data_uniform_seg $extractor $online_ivector_dir + fi +fi + +if [ $stage -le 5 ]; then + echo "$0: Decoding with biased language models..." + + steps/cleanup/decode_segmentation_nnet3.sh \ + --acwt $acwt \ + --beam $beam --lattice-beam $lattice_beam --nj $nj --cmd "$cmd --mem 4G" \ + --skip-scoring true --allow-partial false \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --extra-left-context-initial $extra_left_context_initial \ + --extra-right-context-final $extra_right_context_final \ + --frames-per-chunk $frames_per_chunk \ + ${online_ivector_dir:+--online-ivector-dir $online_ivector_dir} \ + $graph_dir $data_uniform_seg $decode_dir +fi + +frame_shift_opt= +if [ -f $srcdir/frame_subsampling_factor ]; then + frame_shift_opt="--frame-shift 0.0$(cat $srcdir/frame_subsampling_factor)" +fi + +if [ $stage -le 6 ]; then + steps/get_ctm_fast.sh --lmwt $lmwt --cmd "$cmd --mem 4G" \ + --print-silence true $frame_shift_opt \ + $data_uniform_seg $lang $decode_dir $decode_dir/ctm_$lmwt +fi + +# Split the original text into documents, over which we can do +# searching reasonably efficiently. Also get a mapping from the original +# text to the created documents (i.e. text2doc) +# Since the Smith-Waterman alignment is linear in the length of the +# text, we want to keep it reasonably small (a few thousand words). + +if [ $stage -le 7 ]; then + # Split the reference text into documents. 
+ mkdir -p $dir/docs + + # text2doc is a mapping from the original transcript to the documents + # it is split into. + # The format is + # ... + steps/cleanup/internal/split_text_into_docs.pl --max-words $max_words \ + $text $dir/docs/doc2text $dir/docs/docs.txt + utils/utt2spk_to_spk2utt.pl $dir/docs/doc2text > $dir/docs/text2doc +fi + +if [ $stage -le 8 ]; then + # Get TF-IDF for the reference documents. + echo $nj > $dir/docs/num_jobs + + utils/split_data.sh $data_uniform_seg $nj + + mkdir -p $dir/docs/split$nj/ + + # First compute IDF stats + $cmd $dir/log/compute_source_idf_stats.log \ + steps/cleanup/internal/compute_tf_idf.py \ + --tf-weighting-scheme="raw" \ + --idf-weighting-scheme="log" \ + --output-idf-stats=$dir/docs/idf_stats.txt \ + $dir/docs/docs.txt $dir/docs/src_tf_idf.txt + + # Split documents so that they can be accessed easily by parallel jobs. + mkdir -p $dir/docs/split$nj/ + sdir=$dir/docs/split$nj + for n in `seq $nj`; do + + # old2new_utts is a mapping from the original segments to the + # new segments created by uniformly segmenting. + # The format is ... + utils/filter_scp.pl $data_uniform_seg/split$nj/$n/utt2spk $dir/uniform_sub_segments | \ + cut -d ' ' -f 1,2 | utils/utt2spk_to_spk2utt.pl > $sdir/old2new_utts.$n.txt + + if [ ! -z "$utt2text" ]; then + # utt2text, if provided, is a mapping from the to + # . + # Since text2doc is mapping from to documents, we + # first have to find the original-transcripts that are in the current + # split. + utils/filter_scp.pl $sdir/old2new_utts.$n.txt $utt2text | \ + cut -d ' ' -f 2 | sort -u | \ + utils/filter_scp.pl /dev/stdin $dir/docs/text2doc > $sdir/text2doc.$n + else + utils/filter_scp.pl $sdir/old2new_utts.$n.txt \ + $dir/docs/text2doc > $sdir/text2doc.$n + fi + + utils/spk2utt_to_utt2spk.pl $sdir/text2doc.$n | \ + utils/filter_scp.pl /dev/stdin $dir/docs/docs.txt > \ + $sdir/docs.$n.txt + done + + # Compute TF-IDF for the source documents. + $cmd JOB=1:$nj $dir/docs/log/get_tfidf_for_source_texts.JOB.log \ + steps/cleanup/internal/compute_tf_idf.py \ + --tf-weighting-scheme="raw" \ + --idf-weighting-scheme="log" \ + --input-idf-stats=$dir/docs/idf_stats.txt \ + $sdir/docs.JOB.txt $sdir/src_tf_idf.JOB.txt + + sdir=$dir/docs/split$nj + # Make $sdir an absolute pathname. + sdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $sdir ${PWD}` + + for n in `seq $nj`; do + awk -v f="$sdir/src_tf_idf.$n.txt" '{print $1" "f}' \ + $sdir/text2doc.$n + done | perl -ane 'BEGIN { %tfidfs = (); } + { + if (!defined $tfidfs{$F[0]}) { + $tfidfs{$F[0]} = $F[1]; + } + } + END { + while(my ($k, $v) = each %tfidfs) { + print "$k $v\n"; + } }' > $dir/docs/source2tf_idf.scp +fi + +if [ $stage -le 9 ]; then + echo "$0: using default values of non-scored words..." + + # At the level of this script we just hard-code it that non-scored words are + # those that map to silence phones (which is what get_non_scored_words.py + # gives us), although this could easily be made user-configurable. This list + # of non-scored words affects the behavior of several of the data-cleanup + # scripts; essentially, we view the non-scored words as negotiable when it + # comes to the reference transcript, so we'll consider changing the reference + # to match the hyp when it comes to these words. + steps/cleanup/internal/get_non_scored_words.py $lang > $dir/non_scored_words.txt +fi + +if [ $stage -le 10 ]; then + sdir=$dir/query_docs/split$nj + mkdir -p $sdir + + # Compute TF-IDF for the query documents (decode hypotheses). 
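+  # (As a reminder of the weighting in use: with the "raw" tf and "log" idf
+  #  schemes chosen for the source documents above, a term t in document d
+  #  scores roughly tf(t,d) * log(num_docs / num_docs_containing(t)), so words
+  #  that are frequent in one document but rare across documents dominate the
+  #  TF-IDF similarity used for retrieval below; the query computation just
+  #  below uses the "normalized" tf scheme instead of "raw".)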
+ # The output is an archive of TF-IDF indexed by the query. + $cmd JOB=1:$nj $decode_dir/ctm_$lmwt/log/compute_query_tf_idf.JOB.log \ + steps/cleanup/internal/ctm_to_text.pl --non-scored-words $dir/non_scored_words.txt \ + $decode_dir/ctm_$lmwt/ctm.JOB \| \ + steps/cleanup/internal/compute_tf_idf.py \ + --tf-weighting-scheme="normalized" \ + --idf-weighting-scheme="log" \ + --input-idf-stats=$dir/docs/idf_stats.txt \ + --accumulate-over-docs=false \ + - $sdir/query_tf_idf.JOB.ark.txt + + # The relevant documents can be found using TF-IDF similarity and nearby + # documents can also be picked for the Smith-Waterman alignment stage. + + # Get a mapping from the new utterance-ids to original transcripts + if [ -z "$utt2text" ]; then + awk '{print $1" "$2}' $dir/uniform_sub_segments > \ + $dir/new2orig_utt + else + awk '{print $1" "$2}' $dir/uniform_sub_segments | \ + utils/apply_map.pl -f 2 $utt2text > \ + $dir/new2orig_utt + fi + + # The query TF-IDFs are all indexed by the utterance-id of the sub-segments. + # The source TF-IDFs use the document-ids created by splitting the reference + # text into documents. + # For each query, we need to retrieve the documents that were created from + # the same original utterance that the sub-segment was from. For this, + # we have to load the source TF-IDF that has those documents. This + # information is provided using the option --source-text-id2tf-idf-file. + # The output of this script is a file where the first column is the + # query-id (i.e. sub-segment-id) and the remaining columns, which is at least + # one in number and a maxmium of (1 + 2 * num-neighbors-to-search) columns + # is the document-ids for the retrieved documents. + $cmd JOB=1:$nj $dir/log/retrieve_similar_docs.JOB.log \ + steps/cleanup/internal/retrieve_similar_docs.py \ + --query-tfidf=$dir/query_docs/split$nj/query_tf_idf.JOB.ark.txt \ + --source-text-id2tfidf=$dir/docs/source2tf_idf.scp \ + --source-text-id2doc-ids=$dir/docs/text2doc \ + --query-id2source-text-id=$dir/new2orig_utt \ + --num-neighbors-to-search=$num_neighbors_to_search \ + --neighbor-tfidf-threshold=$neighbor_tfidf_threshold \ + --relevant-docs=$dir/query_docs/split$nj/relevant_docs.JOB.txt + + $cmd JOB=1:$nj $decode_dir/ctm_$lmwt/log/get_ctm_edits.JOB.log \ + steps/cleanup/internal/stitch_documents.py \ + --query2docs=$dir/query_docs/split$nj/relevant_docs.JOB.txt \ + --input-documents=$dir/docs/split$nj/docs.JOB.txt \ + --output-documents=- \| \ + steps/cleanup/internal/align_ctm_ref.py --eps-symbol='""' \ + --oov-word="'`cat $lang/oov.txt`'" --symbol-table=$lang/words.txt \ + --hyp-format=CTM --align-full-hyp=$align_full_hyp \ + --hyp=$decode_dir/ctm_$lmwt/ctm.JOB --ref=- \ + --output=$decode_dir/ctm_$lmwt/ctm_edits.JOB + + for n in `seq $nj`; do + cat $decode_dir/ctm_$lmwt/ctm_edits.$n + done > $decode_dir/ctm_$lmwt/ctm_edits + +fi + +if [ $stage -le 11 ]; then + $cmd $dir/log/resolve_ctm_edits.log \ + steps/cleanup/internal/resolve_ctm_edits_overlaps.py \ + ${data_uniform_seg}/segments $decode_dir/ctm_$lmwt/ctm_edits $dir/ctm_edits +fi + +if [ $stage -le 12 ]; then + echo "$0: modifying ctm-edits file to allow repetitions [for dysfluencies] and " + echo " ... to fix reference mismatches involving non-scored words. " + + $cmd $dir/log/modify_ctm_edits.log \ + steps/cleanup/internal/modify_ctm_edits.py --verbose=3 $dir/non_scored_words.txt \ + $dir/ctm_edits $dir/ctm_edits.modified + + echo " ... See $dir/log/modify_ctm_edits.log for details and stats, including" + echo " a list of commonly-repeated words." 
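+  # A rough illustration of what "allowing repetitions" means: if the decoded
+  # hypothesis has "i i think" where the reference has "i think" (a dysfluency),
+  # the repeated "i" would otherwise look like an insertion error; after this
+  # step such repetitions can be treated as acceptable, so the surrounding
+  # words are not needlessly discarded later. (This is only a sketch; see the
+  # log above for the actual behaviour and stats.)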
+fi + +if [ $stage -le 13 ]; then + echo "$0: applying 'taint' markers to ctm-edits file to mark silences and" + echo " ... non-scored words that are next to errors." + $cmd $dir/log/taint_ctm_edits.log \ + steps/cleanup/internal/taint_ctm_edits.py --remove-deletions=false \ + $dir/ctm_edits.modified $dir/ctm_edits.tainted + echo "... Stats, including global cor/ins/del/sub stats, are in $dir/log/taint_ctm_edits.log." +fi + +if [ $stage -le 14 ]; then + echo "$0: creating segmentation from ctm-edits file." + + segmentation_opts=( + --min-split-point-duration=$min_split_point_duration + --max-deleted-words-kept-when-merging=$max_deleted_words_kept_when_merging + --merging.max-wer=$max_wer + --merging.max-segment-length=$max_segment_length_for_merging + --merging.max-bad-proportion=$max_bad_proportion + --merging.max-intersegment-incorrect-words-length=$max_intersegment_incorrect_words_length + --splitting.max-segment-length=$max_segment_length_for_splitting + --splitting.hard-max-segment-length=$hard_max_segment_length + --splitting.min-silence-length=$min_silence_length_to_split_at + --splitting.min-non-scored-length=$min_non_scored_length_to_split_at + ) + + $cmd $dir/log/segment_ctm_edits.log \ + steps/cleanup/internal/segment_ctm_edits_mild.py \ + ${segmentation_opts[@]} $segmentation_extra_opts \ + --oov-symbol-file=$lang/oov.txt \ + --ctm-edits-out=$dir/ctm_edits.segmented \ + --word-stats-out=$dir/word_stats.txt \ + $dir/non_scored_words.txt \ + $dir/ctm_edits.tainted $dir/text $dir/segments + + echo "$0: contents of $dir/log/segment_ctm_edits.log are:" + cat $dir/log/segment_ctm_edits.log + echo "For word-level statistics on p(not-being-in-a-segment), with 'worst' words at the top," + echo "see $dir/word_stats.txt" + echo "For detailed utterance-level debugging information, see $dir/ctm_edits.segmented" +fi + +mkdir -p $out_data +if [ $stage -le 15 ]; then + utils/data/subsegment_data_dir.sh $data_uniform_seg \ + $dir/segments $dir/text $out_data +fi diff --git a/egs/wsj/s5/steps/compare_alignments.sh b/egs/wsj/s5/steps/compare_alignments.sh new file mode 100755 index 00000000000..d94d2197fee --- /dev/null +++ b/egs/wsj/s5/steps/compare_alignments.sh @@ -0,0 +1,220 @@ +#!/bin/bash + +# Copyright 2018 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0. + +set -e +stage=0 +cmd=run.pl # We use this only for get_ctm.sh, which can be a little slow. +num_to_sample=1000 # We sample this many utterances for human-readable display, starting from the worst and then + # starting from the middle. +cleanup=true + +if [ -f ./path.sh ]; then . ./path.sh; fi + +. ./utils/parse_options.sh + +if [ $# -ne 5 ] && [ $# -ne 7 ]; then + cat < + or: $0 [options] + e.g.: $0 data/lang data/train exp/tri2_ali exp/tri3_ali exp/compare_ali_2_3 + + Options: + --cmd (run.pl|queue.pl...) # specify how to run the sub-processes. + # (passed through to get_train_ctm.sh) + --cleanup # Specify --cleanup false to prevent + # cleanup of temporary files. + --stage # Enables you to run part of the script. + +EOF + exit 1 +fi + +if [ $# -eq 5 ]; then + lang1=$1 + lang2=$1 + data1=$2 + data2=$2 + ali_dir1=$3 + ali_dir2=$4 + dir=$5 +else + lang1=$1 + lang2=$2 + data1=$3 + data2=$4 + ali_dir1=$5 + ali_dir2=$6 + dir=$7 +fi + +for f in $lang1/phones.txt $lang2/phones.txt $data1/utt2spk $data2/utt2spk \ + $ali_dir1/ali.1.gz $ali_dir2/ali.2.gz; do + if [ ! -f $f ]; then + echo "$0: expected file $f to exist" + exit 1 + fi +done + +# This will exit if the phone symbol id's are different, due to +# `set -e` above. 
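+# (Comparing two sets of alignments is only meaningful when both were built
+# with the same phone inventory: the confusion matrix and per-phone statistics
+# computed below index phones by their integer ids from phones.txt, so
+# mismatched symbol tables would make those ids incomparable.)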
+utils/lang/check_phones_compatible.sh $lang1/phones.txt $lang2/phones.txt + +nj1=$(cat $ali_dir1/num_jobs) +nj2=$(cat $ali_dir2/num_jobs) + +mkdir -p $dir/log + + +if [ $stage -le 0 ]; then + echo "$0: converting alignments to phones." + + for j in $(seq $nj1); do gunzip -c $ali_dir1/ali.$j.gz; done | \ + ali-to-phones --per-frame=true $ali_dir1/final.mdl ark:- ark:- | gzip -c > $dir/phones1.gz + + for j in $(seq $nj2); do gunzip -c $ali_dir2/ali.$j.gz; done | \ + ali-to-phones --per-frame=true $ali_dir2/final.mdl ark:- ark:- | gzip -c > $dir/phones2.gz +fi + +if [ $stage -le 1 ]; then + echo "$0: getting comparison stats and utterance stats." + compare-int-vector --binary=false --write-confusion-matrix=$dir/conf.mat \ + "ark:gunzip -c $dir/phones1.gz|" "ark:gunzip -c $dir/phones2.gz|" 2>$dir/log/compare_phones.log > $dir/utt_stats.phones + tail -n 8 $dir/log/compare_phones.log +fi + +if [ $stage -le 3 ]; then + cat $dir/conf.mat | grep -v -F '[' | sed 's/]//' | awk '{n=NF; for (k=1;k<=n;k++) { conf[NR,k] = $k; row_tot[NR] += $k; col_tot[k] += $k; } } END{ + for (row=1;row<=n;row++) for (col=1;col<=n;col++) { + val = conf[row,col]; this_row_tot = row_tot[row]; this_col_tot = col_tot[col]; + rval=conf[col,row] + min_tot = (this_row_tot < this_col_tot ? this_row_tot : this_col_tot); + if (val != 0) { + phone1 = row-1; phone2 = col-1; + if (row == col) printf("COR %d %d %.2f%\n", phone1, val, (val * 100 / this_row_tot)); + else { + norm_prob = val * val / min_tot; # heuristic for sorting. + printf("SUB %d %d %d %d %.2f%% %.2f%%\n", + norm_prob, phone1, phone2, val, (val * 100 / min_tot), (rval * 100 / min_tot)); }}}}' > $dir/phone_stats.all + + ( + echo "# Format: " + grep '^COR' $dir/phone_stats.all | sort -n -k4,4 | awk '{print $2, $3, $4}' | utils/int2sym.pl -f 1 $lang1/phones.txt + ) > $dir/phones_correct.txt + + ( + echo "#Format: " + echo "# is the number of frames that were labeled in the first" + echo "# set of alignments and in the second." + echo "# is divided by the smaller of the total num-frames of" + echo "# phone1 or phone2, whichever is smaller; expressed as a percentage." + echo "# is the same but for the reverse substitution, from" + echo "# to ; the comparison with the substitutions are)." + grep '^SUB' $dir/phone_stats.all | sort -nr -k2,2 | awk '{print $3,$4,$5,$6,$7}' | utils/int2sym.pl -f 1-2 $lang1/phones.txt + ) > $dir/phone_subs.txt +fi + +if [ $stage -le 4 ]; then + echo "$0: getting CTMs" + steps/get_train_ctm.sh --use-segments false --print-silence true --cmd "$cmd" --frame-shift 1.0 $data1 $lang1 $ali_dir1 $dir/ctm1 + steps/get_train_ctm.sh --use-segments false --print-silence true --cmd "$cmd" --frame-shift 1.0 $data2 $lang2 $ali_dir2 $dir/ctm2 +fi + +if [ $stage -le 5 ]; then + oov=$(cat $lang1/oov.int) + # Note: below, we use $lang1 for both setups; this is by design as compare-int-vector + # assumes they use the same symbol table. 
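+  # In brief, the loop below turns each CTM into a per-frame word sequence:
+  # since get_train_ctm.sh above was run with --frame-shift 1.0, the start
+  # times and durations are effectively frame counts, and the awk command
+  # repeats each word id once per frame. The resulting words{1,2}.gz can then
+  # be compared frame-by-frame with compare-int-vector just below, the same
+  # way the phone alignments were compared in stage 1.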
+ for n in 1 2; do + cat $dir/ctm${n}/ctm | utils/sym2int.pl --map-oov $oov -f 5 $lang1/words.txt | \ + awk 'BEGIN{utt_id="";} { if (utt_id != $1) { if (utt_id != "") printf("\n"); utt_id=$1; printf("%s ", utt_id); } t_start=int($3); t_end=t_start + int($4); word=$5; for (t=t_start; t$dir/words${n}.gz + done +fi + +if [ $stage -le 5 ]; then + compare-int-vector --binary=false --write-tot-counts=$dir/words_tot.vec --write-diff-counts=$dir/words_diff.vec \ + "ark:gunzip -c $dir/words1.gz|" "ark:gunzip -c $dir/words2.gz|" 2>$dir/log/compare_words.log >$dir/utt_stats.words + tail -n 8 $dir/log/compare_words.log +fi + +if [ $stage -le 6 ]; then + + ( echo "# Word stats. Format:"; + echo " " + + paste <(awk '{for (n=2;n 0) print $1*$1/$2, $1/$2, $1, $2, (NR-1)}' | utils/int2sym.pl -f 5 $lang1/words.txt | \ + sort -nr | awk '{print $2, $3, $4, $5;}' + ) > $dir/word_stats.txt + +fi + +if [ $stage -le 7 ]; then + for type in phones words; do + num_utts=$(wc -l <$dir/utt_stats.$type) + cat $dir/utt_stats.$type | awk -v type=$type 'BEGIN{print "Utterance-id proportion-"type"-changed num-frames num-wrong-frames"; } + {print $1, $3 * 1.0 / $2, $2, $3; }' | sort -nr -k2,2 > $dir/utt_stats.$type.sorted + ( + echo "$0: Percentiles 100, 90, .. 0 of proportion-$type-changed distribution (over utterances) are:" + cat $dir/utt_stats.$type.sorted | awk -v n=$num_utts 'BEGIN{k=int((n-1)/10);} {if (NR % k == 1) printf("%s ", $2); } END{print "";}' + ) | tee $dir/utt_stats.$type.percentiles + done +fi + + +if [ $stage -le 8 ]; then + # Display the 1000 worst utterances, and 1000 utterances from the middle of the pack, in a readable format. + num_utts=$(wc -l <$dir/utt_stats.words.sorted) + half_num_utts=$[$num_utts/2]; + if [ $num_to_sample -gt $half_num_utts ]; then + num_to_sample=$half_num_utts + fi + head -n $num_to_sample $dir/utt_stats.words.sorted | awk '{print $1}' > $dir/utt_ids.worst + tail -n +$half_num_utts $dir/utt_stats.words.sorted | head -n $num_to_sample | awk '{print $1}' > $dir/utt_ids.mid + + for suf in worst mid; do + for n in 1 2; do + gunzip -c $dir/phones${n}.gz | copy-int-vector ark:- ark,t:- | utils/filter_scp.pl $dir/utt_ids.$suf >$dir/temp + # the next command reorders them, and duplicates the utterance-idwhich we'll later use + # that to display the word sequence. + awk '{print $1,$1,$1}' <$dir/utt_ids.$suf | utils/apply_map.pl -f 3 $dir/temp > $dir/phones${n}.$suf + rm $dir/temp + done + # the stuff with 0 and below is a kind of hack so that if the phones are the same, we end up + # with just the phone, but if different, we end up with p1/p2. + # The apply_map.pl stuff is to put the transcript there. + + ( + echo "# Format: ... ... " + echo "# If the two alignments have the same phone, just that phone will be printed;" + echo "# otherwise the two phones will be printed, as in 'phone1/phone2'. So '/' is present" + echo "# whenever there is a mismatch." 
+ + paste $dir/phones1.$suf $dir/phones2.$suf | perl -ane ' @A = split("\t", $_); @A1 = split(" ", $A[0]); @A2 = split(" ", $A[1]); + $utt = shift @A1; shift @A2; print $utt, " "; + for ($n = 0; $n < @A1 && $n < @A2; $n++) { $a1=$A1[$n]; $a2=$A2[$n]; if ($a1 eq $a2) { print "$a1 "; } else { print "$a1 0 $a2 "; }} + print "\n" ' | utils/int2sym.pl -f 3- $lang1/phones.txt | sed 's: :/:g' | \ + utils/apply_map.pl -f 2 $data1/text + ) > $dir/compare_phones_${suf}.txt + done +fi + + +if [ $stage -le 9 ] && $cleanup; then + rm $dir/phones{1,2}.gz $dir/words{1,2}.gz $dir/ctm*/ctm $dir/*.vec $dir/conf.mat \ + $dir/utt_ids.* $dir/phones{1,2}.{mid,worst} $dir/utt_stats.{phones,words} \ + $dir/phone_stats.all +fi + +# clean up +exit 0 diff --git a/egs/wsj/s5/steps/compute_vad_decision.sh b/egs/wsj/s5/steps/compute_vad_decision.sh new file mode 100755 index 00000000000..4cf3c5b2b79 --- /dev/null +++ b/egs/wsj/s5/steps/compute_vad_decision.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# Apache 2.0 + +# To be run from .. (one directory up from here) +# see ../run.sh for example + +# Compute energy based VAD output + +nj=4 +cmd=run.pl +vad_config=conf/vad.conf + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# -lt 1 ] || [ $# -gt 3 ]; then + echo "Usage: $0 [options] [ []]"; + echo "e.g.: $0 data/train exp/make_vad mfcc" + echo "Note: defaults to /log, and defaults to /data" + echo " Options:" + echo " --vad-config # config passed to compute-vad-energy" + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + exit 1; +fi + +data=$1 +if [ $# -ge 2 ]; then + logdir=$2 +else + logdir=$data/log +fi +if [ $# -ge 3 ]; then + vaddir=$3 +else + vaddir=$data/data +fi + + +# make $vaddir an absolute pathname. +vaddir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $vaddir ${PWD}` + +# use "name" as part of name of the archive. +name=`basename $data` + +mkdir -p $vaddir || exit 1; +mkdir -p $logdir || exit 1; + +if [ -f $data/vad.scp ]; then + mkdir -p $data/.backup + echo "$0: moving $data/vad.scp to $data/.backup" + mv $data/vad.scp $data/.backup +fi + +for f in $data/feats.scp "$vad_config"; do + if [ ! 
-f $f ]; then + echo "compute_vad_decision.sh: no such file $f" + exit 1; + fi +done + +utils/split_data.sh $data $nj || exit 1; +sdata=$data/split$nj; + +$cmd JOB=1:$nj $logdir/vad_${name}.JOB.log \ + compute-vad --config=$vad_config scp:$sdata/JOB/feats.scp \ + ark,scp:$vaddir/vad_${name}.JOB.ark,$vaddir/vad_${name}.JOB.scp || exit 1 + +for ((n=1; n<=nj; n++)); do + cat $vaddir/vad_${name}.$n.scp || exit 1; +done > $data/vad.scp + +nc=`cat $data/vad.scp | wc -l` +nu=`cat $data/feats.scp | wc -l` +if [ $nc -ne $nu ]; then + echo "**Warning it seems not all of the speakers got VAD output ($nc != $nu);" + echo "**validate_data_dir.sh will fail; you might want to use fix_data_dir.sh" + [ $nc -eq 0 ] && exit 1; +fi + + +echo "Created VAD output for $name" diff --git a/egs/wsj/s5/steps/conf/append_eval_to_ctm.py b/egs/wsj/s5/steps/conf/append_eval_to_ctm.py index f8e2aad891d..90679d2b341 100755 --- a/egs/wsj/s5/steps/conf/append_eval_to_ctm.py +++ b/egs/wsj/s5/steps/conf/append_eval_to_ctm.py @@ -3,6 +3,7 @@ # Copyright 2015 Brno University of Technology (author: Karel Vesely) # Apache 2.0 +from __future__ import print_function import sys,operator # Append Levenshtein alignment of 'hypothesis' and 'reference' into 'CTM': @@ -15,7 +16,7 @@ # 'U' = unknown (not part of scored segment) if len(sys.argv) != 4: - print 'Usage: %s eval-in ctm-in ctm-eval-out' % __file__ + print('Usage: %s eval-in ctm-in ctm-eval-out' % __file__) sys.exit(1) dummy, eval_in, ctm_in, ctm_eval_out = sys.argv @@ -54,7 +55,7 @@ # Build the 'ctm' with 'eval' column added, ctm_eval = [] -for utt,ctm_part in ctm.iteritems(): +for utt,ctm_part in ctm.items(): ctm_part.sort(key = operator.itemgetter(2)) # Sort by 'beg' time, try: # merging 'tuples' by '+', the record has format: @@ -69,7 +70,7 @@ # append, ctm_eval.extend(merged) except KeyError: - print 'Missing key', utt, 'in the word-evaluation stats from scoring' + print('Missing key', utt, 'in the word-evaluation stats from scoring') # Sort again, ctm_eval.sort(key = operator.itemgetter(0,1,2)) diff --git a/egs/wsj/s5/steps/conf/append_prf_to_ctm.py b/egs/wsj/s5/steps/conf/append_prf_to_ctm.py index 547b6176c9f..42acc5e22b7 100755 --- a/egs/wsj/s5/steps/conf/append_prf_to_ctm.py +++ b/egs/wsj/s5/steps/conf/append_prf_to_ctm.py @@ -3,6 +3,7 @@ # Copyright 2015 Brno University of Technology (author: Karel Vesely) # Apache 2.0 +from __future__ import print_function import sys # Append Levenshtein alignment of 'hypothesis' and 'reference' into 'CTM': @@ -16,7 +17,7 @@ # Parse options, if len(sys.argv) != 4: - print "Usage: %s prf ctm_in ctm_out" % __file__ + print("Usage: %s prf ctm_in ctm_out" % __file__) sys.exit(1) prf_file, ctm_file, ctm_out_file = sys.argv[1:] diff --git a/egs/wsj/s5/steps/conf/convert_ctm_to_tra.py b/egs/wsj/s5/steps/conf/convert_ctm_to_tra.py index 8fec0064fd7..25899e19264 100755 --- a/egs/wsj/s5/steps/conf/convert_ctm_to_tra.py +++ b/egs/wsj/s5/steps/conf/convert_ctm_to_tra.py @@ -3,6 +3,7 @@ # Copyright 2015 Brno University of Technology (author: Karel Vesely) # Apache 2.0 +from __future__ import print_function import sys, operator # This scripts loads a 'ctm' file and converts it into the 'tra' format: @@ -14,7 +15,7 @@ # - confidences if len(sys.argv) != 3: - print 'Usage: %s ctm-in tra-out' % __file__ + print('Usage: %s ctm-in tra-out' % __file__) sys.exit(1) dummy, ctm_in, tra_out = sys.argv @@ -31,7 +32,7 @@ # Store the in 'tra' format, with open(tra_out,'w') as f: - for utt,tuples in tra.iteritems(): + for utt,tuples in tra.items(): 
tuples.sort(key = operator.itemgetter(0)) # Sort by 'beg' time, f.write('%s %s\n' % (utt,' '.join([t[1] for t in tuples]))) diff --git a/egs/wsj/s5/steps/conf/get_ctm_conf.sh b/egs/wsj/s5/steps/conf/get_ctm_conf.sh new file mode 100755 index 00000000000..8dbc9f449cd --- /dev/null +++ b/egs/wsj/s5/steps/conf/get_ctm_conf.sh @@ -0,0 +1,93 @@ +#!/bin/bash +# Copyright Johns Hopkins University (Author: Daniel Povey) 2012. Apache 2.0. + +# This script produces CTM files from a decoding directory that has lattices +# present. This version gives you confidence scores. See also steps/get_ctm.sh + + +# begin configuration section. +cmd=run.pl +stage=0 +min_lmwt=5 +max_lmwt=20 +use_segments=true # if we have a segments file, use it to convert + # the segments to be relative to the original files. +iter=final +#end configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: $0 [options] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --stage (0|1|2) # start scoring script from part-way through." + echo " --use-segments (true|false) # use segments and reco2file_and_channel files " + echo " # to produce a ctm relative to the original audio" + echo " # files, with channel information (typically needed" + echo " # for NIST scoring)." + echo "e.g.:" + echo "$0 data/train data/lang exp/tri4a/decode/" + echo "See also: steps/get_ctm.sh, steps/get_ctm_conf_fast.sh" + exit 1; +fi + +data=$1 +lang=$2 # Note: may be graph directory not lang directory, but has the necessary stuff copied. +dir=$3 + +model=$dir/../$iter.mdl # assume model one level up from decoding dir. + + +for f in $lang/words.txt $model $dir/lat.1.gz; do + [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; +done + +name=`basename $data`; # e.g. eval2000 + +mkdir -p $dir/scoring/log + +if [ -f $dir/../frame_shift ]; then + frame_shift_opt="--frame-shift=$(cat $dir/../frame_shift)" + echo "$0: $dir/../frame_shift exists, using $frame_shift_opt" +elif [ -f $dir/../frame_subsampling_factor ]; then + factor=$(cat $dir/../frame_subsampling_factor) || exit 1 + frame_shift_opt="--frame-shift=0.0$factor" + echo "$0: $dir/../frame_subsampling_factor exists, using $frame_shift_opt" +fi + +if [ $stage -le 0 ]; then + if [ -f $data/segments ] && $use_segments; then + f=$data/reco2file_and_channel + [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; + filter_cmd="utils/convert_ctm.pl $data/segments $data/reco2file_and_channel" + else + filter_cmd=cat + fi + + if [ -f $lang/phones/word_boundary.int ]; then + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/get_ctm.LMWT.log \ + mkdir -p $dir/score_LMWT/ '&&' \ + lattice-prune --inv-acoustic-scale=LMWT --beam=5 "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-align-words $lang/phones/word_boundary.int $model ark:- ark:- \| \ + lattice-to-ctm-conf $frame_shift_opt --decode-mbr=true --inv-acoustic-scale=LMWT ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \| \ + $filter_cmd '>' $dir/score_LMWT/$name.ctm || exit 1; + else + if [ ! -f $lang/phones/align_lexicon.int ]; then + echo "$0: neither $lang/phones/word_boundary.int nor $lang/phones/align_lexicon.int exists: cannot align." 
+ exit 1; + fi + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/get_ctm.LMWT.log \ + mkdir -p $dir/score_LMWT/ '&&' \ + lattice-prune --inv-acoustic-scale=LMWT --beam=5 "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-align-words-lexicon $lang/phones/align_lexicon.int $model ark:- ark:- \| \ + lattice-to-ctm-conf $frame_shift_opt --decode-mbr=true --inv-acoustic-scale=LMWT ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \| \ + $filter_cmd '>' $dir/score_LMWT/$name.ctm || exit 1; + fi +fi + diff --git a/egs/wsj/s5/steps/conf/parse_arpa_unigrams.py b/egs/wsj/s5/steps/conf/parse_arpa_unigrams.py index 1be32d4c4d7..f0a2fe13497 100755 --- a/egs/wsj/s5/steps/conf/parse_arpa_unigrams.py +++ b/egs/wsj/s5/steps/conf/parse_arpa_unigrams.py @@ -3,11 +3,12 @@ # Copyright 2015 Brno University of Technology (author: Karel Vesely) # Apache 2.0 +from __future__ import print_function import sys, gzip, re # Parse options, if len(sys.argv) != 4: - print "Usage: %s " % __file__ + print("Usage: %s " % __file__) sys.exit(0) words_txt, arpa_gz, unigrams_out = sys.argv[1:] @@ -31,7 +32,7 @@ # Create list, 'wrd id log_p_unigram', words_unigram = [[wrd, id, (wrd_log10[wrd] if wrd in wrd_log10 else -99)] for wrd,id in words ] -print >>sys.stderr, words_unigram[0] +print(words_unigram[0], file=sys.stderr) # Store, with open(unigrams_out,'w') as f: f.writelines(['%s %s %g\n' % (w,i,p) for (w,i,p) in words_unigram]) diff --git a/egs/wsj/s5/steps/conf/prepare_calibration_data.py b/egs/wsj/s5/steps/conf/prepare_calibration_data.py index bc8f92a2f7f..c4da720ba71 100755 --- a/egs/wsj/s5/steps/conf/prepare_calibration_data.py +++ b/egs/wsj/s5/steps/conf/prepare_calibration_data.py @@ -3,6 +3,7 @@ # Copyright 2015 Brno University of Technology (author: Karel Vesely) # Apache 2.0 +from __future__ import division import sys, math from optparse import OptionParser @@ -82,7 +83,7 @@ depths = dict() for l in open(o.lattice_depth): utt,d = l.split(' ',1) - depths[utt] = map(int,d.split()) + depths[utt] = [int(i) for i in d.split()] # Load the 'word_categories' mapping for categorical input features derived from 'lang/words.txt', wrd_to_cat = [ l.split() for l in open(word_categories_file) ] diff --git a/egs/wsj/s5/steps/data/augment_data_dir.py b/egs/wsj/s5/steps/data/augment_data_dir.py old mode 100644 new mode 100755 index b78a644074f..7edcdda2636 --- a/egs/wsj/s5/steps/data/augment_data_dir.py +++ b/egs/wsj/s5/steps/data/augment_data_dir.py @@ -103,8 +103,8 @@ def AugmentWav(utt, wav, dur, fg_snr_opts, bg_snr_opts, fg_noise_utts, \ tot_noise_dur += noise_dur + interval noises.append(noise) - start_times_str = "--start-times='" + ",".join(map(str,start_times)) + "'" - snrs_str = "--snrs='" + ",".join(map(str,snrs)) + "'" + start_times_str = "--start-times='" + ",".join([str(i) for i in start_times]) + "'" + snrs_str = "--snrs='" + ",".join([str(i) for i in snrs]) + "'" noises_str = "--additive-signals='" + ",".join(noises).strip() + "'" # If the wav is just a file @@ -130,11 +130,11 @@ def CopyFileIfExists(utt_suffix, filename, input_dir, output_dir): def main(): args = GetArgs() - fg_snrs = map(int, args.fg_snr_str.split(":")) - bg_snrs = map(int, args.bg_snr_str.split(":")) + fg_snrs = [int(i) for i in args.fg_snr_str.split(":")] + bg_snrs = [int(i) for i in args.bg_snr_str.split(":")] input_dir = args.input_dir output_dir = args.output_dir - num_bg_noises = map(int, args.num_bg_noises.split(":")) + num_bg_noises = [int(i) for i in args.num_bg_noises.split(":")] reco2dur = ParseFileToDict(input_dir + "/reco2dur", 
value_processor = lambda x: float(x[0])) wav_scp_file = open(input_dir + "/wav.scp", 'r').readlines() diff --git a/egs/wsj/s5/steps/data/reverberate_data_dir.py b/egs/wsj/s5/steps/data/reverberate_data_dir.py index 71e64d9e680..b1745a4b723 100755 --- a/egs/wsj/s5/steps/data/reverberate_data_dir.py +++ b/egs/wsj/s5/steps/data/reverberate_data_dir.py @@ -1,10 +1,10 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2016 Tom Ko +# 2018 David Snyder # Apache 2.0 # script to generate reverberated data # we're using python 3.x style print but want it to work in python 2.x, -from __future__ import print_function import argparse, shlex, glob, math, os, random, sys, warnings, copy, imp, ast data_lib = imp.load_source('dml', 'steps/data/data_dir_manipulation_lib.py') @@ -120,17 +120,18 @@ def CheckArgs(args): return args -class list_cyclic_iterator: +class list_cyclic_iterator(object): def __init__(self, list): self.list_index = 0 self.list = list random.shuffle(self.list) - def next(self): + def __next__(self): item = self.list[self.list_index] self.list_index = (self.list_index + 1) % len(self.list) return item + next = __next__ # for Python 2 # This functions picks an item from the collection according to the associated probability distribution. # The probability estimate of each item in the collection is stored in the "probability" field of @@ -167,14 +168,13 @@ def ParseFileToDict(file, assert2fields = False, value_processor = None): # This function creates a file and write the content of a dictionary into it def WriteDictToFile(dict, file_name): file = open(file_name, 'w') - keys = dict.keys() - keys.sort() + keys = sorted(dict.keys()) for key in keys: value = dict[key] if type(value) in [list, tuple] : if type(value) is tuple: value = list(value) - value.sort() + value = sorted(value) value = ' '.join(str(value)) file.write('{0} {1}\n'.format(key, value)) file.close() @@ -185,8 +185,7 @@ def CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_origina corrupted_utt2uniq = {} # Parse the utt2spk to get the utterance id utt2spk = ParseFileToDict(input_dir + "/utt2spk", value_processor = lambda x: " ".join(x)) - keys = utt2spk.keys() - keys.sort() + keys = sorted(utt2spk.keys()) if include_original: start_index = 0 else: @@ -219,11 +218,11 @@ def AddPointSourceNoise(noise_addition_descriptor, # descriptor to store the in if noise.bg_fg_type == "background": noise_rvb_command = """wav-reverberate --impulse-response="{0}" --duration={1}""".format(noise_rir.rir_rspecifier, speech_dur) noise_addition_descriptor['start_times'].append(0) - noise_addition_descriptor['snrs'].append(background_snrs.next()) + noise_addition_descriptor['snrs'].append(next(background_snrs)) else: noise_rvb_command = """wav-reverberate --impulse-response="{0}" """.format(noise_rir.rir_rspecifier) noise_addition_descriptor['start_times'].append(round(random.random() * speech_dur, 2)) - noise_addition_descriptor['snrs'].append(foreground_snrs.next()) + noise_addition_descriptor['snrs'].append(next(foreground_snrs)) # check if the rspecifier is a pipe or not if len(noise.noise_rspecifier.split()) == 1: @@ -274,7 +273,7 @@ def GenerateReverberationOpts(room_dict, # the room dictionary, please refer to else: noise_addition_descriptor['noise_io'].append("{0} wav-reverberate --duration={1} - - |".format(isotropic_noise.noise_rspecifier, speech_dur)) noise_addition_descriptor['start_times'].append(0) - noise_addition_descriptor['snrs'].append(background_snrs.next()) + 
noise_addition_descriptor['snrs'].append(next(background_snrs)) noise_addition_descriptor = AddPointSourceNoise(noise_addition_descriptor, # descriptor to store the information of the noise added room, # the room selected @@ -290,8 +289,8 @@ def GenerateReverberationOpts(room_dict, # the room dictionary, please refer to assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['snrs']) if len(noise_addition_descriptor['noise_io']) > 0: reverberate_opts += "--additive-signals='{0}' ".format(','.join(noise_addition_descriptor['noise_io'])) - reverberate_opts += "--start-times='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['start_times']))) - reverberate_opts += "--snrs='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['snrs']))) + reverberate_opts += "--start-times='{0}' ".format(','.join([str(x) for x in noise_addition_descriptor['start_times']])) + reverberate_opts += "--snrs='{0}' ".format(','.join([str(x) for x in noise_addition_descriptor['snrs']])) return reverberate_opts @@ -331,8 +330,7 @@ def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kal foreground_snrs = list_cyclic_iterator(foreground_snr_array) background_snrs = list_cyclic_iterator(background_snr_array) corrupted_wav_scp = {} - keys = wav_scp.keys() - keys.sort() + keys = sorted(wav_scp.keys()) if include_original: start_index = 0 else: @@ -373,8 +371,8 @@ def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kal # This function replicate the entries in files like segments, utt2spk, text def AddPrefixToFields(input_file, output_file, num_replicas, include_original, prefix, field = [0]): - list = map(lambda x: x.strip(), open(input_file)) - f = open(output_file, "w") + list = [x.strip() for x in open(input_file, encoding='utf-8')] + f = open(output_file, "w" ,encoding='utf-8') if include_original: start_index = 0 else: @@ -413,16 +411,10 @@ def CreateReverberatedCopy(input_dir, wav_scp = ParseFileToDict(input_dir + "/wav.scp", value_processor = lambda x: " ".join(x)) if not os.path.isfile(input_dir + "/reco2dur"): print("Getting the duration of the recordings..."); - read_entire_file="false" - for value in wav_scp.values(): - # we will add more checks for sox commands which modify the header as we come across these cases in our data - if "sox" in value and "speed" in value: - read_entire_file="true" - break - data_lib.RunKaldiCommand("wav-to-duration --read-entire-file={1} scp:{0}/wav.scp ark,t:{0}/reco2dur".format(input_dir, read_entire_file)) + data_lib.RunKaldiCommand("utils/data/get_reco2dur.sh {}".format(input_dir)) durations = ParseFileToDict(input_dir + "/reco2dur", value_processor = lambda x: float(x[0])) - foreground_snr_array = map(lambda x: float(x), foreground_snr_string.split(':')) - background_snr_array = map(lambda x: float(x), background_snr_string.split(':')) + foreground_snr_array = [float(x) for x in foreground_snr_string.split(':')] + background_snr_array = [float(x) for x in background_snr_string.split(':')] GenerateReverberatedWavScp(wav_scp, durations, output_dir, room_dict, pointsource_noise_list, iso_noise_dict, foreground_snr_array, background_snr_array, num_replicas, include_original, prefix, @@ -451,11 +443,11 @@ def CreateReverberatedCopy(input_dir, # This function smooths the probability distribution in the list -def SmoothProbabilityDistribution(list, smoothing_weight=0.0, target_sum=1.0): - if len(list) > 0: +def SmoothProbabilityDistribution(set_list, smoothing_weight=0.0, 
target_sum=1.0): + if len(list(set_list)) > 0: num_unspecified = 0 accumulated_prob = 0 - for item in list: + for item in set_list: if item.probability is None: num_unspecified += 1 else: @@ -469,7 +461,7 @@ def SmoothProbabilityDistribution(list, smoothing_weight=0.0, target_sum=1.0): warnings.warn("The sum of probabilities specified by user is larger than or equal to 1. " "The items without probabilities specified will be given zero to their probabilities.") - for item in list: + for item in set_list: if item.probability is None: item.probability = uniform_probability else: @@ -477,11 +469,11 @@ def SmoothProbabilityDistribution(list, smoothing_weight=0.0, target_sum=1.0): item.probability = (1 - smoothing_weight) * item.probability + smoothing_weight * uniform_probability # Normalize the probability - sum_p = sum(item.probability for item in list) - for item in list: + sum_p = sum(item.probability for item in set_list) + for item in set_list: item.probability = item.probability / sum_p * target_sum - return list + return set_list # This function parse the array of rir set parameter strings. @@ -527,7 +519,7 @@ def ParseRirList(rir_set_para_array, smoothing_weight, sampling_rate = None): rir_list = [] for rir_set in set_list: - current_rir_list = map(lambda x: rir_parser.parse_args(shlex.split(x.strip())),open(rir_set.filename)) + current_rir_list = [rir_parser.parse_args(shlex.split(x.strip())) for x in open(rir_set.filename)] for rir in current_rir_list: if sampling_rate is not None: # check if the rspecifier is a pipe or not @@ -592,7 +584,7 @@ def ParseNoiseList(noise_set_para_array, smoothing_weight, sampling_rate = None) pointsource_noise_list = [] iso_noise_dict = {} for noise_set in set_list: - current_noise_list = map(lambda x: noise_parser.parse_args(shlex.split(x.strip())),open(noise_set.filename)) + current_noise_list = [noise_parser.parse_args(shlex.split(x.strip())) for x in open(noise_set.filename)] current_pointsource_noise_list = [] for noise in current_noise_list: if sampling_rate is not None: diff --git a/egs/wsj/s5/steps/diagnostic/analyze_lats.sh b/egs/wsj/s5/steps/diagnostic/analyze_lats.sh index d580f516527..df1a6d64801 100755 --- a/egs/wsj/s5/steps/diagnostic/analyze_lats.sh +++ b/egs/wsj/s5/steps/diagnostic/analyze_lats.sh @@ -50,10 +50,10 @@ $cmd JOB=1:$num_jobs $dir/log/lattice_best_path.JOB.log \ $cmd JOB=1:$num_jobs $dir/log/get_lattice_stats.JOB.log \ ali-to-phones --write-lengths=true "$model" "ark:gunzip -c $dir/ali_tmp.JOB.gz|" ark,t:- \| \ - sed -E 's/^[^ ]+ //' \| \ - awk 'BEGIN{FS=" ; "; OFS="\n";} {print "begin " $1; if (NF>1) print "end " $NF; for (n=1;n<=NF;n++) print "all " $n; }' \| \ - sort \| uniq -c \| gzip -c '>' $dir/phone_stats.JOB.gz || exit 1 - + perl -ne 'chomp;s/^\S+\s*//;@a=split /\s;\s/, $_;$count{"begin ".$a[$0]."\n"}++; + if(@a>1){$count{"end ".$a[-1]."\n"}++;}for($i=0;$i<@a;$i++){$count{"all ".$a[$i]."\n"}++;} + END{for $k (sort keys %count){print "$count{$k} $k"}}' \| \ + gzip -c '>' $dir/phone_stats.JOB.gz || exit 1 $cmd $dir/log/analyze_alignments.log \ gunzip -c "$dir/phone_stats.*.gz" \| \ @@ -67,16 +67,16 @@ echo "$0: see stats in $dir/log/analyze_alignments.log" # escaped since it needs to be passed to $cmd. # the 'paste' command will paste together the phone-indexes and the depths # so that one line will be like utt-id1 phone1 phone2 phone3 .. utt-id1 depth1 depth2 depth3 ... 
-# the awk command computes counts of pairs (phone, lattice-depth) and outputs lines +# the following command computes counts of pairs (phone, lattice-depth) and outputs lines # containing 3 integers representing: # phone lattice_depth, count[phone,lattice_depth] $cmd JOB=1:$num_jobs $dir/log/lattice_best_path.JOB.log \ ali-to-phones --per-frame=true "$model" "ark:gunzip -c $dir/ali_tmp.JOB.gz|" ark,t:- \| \ paste /dev/stdin '<(' gunzip -c $dir/depth_tmp.JOB.gz ')' \| \ - awk '{ half=NF/2; for (n=2; n<=half; n++) { m=n+half; count[$n " " $m]++;}} END{for(k in count) print k, count[k]; }' \| \ + perl -ane '$half=@F/2;for($i=1;$i<$half;$i++){$j=$i+$half;$count{$F[$i]." ".$F[$j]}++;} + END{for $k (sort keys %count){print "$k $count{$k}\n"}}' \| \ gzip -c '>' $dir/depth_stats_tmp.JOB.gz - $cmd $dir/log/analyze_lattice_depth_stats.log \ gunzip -c "$dir/depth_stats_tmp.*.gz" \| \ steps/diagnostic/analyze_lattice_depth_stats.py $lang || exit 1 diff --git a/egs/wsj/s5/steps/diagnostic/analyze_lattice_depth_stats.py b/egs/wsj/s5/steps/diagnostic/analyze_lattice_depth_stats.py index 56b9f69b3c9..6ed2bf78115 100755 --- a/egs/wsj/s5/steps/diagnostic/analyze_lattice_depth_stats.py +++ b/egs/wsj/s5/steps/diagnostic/analyze_lattice_depth_stats.py @@ -5,6 +5,7 @@ # Apache 2.0. from __future__ import print_function +from __future__ import division import argparse import sys, os from collections import defaultdict diff --git a/egs/wsj/s5/steps/dict/apply_g2p_phonetisaurus.sh b/egs/wsj/s5/steps/dict/apply_g2p_phonetisaurus.sh new file mode 100755 index 00000000000..a793f91fd0a --- /dev/null +++ b/egs/wsj/s5/steps/dict/apply_g2p_phonetisaurus.sh @@ -0,0 +1,99 @@ +#!/bin/bash +# Copyright 2014 Johns Hopkins University (Author: Yenda Trmal) +# Copyright 2016 Xiaohui Zhang +# 2018 Ruizhe Huang +# Apache 2.0 + +# This script applies a trained Phonetisarus G2P model to +# synthesize pronunciations for missing words (i.e., words in +# transcripts but not the lexicon), and output the expanded lexicon. +# The user could specify either nbest or pmass option +# to determine the number of output pronunciation variants, +# or use them together to get the intersection of two options. + +# Begin configuration section. +stage=0 +nbest= # Generate up to N, like N=3, pronunciation variants for each word + # (The maximum size of the nbest list, not considering pruning and taking the prob-mass yet). +thresh=5 # Pruning threshold for the n-best list, in (0, 99], which is a -log-probability value. + # A large threshold makes the nbest list shorter, and less likely to hit the max size. + # This value corresponds to the weight_threshold in shortest-path.h of openfst. +pmass= # Select the top variants from the pruned nbest list, + # summing up to this total prob-mass for a word. + # On the "boundary", it's greedy by design, e.g. if pmass = 0.8, + # and we have prob(pron_1) = 0.5, and prob(pron_2) = 0.4, then we get both. +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. utils/parse_options.sh || exit 1; + +set -u +set -e + +if [ $# != 3 ]; then + echo "Usage: $0 [options] " + echo "... where is a list of words whose pronunciation is to be generated." + echo " is a directory used as a target during training of G2P" + echo " is the directory where the output lexicon should be stored." + echo " The format of the output lexicon output-dir/lexicon.lex is" + echo " \t\t per line." 
+ echo "e.g.: $0 --nbest 1 exp/g2p/oov_words.txt exp/g2p exp/g2p/oov_lex" + echo "" + echo "main options (for others, see top of script file)" + echo " --nbest # Generate upto N pronunciation variants for each word." + echo " --pmass # Select the top variants from the pruned nbest list," + echo " # summing up to this total prob-mass, within [0, 1], for a word." + echo " --thresh # Pruning threshold for n-best." + exit 1; +fi + +wordlist=$1 +modeldir=$2 +outdir=$3 + +model=$modeldir/model.fst +output_lex=$outdir/lexicon.lex +mkdir -p $outdir + +[ ! -f ${model:-} ] && echo "$0: File $model not found in the directory $modeldir." && exit 1 +[ ! -f $wordlist ] && echo "$0: File $wordlist not found!" && exit 1 +[ -z $pmass ] && [ -z $nbest ] && echo "$0: nbest or/and pmass should be specified." && exit 1; +if ! phonetisaurus=`which phonetisaurus-apply` ; then + echo "Phonetisarus was not found !" + echo "Go to $KALDI_ROOT/tools and execute extras/install_phonetisaurus.sh" + exit 1 +fi + +cp $wordlist $outdir/wordlist.txt + +# three options: 1) nbest, 2) pmass, 3) nbest+pmass, +nbest=${nbest:-20} # if nbest is not specified, set it to 20, due to Phonetisaurus mechanism +pmass=${pmass:-1.0} # if pmass is not specified, set it to 1.0, due to Phonetisaurus mechanism + +[[ ! $nbest =~ ^[1-9][0-9]*$ ]] && echo "$0: nbest should be a positive integer." && exit 1; + +echo "Applying the G2P model to wordlist $wordlist" +phonetisaurus-apply --pmass $pmass --nbest $nbest --thresh $thresh \ + --word_list $wordlist --model $model \ + --accumulate --verbose --prob \ + 1>$output_lex + +echo "Completed. Synthesized lexicon for new words is in $output_lex" + +# Some words might have been removed or skipped during the process, +# let's check it and warn the user if so... +nlex=`cut -f 1 $output_lex | sort -u | wc -l` +nwlist=`cut -f 1 $wordlist | sort -u | wc -l` +if [ $nlex -ne $nwlist ] ; then + failed_wordlist=$outdir/lexicon.failed + echo "WARNING: Unable to generate pronunciation for all words. "; + echo "WARINNG: Wordlist: $nwlist words" + echo "WARNING: Lexicon : $nlex words" + comm -13 <(cut -f 1 $output_lex | sort -u ) \ + <(cut -f 1 $wordlist | sort -u ) \ + >$failed_wordlist && echo "WARNING: The list of failed words is in $failed_wordlist" +fi +exit 0 + diff --git a/egs/wsj/s5/steps/dict/apply_lexicon_edits.py b/egs/wsj/s5/steps/dict/apply_lexicon_edits.py index a5bdbc30d46..f8568971fb7 100755 --- a/egs/wsj/s5/steps/dict/apply_lexicon_edits.py +++ b/egs/wsj/s5/steps/dict/apply_lexicon_edits.py @@ -10,7 +10,7 @@ def GetArgs(): parser = argparse.ArgumentParser(description = "Apply an lexicon edits file (output from steps/dict/select_prons_bayesian.py)to an input lexicon" "to produce a learned lexicon.", - epilog = "See steps/dict/learn_lexicon.sh for example") + epilog = "See steps/dict/learn_lexicon_greedy.sh for example") parser.add_argument("in_lexicon", metavar='', type = str, help = "Input lexicon. Each line must be .") diff --git a/egs/wsj/s5/steps/dict/get_pron_stats.py b/egs/wsj/s5/steps/dict/get_pron_stats.py index b5202a69abb..e8106bdd1ac 100755 --- a/egs/wsj/s5/steps/dict/get_pron_stats.py +++ b/egs/wsj/s5/steps/dict/get_pron_stats.py @@ -10,15 +10,16 @@ import sys def GetArgs(): - parser = argparse.ArgumentParser(description = "Accumulate statistics from lattice-alignment outputs for lexicon" - "learning. 
The inputs are a file containing arc level information from lattice-align-words," - "and a map which maps word-position-dependent phones to word-position-independent phones" - "(output from steps/cleanup/debug_lexicon.txt). The output contains accumulated soft-counts" - "of pronunciations", - epilog = "cat exp/tri3_lex_0.4_work/lats/arc_info_sym.*.txt \\|" - " steps/dict/get_pron_stats.py - exp/tri3_lex_0.4_work/phone_decode/phone_map.txt \\" - " exp/tri3_lex_0.4_work/lats/pron_stats.txt" - "See steps/dict/learn_lexicon.sh for examples in detail.") + parser = argparse.ArgumentParser( + description = "Accumulate statistics from lattice-alignment outputs for lexicon" + "learning. The inputs are a file containing arc level information from lattice-align-words," + "and a map which maps word-position-dependent phones to word-position-independent phones" + "(output from steps/cleanup/debug_lexicon.txt). The output contains accumulated soft-counts" + "of pronunciations", + epilog = "cat exp/tri3_lex_0.4_work/lats/arc_info_sym.*.txt \\|" + " steps/dict/get_pron_stats.py - exp/tri3_lex_0.4_work/phone_decode/phone_map.txt \\" + " exp/tri3_lex_0.4_work/lats/pron_stats.txt" + "See steps/dict/learn_lexicon_greedy.sh for examples in detail.") parser.add_argument("arc_info_file", metavar = "", type = str, help = "Input file containing per arc statistics; " @@ -75,14 +76,14 @@ def GetStatsFromArcInfo(arc_info_file_handle, phone_map_handle): prons[word].add(phones) stats_unmapped[(word, phones)] = stats_unmapped.get((word, phones), 0) + count - for word_pron, count in stats_unmapped.iteritems(): + for word_pron, count in stats_unmapped.items(): phones_unmapped = word_pron[1].split() phones = [phone_map[phone] for phone in phones_unmapped] stats[(word_pron[0], " ".join(phones))] = count return stats def WriteStats(stats, file_handle): - for word_pron, count in stats.iteritems(): + for word_pron, count in stats.items(): print('{2} {0} {1}'.format(word_pron[0], word_pron[1], count), file=file_handle) file_handle.close() diff --git a/egs/wsj/s5/steps/dict/internal/get_subsegments.py b/egs/wsj/s5/steps/dict/internal/get_subsegments.py new file mode 100755 index 00000000000..c431b4c7066 --- /dev/null +++ b/egs/wsj/s5/steps/dict/internal/get_subsegments.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python + +# Copyright 2018 Xiaohui Zhang +# Apache 2.0. + +# we're using python 3.x style print but want it to work in python 2.x, +from __future__ import print_function +import argparse +import sys +import string + +def GetArgs(): + parser = argparse.ArgumentParser( + description = "The purpose of this script is to use a ctm and a vocab file" + "to extract sub-utterances and a sub-segmentation. Extracted sub-utterances" + "are all the strings of consecutive in-vocab words from the ctm" + "surrounded by an out-of-vocab word at each end if present.", + epilog = "e.g. steps/dict/internal/get_subsegments.py exp/tri3_lex_0.4_work/phonetic_decoding/word.ctm \\" + "exp/tri3_lex_0.4_work/learn_vocab.txt exp/tri3_lex_0.4_work/resegmentation/subsegments \\" + "exp/tri3_lex_0.4_work/resegmentation/text" + "See steps/dict/learn_lexicon_greedy.sh for an example.") + + parser.add_argument("ctm", metavar='', type = str, + help = "Input ctm file." + "each line must be ") + parser.add_argument("vocab", metavar='', type = str, + help = "Vocab file." + "each line must be ") + parser.add_argument("subsegment", metavar='', type = str, + help = "Subsegment file. 
Each line is in format:" + " ") + parser.add_argument("text", metavar='', type = str, + help = "Text file. Each line is in format:" + " ... .") + + print (' '.join(sys.argv), file = sys.stderr) + + args = parser.parse_args() + args = CheckArgs(args) + + return args + +def CheckArgs(args): + if args.ctm == "-": + args.ctm_handle = sys.stdin + else: + args.ctm_handle = open(args.ctm) + + if args.vocab is not '': + if args.vocab == "-": + args.vocab_handle = sys.stdout + else: + args.vocab_handle = open(args.vocab) + + args.subsegment_handle = open(args.subsegment, 'w') + args.text_handle = open(args.text, 'w') + + return args + +def GetSubsegments(args, vocab): + sub_utt = list() + last_is_oov = False + is_oov = False + utt_id_last = None + start_times = {} + end_times = {} + sub_utts = {} + sub_utt_id = 1 + sub_utt_id_last = 1 + end_time_last = 0.0 + for line in args.ctm_handle: + splits = line.strip().split() + if len(splits) < 5: + raise Exception("problematic line",line) + + utt_id = splits[0] + start = float(splits[2]) + dur = float(splits[3]) + word = splits[4] + if utt_id != utt_id_last: + sub_utt_id = 1 + if len(sub_utt)>1: + sub_utts[utt_id_last+'-'+str(sub_utt_id_last)] = (utt_id_last, sub_utt) + end_times[utt_id_last+'-'+str(sub_utt_id_last)] = ent_time_last + sub_utt = [] + start_times[utt_id+'-'+str(sub_utt_id)] = start + is_oov_last = False + if word == '': + is_oov = True + end_times[utt_id+'-'+str(sub_utt_id)] = start + dur + elif word in vocab: + is_oov = True + sub_utt.append(word) + end_times[utt_id+'-'+str(sub_utt_id)] = start + dur + else: + is_oov = False + if is_oov_last == True: + sub_utt.append(word) + sub_utts[utt_id+'-'+str(sub_utt_id_last)] = (utt_id, sub_utt) + end_times[utt_id+'-'+str(sub_utt_id_last)] = start + dur + sub_utt_id += 1 + sub_utt = [word] + start_times[utt_id+'-'+str(sub_utt_id)] = start + utt_id_last = utt_id + sub_utt_id_last = sub_utt_id + is_oov_last = is_oov + ent_time_last = start + dur + + if is_oov: + if word != '': + sub_utt.append(word) + sub_utts[utt_id+'-'+str(sub_utt_id_last)] = (utt_id, sub_utt) + end_times[utt_id+'-'+str(sub_utt_id_last)] = start + dur + + for utt,v in sorted(sub_utts.items()): + print(utt, ' '.join(sub_utts[utt][1]), file=args.text_handle) + print(utt, sub_utts[utt][0], start_times[utt], end_times[utt], file=args.subsegment_handle) + +def ReadVocab(vocab_file_handle): + vocab = set() + if vocab_file_handle: + for line in vocab_file_handle.readlines(): + splits = line.strip().split() + if len(splits) == 0: + continue + if len(splits) > 1: + raise Exception('Invalid format of line ' + line + + ' in vocab file.') + word = splits[0] + vocab.add(word) + return vocab + +def Main(): + args = GetArgs() + + vocab = ReadVocab(args.vocab_handle) + GetSubsegments(args, vocab) + +if __name__ == "__main__": + Main() diff --git a/egs/wsj/s5/steps/dict/internal/prune_pron_candidates.py b/egs/wsj/s5/steps/dict/internal/prune_pron_candidates.py index 1f2863424f3..60c7f75bbe8 100755 --- a/egs/wsj/s5/steps/dict/internal/prune_pron_candidates.py +++ b/egs/wsj/s5/steps/dict/internal/prune_pron_candidates.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2016 Xiaohui Zhang +# Copyright 2018 Xiaohui Zhang # Apache 2.0. from __future__ import print_function @@ -10,27 +10,36 @@ import math def GetArgs(): - parser = argparse.ArgumentParser(description = "Prune pronunciation candidates based on soft-counts from lattice-alignment" - "outputs, and a reference lexicon. 
Basically, for each word we sort all pronunciation" - "cadidates according to their soft-counts, and then select the top r * N candidates" - "(For words in the reference lexicon, N = # pron variants given by the reference" - "lexicon; For oov words, N = avg. # pron variants per word in the reference lexicon)." - "r is a user-specified constant, like 2.", - epilog = "See steps/dict/learn_lexicon.sh for example") - - parser.add_argument("--r", type = float, default = "2.0", - help = "a user-specified ratio parameter which determines how many" - "pronunciation candidates we want to keep for each word.") + parser = argparse.ArgumentParser( + description = "Prune pronunciation candidates based on soft-counts from lattice-alignment" + "outputs, and a reference lexicon. Basically, for each word we sort all pronunciation" + "cadidates according to their soft-counts, and then select the top variant-counts-ratio * N candidates" + "(For words in the reference lexicon, N = # pron variants given by the reference" + "lexicon; For oov words, N = avg. # pron variants per word in the reference lexicon).", + epilog = "See steps/dict/learn_lexicon_greedy.sh for example") + + parser.add_argument("--variant-counts-ratio", type = float, default = "3.0", + help = "A user-specified ratio parameter which determines how many" + "pronunciation candidates we want to keep for each word at most.") parser.add_argument("pron_stats", metavar = "", type = str, - help = "File containing soft-counts of all pronounciation candidates; " + help = "File containing soft-counts of pronounciation candidates; " "each line must be ") + parser.add_argument("lexicon_phonetic_decoding", metavar = "", type = str, + help = "Lexicon containing pronunciation candidates from phonetic decoding." + "each line must be ") + parser.add_argument("lexiconp_g2p", metavar = "", type = str, + help = "Lexicon with probabilities for pronunciation candidates from G2P." + "each line must be ") parser.add_argument("ref_lexicon", metavar = "", type = str, help = "Reference lexicon file, where we obtain # pron variants for" "each word, based on which we prune the pron candidates." "Each line must be ") - parser.add_argument("pruned_prons", metavar = "", type = str, - help = "An output file in lexicon format, which contains prons we want to" - "prune off from the pron_stats file.") + parser.add_argument("lexicon_phonetic_decoding_pruned", metavar = "", type = str, + help = "Output lexicon containing pronunciation candidates from phonetic decoding after pruning." + "each line must be ") + parser.add_argument("lexicon_g2p_pruned", metavar = "", type = str, + help = "Output lexicon containing pronunciation candidates from G2P after pruning." 
+ "each line must be ") print (' '.join(sys.argv), file=sys.stderr) @@ -40,12 +49,13 @@ def GetArgs(): return args def CheckArgs(args): + print(args) args.pron_stats_handle = open(args.pron_stats) + args.lexicon_phonetic_decoding_handle = open(args.lexicon_phonetic_decoding) + args.lexiconp_g2p_handle = open(args.lexiconp_g2p) args.ref_lexicon_handle = open(args.ref_lexicon) - if args.pruned_prons == "-": - args.pruned_prons_handle = sys.stdout - else: - args.pruned_prons_handle = open(args.pruned_prons, "w") + args.lexicon_phonetic_decoding_pruned_handle = open(args.lexicon_phonetic_decoding_pruned, "w") + args.lexicon_g2p_pruned_handle = open(args.lexicon_g2p_pruned, "w") return args def ReadStats(pron_stats_handle): @@ -62,13 +72,11 @@ def ReadStats(pron_stats_handle): phones = ' '.join(splits[2:]) stats[word].append((phones, count)) - for word, entry in stats.iteritems(): - entry.sort(key=lambda x: x[1]) return stats -def ReadLexicon(ref_lexicon_handle): - ref_lexicon = defaultdict(set) - for line in ref_lexicon_handle.readlines(): +def ReadLexicon(lexicon_handle): + lexicon = defaultdict(set) + for line in lexicon_handle.readlines(): splits = line.strip().split() if len(splits) == 0: continue @@ -77,42 +85,74 @@ def ReadLexicon(ref_lexicon_handle): + ' in lexicon file.') word = splits[0] phones = ' '.join(splits[1:]) - ref_lexicon[word].add(phones) - return ref_lexicon + lexicon[word].add(phones) + return lexicon -def PruneProns(args, stats, ref_lexicon): +def ReadLexiconp(lexiconp_handle): + lexicon = defaultdict(set) + pron_probs = defaultdict(float) + for line in lexiconp_handle.readlines(): + splits = line.strip().split() + if len(splits) == 0: + continue + if len(splits) < 3: + raise Exception('Invalid format of line ' + line + + ' in lexicon file.') + word = splits[1] + prob = float(splits[0]) + phones = ' '.join(splits[2:]) + pron_probs[(word, phones)] = prob + lexicon[word].add(phones) + return lexicon, pron_probs + +def PruneProns(args, stats, ref_lexicon, lexicon_phonetic_decoding, lexicon_g2p, lexicon_g2p_probs): + # For those pron candidates from lexicon_phonetic_decoding/g2p which don't + # have stats, we append them to the "stats" dict, with a zero count. + for word, entry in stats.iteritems(): + prons_with_stats = set() + for (pron, count) in entry: + prons_with_stats.add(pron) + for pron in lexicon_g2p[word]: + if pron not in prons_with_stats: + entry.append((pron, lexicon_g2p_probs[(word, pron)]-1.0)) + entry.sort(key=lambda x: x[1]) + # Compute the average # pron variants counts per word in the reference lexicon. 
num_words_ref = 0 num_prons_ref = 0 for word, prons in ref_lexicon.iteritems(): num_words_ref += 1 num_prons_ref += len(prons) - avg_variants_counts_ref = math.ceil(float(num_prons_ref) / float(num_words_ref)) - + avg_variant_counts_ref = round(float(num_prons_ref) / float(num_words_ref)) for word, entry in stats.iteritems(): if word in ref_lexicon: - variants_counts = args.r * len(ref_lexicon[word]) + variant_counts = args.variant_counts_ratio * len(ref_lexicon[word]) else: - variants_counts = args.r * avg_variants_counts_ref + variant_counts = args.variant_counts_ratio * avg_variant_counts_ref num_variants = 0 - while num_variants < variants_counts: + count = 0.0 + while num_variants < variant_counts: try: - pron, prob = entry.pop() - if word not in ref_lexicon or pron not in ref_lexicon[word]: + pron, count = entry.pop() + if word in ref_lexicon and pron in ref_lexicon[word]: + continue + if pron in lexicon_phonetic_decoding[word]: + num_variants += 1 + print('{0} {1}'.format(word, pron), file=args.lexicon_phonetic_decoding_pruned_handle) + if pron in lexicon_g2p[word]: num_variants += 1 + print('{0} {1}'.format(word, pron), file=args.lexicon_g2p_pruned_handle) except IndexError: break - - for word, entry in stats.iteritems(): - for pron, prob in entry: - if word not in ref_lexicon or pron not in ref_lexicon[word]: - print('{0} {1}'.format(word, pron), file=args.pruned_prons_handle) def Main(): args = GetArgs() ref_lexicon = ReadLexicon(args.ref_lexicon_handle) + lexicon_phonetic_decoding = ReadLexicon(args.lexicon_phonetic_decoding_handle) + lexicon_g2p, lexicon_g2p_probs = ReadLexiconp(args.lexiconp_g2p_handle) stats = ReadStats(args.pron_stats_handle) - PruneProns(args, stats, ref_lexicon) + + PruneProns(args, stats, ref_lexicon, lexicon_phonetic_decoding, lexicon_g2p, lexicon_g2p_probs) if __name__ == "__main__": Main() diff --git a/egs/wsj/s5/steps/dict/internal/sum_arc_info.py b/egs/wsj/s5/steps/dict/internal/sum_arc_info.py new file mode 100755 index 00000000000..5f02bc5fc29 --- /dev/null +++ b/egs/wsj/s5/steps/dict/internal/sum_arc_info.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python + +# Copyright 2018 Xiaohui Zhang +# Apache 2.0 + +from __future__ import print_function +from collections import defaultdict +import argparse +import sys + +class StrToBoolAction(argparse.Action): + """ A custom action to convert bools from shell format i.e., true/false + to python format i.e., True/False """ + def __call__(self, parser, namespace, values, option_string=None): + if values == "true": + setattr(namespace, self.dest, True) + elif values == "false": + setattr(namespace, self.dest, False) + else: + raise Exception("Unknown value {0} for --{1}".format(values, self.dest)) + + +def GetArgs(): + parser = argparse.ArgumentParser( + description = "Accumulate statistics from per arc lattice statitics" + "for lexicon learning", + epilog = "See steps/dict/learn_lexicon_greedy.sh for example") + + parser.add_argument("--set-sum-to-one", type = str, default = True, + action = StrToBoolAction, choices = ["true", "false"], + help = "If normalize posteriors such that the sum of " + "pronunciation posteriors of a word in an utterance is 1.") + parser.add_argument("arc_info_file", metavar = "", type = str, + help = "File containing per arc statistics; " + "each line must be " + "") + parser.add_argument("phone_map", metavar = "", type = str, + help = "An input phone map used to remove word boundary markers from phones;" + "generated in steps/cleanup/debug_lexicon.sh") + parser.add_argument("stats_file", 
metavar = "", type = str, + help = "Write accumulated statitistics to this file" + "each line is " + "") + + print (' '.join(sys.argv), file=sys.stderr) + + args = parser.parse_args() + args = CheckArgs(args) + + return args + +def CheckArgs(args): + if args.arc_info_file == "-": + args.arc_info_file_handle = sys.stdin + else: + args.arc_info_file_handle = open(args.arc_info_file) + + args.phone_map_handle = open(args.phone_map) + + if args.stats_file == "-": + args.stats_file_handle = sys.stdout + else: + args.stats_file_handle = open(args.stats_file, "w") + + return args + +def Main(): + args = GetArgs() + + lexicon = defaultdict(list) + prons = defaultdict(list) + start_frames = {} + stats = defaultdict(lambda : defaultdict(float)) + sum_tot = defaultdict(float) + + phone_map = {} + for line in args.phone_map_handle.readlines(): + splits = line.strip().split() + phone_map[splits[0]] = splits[1] + + for line in args.arc_info_file_handle.readlines(): + splits = line.strip().split() + + if (len(splits) == 0): + continue + + if (len(splits) < 6): + raise Exception('Invalid format of line ' + line + + ' in ' + args.arc_info_file) + + utt = splits[0] + start_frame = int(splits[1]) + word = splits[4] + count = float(splits[3]) + phones_unmapped = splits[5:] + phones = [phone_map[phone] for phone in phones_unmapped] + phones = ' '.join(phones) + overlap = False + if word == '': + continue + if (word, utt) not in start_frames: + start_frames[(word, utt)] = start_frame + + if (word, utt) in stats: + stats[word, utt][phones] = stats[word, utt].get(phones, 0) + count + else: + stats[(word, utt)][phones] = count + sum_tot[(word, utt)] += count + + if phones not in prons[word]: + prons[word].append(phones) + + for (word, utt) in stats: + count_sum = 0.0 + counts = dict() + for phones in stats[(word, utt)]: + count = stats[(word, utt)][phones] + count_sum += count + counts[phones] = count + # By default we normalize the pron posteriors of each word in each utterance, + # so that they sum up exactly to one. If a word occurs two times in a utterance, + # the effect of this operation is to average the posteriors of these two occurences + # so that there's only one "equivalent occurence" of this word in the utterance. + # However, this case should be extremely rare if the utterances are already + # short sub-utterances produced by steps/dict/internal/get_subsegments.py + for phones in stats[(word, utt)]: + count = counts[phones] / count_sum + print(word, utt, start_frames[(word, utt)], count, phones, file=args.stats_file_handle) + # # Diagnostics info implying incomplete arc_info or multiple occurences of a word in a utterance: + # if count_sum < 0.9 or count_sum > 1.1: + # print(word, utt, start_frame, count_sum, stats[word, utt], file=sys.stderr) + + args.stats_file_handle.close() + +if __name__ == "__main__": + Main() diff --git a/egs/wsj/s5/steps/dict/learn_lexicon.sh b/egs/wsj/s5/steps/dict/learn_lexicon_bayesian.sh similarity index 93% rename from egs/wsj/s5/steps/dict/learn_lexicon.sh rename to egs/wsj/s5/steps/dict/learn_lexicon_bayesian.sh index a719422b593..042f8f94da4 100755 --- a/egs/wsj/s5/steps/dict/learn_lexicon.sh +++ b/egs/wsj/s5/steps/dict/learn_lexicon_bayesian.sh @@ -36,6 +36,7 @@ oov_symbol= lexicon_g2p= min_prob=0.3 +variant_counts_ratio=8 variants_prob_mass=0.7 variants_prob_mass_ref=0.9 @@ -93,6 +94,10 @@ if [ $# -lt 6 ] || [ $# -gt 7 ]; then echo " --min-prob # The cut-off parameter used to select pronunciation candidates from phonetic" echo " # decoding. 
We remove pronunciations with probabilities less than this value" echo " # after normalizing the probs s.t. the max-prob is 1.0 for each word." + echo " --variant-counts-ratio # This ratio parameter determines the maximum number of pronunciation" + echo " # candidates we will keep for each word, after pruning according to lattice statistics from" + echo " # the first iteration of lattice generation. See steps/dict/internal/prune_pron_candidates.py" + echo " # for details." echo " --prior-mean # Mean of priors (summing up to 1) assigned to three exclusive pronunciation" echo " # source: reference lexicon, g2p, and phonetic decoding (used in the Bayesian" echo " # pronunciation selection procedure). We recommend setting a larger prior" @@ -150,17 +155,17 @@ if [ $stage -le 0 ]; then # Remove non-scored-words from the reference lexicon. awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $0}' $dir/non_scored_words \ - $ref_dict/lexicon.txt | tr -s '\t' ' ' > $dir/ref_lexicon.txt + $ref_dict/lexicon.txt | tr -s '\t' ' ' | awk '$1=$1' > $dir/ref_lexicon.txt cat $dir/ref_lexicon.txt | awk '{print $1}' | sort | uniq > $dir/ref_vocab.txt awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $0}' $dir/non_scored_words \ $target_vocab | sort | uniq > $dir/target_vocab.txt # From the reference lexicon, we estimate the target_num_prons_per_word as, - # ceiling(avg. # prons per word in the reference lexicon). This'll be used as + # round(avg. # prons per word in the reference lexicon). This'll be used as # the upper bound of # pron variants per word when we apply G2P or select prons to # construct the learned lexicon in later stages. - python -c 'import sys; import math; print int(math.ceil(float(sys.argv[1])/float(sys.argv[2])))' \ + python -c 'import sys; import math; print int(round(float(sys.argv[1])/float(sys.argv[2])))' \ `wc -l $dir/ref_lexicon.txt | awk '{print $1}'` `wc -l $dir/ref_vocab.txt | awk '{print $1}'` \ > $dir/target_num_prons_per_word || exit 1; @@ -225,10 +230,11 @@ if [ $stage -le 2 ]; then # Get the oov words list (w.r.t ref vocab) which are in training data. awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $1}' $dir/ref_lexicon.txt \ - $dir/train_counts.txt | sort > $dir/oov_train.txt + $dir/train_counts.txt | awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $0}' \ + $dir/non_scored_words - | sort > $dir/oov_train.txt || exit 1; awk 'NR==FNR{a[$1] = 1; next} {if(($1 in a)) b+=$2; else c+=$2} END{print c/(b+c)}' \ - $dir/ref_vocab.txt $dir/train_counts.txt > $dir/train_oov_rate + $dir/ref_vocab.txt $dir/train_counts.txt > $dir/train_oov_rate || exit 1; echo "OOV rate (w.r.t. the reference lexicon) of the acoustic training data is:" cat $dir/train_oov_rate @@ -237,14 +243,14 @@ if [ $stage -le 2 ]; then # cannot be found in lexicon_g2p, we simply assign oov_symbol's pronunciaiton # (like NSN) to them, in order to get phonetic decoding pron candidates for them later on. awk 'NR==FNR{a[$1] = 1; next} ($1 in a)' $dir/oov_train.txt \ - $dir/lexicon_g2p.txt > $dir/g2p_prons_for_oov_train.txt + $dir/lexicon_g2p.txt > $dir/g2p_prons_for_oov_train.txt || exit 1; # Get the pronunciation of oov_symbol. - oov_pron=`cat $dir/non_scored_entries | grep $oov_symbol | cut -f2- -d' '` + oov_pron=`cat $dir/non_scored_entries | grep $oov_symbol | awk '{print $2}'` # For oov words in training data for which we don't even have G2P pron candidates, # we simply assign them the pronunciation of the oov symbol (like ). 
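For reference, the python -c one-liner above just divides two wc -l line counts and rounds the result. A slightly more self-contained sketch of the same estimate follows (Python 3; the function name and path are illustrative and not part of the patch).

# Sketch: estimate the target number of pron variants per word as
# round(#lexicon-entries / #distinct-words), mirroring the `python -c` call above.
def target_num_prons_per_word(lexicon_path):
    num_entries = 0
    words = set()
    with open(lexicon_path, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            num_entries += 1          # one lexicon entry per non-empty line
            words.add(fields[0])      # first field is the word
    return int(round(num_entries / float(len(words))))

# e.g. target_num_prons_per_word("data/local/dict/lexicon.txt")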
awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $1}' $dir/g2p_prons_for_oov_train.txt \ - $dir/oov_train.txt | awk -v op=$oov_pron '{print $0" "op}' > $dir/oov_train_no_pron.txt + $dir/oov_train.txt | awk -v op="$oov_pron" '{print $0" "op}' > $dir/oov_train_no_pron.txt || exit 1; cat $dir/oov_train_no_pron.txt $dir/g2p_prons_for_oov_train.txt $dir/ref_lexicon.txt | \ awk 'NR==FNR{a[$1] = 1; next} ($1 in a)' $dir/train_counts.txt - | \ @@ -263,7 +269,7 @@ if [ $stage -le 3 ]; then # We prune the phonetic decoding generated prons relative to the largest count, by setting "min_prob", # and only leave prons who are not present in the reference lexicon / g2p-generated lexicon. - cat $dir/ref_lexicon.txt $dir/lexicon_g2p.txt > $dir/phonetic_decoding/filter_lexicon.txt + cat $dir/ref_lexicon.txt $dir/lexicon_g2p.txt | sort -u > $dir/phonetic_decoding/filter_lexicon.txt $cmd $dir/phonetic_decoding/log/prons_to_lexicon.log steps/dict/prons_to_lexicon.py \ --min-prob=$min_prob --filter-lexicon=$dir/phonetic_decoding/filter_lexicon.txt \ @@ -295,7 +301,7 @@ if [ $stage -le 4 ]; then # Generate lattices for the acoustic training data with the combined lexicon. if $retrain_src_mdl; then mdl_dir=$dir/${src_mdl_dir}_retrained; else mdl_dir=$src_mdl_dir; fi - steps/align_fmllr_lats.sh --cmd "$decode_cmd" --nj $nj \ + steps/align_fmllr_lats.sh --acoustic-scale 0.05 --cmd "$decode_cmd" --nj $nj \ $data $dir/lang_combined_iter1 $mdl_dir $dir/lats_iter1 || exit 1; # Get arc level information from the lattice. @@ -321,13 +327,10 @@ if [ $stage -le 5 ]; then rm $dir/dict_combined_iter2/lexiconp.txt $dir/dict_combined_iter2/lexicon.txt 2>/dev/null # Prune away pronunciations which have low acoustic evidence from the first pass of lattice alignment. - $cmd $dir/lats_iter1/log/prune_pron_candidates.log steps/dict/internal/prune_pron_candidates.py $dir/lats_iter1/pron_stats.txt $dir/ref_lexicon.txt $dir/pruned_prons.txt - - awk 'NR==FNR{a[$0] = 1; next} (!($0 in a))' $dir/pruned_prons.txt $dir/lexicon_phonetic_decoding.txt \ - > $dir/lexicon_phonetic_decoding_pruned.txt - - awk 'NR==FNR{a[$0] = 1; next} (!($0 in a))' $dir/pruned_prons.txt $dir/lexicon_g2p.txt \ - > $dir/lexicon_g2p_pruned.txt \ + $cmd $dir/lats_iter1/log/prune_pron_candidates.log steps/dict/internal/prune_pron_candidates.py \ + --variant-counts-ratio $variant_counts_ratio \ + $dir/lats_iter1/pron_stats.txt $dir/lexicon_phonetic_decoding_pruned.txt $dir/lexiconp_g2p.txt $dir/ref_lexicon.txt \ + $dir/lexicon_phonetic_decoding_pruned.txt $dir/lexicon_g2p_pruned.txt # Filter out words which don't appear in the acoustic training data cat $dir/lexicon_phonetic_decoding_pruned.txt $dir/lexicon_g2p_pruned.txt \ @@ -402,7 +405,7 @@ if [ $stage -le 7 ]; then # target vocab. We'll just assign to them pronunciations from lexicon_g2p, if any. cat $dir/lats_iter2/out_of_ref_vocab_prons_learned.txt $dir/ref_lexicon.txt | \ awk 'NR==FNR{a[$1] = 1; next} !($1 in a)' - \ - $dir/target_vocab.txt | sort | uniq > $dir/oov_no_acoustics.txt + $dir/target_vocab.txt | sort | uniq > $dir/oov_no_acoustics.txt || exit 1; awk 'NR==FNR{a[$1] = 1; next} ($1 in a)' $dir/oov_no_acoustics.txt \ $dir/lexicon_g2p.txt > $dir/g2p_prons_for_oov_no_acoustics.txt @@ -426,5 +429,5 @@ if [ $stage -le 8 ]; then echo " ... sort -u \> $dest_dict/lexicon.txt to re-produce the final learned lexicon." 
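A note for readers less used to awk: the awk 'NR==FNR{a[$1] = 1; next} ...' pattern used throughout these scripts loads the first file's first column into a hash and then filters the second file against it. A rough Python 3 equivalent, with illustrative file names, is sketched here.

# Python 3 equivalent of the recurring awk idiom
#   awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $0}' exclude.txt input.txt
# i.e. load column 1 of the first file as a set, then print lines of the second
# file depending on whether their first field is in that set. File names are illustrative.
import sys

def filter_by_first_field(key_file_path, input_path, keep_if_present=False):
    with open(key_file_path, encoding="utf-8") as f:
        keys = {line.split()[0] for line in f if line.split()}
    with open(input_path, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            if (fields[0] in keys) == keep_if_present:
                sys.stdout.write(line)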
cp $dir/lats_iter2/ref_lexicon_edits.txt $dest_dict/lexicon_edits.txt 2>/dev/null steps/dict/apply_lexicon_edits.py $dest_dict/lexicon0.txt $dir/lats_iter2/ref_lexicon_edits.txt - | \ - sort | uniq > $dest_dict/lexicon.txt + sort | uniq > $dest_dict/lexicon.txt || exit 1; fi diff --git a/egs/wsj/s5/steps/dict/learn_lexicon_greedy.sh b/egs/wsj/s5/steps/dict/learn_lexicon_greedy.sh new file mode 100755 index 00000000000..56e85f20d62 --- /dev/null +++ b/egs/wsj/s5/steps/dict/learn_lexicon_greedy.sh @@ -0,0 +1,546 @@ +#! /bin/bash + +# Copyright 2018 Xiaohui Zhang +# Apache 2.0 + +# This recipe has similar inputs and outputs to steps/dict/learn_lexicon_bayesian.sh +# The major difference is, instead of using a Bayesian framework for +# pronunciation selection, we use a likelihood-reduction based greedy +# pronunciation selection framework presented in the paper: +# "Acoustic data-driven lexicon learning based on a greedy pronunciation " +# "selection framework, by X. Zhang, V. Manohar, D. Povey and S. Khudanpur," +# "Interspeech 2017." + +# This script demonstrates how to expand an existing lexicon using a combination +# of acoustic evidence and G2P to learn a lexicon that covers words in a target +# vocab, and agrees sufficiently with the acoustics. The basic idea is to +# run phonetic decoding on acoustic training data using an existing +# acoustic model (possibly re-trained using a G2P-expanded lexicon) to get +# alternative pronunciations for words in training data. Then we combine three +# exclusive sources of pronunciations: the reference lexicon (supposedly +# hand-derived), phonetic decoding, and G2P (optional) into one lexicon and then run +# lattice alignment on the same data, to collect acoustic evidence (soft +# counts) of all pronunciations. Based on these statistics, we use a greedy +# framework (see steps/dict/select_prons_greedy.py for details) to select an +# informative subset of pronunciations for each word with acoustic evidence. +# Two important parameters are alpha and beta. Basically, the three dimensions of alpha +# and beta correspond to three pronunciation sources: phonetic-decoding, G2P and +# the reference lexicon, and the larger a value is, the more aggressively we'll +# prune pronunciations from that source. The valid range of each dim. is [0, 1] +# (for alpha, where 0 means we never prune prons from that source) and [0, 100] (for beta). +# The output of steps/dict/select_prons_greedy.py is a learned lexicon whose vocab +# matches the user-specified target-vocab, and two intermediate outputs which were +# used to generate the learned lexicon: an edits file which records the recommended +# changes to all in-ref-vocab words' prons, and a half-learned lexicon +# ($dest_dict/lexicon0.txt) where all in-ref-vocab words' prons were untouched +# (on top of which we apply the edits file to produce the final learned lexicon). +# The user can always modify the edits file manually and then re-apply it on the +# half-learned lexicon using steps/dict/apply_lexicon_edits.py to produce the +# final learned lexicon. See the last stage in this script for details. + +stage=0 +# Begin configuration section. +cmd=run.pl +nj= +stage=0 +oov_symbol= +lexiconp_g2p= +min_prob=0.3 +variant_counts_ratio=8 +variant_counts_no_acoustics=1 +alpha="0,0,0" +beta="0,0,0" +delta=0.0000001 +num_gauss= +num_leaves= +retrain_src_mdl=true +cleanup=true +nj_select_prons=200 +learn_iv_prons=false # whether we want to learn the prons of IV words (w.r.t. ref_vocab). + +# End configuration section. + +. ./path.sh +. 
utils/parse_options.sh + +if [ $# -lt 6 ] || [ $# -gt 7 ]; then + echo "Usage: $0 [options] \\" + echo " ." + echo " This script does lexicon expansion using a combination of acoustic" + echo " evidence and G2P to produce a lexicon that covers words of a target vocab:" + echo "" + echo "Arguments:" + echo " The dir which contains the reference lexicon (most probably hand-derived)" + echo " we want to expand/improve, and nonsilence_phones.txt,.etc which we need " + echo " for building new dict dirs." + echo " The vocabulary we want the final learned lexicon to cover (one word per line)." + echo " acoustic training data we use to get alternative" + echo " pronunciations and collet acoustic evidence." + echo " The dir containing an SAT-GMM acoustic model (we optionaly we re-train it" + echo " using G2P expanded lexicon) to do phonetic decoding (to get alternative" + echo " pronunciations) and lattice-alignment (to collect acoustic evidence for" + echo " evaluating all prounciations)" + echo " The reference lang dir which we use to get non-scored-words" + echo " like for building new dict dirs" + echo " The dict dir where we put the final learned lexicon, whose vocab" + echo " matches ." + echo " The dir which contains all the intermediate outputs of this script." + echo "" + echo "Note: and the vocab of don't have to match. For words" + echo " who are in but not seen in , their pronunciations" + echo " will be given by G2P at the end." + echo "" + echo "e.g. $0 data/local/dict data/local/lm/librispeech-vocab.txt data/train \\" + echo " exp/tri3 data/lang data/local/dict_learned" + echo "Options:" + echo " --stage # stage to run from, to enable resuming from partially" + echo " # completed run (default: 0)" + echo " --cmd '$cmd' # command to submit jobs with (e.g. run.pl, queue.pl)" + echo " --nj # number of parallel jobs" + echo " --oov-symbol '$oov_symbol' # oov symbol, like ." + echo " --lexiconp-g2p # a lexicon (with prob in the second column) file containing g2p generated" + echo " # pronunciations, for words in acoustic training data / target vocabulary. It's optional." + echo " --min-prob # The cut-off parameter used to select pronunciation candidates from phonetic" + echo " # decoding. We remove pronunciations with probabilities less than this value" + echo " # after normalizing the probs s.t. the max-prob is 1.0 for each word." + echo " --variant-counts-ratio # This ratio parameter determines the maximum number of pronunciation" + echo " # candidates we will keep for each word, after pruning according to lattice statistics from" + echo " # the first iteration of lattice generation. See steps/dict/internal/prune_pron_candidates.py" + echo " # for details." + echo " --variant-counts-no-acoustics # how many g2p-prons per word we want to include for each words unseen in acoustic training data." + echo " --alpha ,, # scaling factors used in the greedy pronunciation selection framework, " + echo " # see steps/dict/select_prons_greedy.py for details." + echo " --beta ,, # smoothing factors used in the greedy pronunciation selection framework, " + echo " # see steps/dict/select_prons_greedy.py for details." + echo " --delta # a floor value used in the greedy pronunciation selection framework, " + echo " # see steps/dict/select_prons_greedy.py for details." + echo " --num-gauss # number of gaussians for the re-trained SAT model (on top of )." + echo " --num-leaves # number of leaves for the re-trained SAT model (on top of )." 
+ echo " --retrain-src-mdl # true if you want to re-train the src_mdl before phone decoding (default false)." + exit 1 +fi + +echo "$0 $@" # Print the command line for logging + +ref_dict=$1 +target_vocab=$2 +data=$3 +src_mdl_dir=$4 +ref_lang=$5 +dest_dict=$6 + +if [ -z "$oov_symbol" ]; then + echo "$0: the --oov-symbol option is required." + exit 1 +fi + +if [ $# -gt 6 ]; then + dir=$7 # Most intermediate outputs will be put here. +else + dir=${src_mdl_dir}_lex_learn_work +fi + +mkdir -p $dir +if [ $stage -le 0 ]; then + echo "$0: Some preparatory work." + # Get the word counts of training data. + awk '{for (n=2;n<=NF;n++) counts[$n]++;} END{for (w in counts) printf "%s %d\n",w, counts[w];}' \ + $data/text | sort > $dir/train_counts.txt + + # Get the non-scored entries and exclude them from the reference lexicon/vocab, and target_vocab. + steps/cleanup/internal/get_non_scored_words.py $ref_lang > $dir/non_scored_words + awk 'NR==FNR{a[$1] = 1; next} {if($1 in a) print $0}' $dir/non_scored_words \ + $ref_dict/lexicon.txt > $dir/non_scored_entries + + # Remove non-scored-words from the reference lexicon. + awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $0}' $dir/non_scored_words \ + $ref_dict/lexicon.txt | tr -s '\t' ' ' | awk '$1=$1' > $dir/ref_lexicon.txt + + cat $dir/ref_lexicon.txt | awk '{print $1}' | sort | uniq > $dir/ref_vocab.txt + awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $0}' $dir/non_scored_words \ + $target_vocab | sort | uniq > $dir/target_vocab.txt + + # From the reference lexicon, we estimate the target_num_prons_per_word as, + # round(avg. # prons per word in the reference lexicon). This'll be used as + # the upper bound of # pron variants per word when we apply G2P or select prons to + # construct the learned lexicon in later stages. + python -c 'import sys; import math; print int(round(float(sys.argv[1])/float(sys.argv[2])))' \ + `wc -l $dir/ref_lexicon.txt | awk '{print $1}'` `wc -l $dir/ref_vocab.txt | awk '{print $1}'` \ + > $dir/target_num_prons_per_word || exit 1; + + if [ -z $lexiconp_g2p ]; then + # create an empty list of g2p generated prons, if it's not given. + touch $dir/lexicon_g2p.txt + touch $dir/lexiconp_g2p.txt + else + # Exchange the 1st column (word) and 2nd column (prob) and remove pronunciations + # which are already in the reference lexicon. + cat $lexiconp_g2p | awk '{a=$1;b=$2; $1="";$2="";print b" "a$0}' | \ + awk 'NR==FNR{a[$0] = 1; next} {w=$2;for (n=3;n<=NF;n++) w=w" "$n; if(!(w in a)) print $0}' \ + $dir/ref_lexicon.txt - > $dir/lexiconp_g2p.txt 2>/dev/null + + # make a copy where we remove the first column (probabilities). + cat $dir/lexiconp_g2p.txt | cut -f1,3- > $dir/lexicon_g2p.txt 2>/dev/null + fi + variant_counts=`cat $dir/target_num_prons_per_word` || exit 1; + $cmd $dir/log/prune_g2p_lexicon.log steps/dict/prons_to_lexicon.py \ + --top-N=$variant_counts $dir/lexiconp_g2p.txt \ + $dir/lexicon_g2p_variant_counts${variant_counts}.txt || exit 1; +fi + +if [ $stage -le 1 ] && $retrain_src_mdl; then + echo "$0: Expand the reference lexicon to cover all words in the target vocab. and then" + echo " ... re-train the source acoustic model for phonetic decoding. " + mkdir -p $dir/dict_expanded_target_vocab + cp $ref_dict/{extra_questions.txt,optional_silence.txt,nonsilence_phones.txt,silence_phones.txt} \ + $dir/dict_expanded_target_vocab 2>/dev/null + rm $dir/dict_expanded_target_vocab/lexiconp.txt $dir/dict_expanded_target_vocab/lexicon.txt 2>/dev/null + + # Get the oov words list (w.r.t ref vocab) which are in the target vocab. 
+ awk 'NR==FNR{a[$1] = 1; next} !($1 in a)' $dir/ref_lexicon.txt \ + $dir/target_vocab.txt | sort | uniq > $dir/oov_target_vocab.txt + + # Assign pronunciations from lexicon_g2p.txt to oov_target_vocab. For words which + # cannot be found in lexicon_g2p.txt, we simply ignore them. + awk 'NR==FNR{a[$1] = 1; next} ($1 in a)' $dir/oov_target_vocab.txt \ + $dir/lexicon_g2p.txt > $dir/lexicon_g2p_oov_target_vocab.txt + + cat $dir/lexicon_g2p_oov_target_vocab.txt $dir/ref_lexicon.txt | \ + awk 'NR==FNR{a[$1] = 1; next} ($1 in a)' $dir/target_vocab.txt - | \ + cat $dir/non_scored_entries - | + sort | uniq > $dir/dict_expanded_target_vocab/lexicon.txt + + utils/prepare_lang.sh --phone-symbol-table $ref_lang/phones.txt $dir/dict_expanded_target_vocab \ + $oov_symbol $dir/lang_expanded_target_vocab_tmp $dir/lang_expanded_target_vocab || exit 1; + + # Align the acoustic training data using the given src_mdl_dir. + alidir=${src_mdl_dir}_ali_$(basename $data) + steps/align_fmllr.sh --nj $nj --cmd "$train_cmd" \ + $data $dir/lang_expanded_target_vocab $src_mdl_dir $alidir || exit 1; + + # Train another SAT system on the given data and put it in $dir/${src_mdl_dir}_retrained + # this model will be used for phonetic decoding and lattice alignment later on. + if [ -z $num_leaves ] || [ -z $num_gauss ] ; then + echo "num_leaves and num_gauss need to be specified." && exit 1; + fi + steps/train_sat.sh --cmd "$train_cmd" $num_leaves $num_gauss \ + $data $dir/lang_expanded_target_vocab $alidir $dir/${src_mdl_dir}_retrained || exit 1; +fi + +if [ $stage -le 2 ]; then + echo "$0: Expand the reference lexicon to cover all words seen in," + echo " ... acoustic training data, and prepare corresponding dict and lang directories." + echo " ... This is needed when generate pron candidates from phonetic decoding." + mkdir -p $dir/dict_expanded_train + cp $ref_dict/{extra_questions.txt,optional_silence.txt,nonsilence_phones.txt,silence_phones.txt} \ + $dir/dict_expanded_train 2>/dev/null + rm $dir/dict_expanded_train/lexiconp.txt $dir/dict_expanded_train/lexicon.txt 2>/dev/null + + # Get the oov words list (w.r.t ref vocab) which are in training data. + awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $1}' $dir/ref_lexicon.txt \ + $dir/train_counts.txt | awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $0}' \ + $dir/non_scored_words - | sort > $dir/oov_train.txt || exit 1; + + awk 'NR==FNR{a[$1] = 1; next} {if(($1 in a)) b+=$2; else c+=$2} END{print c/(b+c)}' \ + $dir/ref_vocab.txt $dir/train_counts.txt > $dir/train_oov_rate || exit 1; + + echo "OOV rate (w.r.t. the reference lexicon) of the acoustic training data is:" + cat $dir/train_oov_rate + + # Assign pronunciations from lexicon_g2p to oov_train. For words which + # cannot be found in lexicon_g2p, we simply assign oov_symbol's pronunciaiton + # (like NSN) to them, in order to get phonetic decoding pron candidates for them later on. + variant_counts=`cat $dir/target_num_prons_per_word` || exit 1; + awk 'NR==FNR{a[$1] = 1; next} ($1 in a)' $dir/oov_train.txt \ + $dir/lexicon_g2p_variant_counts${variant_counts}.txt > $dir/g2p_prons_for_oov_train.txt || exit 1; + + # Get the pronunciation of oov_symbol. + oov_pron=`cat $dir/non_scored_entries | grep $oov_symbol | awk '{print $2}'` + # For oov words in training data for which we don't even have G2P pron candidates, + # we simply assign them the pronunciation of the oov symbol (like ), + # so that we can get pronunciations for them from phonetic decoding. 
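The train_oov_rate awk in stage 2 computes a token-weighted OOV rate: the fraction of training word tokens whose word type is missing from the reference vocabulary. The same computation as a Python 3 sketch with toy inputs follows.

# Sketch: token-weighted OOV rate w.r.t. a reference vocabulary, mirroring
#   awk 'NR==FNR{a[$1]=1; next} {if(($1 in a)) b+=$2; else c+=$2} END{print c/(b+c)}'
# counts: word -> token count in training data; ref_vocab: set of in-vocabulary words.
def oov_rate(counts, ref_vocab):
    in_vocab = sum(c for w, c in counts.items() if w in ref_vocab)
    out_of_vocab = sum(c for w, c in counts.items() if w not in ref_vocab)
    return out_of_vocab / float(in_vocab + out_of_vocab)

print(oov_rate({"hello": 90, "zyzzyva": 10}, {"hello"}))  # 0.1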
+ awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $1}' $dir/g2p_prons_for_oov_train.txt \ + $dir/oov_train.txt | awk -v op="$oov_pron" '{print $0" "op}' > $dir/oov_train_no_pron.txt || exit 1; + + cat $dir/oov_train_no_pron.txt $dir/g2p_prons_for_oov_train.txt $dir/ref_lexicon.txt | \ + awk 'NR==FNR{a[$1] = 1; next} ($1 in a)' $dir/train_counts.txt - | \ + cat - $dir/non_scored_entries | \ + sort | uniq > $dir/dict_expanded_train/lexicon.txt || exit 1; + + utils/prepare_lang.sh $dir/dict_expanded_train $oov_symbol \ + $dir/lang_expanded_train_tmp $dir/lang_expanded_train || exit 1; +fi + +if [ $stage -le 3 ]; then + echo "$0: Generate pronunciation candidates from phonetic decoding on acoustic training data.." + if $retrain_src_mdl; then mdl_dir=$dir/${src_mdl_dir}_retrained; else mdl_dir=$src_mdl_dir; fi + steps/cleanup/debug_lexicon.sh --nj $nj \ + --cmd "$decode_cmd" $data $dir/lang_expanded_train \ + $mdl_dir $dir/dict_expanded_train/lexicon.txt $dir/phonetic_decoding || exit 1; +fi + +if [ $stage -le 4 ]; then + echo "$0: Combine the reference lexicon and pronunciations from phone-decoding/G2P into one" + echo " ... lexicon, and run lattice alignment using this lexicon on acoustic training data" + echo " ... to collect acoustic evidence." + # We first prune the phonetic decoding generated prons relative to the largest count, by setting "min_prob", + # and only leave prons who are not present in the reference lexicon / g2p-generated lexicon. + cat $dir/ref_lexicon.txt $dir/lexicon_g2p.txt | sort -u > $dir/phonetic_decoding/filter_lexicon.txt + + $cmd $dir/phonetic_decoding/log/prons_to_lexicon.log steps/dict/prons_to_lexicon.py \ + --min-prob=$min_prob --filter-lexicon=$dir/phonetic_decoding/filter_lexicon.txt \ + $dir/phonetic_decoding/prons.txt $dir/lexicon_pd_with_eps.txt + + # We abandon phonetic-decoding candidates for infrequent words. + awk '{if($2 < 3) print $1}' $dir/train_counts.txt > $dir/pd_candidates_to_exclude.txt + awk 'NR==FNR{a[$1] = $2; next} {if(a[$1]<10) print $1}' $dir/train_counts.txt \ + $dir/oov_train_no_pron.txt >> $dir/pd_candidates_to_exclude.txt + + if [ -s $dir/pd_candidates_to_exclude.txt ]; then + cat $dir/lexicon_pd_with_eps.txt | grep -vP "|||\[.*\]" | \ + awk 'NR==FNR{a[$0] = 1; next} {if(!($1 in a)) print $0}' $dir/pd_candidates_to_exclude.txt - | \ + sort | uniq > $dir/lexicon_pd.txt || exit 1; + else + cat $dir/lexicon_pd_with_eps.txt | grep -vP "|||\[.*\]" | \ + sort | uniq > $dir/lexicon_pd.txt || exit 1; + fi + + # Combine the reference lexicon, pronunciations from G2P and phonetic decoding into one lexicon. + mkdir -p $dir/dict_combined_iter1 + cp $ref_dict/{extra_questions.txt,optional_silence.txt,nonsilence_phones.txt,silence_phones.txt} \ + $dir/dict_combined_iter1/ 2>/dev/null + rm $dir/dict_combined_iter1/lexiconp.txt $dir/dict_combined_iter1/lexicon.txt 2>/dev/null + + # Filter out words which don't appear in the acoustic training data + cat $dir/lexicon_pd.txt $dir/lexicon_g2p.txt \ + $dir/ref_lexicon.txt | tr -s '\t' ' ' | \ + awk 'NR==FNR{a[$1] = 1; next} ($1 in a)' $dir/train_counts.txt - | \ + cat $dir/non_scored_entries - | \ + sort | uniq > $dir/dict_combined_iter1/lexicon.txt + + utils/prepare_lang.sh --phone-symbol-table $ref_lang/phones.txt \ + $dir/dict_combined_iter1 $oov_symbol \ + $dir/lang_combined_iter1_tmp $dir/lang_combined_iter1 || exit 1; + + # Generate lattices for the acoustic training data with the combined lexicon. 
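The --min-prob pruning that prons_to_lexicon.py applies here is, per the help text, done after normalizing each word's candidate probabilities so that the best pron gets probability 1.0. Below is a hedged sketch of that rule (Python 3, toy data); the real script additionally filters candidates against filter_lexicon.txt.

# Sketch of the --min-prob pruning described above: for each word, scale the
# candidate probabilities so the most likely pron is 1.0, then drop candidates
# whose normalized probability falls below min_prob. Toy data only.
def prune_by_min_prob(word_prons, min_prob=0.3):
    # word_prons: word -> list of (pronunciation, probability or soft-count)
    pruned = {}
    for word, prons in word_prons.items():
        best = max(p for _, p in prons)
        if best <= 0:
            continue  # no usable evidence for this word
        pruned[word] = [(pron, p / best) for pron, p in prons if p / best >= min_prob]
    return pruned

print(prune_by_min_prob({"data": [("D EY T AH", 8.0), ("D AE T AH", 3.0), ("D AH T AH", 0.5)]}))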
+ if $retrain_src_mdl; then mdl_dir=$dir/${src_mdl_dir}_retrained; else mdl_dir=$src_mdl_dir; fi + + # Get the vocab for words for which we want to learn pronunciations. + if $learn_iv_prons; then + # If we want to learn the prons of IV words (w.r.t. ref_vocab), the learn_vocab is just the intersection of + # target_vocab and the vocab of words seen in acoustic training data (first col. of train_counts.txt) + awk 'NR==FNR{a[$1] = 1; next} {if($1 in a) print $1}' $dir/target_vocab.txt $dir/train_counts.txt \ + > $dir/learn_vocab.txt + else + # Exclude words from the ref_vocab if we don't want to learn the pronunciations of IV words. + awk 'NR==FNR{a[$1] = 1; next} {if($1 in a) print $1}' $dir/target_vocab.txt $dir/train_counts.txt | \ + awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $1}' $dir/ref_vocab.txt - > $dir/learn_vocab.txt + fi + + # In order to get finer lattice stats of alternative prons, we want to make lattices deeper. + # To speed up lattice generation, we use a ctm to create sub-utterances and a sub-segmentation + # for each instance of a word within learn_vocab (or a string of consecutive words within learn_vocab), + # including a single out-of-learn-vocab word at the boundary if present. + mkdir -p $dir/resegmentation + steps/dict/internal/get_subsegments.py $dir/phonetic_decoding/word.ctm $dir/learn_vocab.txt \ + $dir/resegmentation/subsegments $dir/resegmentation/text || exit 1; + utils/data/subsegment_data_dir.sh $data $dir/resegmentation/subsegments $dir/resegmentation/text \ + $dir/resegmentation/data || exit 1; + steps/compute_cmvn_stats.sh $dir/resegmentation/data || exit 1; + + steps/align_fmllr_lats.sh --beam 20 --retry-beam 50 --final-beam 30 --acoustic-scale 0.05 --cmd "$decode_cmd" --nj $nj \ + $dir/resegmentation/data $dir/lang_combined_iter1 $mdl_dir $dir/lats_iter1 || exit 1; + + # Get arc level information from the lattice. + $cmd JOB=1:$nj $dir/lats_iter1/log/get_arc_info.JOB.log \ + lattice-align-words $dir/lang_combined_iter1/phones/word_boundary.int \ + $dir/lats_iter1/final.mdl \ + "ark:gunzip -c $dir/lats_iter1/lat.JOB.gz |" ark:- \| \ + lattice-arc-post --acoustic-scale=0.1 $dir/lats_iter1/final.mdl ark:- - \| \ + utils/int2sym.pl -f 5 $dir/lang_combined_iter1/words.txt \| \ + utils/int2sym.pl -f 6- $dir/lang_combined_iter1/phones.txt '>' \ + $dir/lats_iter1/arc_info_sym.JOB.txt || exit 1; + + # Compute soft counts (pron_stats) of every particular word-pronunciation pair by + # summing up arc level information over all utterances. We'll use this to prune + # pronunciation candidates before the next iteration of lattice generation. + cat $dir/lats_iter1/arc_info_sym.*.txt | steps/dict/get_pron_stats.py - \ + $dir/phonetic_decoding/phone_map.txt $dir/lats_iter1/pron_stats.txt || exit 1; + + # Accumlate utterance-level pronunciation posteriors (into arc_stats) by summing up + # posteriors of arcs representing the same word & pronunciation and starting + # from roughly the same location. See steps/dict/internal/sum_arc_info.py for details. + for i in `seq 1 $nj`;do + cat $dir/lats_iter1/arc_info_sym.${i}.txt | sort -n -k1 -k2 -k3r | \ + steps/dict/internal/sum_arc_info.py - $dir/phonetic_decoding/phone_map.txt $dir/lats_iter1/arc_info_summed.${i}.txt + done + cat $dir/lats_iter1/arc_info_summed.*.txt | sort -k1 -k2 > $dir/lats_iter1/arc_stats.txt + + # Prune the phonetic_decoding lexicon so that any pronunciation that only has non-zero posterior at one word example will be removed. + # The pruned lexicon is put in $dir/lats_iter1. 
After further pruning in the next stage it'll be put back to $dir. + awk 'NR==FNR{w=$1;for (n=5;n<=NF;n++) w=w" "$n;a[w]+=1;next} {if($0 in a && a[$0]>1) print $0}' \ + $dir/lats_iter1/arc_stats.txt $dir/lexicon_pd.txt > $dir/lats_iter1/lexicon_pd_pruned.txt +fi + +# Here we re-generate lattices (with a wider beam and a pruned combined lexicon) and re-collect pronunciation statistics +if [ $stage -le 5 ]; then + echo "$0: Prune the pronunciation candidates generated from G2P/phonetic decoding, and re-do lattice-alignment." + mkdir -p $dir/dict_combined_iter2 + cp $ref_dict/{extra_questions.txt,optional_silence.txt,nonsilence_phones.txt,silence_phones.txt} \ + $dir/dict_combined_iter2/ 2>/dev/null + rm $dir/dict_combined_iter2/lexiconp.txt $dir/dict_combined_iter2/lexicon.txt 2>/dev/null + + # Prune away pronunciations which have low acoustic evidence from the first pass of lattice generation. + $cmd $dir/lats_iter1/log/prune_pron_candidates.log steps/dict/internal/prune_pron_candidates.py \ + --variant-counts-ratio $variant_counts_ratio \ + $dir/lats_iter1/pron_stats.txt $dir/lats_iter1/lexicon_pd_pruned.txt $dir/lexiconp_g2p.txt $dir/ref_lexicon.txt \ + $dir/lexicon_pd_pruned.txt $dir/lexicon_g2p_pruned.txt + + # Filter out words which don't appear in the acoustic training data. + cat $dir/lexicon_pd_pruned.txt $dir/lexicon_g2p_pruned.txt \ + $dir/ref_lexicon.txt | tr -s '\t' ' ' | \ + awk 'NR==FNR{a[$1] = 1; next} ($1 in a)' $dir/train_counts.txt - | \ + cat $dir/non_scored_entries - | \ + sort | uniq > $dir/dict_combined_iter2/lexicon.txt + + utils/prepare_lang.sh --phone-symbol-table $ref_lang/phones.txt \ + $dir/dict_combined_iter2 $oov_symbol \ + $dir/lang_combined_iter2_tmp $dir/lang_combined_iter2 || exit 1; + + # Re-generate lattices with a wider beam, so that we'll get deeper lattices. + if $retrain_src_mdl; then mdl_dir=$dir/${src_mdl_dir}_retrained; else mdl_dir=$src_mdl_dir; fi + steps/align_fmllr_lats.sh --beam 30 --retry-beam 60 --final-beam 50 --acoustic-scale 0.05 --cmd "$decode_cmd" --nj $nj \ + $dir/resegmentation/data $dir/lang_combined_iter2 $mdl_dir $dir/lats_iter2 || exit 1; + + # Get arc level information from the lattice as we did in the last stage. + $cmd JOB=1:$nj $dir/lats_iter2/log/get_arc_info.JOB.log \ + lattice-align-words $dir/lang_combined_iter2/phones/word_boundary.int \ + $dir/lats_iter2/final.mdl \ + "ark:gunzip -c $dir/lats_iter2/lat.JOB.gz |" ark:- \| \ + lattice-arc-post --acoustic-scale=0.1 $dir/lats_iter2/final.mdl ark:- - \| \ + utils/int2sym.pl -f 5 $dir/lang_combined_iter2/words.txt \| \ + utils/int2sym.pl -f 6- $dir/lang_combined_iter2/phones.txt '>' \ + $dir/lats_iter2/arc_info_sym.JOB.txt || exit 1; + + # Compute soft counts (pron_stats) of every particular word-pronunciation pair as + # we did in the last stage. The stats will only be used as diagnostics. + cat $dir/lats_iter2/arc_info_sym.*.txt | steps/dict/get_pron_stats.py - \ + $dir/phonetic_decoding/phone_map.txt $dir/lats_iter2/pron_stats.txt || exit 1; + + # Accumlate utterance-level pronunciation posteriors as we did in the last stage. + for i in `seq 1 $nj`;do + cat $dir/lats_iter2/arc_info_sym.${i}.txt | sort -n -k1 -k2 -k3r | \ + steps/dict/internal/sum_arc_info.py - $dir/phonetic_decoding/phone_map.txt $dir/lats_iter2/arc_info_summed.${i}.txt + done + cat $dir/lats_iter2/arc_info_summed.*.txt | sort -k1 -k2 > $dir/lats_iter2/arc_stats.txt + + # The pron_stats are the acoustic evidence which the likelihood-reduction-based pronunciation + # selection procedure will be based on. 
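The awk filter above keeps a phonetic-decoding pronunciation only if its (word, pronunciation) pair is supported by more than one word example in arc_stats.txt, whose lines carry word, utterance, start frame, posterior and phones. A Python 3 sketch of that filter, with illustrative file handling, is shown here.

# Sketch of the pruning step above: keep only lexicon entries whose
# (word, pronunciation) pair appears in more than one word example of the
# arc-stats file ("<word> <utt> <start-frame> <posterior> <phones...>").
# File names are illustrative.
from collections import Counter

def prune_singleton_prons(arc_stats_path, lexicon_path, out_path):
    support = Counter()
    with open(arc_stats_path, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if len(fields) < 5:
                continue
            word, phones = fields[0], " ".join(fields[4:])
            support[(word, phones)] += 1
    with open(lexicon_path, encoding="utf-8") as f, open(out_path, "w", encoding="utf-8") as out:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            word, phones = fields[0], " ".join(fields[1:])
            if support[(word, phones)] > 1:
                out.write(line)
    return out_path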
+ # Split the utterance-level pronunciation posterior stats into $nj_select_prons pieces, + # so that the following pronunciation selection stage can be parallelized. + numsplit=$nj_select_prons + awk '{print $1"-"$2" "$1}' $dir/lats_iter2/arc_stats.txt > $dir/lats_iter2/utt2word + utt2words=$(for n in `seq $numsplit`; do echo $dir/lats_iter2/utt2word.$n; done) + utils/split_scp.pl --utt2spk=$dir/lats_iter2/utt2word $dir/lats_iter2/utt2word $utt2words || exit 1 + for n in `seq $numsplit`; do + (cat $dir/lats_iter2/utt2word.$n | awk '{$1=substr($1,length($2)+2);print $2" "$1}' - > $dir/lats_iter2/word2utt.$n + awk 'NR==FNR{a[$0] = 1; next} {b=$1" "$2; if(b in a) print $0}' $dir/lats_iter2/word2utt.$n \ + $dir/lats_iter2/arc_stats.txt > $dir/lats_iter2/arc_stats.${n}.txt + ) & + done + wait +fi + +if [ $stage -le 6 ]; then + echo "$0: Select pronunciations according to the acoustic evidence from lattice alignment." + # Given the acoustic evidence (soft-counts), we use a Bayesian framework to select pronunciations + # from three exclusive candidate sources: reference (hand-derived) lexicon, G2P and phonetic decoding. + # The posteriors for all candidate prons for all words are printed into pron_posteriors.txt + # For words which are out of the ref. vocab, the learned prons are written into out_of_ref_vocab_prons_learned.txt. + # Among them, for words without acoustic evidence, we just ignore them, even if pron candidates from G2P were provided). + # For words in the ref. vocab, we instead output a human readable & editable "edits" file called + # ref_lexicon_edits.txt, which records all proposed changes to the prons (if any). Also, a + # summary is printed into the log file. + + $cmd JOB=1:$nj_select_prons $dir/lats_iter2/log/generate_learned_lexicon.JOB.log \ + steps/dict/select_prons_greedy.py \ + --alpha=${alpha} --beta=${beta} \ + --delta=${delta} \ + $ref_dict/silence_phones.txt $dir/lats_iter2/arc_stats.JOB.txt $dir/train_counts.txt $dir/ref_lexicon.txt \ + $dir/lexicon_g2p_pruned.txt $dir/lexicon_pd_pruned.txt \ + $dir/lats_iter2/learned_lexicon.JOB.txt || exit 1; + + cat $dir/lats_iter2/learned_lexicon.*.txt > $dir/lats_iter2/learned_lexicon.txt + rm $dir/lats_iter2/learned_lexicon.*.txt + + $cmd $dir/lats_iter2/log/lexicon_learning_summary.log \ + steps/dict/merge_learned_lexicons.py \ + $dir/lats_iter2/arc_stats.txt $dir/train_counts.txt $dir/ref_lexicon.txt \ + $dir/lexicon_g2p_pruned.txt $dir/lexicon_pd_pruned.txt \ + $dir/lats_iter2/learned_lexicon.txt \ + $dir/lats_iter2/out_of_ref_vocab_prons_learned.txt $dir/lats_iter2/ref_lexicon_edits.txt || exit 1; + + cp $dir/lats_iter2/ref_lexicon_edits.txt $dir/lats_iter2/ref_lexicon_edits.txt + # Remove some stuff that takes up space and is unlikely to be useful later on. + if $cleanup; then + rm -r $dir/lats_iter*/{fsts*,lat*} 2>/dev/null + fi +fi + +if [ $stage -le 7 ]; then + echo "$0: Expand the learned lexicon further to cover words in target vocab that are." + echo " ... not seen in acoustic training data." + mkdir -p $dest_dict + cp $ref_dict/{extra_questions.txt,optional_silence.txt,nonsilence_phones.txt,silence_phones.txt} \ + $dest_dict 2>/dev/null + rm $dest_dict/lexiconp.txt $dest_dict/lexicon.txt 2>/dev/null + # Get the list of oov (w.r.t. ref vocab) without acoustic evidence, which are in the + # target vocab. We'll just assign to them pronunciations from lexicon_g2p, if any. 
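The split_scp.pl call above treats the word as the "speaker" key, so every statistics line for a given word lands in the same shard; that is what lets the per-shard pronunciation selection jobs in the next stage run independently. A simplified Python 3 sketch of the same grouping idea follows (split_scp.pl additionally balances shard sizes, which this sketch does not attempt).

# Sketch: shard arc-stats lines so that all lines for the same word go to the
# same shard, mirroring the role of the --utt2spk trick above with the word
# acting as the "speaker". Purely illustrative.
def shard_stats_by_word(lines, num_shards):
    shards = [[] for _ in range(num_shards)]
    word_to_shard = {}
    next_shard = 0
    for line in lines:
        fields = line.split()
        if not fields:
            continue
        word = fields[0]
        if word not in word_to_shard:
            word_to_shard[word] = next_shard % num_shards  # round-robin over new words
            next_shard += 1
        shards[word_to_shard[word]].append(line)
    return shards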
+ cat $dir/lats_iter2/out_of_ref_vocab_prons_learned.txt $dir/ref_lexicon.txt | \ + awk 'NR==FNR{a[$1] = 1; next} !($1 in a)' - \ + $dir/target_vocab.txt | sort | uniq > $dir/oov_no_acoustics.txt || exit 1; + + variant_counts=$variant_counts_no_acoustics + + $cmd $dir/log/prune_g2p_lexicon.log steps/dict/prons_to_lexicon.py \ + --top-N=$variant_counts $dir/lexiconp_g2p.txt \ + $dir/lexicon_g2p_variant_counts${variant_counts}.txt || exit 1; + + awk 'NR==FNR{a[$1] = 1; next} ($1 in a)' $dir/oov_no_acoustics.txt \ + $dir/lexicon_g2p_variant_counts${variant_counts}.txt > $dir/g2p_prons_for_oov_no_acoustics.txt || exit 1; + + # Get the pronunciation of oov_symbol. + oov_pron=`cat $dir/non_scored_entries | grep $oov_symbol | awk '{print $2}'` || exit 1; + # For oov words in target_vocab for which we don't even have G2P pron candidates, + # we simply assign them the pronunciation of the oov symbol (like ). + if [ -s $dir/g2p_prons_for_oov_no_acoustics.txt ]; then + awk 'NR==FNR{a[$1] = 1; next} {if(!($1 in a)) print $1}' $dir/g2p_prons_for_oov_no_acoustics.txt \ + $dir/oov_no_acoustics.txt | awk -v op="$oov_pron" '{print $0" "op}' > $dir/oov_target_vocab_no_pron.txt || exit 1; + else + awk -v op="$oov_pron" '{print $0" "op}' $dir/oov_no_acoustics.txt > $dir/oov_target_vocab_no_pron.txt || exit 1 + fi + + # We concatenate three lexicons together: G2P lexicon for oov words without acoustics, + # learned lexicon for oov words with acoustics, and the original reference lexicon (for + # this part, later on we'll apply recommended changes using steps/dict/apply_lexicon_edits.py). + cat $dir/g2p_prons_for_oov_no_acoustics.txt $dir/lats_iter2/out_of_ref_vocab_prons_learned.txt \ + $dir/oov_target_vocab_no_pron.txt $dir/ref_lexicon.txt | tr -s '\t' ' ' | sort | uniq > $dest_dict/lexicon.temp + + awk 'NR==FNR{a[$1] = 1; next} ($1 in a)' $dir/target_vocab.txt \ + $dest_dict/lexicon.temp | sort | uniq > $dest_dict/lexicon.nosil + + cat $dir/non_scored_entries $dest_dict/lexicon.nosil | sort | uniq >$dest_dict/lexicon0.txt +fi + +if [ $stage -le 8 ]; then + echo "$0: Apply the ref_lexicon_edits file to the reference lexicon." + echo " ... The user can inspect/modify the edits file and then re-run:" + echo " ... steps/dict/apply_lexicon_edits.py $dest_dict/lexicon0.txt $dir/lats_iter2/ref_lexicon_edits.txt - | \\" + echo " ... sort -u \> $dest_dict/lexicon.txt to re-produce the final learned lexicon." + cp $dir/lats_iter2/ref_lexicon_edits.txt $dest_dict/lexicon_edits.txt 2>/dev/null + steps/dict/apply_lexicon_edits.py $dest_dict/lexicon0.txt $dir/lats_iter2/ref_lexicon_edits.txt - | \ + sort | uniq > $dest_dict/lexicon.txt || exit 1; +fi + +echo "Lexicon learning ends successfully. Please refer to $dir/lats_iter2/log/lexicon_learning_summary.log" +echo " for a summary. The learned lexicon, whose vocab matches the target_vocab, is $dest_dict/lexicon.txt" diff --git a/egs/wsj/s5/steps/dict/merge_learned_lexicons.py b/egs/wsj/s5/steps/dict/merge_learned_lexicons.py new file mode 100755 index 00000000000..6df7eb7a744 --- /dev/null +++ b/egs/wsj/s5/steps/dict/merge_learned_lexicons.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python + +# Copyright 2018 Xiaohui Zhang +# Apache 2.0. + +from __future__ import print_function +from collections import defaultdict +import argparse +import sys +import math + +def GetArgs(): + parser = argparse.ArgumentParser( + description = "Convert a learned lexicon produced by steps/dict/select_prons_greedy.py " + "into a lexicon for OOV words (w.r.t. ref.
vocab) and a human editable lexicon-edit file " + "for in-vocab words, and generate detailed summaries of the lexicon learning results. " + "The inputs are a learned lexicon, an arc-stats file, and three source lexicons " + "(phonetic-decoding(PD)/G2P/ref). The outputs are: a learned lexicon for OOVs" + "(learned_lexicon_oov), and a lexicon_edits file (ref_lexicon_edits) containing" + "suggested modifications of prons, for in-vocab words.", + epilog = "See steps/dict/learn_lexicon_greedy.sh for example.") + parser.add_argument("arc_stats_file", metavar = "", type = str, + help = "File containing word-pronunciation statistics obtained from lattices; " + "each line must be ") + parser.add_argument("word_counts_file", metavar = "", type = str, + help = "File containing word counts in acoustic training data; " + "each line must be .") + parser.add_argument("ref_lexicon", metavar = "", type = str, + help = "The reference lexicon (most probably hand-derived)." + "Each line must be ") + parser.add_argument("g2p_lexicon", metavar = "", type = str, + help = "Candidate pronunciations from G2P results." + "Each line must be ") + parser.add_argument("pd_lexicon", metavar = "", type = str, + help = "Candidate pronunciations from phonetic decoding results." + "Each line must be ") + parser.add_argument("learned_lexicon", metavar = "", type = str, + help = "Learned lexicon." + "Each line must be ") + parser.add_argument("learned_lexicon_oov", metavar = "", type = str, + help = "Output file which is the learned lexicon for words out of the ref. vocab.") + parser.add_argument("ref_lexicon_edits", metavar = "", type = str, + help = "Output file containing human-readable & editable pronunciation info (and the" + "accept/reject decision made by our algorithm) for those words in ref. vocab," + "to which any change has been recommended. The info for each word is like:" + "------------ an 4086.0 --------------" + "R | Y | 2401.6 | AH N" + "R | Y | 640.8 | AE N" + "P | Y | 1035.5 | IH N" + "R(ef), P(hone-decoding) represents the pronunciation source" + "Y/N means the recommended decision of including this pron or not" + "and the numbers are soft counts accumulated from lattice-align-word outputs.
" + "See the function WriteEditsAndSummary for more details.") + + print (' '.join(sys.argv), file=sys.stderr) + + args = parser.parse_args() + args = CheckArgs(args) + + return args + +def CheckArgs(args): + if args.arc_stats_file == "-": + args.arc_stats_file_handle = sys.stdin + else: + args.arc_stats_file_handle = open(args.arc_stats_file) + args.word_counts_file_handle = open(args.word_counts_file) + args.ref_lexicon_handle = open(args.ref_lexicon) + args.g2p_lexicon_handle = open(args.g2p_lexicon) + args.pd_lexicon_handle = open(args.pd_lexicon) + args.learned_lexicon_handle = open(args.learned_lexicon) + args.learned_lexicon_oov_handle = open(args.learned_lexicon_oov, "w") + args.ref_lexicon_edits_handle = open(args.ref_lexicon_edits, "w") + + return args + +def ReadArcStats(arc_stats_file_handle): + stats = defaultdict(lambda : defaultdict(dict)) + stats_summed = defaultdict(float) + for line in arc_stats_file_handle.readlines(): + splits = line.strip().split() + + if (len(splits) == 0): + continue + + if (len(splits) < 5): + raise Exception('Invalid format of line ' + line + + ' in ' + arc_stats_file) + utt = splits[1] + start_frame = int(splits[2]) + word = splits[0] + count = float(splits[3]) + phones = splits[4:] + phones = ' '.join(phones) + stats[word][(utt, start_frame)][phones] = count + stats_summed[(word, phones)] += count + return stats, stats_summed + +def ReadWordCounts(word_counts_file_handle): + counts = {} + for line in word_counts_file_handle.readlines(): + splits = line.strip().split() + if len(splits) < 2: + raise Exception('Invalid format of line ' + line + + ' in counts file.') + word = splits[0] + count = int(splits[1]) + counts[word] = count + return counts + +def ReadLexicon(args, lexicon_file_handle, counts): + # we're skipping any word not in counts (not seen in training data), + # cause we're only learning prons for words who have acoustic examples. + lexicon = defaultdict(set) + for line in lexicon_file_handle.readlines(): + splits = line.strip().split() + if len(splits) == 0: + continue + if len(splits) < 2: + raise Exception('Invalid format of line ' + line + + ' in lexicon file.') + word = splits[0] + if word not in counts: + continue + phones = ' '.join(splits[1:]) + lexicon[word].add(phones) + return lexicon + +def WriteEditsAndSummary(args, learned_lexicon, ref_lexicon, pd_lexicon, g2p_lexicon, counts, stats, stats_summed): + # Note that learned_lexicon and ref_lexicon are dicts of sets of prons, while the other two lexicons are sets of (word, pron) pairs. + threshold = 2 + words = [defaultdict(set) for i in range(4)] # "words" contains four bins, where we + # classify each word into, according to whether it's count > threshold, + # and whether it's OOVs w.r.t the reference lexicon. + + src = {} + print("# Note: This file contains pronunciation info for words who have candidate " + "prons from G2P/phonetic-decoding accepted in the learned lexicon" + ", sorted by their counts in acoustic training data, " + ,file=args.ref_lexicon_edits_handle) + print("# 1st Col: source of the candidate pron: G(2P) / P(hone-decoding) / R(eference)." + ,file=args.ref_lexicon_edits_handle) + print("# 2nd Col: accepted or not in the learned lexicon (Y/N).", file=args.ref_lexicon_edits_handle) + print("# 3rd Col: soft counts from lattice-alignment (not augmented by prior-counts)." + ,file=args.ref_lexicon_edits_handle) + print("# 4th Col: the pronunciation cadidate.", file=args.ref_lexicon_edits_handle) + + # words which are to be printed into the edits file. 
+ words_to_edit = [] + num_prons_tot = 0 + for word in learned_lexicon: + num_prons_tot += len(learned_lexicon[word]) + count = len(stats[word]) # This count could be smaller than the count read from the dict "counts", + # since in each sub-utterance, multiple occurences (which is rare) of the same word are compressed into one. + # We use this count here so that in the edit-file, soft counts for each word sum up to one. + flags = ['0' for i in range(3)] # "flags" contains three binary indicators, + # indicating where this word's pronunciations come from. + for pron in learned_lexicon[word]: + if word in pd_lexicon and pron in pd_lexicon[word]: + flags[0] = '1' + src[(word, pron)] = 'P' + elif word in ref_lexicon and pron in ref_lexicon[word]: + flags[1] = '1' + src[(word, pron)] = 'R' + elif word in g2p_lexicon and pron in g2p_lexicon[word]: + flags[2] = '1' + src[(word, pron)] = 'G' + if word in ref_lexicon: + all_ref_prons_accepted = True + for pron in ref_lexicon[word]: + if pron not in learned_lexicon[word]: + all_ref_prons_accepted = False + break + if not all_ref_prons_accepted or flags[0] == '1' or flags[2] == '1': + words_to_edit.append((word, len(stats[word]))) + if count > threshold: + words[0][flags[0] + flags[1] + flags[2]].add(word) + else: + words[1][flags[0] + flags[1] + flags[2]].add(word) + else: + if count > threshold: + words[2][flags[0] + flags[2]].add(word) + else: + words[3][flags[0] + flags[2]].add(word) + + words_to_edit_sorted = sorted(words_to_edit, key=lambda entry: entry[1], reverse=True) + for word, count in words_to_edit_sorted: + print("------------",word, "%2.1f" % count, "--------------", file=args.ref_lexicon_edits_handle) + learned_prons = [] + for pron in learned_lexicon[word]: + learned_prons.append((src[(word, pron)], 'Y', stats_summed[(word, pron)], pron)) + for pron in ref_lexicon[word]: + if pron not in learned_lexicon[word]: + learned_prons.append(('R', 'N', stats_summed[(word, pron)], pron)) + learned_prons_sorted = sorted(learned_prons, key=lambda item: item[2], reverse=True) + for item in learned_prons_sorted: + print('{} | {} | {:.2f} | {}'.format(item[0], item[1], item[2], item[3]), file=args.ref_lexicon_edits_handle) + + num_oovs_with_acoustic_evidence = len(set(learned_lexicon.keys()).difference(set(ref_lexicon.keys()))) + num_oovs = len(set(counts.keys()).difference(set(ref_lexicon.keys()))) + num_ivs = len(learned_lexicon) - num_oovs_with_acoustic_evidence + print("Average num. 
prons per word in the learned lexicon is {}".format(float(num_prons_tot)/float(len(learned_lexicon))), file=sys.stderr) + # print("Here are the words whose reference pron candidates were all declined", words[0]['100'], file=sys.stderr) + print("-------------------------------------------------Summary------------------------------------------", file=sys.stderr) + print("We have acoustic evidence for {} out of {} in-vocab (w.r.t the reference lexicon) words from the acoustic training data.".format(num_ivs, len(ref_lexicon)), file=sys.stderr) + print(" Among those frequent words whose counts in the training text > ", threshold, ":", file=sys.stderr) + num_freq_ivs_from_all_sources = len(words[0]['111']) + len(words[0]['110']) + len(words[0]['011']) + num_freq_ivs_from_g2p_or_phonetic_decoding = len(words[0]['101']) + len(words[0]['001']) + len(words[0]['100']) + num_freq_ivs_from_ref = len(words[0]['010']) + num_infreq_ivs_from_all_sources = len(words[1]['111']) + len(words[1]['110']) + len(words[1]['011']) + num_infreq_ivs_from_g2p_or_phonetic_decoding = len(words[1]['101']) + len(words[1]['001']) + len(words[1]['100']) + num_infreq_ivs_from_ref = len(words[1]['010']) + print(' {} words\' selected prons came from the reference lexicon, G2P/phonetic-decoding.'.format(num_freq_ivs_from_all_sources), file=sys.stderr) + print(' {} words\' selected prons come from G2P/phonetic-decoding-generated.'.format(num_freq_ivs_from_g2p_or_phonetic_decoding), file=sys.stderr) + print(' {} words\' selected prons came from the reference lexicon only.'.format(num_freq_ivs_from_ref), file=sys.stderr) + print(' For those words whose counts in the training text <= {}:'.format(threshold), file=sys.stderr) + print(' {} words\' selected prons came from the reference lexicon, G2P/phonetic-decoding.'.format(num_infreq_ivs_from_all_sources), file=sys.stderr) + print(' {} words\' selected prons come from G2P/phonetic-decoding-generated.'.format(num_infreq_ivs_from_g2p_or_phonetic_decoding), file=sys.stderr) + print(' {} words\' selected prons came from the reference lexicon only.'.format(num_infreq_ivs_from_ref), file=sys.stderr) + print("---------------------------------------------------------------------------------------------------", file=sys.stderr) + num_freq_oovs_from_both_sources = len(words[2]['11']) + num_freq_oovs_from_phonetic_decoding = len(words[2]['10']) + num_freq_oovs_from_g2p = len(words[2]['01']) + num_infreq_oovs_from_both_sources = len(words[3]['11']) + num_infreq_oovs_from_phonetic_decoding = len(words[3]['10']) + num_infreq_oovs_from_g2p = len(words[3]['01']) + print('We have acoustic evidence for {} out of {} OOV (w.r.t the reference lexicon) words from the acoustic training data.'.format(num_oovs_with_acoustic_evidence, num_oovs), file=sys.stderr) + print(' Among those words whose counts in the training text > {}:'.format(threshold), file=sys.stderr) + print(' {} words\' selected prons came from G2P and phonetic-decoding.'.format(num_freq_oovs_from_both_sources), file=sys.stderr) + print(' {} words\' selected prons came from phonetic decoding only.'.format(num_freq_oovs_from_phonetic_decoding), file=sys.stderr) + print(' {} words\' selected prons came from G2P only.'.format(num_freq_oovs_from_g2p), file=sys.stderr) + print(' For those words whose counts in the training text <= {}:'.format(threshold), file=sys.stderr) + print(' {} words\' selected prons came from G2P and phonetic-decoding.'.format(num_infreq_oovs_from_both_sources), file=sys.stderr) + print(' {} words\' selected prons came from 
phonetic decoding only.'.format(num_infreq_oovs_from_phonetic_decoding), file=sys.stderr) + print(' {} words\' selected prons came from G2P only.'.format(num_infreq_oovs_from_g2p), file=sys.stderr) + +def WriteLearnedLexiconOov(learned_lexicon, ref_lexicon, file_handle): + for word, prons in learned_lexicon.iteritems(): + if word not in ref_lexicon: + for pron in prons: + print('{0} {1}'.format(word, pron), file=file_handle) + file_handle.close() + +def Main(): + args = GetArgs() + + # Read in three lexicon sources, word counts, and pron stats. + counts = ReadWordCounts(args.word_counts_file_handle) + ref_lexicon = ReadLexicon(args, args.ref_lexicon_handle, counts) + g2p_lexicon = ReadLexicon(args, args.g2p_lexicon_handle, counts) + pd_lexicon = ReadLexicon(args, args.pd_lexicon_handle, counts) + stats, stats_summed = ReadArcStats(args.arc_stats_file_handle) + learned_lexicon = ReadLexicon(args, args.learned_lexicon_handle, counts) + + # Write the learned prons for words out of the ref. vocab into learned_lexicon_oov. + WriteLearnedLexiconOov(learned_lexicon, ref_lexicon, args.learned_lexicon_oov_handle) + # Edits will be printed into ref_lexicon_edits, and the summary will be printed into stderr. + WriteEditsAndSummary(args, learned_lexicon, ref_lexicon, pd_lexicon, g2p_lexicon, counts, stats, stats_summed) + +if __name__ == "__main__": + Main() diff --git a/egs/wsj/s5/steps/dict/prons_to_lexicon.py b/egs/wsj/s5/steps/dict/prons_to_lexicon.py index 2a87d172602..37d7810411b 100755 --- a/egs/wsj/s5/steps/dict/prons_to_lexicon.py +++ b/egs/wsj/s5/steps/dict/prons_to_lexicon.py @@ -6,6 +6,7 @@ # we're using python 3.x style print but want it to work in python 2.x, from __future__ import print_function +from collections import defaultdict import argparse import sys @@ -21,15 +22,15 @@ def __call__(self, parser, namespace, values, option_string=None): raise Exception("Unknown value {0} for --{1}".format(values, self.dest)) def GetArgs(): - parser = argparse.ArgumentParser(description = "Converts pronunciation statistics (from phone level decoding) " - "into a lexicon for lexicon learning. We prune the pronunciations " + parser = argparse.ArgumentParser(description = "Converts pronunciation statistics (from phonetic decoding or g2p) " + "into a lexicon for. We prune the pronunciations " "based on a provided stats file, and optionally filter out entries which are present " "in a filter lexicon.", epilog = "e.g. steps/dict/prons_to_lexicon.py --min-prob=0.4 \\" "--filter-lexicon=exp/tri3_lex_0.4_work/phone_decode/filter_lexicon.txt \\" "exp/tri3_lex_0.4_work/phone_decode/prons.txt \\" "exp/tri3_lex_0.4_work/lexicon_phone_decoding.txt" - "See steps/dict/learn_lexicon.sh for examples in detail.") + "See steps/dict/learn_lexicon_greedy.sh for examples in detail.") parser.add_argument("--set-sum-to-one", type = str, default = False, action = StrToBoolAction, choices = ["true", "false"], @@ -39,6 +40,8 @@ def GetArgs(): action = StrToBoolAction, choices = ["true", "false"], help = "If normalize lexicon such that the max " "probability is 1.") + parser.add_argument("--top-N", type = int, default = 0, + help = "If non-zero, we just take the top N pronunciations (according to stats/pron-probs) for each word.") parser.add_argument("--min-prob", type = float, default = 0.1, help = "Remove pronunciation with probabilities less " "than this value after normalization.") @@ -46,8 +49,7 @@ def GetArgs(): help = "Exclude entries in this filter lexicon from the output lexicon." 
"each line must be ") parser.add_argument("stats_file", metavar='', type = str, - help = "Input file containing pronunciation statistics, representing how many times " - "each word-pronunciation appear in the phonetic decoding results." + help = "Input lexicon file containing pronunciation statistics/probs in the first column." "each line must be ") parser.add_argument("out_lexicon", metavar='', type = str, help = "Output lexicon.") @@ -150,6 +152,18 @@ def NormalizeLexicon(lexicon, set_max_to_one = True, prob = 0 lexicon[entry] = prob +def TakeTopN(lexicon, top_N): + lexicon_reshaped = defaultdict(list) + lexicon_pruned = {} + for entry, prob in lexicon.iteritems(): + lexicon_reshaped[entry[0]].append([entry[1], prob]) + for word in lexicon_reshaped: + prons = lexicon_reshaped[word] + sorted_prons = sorted(prons, reverse=True, key=lambda prons: prons[1]) + for i in range(len(sorted_prons)): + if i >= top_N: + lexicon[(word, sorted_prons[i][0])] = 0 + def WriteLexicon(args, lexicon, filter_lexicon): words = set() num_removed = 0 @@ -179,10 +193,15 @@ def Main(): word_probs = ConvertWordCountsToProbs(args, lexicon, word_count) lexicon = ConvertWordProbsToLexicon(word_probs) - filter_lexicon = ReadLexicon(args.filter_lexicon_handle) - NormalizeLexicon(lexicon, set_max_to_one = args.set_max_to_one, - set_sum_to_one = args.set_sum_to_one, - min_prob = args.min_prob) + filter_lexicon = set() + if args.filter_lexicon is not '': + filter_lexicon = ReadLexicon(args.filter_lexicon_handle) + if args.top_N > 0: + TakeTopN(lexicon, args.top_N) + else: + NormalizeLexicon(lexicon, set_max_to_one = args.set_max_to_one, + set_sum_to_one = args.set_sum_to_one, + min_prob = args.min_prob) WriteLexicon(args, lexicon, filter_lexicon) args.out_lexicon_handle.close() diff --git a/egs/wsj/s5/steps/dict/prune_pron_candidates.py b/egs/wsj/s5/steps/dict/prune_pron_candidates.py index affc5b17705..cd90a389a7c 100755 --- a/egs/wsj/s5/steps/dict/prune_pron_candidates.py +++ b/egs/wsj/s5/steps/dict/prune_pron_candidates.py @@ -4,6 +4,7 @@ # Apache 2.0. from __future__ import print_function +from __future__ import division from collections import defaultdict import argparse import sys @@ -16,7 +17,7 @@ def GetArgs(): "(For words in the reference lexicon, N = # pron variants given by the reference" "lexicon; For oov words, N = avg. # pron variants per word in the reference lexicon)." "r is a user-specified constant, like 2.", - epilog = "See steps/dict/learn_lexicon.sh for example") + epilog = "See steps/dict/learn_lexicon_greedy.sh for example") parser.add_argument("--r", type = float, default = "2.0", help = "a user-specified ratio parameter which determines how many" @@ -61,7 +62,7 @@ def ReadStats(pron_stats_handle): phones = ' '.join(splits[2:]) stats[word].append((phones, count)) - for word, entry in stats.iteritems(): + for word, entry in stats.items(): entry.sort(key=lambda x: x[1]) return stats @@ -86,12 +87,12 @@ def PruneProns(args, stats, ref_lexicon): # Compute the average # pron variants counts per word in the reference lexicon. 
num_words_ref = 0 num_prons_ref = 0 - for word, prons in ref_lexicon.iteritems(): + for word, prons in ref_lexicon.items(): num_words_ref += 1 num_prons_ref += len(prons) avg_variants_counts_ref = math.ceil(float(num_prons_ref) / float(num_words_ref)) - for word, entry in stats.iteritems(): + for word, entry in stats.items(): if word in ref_lexicon: variants_counts = args.r * len(ref_lexicon[word]) else: @@ -105,7 +106,7 @@ def PruneProns(args, stats, ref_lexicon): except IndexError: break - for word, entry in stats.iteritems(): + for word, entry in stats.items(): for pron, prob in entry: if word not in ref_lexicon or pron not in ref_lexicon[word]: print('{0} {1}'.format(word, pron), file=args.pruned_prons_handle) diff --git a/egs/wsj/s5/steps/dict/select_prons_bayesian.py b/egs/wsj/s5/steps/dict/select_prons_bayesian.py index e728a4af0b8..893dd7cb818 100755 --- a/egs/wsj/s5/steps/dict/select_prons_bayesian.py +++ b/egs/wsj/s5/steps/dict/select_prons_bayesian.py @@ -4,6 +4,7 @@ # Apache 2.0. from __future__ import print_function +from __future__ import division from collections import defaultdict import argparse import sys @@ -23,7 +24,7 @@ def GetArgs(): "a learned lexicon for words out of the ref. vocab (learned_lexicon_oov)," "and a lexicon_edits file containing suggested modifications of prons, for" "words within the ref. vocab (ref_lexicon_edits).", - epilog = "See steps/dict/learn_lexicon.sh for example.") + epilog = "See steps/dict/learn_lexicon_bayesian.sh for example.") parser.add_argument("--prior-mean", type = str, default = "0,0,0", help = "Mean of priors (summing up to 1) assigned to three exclusive n" "pronunciatio sources: reference lexicon, g2p, and phonetic decoding. We " @@ -162,7 +163,7 @@ def FilterPhoneticDecodingLexicon(args, phonetic_decoding_lexicon, stats): for line in args.silence_file_handle: silphones.add(line.strip()) rejected_candidates = set() - for word, prons in phonetic_decoding_lexicon.iteritems(): + for word, prons in phonetic_decoding_lexicon.items(): for pron in prons: for phone in pron.split(): if phone in silphones: @@ -194,7 +195,7 @@ def ComputePriorCounts(args, counts, ref_lexicon, g2p_lexicon, phonetic_decoding prior_mean[2] = 0 prior_mean_sum = sum(prior_mean) try: - prior_mean = [t / prior_mean_sum for t in prior_mean] + prior_mean = [float(t) / prior_mean_sum for t in prior_mean] except ZeroDivisionError: print('WARNING: word {} appears in train_counts but not in any lexicon.'.format(word), file=sys.stderr) prior_counts[word] = [t * args.prior_counts_tot for t in prior_mean] @@ -206,20 +207,20 @@ def ComputePosteriors(args, stats, ref_lexicon, g2p_lexicon, phonetic_decoding_l # The soft-counts were augmented by a user-specified prior count, according the source # (ref/G2P/phonetic-decoding) of this pronunciation. 
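[Editor note on the ComputePosteriors hunk that follows: it implements the augmentation described in the comment above. Each candidate pronunciation's acoustic soft count is increased by the word's prior count for its source (reference / G2P / phonetic decoding), shared evenly among that source's candidates, and the augmented counts are then normalized into per-word posteriors. A toy, single-word illustration (all pronunciations and numbers are invented):

# Toy illustration of the augmented-count -> posterior computation in
# select_prons_bayesian.py (one word, two candidate pronunciations).
soft_counts = {"t ah m ey t ow": 8.0,   # candidate from the reference lexicon
               "t ah m aa t ow": 1.5}   # candidate from G2P
prior_counts = {"ref": 10.0, "g2p": 4.0, "pd": 1.0}   # prior mass per source for this word
source_of = {"t ah m ey t ow": "ref", "t ah m aa t ow": "g2p"}
num_cands_per_source = {"ref": 1, "g2p": 1, "pd": 0}

augmented = {}
for pron, tau in soft_counts.items():
    src = source_of[pron]
    # prior count of the source, shared evenly among that source's candidates,
    # plus the observed soft count
    augmented[pron] = prior_counts[src] / num_cands_per_source[src] + tau

total = sum(augmented.values())
posteriors = {pron: c / total for pron, c in augmented.items()}
print(posteriors)
]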
- for word, prons in ref_lexicon.iteritems(): + for word, prons in ref_lexicon.items(): for pron in prons: # c is the augmented soft count (observed count + prior count) - c = prior_counts[word][0] / len(ref_lexicon[word]) + stats.get((word, pron), 0) + c = float(prior_counts[word][0]) / len(ref_lexicon[word]) + stats.get((word, pron), 0) posteriors[word].append((pron, c)) - for word, prons in g2p_lexicon.iteritems(): + for word, prons in g2p_lexicon.items(): for pron in prons: - c = prior_counts[word][1] / len(g2p_lexicon[word]) + stats.get((word, pron), 0) + c = float(prior_counts[word][1]) / len(g2p_lexicon[word]) + stats.get((word, pron), 0) posteriors[word].append((pron, c)) - for word, prons in phonetic_decoding_lexicon.iteritems(): + for word, prons in phonetic_decoding_lexicon.items(): for pron in prons: - c = prior_counts[word][2] / len(phonetic_decoding_lexicon[word]) + stats.get((word, pron), 0) + c = float(prior_counts[word][2]) / len(phonetic_decoding_lexicon[word]) + stats.get((word, pron), 0) posteriors[word].append((pron, c)) num_prons_from_ref = sum(len(ref_lexicon[i]) for i in ref_lexicon) @@ -239,10 +240,10 @@ def ComputePosteriors(args, stats, ref_lexicon, g2p_lexicon, phonetic_decoding_l # each entry is a pair: (prounciation, count) count_sum[word] = sum([entry[1] for entry in posteriors[word]]) - for word, entry in posteriors.iteritems(): + for word, entry in posteriors.items(): new_entry = [] for pron, count in entry: - post = count / count_sum[word] + post = float(count) / count_sum[word] new_entry.append((pron, post)) source = 'R' if word in g2p_lexicon and pron in g2p_lexicon[word]: @@ -260,7 +261,7 @@ def SelectPronsBayesian(args, counts, posteriors, ref_lexicon, g2p_lexicon, phon phonetic_decoding_selected = 0 learned_lexicon = defaultdict(set) - for word, entry in posteriors.iteritems(): + for word, entry in posteriors.items(): num_variants = 0 post_tot = 0.0 variants_counts = args.variants_counts @@ -411,7 +412,7 @@ def WriteEditsAndSummary(args, learned_lexicon, ref_lexicon, phonetic_decoding_l print(' {} words\' selected prons came from G2P only.'.format(num_infreq_oovs_from_g2p), file=sys.stderr) def WriteLearnedLexiconOov(learned_lexicon, ref_lexicon, file_handle): - for word, prons in learned_lexicon.iteritems(): + for word, prons in learned_lexicon.items(): if word not in ref_lexicon: for pron in prons: print('{0} {1}'.format(word, pron), file=file_handle) diff --git a/egs/wsj/s5/steps/dict/select_prons_greedy.py b/egs/wsj/s5/steps/dict/select_prons_greedy.py new file mode 100755 index 00000000000..cf71070e134 --- /dev/null +++ b/egs/wsj/s5/steps/dict/select_prons_greedy.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python + +# Copyright 2018 Xiaohui Zhang +# Apache 2.0. + +from __future__ import print_function +from collections import defaultdict +import argparse +import sys +import math + +def GetArgs(): + parser = argparse.ArgumentParser( + description = "Use a greedy framework to select pronunciation candidates" + "from three sources: a reference lexicon, G2P lexicon and phonetic-decoding" + "(PD) lexicon. Basically, this script implements the Alg. 1 in the paper:" + "Acoustic data-driven lexicon learning based on a greedy pronunciation " + "selection framework, by X. Zhang, V. Mahonar, D. Povey and S. Khudanpur," + "Interspeech 2017. The inputs are an arc-stats file, containing " + "acoustic evidence (tau_{uwb} in the paper) and three source lexicons " + "(phonetic-decoding(PD)/G2P/ref). 
The output is the learned lexicon for" + "all words in the arc_stats (acoustic evidence) file.", + epilog = "See steps/dict/learn_lexicon_greedy.sh for example.") + parser.add_argument("--alpha", type = str, default = "0,0,0", + help = "Scaling factors for the likelihood reduction threshold " + "of three pronunciation candidate sources: phonetic-decoding (PD), " + "G2P and reference. The valid range of each dimension is [0, 1], and " + "a large value means we prune pronunciations from this source more " + "aggressively. Setting a dimension to zero means we never want to remove " + "pronunciations from that source. See Section 4.3 in the paper for details.") + parser.add_argument("--beta", type = str, default = "0,0,0", + help = "Smoothing factors for the likelihood reduction term " + "of three pronunciation candidate sources: phonetic-decoding (PD), " + "G2P and reference. The valid range of each dimension is [0, 100], and " + "a large value means we prune pronunciations from this source more " + "aggressively. See Section 4.3 in the paper for details.") + parser.add_argument("--delta", type = float, default = 0.000000001, + help = "Floor value of the pronunciation posterior statistics. " + "The valid range is (0, 0.01). " + "See Section 3 in the paper for details.") + parser.add_argument("silence_phones_file", metavar = "", type = str, + help = "File containing a list of silence phones.") + parser.add_argument("arc_stats_file", metavar = "", type = str, + help = "File containing word-pronunciation statistics obtained from lattices; " + "each line must be ") + parser.add_argument("word_counts_file", metavar = "", type = str, + help = "File containing word counts in acoustic training data; " + "each line must be .") + parser.add_argument("ref_lexicon", metavar = "", type = str, + help = "The reference lexicon (most probably hand-derived). " + "Each line must be ") + parser.add_argument("g2p_lexicon", metavar = "", type = str, + help = "Candidate pronunciations from G2P results. " + "Each line must be ") + parser.add_argument("pd_lexicon", metavar = "", type = str, + help = "Candidate pronunciations from phonetic decoding results. " + "Each line must be ") + parser.add_argument("learned_lexicon", metavar = "", type = str, + help = "Learned lexicon.") + + + print (' '.join(sys.argv), file=sys.stderr) + + args = parser.parse_args() + args = CheckArgs(args) + + return args + +def CheckArgs(args): + args.silence_phones_file_handle = open(args.silence_phones_file) + if args.arc_stats_file == "-": + args.arc_stats_file_handle = sys.stdin + else: + args.arc_stats_file_handle = open(args.arc_stats_file) + args.word_counts_file_handle = open(args.word_counts_file) + args.ref_lexicon_handle = open(args.ref_lexicon) + args.g2p_lexicon_handle = open(args.g2p_lexicon) + args.pd_lexicon_handle = open(args.pd_lexicon) + args.learned_lexicon_handle = open(args.learned_lexicon, "w") + + alpha = args.alpha.strip().split(',') + if len(alpha) is not 3: + raise Exception('Invalid alpha ', args.alpha) + for i in range(0,3): + if float(alpha[i]) < 0 or float(alpha[i]) > 1: + raise Exception('alpha ', alpha[i], + ' is invalid, it must be within [0, 1].') + if float(alpha[i]) == 0: + alpha[i] = -1e-3 + # The absolute likelihood loss (search for loss_abs) is supposed to be positive. + # But it could be negative near zero because of numerical precision limit.
+ # In this case, even if alpha is set to be zero, which means we never want to + # remove pronunciation from that source, the quality score (search for q_b) + # could still be negative, which means this pron could be potentially removed. + # To prevent this, we set alpha as a negative value near zero to ensure + # q_b is always positive. + + args.alpha = [float(alpha[0]), float(alpha[1]), float(alpha[2])] + print("[alpha_{pd}, alpha_{g2p}, alpha_{ref}] is: ", args.alpha) + exit + beta = args.beta.strip().split(',') + if len(beta) is not 3: + raise Exception('Invalid beta ', args.beta) + for i in range(0,3): + if float(beta[i]) < 0 or float(beta[i]) > 100: + raise Exception('beta ', beta[i], + ' is invalid, it must be within [0, 100].') + args.beta = [float(beta[0]), float(beta[1]), float(beta[2])] + print("[beta_{pd}, beta_{g2p}, beta_{ref}] is: ", args.beta) + + if args.delta <= 0 or args.delta > 0.1: + raise Exception('delta ', args.delta, ' is invalid, it must be within' + '(0, 0.01).') + print("delta is: ", args.delta) + + return args + +def ReadArcStats(arc_stats_file_handle): + stats = defaultdict(lambda : defaultdict(dict)) + stats_summed = defaultdict(float) + for line in arc_stats_file_handle.readlines(): + splits = line.strip().split() + + if (len(splits) == 0): + continue + + if (len(splits) < 5): + raise Exception('Invalid format of line ' + line + + ' in ' + arc_stats_file) + utt = splits[1] + start_frame = int(splits[2]) + word = splits[0] + count = float(splits[3]) + phones = splits[4:] + phones = ' '.join(phones) + stats[word][(utt, start_frame)][phones] = count + stats_summed[(word, phones)] += count + return stats, stats_summed + +def ReadWordCounts(word_counts_file_handle): + counts = {} + for line in word_counts_file_handle.readlines(): + splits = line.strip().split() + if len(splits) < 2: + raise Exception('Invalid format of line ' + line + + ' in counts file.') + word = splits[0] + count = int(splits[1]) + counts[word] = count + return counts + +def ReadLexicon(args, lexicon_file_handle, counts): + # we're skipping any word not in counts (not seen in training data), + # cause we're only learning prons for words who have acoustic examples. + lexicon = defaultdict(set) + for line in lexicon_file_handle.readlines(): + splits = line.strip().split() + if len(splits) == 0: + continue + if len(splits) < 2: + raise Exception('Invalid format of line ' + line + + ' in lexicon file.') + word = splits[0] + if word not in counts: + continue + phones = ' '.join(splits[1:]) + lexicon[word].add(phones) + return lexicon + +def FilterPhoneticDecodingLexicon(args, pd_lexicon): + # We want to remove all candidates which contain silence phones + silphones = set() + for line in args.silence_phones_file_handle: + silphones.add(line.strip()) + rejected_candidates = set() + for word, prons in pd_lexicon.iteritems(): + for pron in prons: + for phone in pron.split(): + if phone in silphones: + rejected_candidates.add((word, pron)) + break + for word, pron in rejected_candidates: + pd_lexicon[word].remove(pron) + return pd_lexicon + +# One iteration of Expectation-Maximization computation (Eq. 3-4 in the paper). 
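[Editor note: the OneEMIter function defined next performs one EM update of the per-word pronunciation probabilities from per-example soft counts (Eq. 3-4 of the cited paper). As a toy illustration of what a single update does, assuming two candidate pronunciations and three training examples with invented soft counts:

import math

# Toy version of one EM update over per-example soft counts tau_{u,b}
# (cf. OneEMIter below); delta is the floor applied to missing/tiny counts.
delta = 1e-9
examples = [          # soft counts for pron 0 and pron 1 in each training example
    [0.9, 0.1],
    [0.7, 0.3],
    [0.2, 0.8],
]
pron_probs = [0.5, 0.5]        # current estimate
prob_acc = [0.0, 0.0]
log_like = 0.0
for counts in examples:
    joint = [max(c, delta) * p for c, p in zip(counts, pron_probs)]
    norm = sum(joint)
    for i in range(len(pron_probs)):
        prob_acc[i] += joint[i] / norm   # E-step: posterior responsibility of pron i
    log_like += math.log(norm)
pron_probs = [a / len(examples) for a in prob_acc]   # M-step: average responsibility
log_like /= len(examples)
print(pron_probs, log_like)    # -> [0.6, 0.4] and the average log-likelihood
]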
+def OneEMIter(args, word, stats, prons, pron_probs, debug=False): + prob_acc = [0.0 for i in range(len(prons[word]))] + s = sum(pron_probs) + for i in range(len(pron_probs)): + pron_probs[i] = pron_probs[i] / s + log_like = 0.0 + for (utt, start_frame) in stats[word]: + prob = [] + soft_counts = [] + for i in range(len(prons[word])): + phones = prons[word][i] + soft_count = stats[word][(utt, start_frame)].get(phones, 0) + if soft_count < args.delta: + soft_count = args.delta + soft_counts.append(soft_count) + prob = [i[0] * i[1] for i in zip(soft_counts, pron_probs)] + for i in range(len(prons[word])): + prob_acc[i] += prob[i] / sum(prob) + log_like += math.log(sum(prob)) + pron_probs = [1.0 / float(len(stats[word])) * p for p in prob_acc] + log_like = 1.0 / float(len(stats[word])) * log_like + if debug: + print("Log_like of the word: ", log_like, "pron probs: ", pron_probs) + return pron_probs, log_like + +def SelectPronsGreedy(args, stats, counts, ref_lexicon, g2p_lexicon, pd_lexicon, dianostic_info=False): + prons = defaultdict(list) # Put all possible prons from three source lexicons into this dictionary + src = {} # Source of each (word, pron) pair: 'P' = phonetic-decoding, 'G' = G2P, 'R' = reference + learned_lexicon = defaultdict(set) # Put all selected prons in this dictionary + for lexicon in ref_lexicon, g2p_lexicon, pd_lexicon: + for word in lexicon: + for pron in lexicon[word]: + prons[word].append(pron) + for word in prons: + for pron in prons[word]: + if word in pd_lexicon and pron in pd_lexicon[word]: + src[(word, pron)] = 'P' + if word in g2p_lexicon and pron in g2p_lexicon[word]: + src[(word, pron)] = 'G' + if word in ref_lexicon and pron in ref_lexicon[word]: + src[(word, pron)] = 'R' + + for word in prons: + if word not in stats: + continue + n = len(prons[word]) + pron_probs = [1/float(n) for i in range(n)] + if dianostic_info: + print("pronunciations of word '{}': {}".format(word, prons[word])) + active_indexes = set(range(len(prons[word]))) + + deleted_prons = [] # indexes of prons to be deleted + soft_counts_normalized = [] + while len(active_indexes) > 1: + log_like = 1.0 + log_like_last = -1.0 + num_iters = 0 + while abs(log_like - log_like_last) > 1e-7: + num_iters += 1 + log_like_last = log_like + pron_probs, log_like = OneEMIter(args, word, stats, prons, pron_probs, False) + if log_like_last == 1.0 and len(soft_counts_normalized) == 0: # the first iteration + soft_counts_normalized = pron_probs + if dianostic_info: + print("Avg.(over all egs) soft counts: {}".format(soft_counts_normalized)) + if dianostic_info: + print("\n Log_like after {} iters of EM: {}, estimated pron_probs: {} \n".format( + num_iters, log_like, pron_probs)) + candidates_to_delete = [] + + for i in active_indexes: + pron_probs_mod = [p for p in pron_probs] + pron_probs_mod[i] = 0.0 + for j in range(len(pron_probs_mod)): + if j in active_indexes and j != i: + pron_probs_mod[j] += 0.01 + pron_probs_mod = [s / sum(pron_probs_mod) for s in pron_probs_mod] + log_like2 = 1.0 + log_like2_last = -1.0 + num_iters2 = 0 + # Running EM until convengence + while abs(log_like2 - log_like2_last) > 0.001 : + num_iters2 += 1 + log_like2_last = log_like2 + pron_probs_mod, log_like2 = OneEMIter(args, word, stats, + prons, pron_probs_mod, False) + + loss_abs = log_like - log_like2 # absolute likelihood loss before normalization + # (supposed to be positive, but could be negative near zero because of numerical precision limit). 
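[Editor note on the quality score computed in the lines that follow: the in-code comment appears to have lost its numerator. Reading the code, the score compared against zero works out to q_b = loss_abs * M_w / (M_w + beta_s(b)) + alpha_s(b) * log(delta), where M_w = len(stats[word]) is the number of acoustic examples of the word and s(b) is the source of candidate b; a pronunciation becomes a deletion candidate when q_b < 0. The same computation as a small standalone function (the argument values in the example are invented):

import math

def quality_score(loss_abs, num_examples, alpha_s, beta_s, delta):
    """q_b as computed in the code below (Sec. 4.3 / Alg. 1 of the paper):
    the likelihood loss is smoothed by beta_s and compared against a
    per-source threshold alpha_s * (-log(delta))."""
    loss = loss_abs * num_examples / (num_examples + beta_s)
    thr = alpha_s * (-math.log(delta))
    return loss - thr   # negative => candidate for removal

# e.g. a G2P candidate whose removal costs little likelihood on a word seen 20 times:
print(quality_score(loss_abs=0.05, num_examples=20, alpha_s=0.1, beta_s=5.0, delta=1e-9))
]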
+ log_delta = math.log(args.delta) + thr = -log_delta + loss = loss_abs + source = src[(word, prons[word][i])] + if dianostic_info: + print("\n set the pron_prob of '{}' whose source is {}, to zero results in {}" + " loss in avg. log-likelihood; Num. iters until converging:{}. ".format( + prons[word][i], source, loss, num_iters2)) + # Compute quality score q_b = loss_abs * / (M_w + beta_s(b)) + alpha_s(b) * log_delta + # See Sec. 4.3 and Alg. 1 in the paper. + if source == 'P': + thr *= args.alpha[0] + loss *= float(len(stats[word])) / (float(len(stats[word])) + args.beta[0]) + if source == 'G': + thr *= args.alpha[1] + loss *= float(len(stats[word])) / (float(len(stats[word])) + args.beta[1]) + if source == 'R': + thr *= args.alpha[2] + loss *= float(len(stats[word])) / (float(len(stats[word])) + args.beta[2]) + if loss - thr < 0: # loss - thr here is just q_b + if dianostic_info: + print("Smoothed log-like loss {} is smaller than threshold {} so that the quality" + "score {} is negative, adding the pron to the list of candidates to delete" + ". ".format(loss, thr, loss-thr)) + candidates_to_delete.append((loss-thr, i)) + if len(candidates_to_delete) == 0: + break + candidates_to_delete_sorted = sorted(candidates_to_delete, + key=lambda candidates_to_delete: candidates_to_delete[0]) + + deleted_candidate = candidates_to_delete_sorted[0] + active_indexes.remove(deleted_candidate[1]) + pron_probs[deleted_candidate[1]] = 0.0 + for i in range(len(pron_probs)): + if i in active_indexes: + pron_probs[i] += 0.01 + pron_probs = [s / sum(pron_probs) for s in pron_probs] + source = src[(word, prons[word][deleted_candidate[1]])] + pron = prons[word][deleted_candidate[1]] + soft_count = soft_counts_normalized[deleted_candidate[1]] + quality_score = deleted_candidate[0] + # This part of diagnostic info provides hints to the user on how to adjust the parameters. + if dianostic_info: + print("removed pron {}, from source {} with quality score {:.5f}".format( + pron, source, quality_score)) + if (source == 'P' and soft_count > 0.7 and len(stats[word]) > 5): + print("WARNING: alpha_{pd} or beta_{pd} may be too large!" + " For the word '{}' whose count is {}, the candidate " + " pronunciation from phonetic decoding '{}' with normalized " + " soft count {} (out of 1) is rejected. It shouldn't have been" + " rejected if alpha_{pd} is smaller than {}".format( + word, len(stats[word]), pron, soft_count, -loss / log_delta, + -args.alpha[0] * len(stats[word]) + (objf_change + args.beta[0])), + file=sys.stderr) + if loss_abs > thr: + print(" or beta_{pd} is smaller than {}".format( + (loss_abs / thr - 1) * len(stats[word])), file=sys.stderr) + if (source == 'G' and soft_count > 0.7 and len(stats[word]) > 5): + print("WARNING: alpha_{g2p} or beta_{g2p} may be too large!" + " For the word '{}' whose count is {}, the candidate " + " pronunciation from G2P '{}' with normalized " + " soft count {} (out of 1) is rejected. 
It shouldn't have been" + " rejected if alpha_{g2p} is smaller than {} ".format( + word, len(stats[word]), pron, soft_count, -loss / log_delta, + -args.alpha[1] * len(stats[word]) + (objf_change + args.beta[1])), + file=sys.stderr) + if loss_abs > thr: + print(" or beta_{g2p} is smaller than {}.".format(( + loss_abs / thr - 1) * len(stats[word])), file=sys.stderr) + deleted_prons.append(deleted_candidate[1]) + for i in range(len(prons[word])): + if i not in deleted_prons: + learned_lexicon[word].add(prons[word][i]) + + return learned_lexicon + +def WriteLearnedLexicon(learned_lexicon, file_handle): + for word, prons in learned_lexicon.iteritems(): + for pron in prons: + print('{0} {1}'.format(word, pron), file=file_handle) + file_handle.close() + +def Main(): + args = GetArgs() + + # Read in three lexicon sources, word counts, and pron stats. + counts = ReadWordCounts(args.word_counts_file_handle) + ref_lexicon = ReadLexicon(args, args.ref_lexicon_handle, counts) + g2p_lexicon = ReadLexicon(args, args.g2p_lexicon_handle, counts) + pd_lexicon = ReadLexicon(args, args.pd_lexicon_handle, counts) + stats, stats_summed = ReadArcStats(args.arc_stats_file_handle) + pd_lexicon = FilterPhoneticDecodingLexicon(args, pd_lexicon) + + # Select prons to construct the learned lexicon. + learned_lexicon = SelectPronsGreedy(args, stats, counts, ref_lexicon, g2p_lexicon, pd_lexicon) + + # Write the learned prons for words out of the ref. vocab into learned_lexicon_oov. + WriteLearnedLexicon(learned_lexicon, args.learned_lexicon_handle) + +if __name__ == "__main__": + Main() diff --git a/egs/wsj/s5/steps/dict/train_g2p.sh b/egs/wsj/s5/steps/dict/train_g2p.sh index d793bbb5d8f..75eb3fc88ec 100755 --- a/egs/wsj/s5/steps/dict/train_g2p.sh +++ b/egs/wsj/s5/steps/dict/train_g2p.sh @@ -24,13 +24,18 @@ set -e if [ $# != 2 ]; then echo "Usage: $0 [options] " echo " where is the training lexicon (one pronunciation per " - echo " word per line) and is directory where the models will " - echo " be stored" - echo "e.g.: train_g2p.sh data/local/lexicon.txt exp/g2p/" + echo " word per line, with lines like 'hello h uh l ow') and" + echo " is directory where the models will be stored" + echo "e.g.: train_g2p.sh --silence-phones data/local/dict/silence_phones.txt data/local/dict/lexicon.txt exp/g2p/" echo "" echo "main options (for others, see top of script file)" echo " --iters # How many iterations. Relates to N-ngram order" echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --silence-phones # e.g. data/local/dict/silence_phones.txt." + echo " # A list of silence phones, one or more per line" + echo " # Relates to --only-words option" + echo " --only-words (true|false) (default: true) # If true, exclude silence words, i.e." + echo " # words with 1 phone which is a silence." exit 1; fi diff --git a/egs/wsj/s5/steps/dict/train_g2p_phonetisaurus.sh b/egs/wsj/s5/steps/dict/train_g2p_phonetisaurus.sh new file mode 100755 index 00000000000..94c483e09e2 --- /dev/null +++ b/egs/wsj/s5/steps/dict/train_g2p_phonetisaurus.sh @@ -0,0 +1,88 @@ +#!/bin/bash + +# Copyright 2017 Intellisist, Inc. (Author: Navneeth K) +# 2017 Xiaohui Zhang +# 2018 Ruizhe Huang +# Apache License 2.0 + +# This script trains a g2p model using Phonetisaurus. + +stage=0 +encoding='utf-8' +only_words=true +silence_phones= + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. 
utils/parse_options.sh || exit 1; + +set -u +set -e + +if [ $# != 2 ]; then + echo "Usage: $0 [options] " + echo " where is the training lexicon (one pronunciation per " + echo " word per line, with lines like 'hello h uh l ow') and" + echo " is directory where the models will be stored" + echo "e.g.: $0 --silence-phones data/local/dict/silence_phones.txt data/local/dict/lexicon.txt exp/g2p/" + echo "" + echo "main options (for others, see top of script file)" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --silence-phones # e.g. data/local/dict/silence_phones.txt." + echo " # A list of silence phones, one or more per line" + echo " # Relates to --only-words option" + echo " --only-words (true|false) (default: true) # If true, exclude silence words, i.e." + echo " # words with one or multiple phones which are all silence." + exit 1; +fi + +lexicon=$1 +wdir=$2 + +[ ! -f $lexicon ] && echo "Cannot find $lexicon" && exit + +isuconv=`which uconv` +if [ -z $isuconv ]; then + echo "uconv was not found. You must install the icu4c package." + exit 1; +fi + +if ! phonetisaurus=`which phonetisaurus-apply` ; then + echo "Phonetisarus was not found !" + echo "Go to $KALDI_ROOT/tools and execute extras/install_phonetisaurus.sh" + exit 1 +fi + +mkdir -p $wdir + + +# For input lexicon, remove pronunciations containing non-utf-8-encodable characters, +# and optionally remove words that are mapped to a single silence phone from the lexicon. +if [ $stage -le 0 ]; then + if $only_words && [ ! -z "$silence_phones" ]; then + awk 'NR==FNR{a[$1] = 1; next} {s=$2;for(i=3;i<=NF;i++) s=s" "$i; if(!(s in a)) print $1" "s}' \ + $silence_phones $lexicon | \ + awk '{printf("%s\t",$1); for (i=2;i 0'> $wdir/lexicon_tab_separated.txt + else + awk '{printf("%s\t",$1); for (i=2;i 0'> $wdir/lexicon_tab_separated.txt + fi +fi + +if [ $stage -le 1 ]; then + # Align lexicon stage. Lexicon is assumed to have first column tab separated + phonetisaurus-align --input=$wdir/lexicon_tab_separated.txt --ofile=${wdir}/aligned_lexicon.corpus || exit 1; +fi + +if [ $stage -le 2 ]; then + # Convert aligned lexicon to arpa using make_kn_lm.py, a re-implementation of srilm's ngram-count functionality. + ./utils/lang/make_kn_lm.py -ngram-order 7 -text ${wdir}/aligned_lexicon.corpus -lm ${wdir}/aligned_lexicon.arpa +fi + +if [ $stage -le 3 ]; then + # Convert the arpa file to FST. + phonetisaurus-arpa2wfst --lm=${wdir}/aligned_lexicon.arpa --ofile=${wdir}/model.fst +fi + diff --git a/egs/wsj/s5/steps/get_ctm.sh b/egs/wsj/s5/steps/get_ctm.sh index 85286e47bea..6ebce2049d9 100755 --- a/egs/wsj/s5/steps/get_ctm.sh +++ b/egs/wsj/s5/steps/get_ctm.sh @@ -2,9 +2,10 @@ # Copyright Johns Hopkins University (Author: Daniel Povey) 2012. Apache 2.0. # This script produces CTM files from a decoding directory that has lattices -# present. It does this for a range of language model weights; see also +# present. It does this for a range of language model weights; see also # get_ctm_fast.sh which does it for just one LM weight and also supports -# the word insertion penalty. +# the word insertion penalty, and get_ctm_conf.sh which outputs CTM files +# with confidence scores. # begin configuration section. 
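[Editor note: stages 1-3 of train_g2p_phonetisaurus.sh above come down to three commands: align the lexicon, estimate a 7-gram Kneser-Ney LM over the aligned graphone corpus with make_kn_lm.py, and convert the ARPA LM to a WFST. A rough Python driver for the same pipeline, assuming the Phonetisaurus binaries are on the PATH, that the script is run from an egs/*/s5 directory where utils/ is available, and that the tab-separated lexicon from stage 0 already exists (the function name is illustrative):

import subprocess

def train_g2p_phonetisaurus(wdir, ngram_order=7):
    """Sketch of stages 1-3 of train_g2p_phonetisaurus.sh."""
    # Stage 1: align the (tab-separated) lexicon into a graphone corpus.
    subprocess.run(["phonetisaurus-align",
                    "--input={}/lexicon_tab_separated.txt".format(wdir),
                    "--ofile={}/aligned_lexicon.corpus".format(wdir)], check=True)
    # Stage 2: estimate a Kneser-Ney LM over the aligned corpus.
    subprocess.run(["utils/lang/make_kn_lm.py",
                    "-ngram-order", str(ngram_order),
                    "-text", "{}/aligned_lexicon.corpus".format(wdir),
                    "-lm", "{}/aligned_lexicon.arpa".format(wdir)], check=True)
    # Stage 3: convert the ARPA LM to the G2P WFST.
    subprocess.run(["phonetisaurus-arpa2wfst",
                    "--lm={}/aligned_lexicon.arpa".format(wdir),
                    "--ofile={}/model.fst".format(wdir)], check=True)

# train_g2p_phonetisaurus("exp/g2p")
]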
@@ -36,7 +37,7 @@ if [ $# -ne 3 ]; then echo " # not equal to 0.01 seconds" echo "e.g.:" echo "$0 data/train data/lang exp/tri4a/decode/" - echo "See also: steps/get_train_ctm.sh, steps/get_ctm_fast.sh" + echo "See also: steps/get_train_ctm.sh, steps/get_ctm_fast.sh, steps/get_ctm_conf.sh" exit 1; fi diff --git a/egs/wsj/s5/steps/get_ctm_conf.sh b/egs/wsj/s5/steps/get_ctm_conf.sh new file mode 120000 index 00000000000..cee23c66bf8 --- /dev/null +++ b/egs/wsj/s5/steps/get_ctm_conf.sh @@ -0,0 +1 @@ +conf/get_ctm_conf.sh \ No newline at end of file diff --git a/egs/wsj/s5/steps/get_ctm_conf_fast.sh b/egs/wsj/s5/steps/get_ctm_conf_fast.sh new file mode 100755 index 00000000000..088fbd4a9cf --- /dev/null +++ b/egs/wsj/s5/steps/get_ctm_conf_fast.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# 2017 Vimal Manohar +# 2018 Xiaohui Zhang +# 2018 Music Technology Group, Universitat Pompeu Fabra. +# Apache 2.0 + +# This script produces CTM files with confidence scores +# from a decoding directory that has lattices +# present. It does this for one LM weight and also supports +# the word insertion penalty. +# This is similar to get_ctm_conf.sh, but gets the CTM at the utterance-level. +# It can be faster than steps/get_ctm_conf.sh --use-segments false as it splits +# the process across many jobs. + +# begin configuration section. +cmd=run.pl +stage=0 +frame_shift=0.01 +lmwt=10 +wip=0.0 +print_silence=false +#end configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 4 ]; then + echo "Usage: $0 [options] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --stage (0|1|2) # start scoring script from part-way through." + echo " --frame-shift (default=0.01) # specify this if your lattices have a frame-shift" + echo " # not equal to 0.01 seconds" + echo "e.g.:" + echo "$0 data/train data/lang exp/tri4a/decode/" + echo "See also: steps/get_ctm.sh, steps/get_ctm_conf.sh" + exit 1; +fi + +data=$1 +lang=$2 # Note: may be graph directory not lang directory, but has the necessary stuff copied. +decode_dir=$3 +dir=$4 + +if [ -f $decode_dir/final.mdl ]; then + model=$decode_dir/final.mdl +else + model=$decode_dir/../final.mdl # assume model one level up from decoding dir. +fi + +for f in $lang/words.txt $model $decode_dir/lat.1.gz; do + [ ! 
-f $f ] && echo "$0: expecting file $f to exist" && exit 1; +done + +mkdir -p $dir + +nj=$(cat $decode_dir/num_jobs) +echo $nj > $dir/num_jobs + +if [ -f $lang/phones/word_boundary.int ]; then + $cmd JOB=1:$nj $dir/log/get_ctm.JOB.log \ + set -o pipefail '&&' \ + lattice-add-penalty --word-ins-penalty=$wip "ark:gunzip -c $decode_dir/lat.JOB.gz|" ark:- \| \ + lattice-prune --inv-acoustic-scale=$lmwt --beam=5 ark:- ark:- \| \ + lattice-align-words $lang/phones/word_boundary.int $model ark:- ark:- \| \ + lattice-to-ctm-conf --frame-shift=$frame_shift --decode-mbr=true --inv-acoustic-scale=$lmwt ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \ + '>' $dir/ctm.JOB || exit 1; +elif [ -f $lang/phones/align_lexicon.int ]; then + set -o pipefail '&&' \ + lattice-add-penalty --word-ins-penalty=$wip "ark:gunzip -c $decode_dir/lat.JOB.gz|" ark:- \| \ + lattice-prune --inv-acoustic-scale=$lmwt --beam=5 ark:- ark:- \| \ + lattice-align-words-lexicon $lang/phones/align_lexicon.int $model ark:- ark:- \| \ + lattice-to-ctm-conf --frame-shift=$frame_shift --decode-mbr=true --inv-acoustic-scale=$lmwt ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \ + '>' $dir/ctm.JOB || exit 1; +else + echo "$0: neither $lang/phones/word_boundary.int nor $lang/phones/align_lexicon.int exists: cannot align." + exit 1; +fi + +for n in `seq $nj`; do + cat $dir/ctm.$n +done > $dir/ctm diff --git a/egs/wsj/s5/steps/get_ctm_fast.sh b/egs/wsj/s5/steps/get_ctm_fast.sh index 75b666300fe..b0fae12b7bc 100755 --- a/egs/wsj/s5/steps/get_ctm_fast.sh +++ b/egs/wsj/s5/steps/get_ctm_fast.sh @@ -1,7 +1,9 @@ #!/bin/bash -# Copyright Johns Hopkins University (Author: Daniel Povey) 2012. Apache 2.0. -# Copyright 2017 Vimal Manohar -# Music Technology Group, Universitat Pompeu Fabra, 2018. Apache 2.0 +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# 2017 Vimal Manohar +# 2018 Xiaohui Zhang +# 2018 Music Technology Group, Universitat Pompeu Fabra. +# Apache 2.0 # This script produces CTM files from a decoding directory that has lattices # present. It does this for one LM weight and also supports @@ -33,7 +35,7 @@ if [ $# -ne 4 ]; then echo " # not equal to 0.01 seconds" echo "e.g.:" echo "$0 data/train data/lang exp/tri4a/decode/" - echo "See also: steps/get_ctm.sh" + echo "See also: steps/get_ctm.sh, steps/get_ctm_conf.sh" exit 1; fi diff --git a/egs/wsj/s5/steps/get_prons.sh b/egs/wsj/s5/steps/get_prons.sh index e7a25890ba6..4c5453edbe2 100755 --- a/egs/wsj/s5/steps/get_prons.sh +++ b/egs/wsj/s5/steps/get_prons.sh @@ -3,7 +3,7 @@ # 2014 Guoguo Chen # Apache 2.0 -# Begin configuration section. +# Begin configuration section. cmd=run.pl stage=1 lmwt=10 @@ -29,6 +29,13 @@ if [ $# != 3 ]; then exit 1; fi +# As the usage message of nbest-to-prons says, its output has lines that can be interpreted as +# ... +# and you could convert these into text form using a command like: +# gunzip -c prons.*.gz | utils/sym2int.pl -f 4 words.txt | utils/sym2int.pl -f 5- phones.txt + + + data=$1 lang=$2 dir=$3 @@ -66,7 +73,7 @@ fi if [ -f $dir/ali.1.gz ]; then echo "$0: $dir/ali.1.gz exists, so starting from alignments." - + if [ $stage -le 1 ]; then rm $dir/prons.*.gz 2>/dev/null $cmd JOB=1:$nj $dir/log/nbest_to_prons.JOB.log \ diff --git a/egs/wsj/s5/steps/get_train_ctm.sh b/egs/wsj/s5/steps/get_train_ctm.sh index 878e11e45ac..6942014fc88 100755 --- a/egs/wsj/s5/steps/get_train_ctm.sh +++ b/egs/wsj/s5/steps/get_train_ctm.sh @@ -20,8 +20,9 @@ echo "$0 $@" # Print the command line for logging [ -f ./path.sh ] && . ./path.sh . 
parse_options.sh || exit 1; -if [ $# -ne 3 ]; then - echo "Usage: $0 [options] " +if [ $# -ne 3 ] && [ $# -ne 4 ]; then + echo "Usage: $0 [options] []" + echo "( defaults to .)" echo " Options:" echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." echo " --stage (0|1|2) # start scoring script from part-way through." @@ -39,27 +40,31 @@ fi data=$1 lang=$2 # Note: may be graph directory not lang directory, but has the necessary stuff copied. -dir=$3 +ali_dir=$3 +dir=$4 +if [ -z $dir ]; then + dir=$ali_dir +fi -model=$dir/final.mdl # assume model one level up from decoding dir. +model=$ali_dir/final.mdl # assume model one level up from decoding dir. -for f in $lang/words.txt $model $dir/ali.1.gz $lang/oov.int; do +for f in $lang/words.txt $model $ali_dir/ali.1.gz $lang/oov.int; do [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; done oov=`cat $lang/oov.int` || exit 1; -nj=`cat $dir/num_jobs` || exit 1; +nj=`cat $ali_dir/num_jobs` || exit 1; split_data.sh $data $nj || exit 1; sdata=$data/split$nj -mkdir -p $dir/log +mkdir -p $dir/log || exit 1; if [ $stage -le 0 ]; then if [ -f $lang/phones/word_boundary.int ]; then $cmd JOB=1:$nj $dir/log/get_ctm.JOB.log \ - set -o pipefail '&&' linear-to-nbest "ark:gunzip -c $dir/ali.JOB.gz|" \ + set -o pipefail '&&' linear-to-nbest "ark:gunzip -c $ali_dir/ali.JOB.gz|" \ "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $sdata/JOB/text |" \ '' '' ark:- \| \ lattice-align-words $lang/phones/word_boundary.int $model ark:- ark:- \| \ @@ -72,7 +77,7 @@ if [ $stage -le 0 ]; then exit 1; fi $cmd JOB=1:$nj $dir/log/get_ctm.JOB.log \ - set -o pipefail '&&' linear-to-nbest "ark:gunzip -c $dir/ali.JOB.gz|" \ + set -o pipefail '&&' linear-to-nbest "ark:gunzip -c $ali_dir/ali.JOB.gz|" \ "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $sdata/JOB/text |" \ '' '' ark:- \| \ lattice-align-words-lexicon $lang/phones/align_lexicon.int $model ark:- ark:- \| \ @@ -94,4 +99,3 @@ if [ $stage -le 1 ]; then fi rm $dir/ctm.*.gz fi - diff --git a/egs/wsj/s5/steps/libs/common.py b/egs/wsj/s5/steps/libs/common.py index 1e8e2ced6ce..6bf0ea4932c 100644 --- a/egs/wsj/s5/steps/libs/common.py +++ b/egs/wsj/s5/steps/libs/common.py @@ -10,6 +10,7 @@ """ from __future__ import print_function +from __future__ import division import argparse import logging import math @@ -18,6 +19,11 @@ import sys import threading +try: + import thread as thread_module +except: + import _thread as thread_module + logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -230,8 +236,7 @@ def background_command_waiter(command, popen_object, require_zero_status): logger.error(str) # thread.interrupt_main() sends a KeyboardInterrupt to the main # thread, which will generally terminate the program. - import thread - thread.interrupt_main() + thread_module.interrupt_main() else: logger.warning(str) @@ -312,7 +317,7 @@ def read_kaldi_matrix(matrix_file): 'matrix_file' and stores it as a list of rows, where each row is a list. 
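[Editor note: read_kaldi_matrix, patched in the next hunk, parses Kaldi's text-format matrix; judging from the bracket checks in the code, the input looks like '[ 1 2 3' / '4 5 6 ]' with one row per line. A stripped-down Python 3 rendering of just the parsing step (the real function reads from a file and raises a friendlier error message):

# Minimal sketch of the parsing done by read_kaldi_matrix: strip the leading
# "[" and trailing "]", then turn every remaining token into an int.
def parse_text_matrix(lines):
    rows = [line.split() for line in lines if line.strip()]
    assert rows[0][0] == "[" and rows[-1][-1] == "]", "not a text-format matrix"
    rows[0] = rows[0][1:]        # drop leading "["
    rows[-1] = rows[-1][:-1]     # drop trailing "]"
    return [[int(float(tok)) for tok in row] for row in rows if row]

print(parse_text_matrix(["[ 1 2 3", "4 5 6 ]"]))   # -> [[1, 2, 3], [4, 5, 6]]
]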
""" try: - lines = map(lambda x: x.split(), open(matrix_file).readlines()) + lines = [x.split() for x in open(matrix_file).readlines()] first_field = lines[0][0] last_field = lines[-1][-1] lines[0] = lines[0][1:] @@ -322,7 +327,7 @@ def read_kaldi_matrix(matrix_file): "Kaldi matrix file has incorrect format, " "only text format matrix files can be read by this script") for i in range(len(lines)): - lines[i] = map(lambda x: int(float(x)), lines[i]) + lines[i] = [int(float(x)) for x in lines[i]] return lines except IOError: raise Exception("Error while reading the kaldi matrix file " @@ -344,7 +349,7 @@ def write_kaldi_matrix(output_file, matrix): if num_cols != len(matrix[row_index]): raise Exception("All the rows of a matrix are expected to " "have the same length") - f.write(" ".join(map(lambda x: str(x), matrix[row_index]))) + f.write(" ".join([str(x) for x in matrix[row_index]])) if row_index != num_rows - 1: f.write("\n") f.write(" ]") @@ -504,7 +509,7 @@ def compute_idct_matrix(K, N, cepstral_lifter=0): lifter_coeffs = compute_lifter_coeffs(cepstral_lifter, K) for k in range(0, K): for n in range(0, N): - matrix[n][k] = matrix[n][k] / lifter_coeffs[k] + matrix[n][k] = float(matrix[n][k]) / lifter_coeffs[k] return matrix diff --git a/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py b/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py index 73f4e5b6533..97da5e04962 100755 --- a/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py +++ b/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py @@ -322,7 +322,7 @@ def parse_progress_logs_for_param_diff(exp_dir, pattern): groups = mat_obj.groups() iteration = groups[0] differences = parse_difference_string(groups[1]) - component_names = component_names.union(differences.keys()) + component_names = component_names.union(list(differences.keys())) progress_per_iter[int(iteration)] = differences component_names = list(component_names) @@ -435,14 +435,14 @@ def parse_prob_logs(exp_dir, key='accuracy', output="output"): raise KaldiLogParseException("Could not find any lines with {k} in " " {l}".format(k=key, l=valid_prob_files)) - iters = list(set(valid_objf.keys()).intersection(train_objf.keys())) + iters = list(set(valid_objf.keys()).intersection(list(train_objf.keys()))) if not iters: raise KaldiLogParseException("Could not any common iterations with" " key {k} in both {tl} and {vl}".format( k=key, tl=train_prob_files, vl=valid_prob_files)) iters.sort() - return list(map(lambda x: (int(x), float(train_objf[x]), - float(valid_objf[x])), iters)) + return list([(int(x), float(train_objf[x]), + float(valid_objf[x])) for x in iters]) def parse_rnnlm_prob_logs(exp_dir, key='objf'): train_prob_files = "%s/log/train.*.*.log" % (exp_dir) @@ -498,14 +498,14 @@ def parse_rnnlm_prob_logs(exp_dir, key='objf'): raise KaldiLogParseException("Could not find any lines with {k} in " " {l}".format(k=key, l=valid_prob_files)) - iters = list(set(valid_objf.keys()).intersection(train_objf.keys())) + iters = list(set(valid_objf.keys()).intersection(list(train_objf.keys()))) if not iters: raise KaldiLogParseException("Could not any common iterations with" " key {k} in both {tl} and {vl}".format( k=key, tl=train_prob_files, vl=valid_prob_files)) iters.sort() - return map(lambda x: (int(x), float(train_objf[x]), - float(valid_objf[x])), iters) + return [(int(x), float(train_objf[x]), + float(valid_objf[x])) for x in iters] @@ -532,7 +532,7 @@ def generate_acc_logprob_report(exp_dir, key="accuracy", output="output"): try: report.append("%d\t%s\t%g\t%g\t%g" % (x[0], str(times[x[0]]), 
x[1], x[2], x[2]-x[1])) - except KeyError, IndexError: + except (KeyError, IndexError): continue total_time = 0 diff --git a/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py b/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py index 229f290e94c..c932a9c54f7 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py @@ -7,6 +7,8 @@ """ This is a module with methods which will be used by scripts for training of deep neural network acoustic model with chain objective. """ +from __future__ import division +from __future__ import print_function import logging import math @@ -167,7 +169,7 @@ def train_new_models(dir, iter, srand, num_jobs, # work out the 1-based archive index. archive_index = (k % num_archives) + 1 # previous : frame_shift = (k/num_archives) % frame_subsampling_factor - frame_shift = ((archive_index + k/num_archives) + frame_shift = ((archive_index + k//num_archives) % frame_subsampling_factor) multitask_egs_opts = common_train_lib.get_multitask_egs_opts( @@ -413,8 +415,7 @@ def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts, rand_prune=rand_prune)) # the above command would have generated dir/{1..num_lda_jobs}.lda_stats - lda_stat_files = list(map(lambda x: '{0}/{1}.lda_stats'.format(dir, x), - range(1, num_lda_jobs + 1))) + lda_stat_files = ['{0}/{1}.lda_stats'.format(dir, x) for x in range(1, num_lda_jobs + 1)] common_lib.execute_command( """{command} {dir}/log/sum_transform_stats.log \ diff --git a/egs/wsj/s5/steps/libs/nnet3/train/common.py b/egs/wsj/s5/steps/libs/nnet3/train/common.py index 720164e5436..1a038cc23f2 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/common.py @@ -7,6 +7,7 @@ """This module contains classes and methods common to training of nnet3 neural networks. """ +from __future__ import division import argparse import glob @@ -69,9 +70,12 @@ def get_multitask_egs_opts(egs_dir, egs_prefix="", '--output=ark:foo/egs/output.3.ark --weight=ark:foo/egs/weights.3.ark' i.e. egs_prefix is "" for train and "valid_diagnostic." for validation. + + Caution: archive_index is usually an integer, but may be a string ("JOB") + in some cases. """ multitask_egs_opts = "" - egs_suffix = ".{0}".format(archive_index) if archive_index > -1 else "" + egs_suffix = ".{0}".format(archive_index) if archive_index != -1 else "" if use_multitask_egs: output_file_name = ("{egs_dir}/{egs_prefix}output{egs_suffix}.ark" @@ -288,7 +292,7 @@ def halve_range_str(range_str): halved_ranges = [] for r in ranges: # a range may be either e.g. '64', or '128:256' - c = [str(max(1, int(x)/2)) for x in r.split(":")] + c = [str(max(1, int(x)//2)) for x in r.split(":")] halved_ranges.append(":".join(c)) return ','.join(halved_ranges) @@ -525,13 +529,13 @@ def smooth_presoftmax_prior_scale_vector(pdf_counts, presoftmax_prior_scale_power=-0.25, smooth=0.01): total = sum(pdf_counts) - average_count = total/len(pdf_counts) + average_count = float(total) / len(pdf_counts) scales = [] for i in range(len(pdf_counts)): scales.append(math.pow(pdf_counts[i] + smooth * average_count, presoftmax_prior_scale_power)) num_pdfs = len(pdf_counts) - scaled_counts = list(map(lambda x: x * float(num_pdfs) / sum(scales), scales)) + scaled_counts = [x * float(num_pdfs) / sum(scales) for x in scales] return scaled_counts @@ -561,7 +565,7 @@ def get_model_combine_iters(num_iters, num_epochs, in the final model-averaging phase. 
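[Editor note: smooth_presoftmax_prior_scale_vector, touched above, is one of the float-division fixes; the function maps pdf counts c_i to scales (c_i + smooth * avg_count) ** power (power = -0.25 by default) and rescales them to average 1, so rarely seen pdfs get boosted. A quick standalone check with invented counts:

import math

def presoftmax_prior_scales(pdf_counts, power=-0.25, smooth=0.01):
    # (count + smooth * average_count) ** power, rescaled so the scales average to 1
    avg = sum(pdf_counts) / len(pdf_counts)
    scales = [math.pow(c + smooth * avg, power) for c in pdf_counts]
    return [s * len(scales) / sum(scales) for s in scales]

print(presoftmax_prior_scales([1000.0, 100.0, 10.0]))   # rarer pdfs get larger scales
]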
(note: it's a weighted average where the weights are worked out from a subset of training data.)""" - approx_iters_per_epoch_final = num_archives/num_jobs_final + approx_iters_per_epoch_final = float(num_archives) / num_jobs_final # Note: it used to be that we would combine over an entire epoch, # but in practice we very rarely would use any weights from towards # the end of that range, so we are changing it to use not @@ -578,8 +582,8 @@ def get_model_combine_iters(num_iters, num_epochs, # But if this value is > max_models_combine, then the models # are subsampled to get these many models to combine. - num_iters_combine_initial = min(approx_iters_per_epoch_final/2 + 1, - num_iters/2) + num_iters_combine_initial = min(int(approx_iters_per_epoch_final/2) + 1, + int(num_iters/2)) if num_iters_combine_initial > max_models_combine: subsample_model_factor = int( @@ -591,7 +595,7 @@ def get_model_combine_iters(num_iters, num_epochs, models_to_combine.add(num_iters) else: subsample_model_factor = 1 - num_iters_combine = min(max_models_combine, num_iters/2) + num_iters_combine = min(max_models_combine, num_iters//2) models_to_combine = set(range(num_iters - num_iters_combine + 1, num_iters + 1)) @@ -607,8 +611,7 @@ def get_learning_rate(iter, num_jobs, num_iters, num_archives_processed, effective_learning_rate = ( initial_effective_lrate * math.exp(num_archives_processed - * math.log(final_effective_lrate - / initial_effective_lrate) + * math.log(float(final_effective_lrate) / initial_effective_lrate) / num_archives_to_process)) return num_jobs * effective_learning_rate diff --git a/egs/wsj/s5/steps/libs/nnet3/train/dropout_schedule.py b/egs/wsj/s5/steps/libs/nnet3/train/dropout_schedule.py index 0ad93e5977d..0de9074517f 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/dropout_schedule.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/dropout_schedule.py @@ -186,9 +186,22 @@ def _get_component_dropout(dropout_schedule, data_fraction): def _get_dropout_proportions(dropout_schedule, data_fraction): """Returns dropout proportions based on the dropout_schedule for the - fraction of data seen at this stage of training. + fraction of data seen at this stage of training. Returns a list of + pairs (pattern, dropout_proportion); for instance, it might return + the list ['*', 0.625] meaning a dropout proportion of 0.625 is to + be applied to all dropout components. + Returns None if dropout_schedule is None. + dropout_schedule might be (in the sample case using the default pattern of + '*'): '0.1,0.5@0.5,0.1', meaning a piecewise linear function that starts at + 0.1 when data_fraction=0.0, rises to 0.5 when data_fraction=0.5, and falls + again to 0.1 when data_fraction=1.0. It can also contain space-separated + items of the form 'pattern=schedule', for instance: + '*=0.0,0.5,0.0 lstm.*=0.0,0.3@0.75,0.0' + The more specific patterns should go later, otherwise they will be overridden + by the less specific patterns' commands. + Calls _get_component_dropout() for the different component name patterns in dropout_schedule. @@ -198,6 +211,7 @@ def _get_dropout_proportions(dropout_schedule, data_fraction): See _self_test() for examples. data_fraction: The fraction of data seen until this stage of training. 
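[Editor note: the docstring above describes schedule strings such as '0.1,0.5@0.5,0.1' as piecewise-linear functions of the fraction of data seen. A simplified evaluator for that single-pattern case (interior points must carry an explicit '@fraction' here; the real parser in dropout_schedule.py handles per-component patterns and more):

def dropout_at(schedule, data_fraction):
    """Evaluate a piecewise-linear dropout schedule like '0.1,0.5@0.5,0.1'
    at a given fraction of data seen. The first point is pinned to 0.0 and
    the last to 1.0; interior points need an explicit '@fraction'."""
    pieces = schedule.split(',')
    points = []
    for i, piece in enumerate(pieces):
        if '@' in piece:
            value, frac = piece.split('@')
        else:
            value, frac = piece, (0.0 if i == 0 else 1.0)
        points.append((float(frac), float(value)))
    points.sort()
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= data_fraction <= x1:
            return y0 + (y1 - y0) * (data_fraction - x0) / (x1 - x0)
    return points[-1][1]

print(dropout_at('0.1,0.5@0.5,0.1', 0.25))   # -> 0.3, halfway up the rising segment
]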
+ """ if dropout_schedule is None: return None @@ -213,6 +227,10 @@ def _get_dropout_proportions(dropout_schedule, data_fraction): def get_dropout_edit_string(dropout_schedule, data_fraction, iter_): """Return an nnet3-copy --edits line to modify raw_model_string to set dropout proportions according to dropout_proportions. + E.g. if _dropout_proportions(dropout_schedule, data_fraction) + returns [('*', 0.625)], this will return the string: + "nnet3-copy --edits='set-dropout-proportion name=* proportion=0.625'" + Arguments: dropout_schedule: Value for the --trainer.dropout-schedule option. diff --git a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py index cc5c9693a12..f2722350e41 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py @@ -348,8 +348,7 @@ def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts, rand_prune=rand_prune)) # the above command would have generated dir/{1..num_lda_jobs}.lda_stats - lda_stat_files = list(map(lambda x: '{0}/{1}.lda_stats'.format(dir, x), - range(1, num_lda_jobs + 1))) + lda_stat_files = ['{0}/{1}.lda_stats'.format(dir, x) for x in range(1, num_lda_jobs + 1)] common_lib.execute_command( """{command} {dir}/log/sum_transform_stats.log \ diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/attention.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/attention.py index e870c1a60cf..db4cb392f10 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/attention.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/attention.py @@ -6,6 +6,7 @@ """ from __future__ import print_function +from __future__ import division import math import re import sys diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py index e95de336586..7846c983b19 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py @@ -9,6 +9,7 @@ """ from __future__ import print_function +from __future__ import division import math import re import sys @@ -748,7 +749,8 @@ def check_configs(self): if self.config['target-rms'] < 0.0: raise RuntimeError("target-rms has invalid value {0}" .format(self.config['target-rms'])) - if self.config['learning-rate-factor'] <= 0.0: + if (self.config['learning-rate-factor'] != '' and + self.config['learning-rate-factor'] <= 0.0): raise RuntimeError("learning-rate-factor has invalid value {0}" .format(self.config['learning-rate-factor'])) diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/composite_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/composite_layers.py index e1905d0aa48..bf2a90916ae 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/composite_layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/composite_layers.py @@ -135,11 +135,9 @@ def get_full_config(self): def _generate_config(self): configs = [] name = self.name - input_dim = self.descriptors['input']['dim'] input_descriptor = self.descriptors['input']['final-string'] output_dim = self.config['dim'] - assert output_dim == input_dim bottleneck_dim = self.config['bottleneck-dim'] bypass_scale = self.config['bypass-scale'] dropout_proportion = self.config['dropout-proportion'] diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/convolution.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/convolution.py index be8bcaefedf..1628a5e314f 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/convolution.py +++ 
b/egs/wsj/s5/steps/libs/nnet3/xconfig/convolution.py @@ -7,6 +7,7 @@ """ This module has the implementation of convolutional layers. """ from __future__ import print_function +from __future__ import division import math import re import sys @@ -148,7 +149,7 @@ def set_derived_configs(self): if input_dim % height_in != 0: raise RuntimeError("Input dimension {0} is not a multiple of height-in={1}".format( input_dim, height_in)) - self.config['num-filters-in'] = input_dim / height_in + self.config['num-filters-in'] = input_dim // height_in # Check whether 'str' is a sorted, unique, nonempty list of integers, like -1,0,1., @@ -880,7 +881,7 @@ def _generate_normal_resblock_config(self): num_filters_out = self.config['num-filters'] if height_out != height_in: - if height_out < height_in / 2 - 1 or height_out > height_in / 2 + 1: + if height_out < height_in / 2 - 1 or height_out > height_in / 2 + 1: raise RuntimeError("Expected height-out to be about half height-in, or the same: " "height-in={0} height-out={1}".format(height_in, height_out)) if not time_period_out % 2 == 0: @@ -1030,7 +1031,7 @@ def _generate_bottleneck_resblock_config(self): num_filters_out = self.config['num-filters'] if height_out != height_in: - if height_out < height_in / 2 - 1 or height_out > height_in / 2 + 1: + if height_out < height_in / 2 - 1 or height_out > height_in / 2 + 1: raise RuntimeError("Expected height-out to be about half height-in, or the same: " "height-in={0} height-out={1}".format(height_in, height_out)) height_subsample = 2 diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/gru.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/gru.py index 530ba14474a..2f387a6a1e5 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/gru.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/gru.py @@ -1,6 +1,7 @@ # Copyright 2016 Johns Hopkins University (Dan Povey) # 2017 Gaofeng Cheng (UCAS) # 2017 Lu Huang (THU) +# 2018 Hang Lyu # Apache 2.0. 
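[Editor note: the convolution.py change above replaces / with // so that num-filters-in stays an integer under Python 3; the underlying relation is simply input_dim = height_in * num_filters_in. For instance (invented dimensions):

def num_filters_in(input_dim, height_in):
    # the xconfig convolution layer interprets its input as height_in x num_filters_in
    if input_dim % height_in != 0:
        raise RuntimeError("Input dimension {0} is not a multiple of height-in={1}"
                           .format(input_dim, height_in))
    return input_dim // height_in   # '//' keeps this an int on Python 3

print(num_filters_in(2560, 40))   # e.g. height 40 with 64 filters -> 64
]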
@@ -83,7 +84,7 @@ def get_full_config(self): ans.append((config_name, line)) return ans - # convenience function to generate the LSTM config + # convenience function to generate the GRU config def generate_gru_config(self): # assign some variables to reduce verbosity @@ -468,7 +469,7 @@ def output_name(self, auxiliary_output = None): def output_dim(self, auxiliary_output = None): if auxiliary_output is not None: if auxiliary_output in self.auxiliary_outputs(): - if node_name == 'c_t': + if auxiliary_output == 'h_t': return self.config['cell-dim'] # add code for other auxiliary_outputs here when we decide to expose them else: @@ -487,7 +488,7 @@ def get_full_config(self): ans.append((config_name, line)) return ans - # convenience function to generate the PGRU config + # convenience function to generate the Norm-PGRU config def generate_pgru_config(self): # assign some variables to reduce verbosity @@ -711,7 +712,7 @@ def get_full_config(self): ans.append((config_name, line)) return ans - # convenience function to generate the PGRU config + # convenience function to generate the OPGRU config def generate_pgru_config(self): # assign some variables to reduce verbosity @@ -922,7 +923,7 @@ def get_full_config(self): ans.append((config_name, line)) return ans - # convenience function to generate the PGRU config + # convenience function to generate the Norm-OPGRU config def generate_pgru_config(self): # assign some variables to reduce verbosity @@ -1039,3 +1040,1072 @@ def generate_pgru_config(self): configs.append("component-node name={0}.s_t component={0}.s_r input={0}.s_t_preclip_renorm".format(name)) return configs + +# This class is for lines like +# 'fast-gru-layer name=gru1 input=[-1] delay=-3' +# It generates a GRU sub-graph without output projections. +# The output dimension of the layer may be specified via 'cell-dim=xxx', but if not specified, +# the dimension defaults to the same as the input. +# See other configuration values below. +# decay-time is deprecated for GRU and PGRU layers, as we found that PGRUs do not need the decay-time option to generalize to unseen sequence lengths +# +# Parameters of the class, and their defaults: +# input='[-1]' [Descriptor giving the input of the layer.] +# cell-dim=-1 [Dimension of the cell] +# delay=-1 [Delay in the recurrent connections of the GRU/LSTM ] +# clipping-threshold=30 [similar to LSTMs, nnet3 GRUs use a gradient clipping component at the recurrent connections. +# This is the threshold used to decide if clipping has to be activated ] +# zeroing-interval=20 [interval at which we (possibly) zero out the recurrent derivatives.] +# zeroing-threshold=15 [We only zero out the derivs every zeroing-interval, if derivs exceed this value.]
+# self-repair-scale-nonlinearity=1e-5 [It is a constant scaling the self-repair vector computed in derived classes of NonlinearComponent] +# i.e., SigmoidComponent, TanhComponent and RectifiedLinearComponent ] +# ng-per-element-scale-options='' [Additional options used for the diagonal matrices in the GRU/LSTM ] +# gru-nonlinearity-options=' max-change=0.75' [options for GruNonlinearityComponent, see below for detail] +# ng-affine-options='' [Additional options used for the full matrices in the GRU/LSTM, can be used to do things like set biases to initialize to 1] +class XconfigFastGruLayer(XconfigLayerBase): + def __init__(self, first_token, key_to_value, prev_names = None): + assert first_token == "fast-gru-layer" + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = {'input':'[-1]', + 'cell-dim' : -1, # this is a compulsory argument + 'clipping-threshold' : 30.0, + 'delay' : -1, + 'ng-per-element-scale-options' : ' max-change=0.75', + 'ng-affine-options' : ' max-change=0.75 ', + 'self-repair-scale-nonlinearity' : 0.00001, + 'zeroing-interval' : 20, + 'zeroing-threshold' : 15.0, + # if you want to set 'self-repair-scale', ' self-repair-threshold' + # or 'param-stddev' for GruNonlinearityComponent + # For default, they are 1.0e-05, 0.2 and 1.0 / sqrt(d) where d is cell-dim. + # you can add somethig like 'self-repair-scale=xxx' to gru-nonlinearity-options. + # you can also see src/nnet3/nnet-combined-component.h for detail + 'gru-nonlinearity-options' : ' max-change=0.75' + } + + def set_derived_configs(self): + if self.config['cell-dim'] <= 0: + self.config['cell-dim'] = self.descriptors['input']['dim'] + + def check_configs(self): + key = 'cell-dim' + if self.config['cell-dim'] <= 0: + raise RuntimeError("cell-dim has invalid value {0}.".format(self.config[key])) + + if self.config['delay'] == 0: + raise RuntimeError("delay cannot be zero") + + for key in ['self-repair-scale-nonlinearity']: + if self.config[key] < 0.0 or self.config[key] > 1.0: + raise RuntimeError("{0} has invalid value {1}.".format(key, self.config[key])) + + def output_name(self, auxiliary_output = None): + node_name = 'y_t' + return '{0}.{1}'.format(self.name, node_name) + + def output_dim(self, auxiliary_output = None): + return self.config['cell-dim'] + + def get_full_config(self): + ans = [] + config_lines = self.generate_gru_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + # we do not support user specified matrices in LSTM initialization + # so 'ref' and 'final' configs are the same. 
+ ans.append((config_name, line)) + return ans + + # convenience function to generate the GRU config + def generate_gru_config(self): + + # assign some variables to reduce verbosity + name = self.name + # in the below code we will just call descriptor_strings as descriptors for conciseness + input_dim = self.descriptors['input']['dim'] + input_descriptor = self.descriptors['input']['final-string'] + cell_dim = self.config['cell-dim'] + delay = self.config['delay'] + bptrunc_str = ("clipping-threshold={0}" + " zeroing-threshold={1}" + " zeroing-interval={2}" + " recurrence-interval={3}" + "".format(self.config['clipping-threshold'], + self.config['zeroing-threshold'], + self.config['zeroing-interval'], abs(delay))) + repair_nonlin = self.config['self-repair-scale-nonlinearity'] + repair_nonlin_str = "self-repair-scale={0:.10f}".format(repair_nonlin) if repair_nonlin is not None else '' + affine_str = self.config['ng-affine-options'] + + # string for GruNonlinearityComponent + gru_nonlin_str = self.config['gru-nonlinearity-options'] + + # formulation like: + # z_t = \sigmoid ( U^z x_t + W^z y_{t-1} ) # update gate + # r_t = \sigmoid ( U^r x_t + W^r y_{t-1} ) # reset gate + # h_t = \tanh ( U^h x_t + W^h ( y_{t-1} \dot r_t ) ) + # y_t = ( 1 - z_t ) \dot h_t + z_t \dot y_{t-1} + # Note: + # naming convention: + # .W_. e.g. Gru1.W_i.xr for matrix + # providing output to gate i and operating on an appended vector [x,r] + # notation convention: + # In order to be consistent with the notations which are used in + # nnet-combined-component.cc, we map "\tilde{h_t}" and "h_t" which are + # used in paper to "h_t" and "c_t" + + configs = [] + + configs.append("### Begin Gru layer '{0}'".format(name)) + configs.append("# Update gate control : W_z* matrices") + configs.append("component name={0}.W_z.xh type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim + cell_dim, cell_dim, affine_str)) + configs.append("# Reset gate control : W_r* matrices") + configs.append("component name={0}.W_r.xh type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim + cell_dim, cell_dim, affine_str)) + + configs.append("# hpart_t related matrix : W_hpart matrice") + configs.append("component name={0}.W_hpart.x type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim, cell_dim , affine_str)) + + configs.append("# Defining the non-linearities for z_t and r_t") + configs.append("component name={0}.z type=SigmoidComponent dim={1} {2}".format(name, cell_dim, repair_nonlin_str)) + configs.append("component name={0}.r type=SigmoidComponent dim={1} {2}".format(name, cell_dim, repair_nonlin_str)) + + recurrent_connection = '{0}.s_t'.format(name) + + configs.append("# z_t") + configs.append("component-node name={0}.z_t_pre component={0}.W_z.xh input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay)) + configs.append("component-node name={0}.z_t component={0}.z input={0}.z_t_pre".format(name)) + configs.append("# r_t") + configs.append("component-node name={0}.r_t_pre component={0}.W_r.xh input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay)) + configs.append("component-node name={0}.r_t component={0}.r input={0}.r_t_pre".format(name)) + + configs.append("# hpart_t") + configs.append("component-node name={0}.hpart_t component={0}.W_hpart.x input={1}".format(name, input_descriptor)) + + configs.append("# y_t") + 
configs.append("# Note: the output of GruNonlinearityComponent is (h_t, c_t), we just get the second half. Otherwise, in non-projection gru layer, y_t = c_t") + configs.append("component name={0}.gru_nonlin type=GruNonlinearityComponent cell-dim={1} {2}".format(name, cell_dim, gru_nonlin_str)) + configs.append("component-node name={0}.gru_nonlin_t component={0}.gru_nonlin input=Append({0}.z_t, {0}.r_t, {0}.hpart_t, IfDefined(Offset({1}, {2})))".format(name, recurrent_connection, delay)) + configs.append("dim-range-node name={0}.y_t input-node={0}.gru_nonlin_t dim-offset={1} dim={1}".format(name, cell_dim)) + + configs.append("# s_t : recurrence") + configs.append("# Note: in non-projection gru layer, the recurrent part equals the output, namely y_t.") + configs.append("component name={0}.s_r type=BackpropTruncationComponent dim={1} {2}".format(name, cell_dim, bptrunc_str)) + configs.append("component-node name={0}.s_t component={0}.s_r input={0}.y_t".format(name)) + return configs + + +# This class is for lines like +# 'fast-pgru-layer name=pgru1 input=[-1] delay=-3' +# It generates an PGRU sub-graph with output projections. It can also generate +# outputs without projection, but you could use the XconfigGruLayer for this +# simple RNN. +# The output dimension of the layer may be specified via 'cell-dim=xxx', but if not specified, +# the dimension defaults to the same as the input. +# See other configuration values below. +# +# Parameters of the class, and their defaults: +# input='[-1]' [Descriptor giving the input of the layer.] +# cell-dim=-1 [Dimension of the cell] +# recurrent-projection_dim [Dimension of the projection used in recurrent connections, e.g. cell-dim/4] +# non-recurrent-projection-dim [Dimension of the projection in non-recurrent connections, +# in addition to recurrent-projection-dim, e.g. cell-dim/4] +# delay=-1 [Delay in the recurrent connections of the GRU ] +# clipping-threshold=30 [nnet3 GRU use a gradient clipping component at the recurrent connections. +# This is the threshold used to decide if clipping has to be activated ] +# zeroing-interval=20 [interval at which we (possibly) zero out the recurrent derivatives.] +# zeroing-threshold=15 [We only zero out the derivs every zeroing-interval, if derivs exceed this value.] 
+# self_repair_scale_nonlinearity=1e-5 [It is a constant scaling the self-repair vector computed in derived classes of NonlinearComponent] +# i.e., SigmoidComponent, TanhComponent and RectifiedLinearComponent ] +# ng-per-element-scale-options='' [Additional options used for the diagonal matrices in the GRU ] +# gru-nonlinearity-options=' max-change=0.75' [options for GruNonlinearityComponent, see below for detail] +# ng-affine-options='' [Additional options used for the full matrices in the GRU, can be used to do things like set biases to initialize to 1] +class XconfigFastPgruLayer(XconfigLayerBase): + def __init__(self, first_token, key_to_value, prev_names = None): + assert first_token == "fast-pgru-layer" + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = {'input' : '[-1]', + 'cell-dim' : -1, # this is a compulsory argument + 'recurrent-projection-dim' : -1, # defaults to cell-dim / 4 + 'non-recurrent-projection-dim' : -1, # defaults to + # recurrent-projection-dim + 'clipping-threshold' : 30.0, + 'delay' : -1, + 'ng-per-element-scale-options' : ' max-change=0.75 ', + 'ng-affine-options' : ' max-change=0.75 ', + 'self-repair-scale-nonlinearity' : 0.00001, + 'zeroing-interval' : 20, + 'zeroing-threshold' : 15.0, + # if you want to set 'self-repair-scale', ' self-repair-threshold' + # or 'param-stddev' for GruNonlinearityComponent + # For default, they are 1.0e-05, 0.2 and 1.0 / sqrt(d) where d is cell-dim. + # you can add somethig like 'self-repair-scale=xxx' to gru-nonlinearity-options. + # you can also see src/nnet3/nnet-combined-component.h for detail + 'gru-nonlinearity-options' : ' max-change=0.75' + } + + def set_derived_configs(self): + if self.config['recurrent-projection-dim'] <= 0: + self.config['recurrent-projection-dim'] = self.config['cell-dim'] / 4 + + if self.config['non-recurrent-projection-dim'] <= 0: + self.config['non-recurrent-projection-dim'] = \ + self.config['recurrent-projection-dim'] + + def check_configs(self): + for key in ['cell-dim', 'recurrent-projection-dim', + 'non-recurrent-projection-dim']: + if self.config[key] <= 0: + raise RuntimeError("{0} has invalid value {1}.".format( + key, self.config[key])) + + if self.config['delay'] == 0: + raise RuntimeError("delay cannot be zero") + + if (self.config['recurrent-projection-dim'] + + self.config['non-recurrent-projection-dim'] > + self.config['cell-dim']): + raise RuntimeError("recurrent+non-recurrent projection dim exceeds " + "cell dim.") + for key in ['self-repair-scale-nonlinearity']: + if self.config[key] < 0.0 or self.config[key] > 1.0: + raise RuntimeError("{0} has invalid value {2}." 
+ .format(self.layer_type, key, + self.config[key])) + + def auxiliary_outputs(self): + return ['c_t'] + + def output_name(self, auxiliary_output = None): + node_name = 'y_t' + if auxiliary_output is not None: + if auxiliary_output in self.auxiliary_outputs(): + node_name = auxiliary_output + else: + raise Exception("In {0} of type {1}, unknown auxiliary output name {1}".format(self.layer_type, auxiliary_output)) + + return '{0}.{1}'.format(self.name, node_name) + + def output_dim(self, auxiliary_output = None): + if auxiliary_output is not None: + if auxiliary_output in self.auxiliary_outputs(): + if node_name == 'c_t': + return self.config['cell-dim'] + # add code for other auxiliary_outputs here when we decide to expose them + else: + raise Exception("In {0} of type {1}, unknown auxiliary output name {1}".format(self.layer_type, auxiliary_output)) + + return self.config['recurrent-projection-dim'] + self.config['non-recurrent-projection-dim'] + + def get_full_config(self): + ans = [] + config_lines = self.generate_pgru_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + # we do not support user specified matrices in LSTM initialization + # so 'ref' and 'final' configs are the same. + ans.append((config_name, line)) + return ans + + # convenience function to generate the PGRU config + def generate_pgru_config(self): + + # assign some variables to reduce verbosity + name = self.name + # in the below code we will just call descriptor_strings as descriptors for conciseness + input_dim = self.descriptors['input']['dim'] + input_descriptor = self.descriptors['input']['final-string'] + cell_dim = self.config['cell-dim'] + rec_proj_dim = self.config['recurrent-projection-dim'] + nonrec_proj_dim = self.config['non-recurrent-projection-dim'] + delay = self.config['delay'] + repair_nonlin = self.config['self-repair-scale-nonlinearity'] + repair_nonlin_str = "self-repair-scale={0:.10f}".format(repair_nonlin) if repair_nonlin is not None else '' + bptrunc_str = ("clipping-threshold={0}" + " zeroing-threshold={1}" + " zeroing-interval={2}" + " recurrence-interval={3}" + "".format(self.config['clipping-threshold'], + self.config['zeroing-threshold'], + self.config['zeroing-interval'], + abs(delay))) + affine_str = self.config['ng-affine-options'] + pes_str = self.config['ng-per-element-scale-options'] + + # Natural gradient per element scale parameters + # TODO: decide if we want to keep exposing these options + if re.search('param-mean', pes_str) is None and \ + re.search('param-stddev', pes_str) is None: + pes_str += " param-mean=0.0 param-stddev=1.0 " + + # string for GruNonlinearityComponent + gru_nonlin_str = self.config['gru-nonlinearity-options'] + + # formulation like: + # z_t = \sigmoid ( U^z x_t + W^z s_{t-1} ) # update gate + # r_t = \sigmoid ( U^r x_t + W^r s_{t-1} ) # reset gate + # h_t = \tanh ( U^h x_t + W^h ( s_{t-1} \dot r_t ) ) + # c_t = ( 1 - z_t ) \dot h_t + z_t \dot c_{t-1} + # y_t = W^y c_t # dim(y_t) = recurrent_dim + non_recurrent_dim. + # This is the output of the GRU. + # s_t = y_t[0:recurrent_dim-1] # dimension range of y_t + # dim(s_t) = recurrent_dim. + # Note: + # naming convention: + # .W_. e.g. 
Gru1.W_i.xr for matrix + # providing output to gate i and operating on an appended vector [x,r] + # notation convention: + # In order to be consistent with the notations which are used in + # nnet-combined-component.cc, we map "\tilde{h_t}" and "h_t" which are + # used in paper to "h_t" and "c_t" + + configs = [] + configs.append("### Begin Gru layer '{0}'".format(name)) + configs.append("# Update gate control : W_z* matrices") + configs.append("component name={0}.W_z.xs type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim + rec_proj_dim, cell_dim, affine_str)) + configs.append("# Reset gate control : W_r* matrices") + configs.append("component name={0}.W_r.xs type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim + rec_proj_dim, rec_proj_dim, affine_str)) + + + configs.append("# hpart_t related matrix : W_hpart matric") + configs.append("component name={0}.W_hpart.x type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim, cell_dim , affine_str)) + + configs.append("# Defining the non-linearities") + configs.append("component name={0}.z type=SigmoidComponent dim={1} {2}".format(name, cell_dim, repair_nonlin_str)) + configs.append("component name={0}.r type=SigmoidComponent dim={1} {2}".format(name, rec_proj_dim, repair_nonlin_str)) + + recurrent_connection = '{0}.s_t'.format(name) + + configs.append("# z_t and r_t") + configs.append("component-node name={0}.z_t_pre component={0}.W_z.xs input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay)) + configs.append("component-node name={0}.z_t component={0}.z input={0}.z_t_pre".format(name)) + configs.append("component-node name={0}.r_t_pre component={0}.W_r.xs input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay)) + configs.append("component-node name={0}.r_t component={0}.r input={0}.r_t_pre".format(name)) + + configs.append("# hpart_t") + configs.append("component-node name={0}.hpart_t component={0}.W_hpart.x input={1}".format(name, input_descriptor)) + + configs.append("# c_t") + configs.append("# Note: the output of GruNonlinearityComponent is (h_t, c_t), we use the second half.") + configs.append("component name={0}.gru_nonlin type=GruNonlinearityComponent cell-dim={1} recurrent-dim={2} {3}".format(name, cell_dim, rec_proj_dim, gru_nonlin_str)) + configs.append("component-node name={0}.gru_nonlin_t component={0}.gru_nonlin input=Append({0}.z_t, {0}.r_t, {0}.hpart_t, IfDefined(Offset({0}.c_t, {2})), IfDefined(Offset({1}, {2})))".format(name, recurrent_connection, delay)) + configs.append("dim-range-node name={0}.c_t input-node={0}.gru_nonlin_t dim-offset={1} dim={1}".format(name, cell_dim)) + + configs.append("# the projected matrix W_y.c and y_t") + configs.append("component name={0}.W_y.c type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, cell_dim, rec_proj_dim + nonrec_proj_dim, affine_str)) + configs.append("component-node name={0}.y_t component={0}.W_y.c input={0}.c_t".format(name)) + + configs.append("# s_t : recurrence") + configs.append("component name={0}.s_r type=BackpropTruncationComponent dim={1} {2}".format(name, rec_proj_dim, bptrunc_str)) + configs.append("dim-range-node name={0}.s_t_pre input-node={0}.y_t dim-offset=0 dim={1}".format(name, rec_proj_dim)) + configs.append("component-node name={0}.s_t component={0}.s_r input={0}.s_t_pre".format(name)) + return configs + + +# This 
class is for lines like +# 'fast-norm-pgru-layer name=pgru1 input=[-1] delay=-3' + +# Different from the vanilla PGRU, the NormPGRU uses batchnorm in the forward direction +# and renorm in the recurrence. + +# The output dimension of the layer may be specified via 'cell-dim=xxx', but if not specified, +# the dimension defaults to the same as the input. +# See other configuration values below. +# +# Parameters of the class, and their defaults: +# input='[-1]' [Descriptor giving the input of the layer.] +# cell-dim=-1 [Dimension of the cell] +# recurrent-projection_dim [Dimension of the projection used in recurrent connections, e.g. cell-dim/4] +# non-recurrent-projection-dim [Dimension of the projection in non-recurrent connections, +# in addition to recurrent-projection-dim, e.g. cell-dim/4] +# delay=-1 [Delay in the recurrent connections of the GRU ] +# clipping-threshold=30 [nnet3 GRU use a gradient clipping component at the recurrent connections. +# This is the threshold used to decide if clipping has to be activated ] +# zeroing-interval=20 [interval at which we (possibly) zero out the recurrent derivatives.] +# zeroing-threshold=15 [We only zero out the derivs every zeroing-interval, if derivs exceed this value.] +# self_repair_scale_nonlinearity=1e-5 [It is a constant scaling the self-repair vector computed in derived classes of NonlinearComponent] +# i.e., SigmoidComponent, TanhComponent and RectifiedLinearComponent ] +# ng-per-element-scale-options='' [Additional options used for the diagonal matrices in the GRU ] +# gru-nonlinearity-options=' max-change=0.75' [options for GruNonlinearityComponent, see below for detail] +# ng-affine-options='' [Additional options used for the full matrices in the GRU, can be used to do things like set biases to initialize to 1] +class XconfigFastNormPgruLayer(XconfigLayerBase): + def __init__(self, first_token, key_to_value, prev_names = None): + assert first_token == "fast-norm-pgru-layer" + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = {'input' : '[-1]', + 'cell-dim' : -1, # this is a compulsory argument + 'recurrent-projection-dim' : -1, # defaults to cell-dim / 4 + 'non-recurrent-projection-dim' : -1, # defaults to + # recurrent-projection-dim + 'clipping-threshold' : 30.0, + 'delay' : -1, + 'ng-per-element-scale-options' : ' max-change=0.75 ', + 'ng-affine-options' : ' max-change=0.75 ', + 'self-repair-scale-nonlinearity' : 0.00001, + 'zeroing-interval' : 20, + 'zeroing-threshold' : 15.0, + # if you want to set 'self-repair-scale', ' self-repair-threshold' + # or 'param-stddev' for GruNonlinearityComponent + # For default, they are 1.0e-05, 0.2 and 1.0 / sqrt(d) where d is cell-dim. + # you can add somethig like 'self-repair-scale=xxx' to gru-nonlinearity-options. 
+ # you can also see src/nnet3/nnet-combined-component.h for detail + 'gru-nonlinearity-options' : ' max-change=0.75', + 'dropout-proportion' : -1.0, # If -1.0, no dropout components will be added + 'dropout-per-frame' : True # If False, regular dropout, not per frame + } + + def set_derived_configs(self): + if self.config['recurrent-projection-dim'] <= 0: + self.config['recurrent-projection-dim'] = self.config['cell-dim'] / 4 + + if self.config['non-recurrent-projection-dim'] <= 0: + self.config['non-recurrent-projection-dim'] = \ + self.config['recurrent-projection-dim'] + + def check_configs(self): + for key in ['cell-dim', 'recurrent-projection-dim', + 'non-recurrent-projection-dim']: + if self.config[key] <= 0: + raise RuntimeError("{0} has invalid value {1}.".format( + key, self.config[key])) + + if self.config['delay'] == 0: + raise RuntimeError("delay cannot be zero") + + if (self.config['recurrent-projection-dim'] + + self.config['non-recurrent-projection-dim'] > + self.config['cell-dim']): + raise RuntimeError("recurrent+non-recurrent projection dim exceeds " + "cell dim.") + for key in ['self-repair-scale-nonlinearity']: + if self.config[key] < 0.0 or self.config[key] > 1.0: + raise RuntimeError("{0} has invalid value {2}." + .format(self.layer_type, key, + self.config[key])) + if ((self.config['dropout-proportion'] > 1.0 or + self.config['dropout-proportion'] < 0.0) and + self.config['dropout-proportion'] != -1.0 ): + raise RuntimeError("dropout-proportion has invalid value {0}." + .format(self.config['dropout-proportion'])) + + def auxiliary_outputs(self): + return ['c_t'] + + def output_name(self, auxiliary_output = None): + node_name = 'y_t' + if auxiliary_output is not None: + if auxiliary_output in self.auxiliary_outputs(): + node_name = auxiliary_output + else: + raise Exception("In {0} of type {1}, unknown auxiliary output name {1}".format(self.layer_type, auxiliary_output)) + + return '{0}.{1}'.format(self.name, node_name) + + def output_dim(self, auxiliary_output = None): + if auxiliary_output is not None: + if auxiliary_output in self.auxiliary_outputs(): + if node_name == 'c_t': + return self.config['cell-dim'] + # add code for other auxiliary_outputs here when we decide to expose them + else: + raise Exception("In {0} of type {1}, unknown auxiliary output name {1}".format(self.layer_type, auxiliary_output)) + + return self.config['recurrent-projection-dim'] + self.config['non-recurrent-projection-dim'] + + def get_full_config(self): + ans = [] + config_lines = self.generate_pgru_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + # we do not support user specified matrices in LSTM initialization + # so 'ref' and 'final' configs are the same. 
+ ans.append((config_name, line)) + return ans + + # convenience function to generate the Norm-PGRU config + def generate_pgru_config(self): + + # assign some variables to reduce verbosity + name = self.name + # in the below code we will just call descriptor_strings as descriptors for conciseness + input_dim = self.descriptors['input']['dim'] + input_descriptor = self.descriptors['input']['final-string'] + cell_dim = self.config['cell-dim'] + rec_proj_dim = self.config['recurrent-projection-dim'] + nonrec_proj_dim = self.config['non-recurrent-projection-dim'] + delay = self.config['delay'] + repair_nonlin = self.config['self-repair-scale-nonlinearity'] + repair_nonlin_str = "self-repair-scale={0:.10f}".format(repair_nonlin) if repair_nonlin is not None else '' + bptrunc_str = ("clipping-threshold={0}" + " zeroing-threshold={1}" + " zeroing-interval={2}" + " recurrence-interval={3}" + "".format(self.config['clipping-threshold'], + self.config['zeroing-threshold'], + self.config['zeroing-interval'], + abs(delay))) + affine_str = self.config['ng-affine-options'] + pes_str = self.config['ng-per-element-scale-options'] + dropout_proportion = self.config['dropout-proportion'] + dropout_per_frame = 'true' if self.config['dropout-per-frame'] else 'false' + + # Natural gradient per element scale parameters + # TODO: decide if we want to keep exposing these options + if re.search('param-mean', pes_str) is None and \ + re.search('param-stddev', pes_str) is None: + pes_str += " param-mean=0.0 param-stddev=1.0 " + + # string for GruNonlinearityComponent + gru_nonlin_str = self.config['gru-nonlinearity-options'] + + # formulation like: + # z_t = \sigmoid ( U^z x_t + W^z s_{t-1} ) # update gate + # r_t = \sigmoid ( U^r x_t + W^r s_{t-1} ) # reset gate + # h_t = \tanh ( U^h x_t + W^h ( s_{t-1} \dot r_t ) ) + # c_t = ( 1 - z_t ) \dot h_t + z_t \dot c_{t-1} + # y_t_tmp = W^y c_t + # s_t = renorm ( y_t_tmp[0:rec_proj_dim-1] ) # dim(s_t) = recurrent_dim. + # y_t = batchnorm ( y_t_tmp ) # dim(y_t) = recurrent_dim + non_recurrent_dim. + # This is the output of the GRU. + # Note: + # naming convention: + # .W_. e.g. 
Gru1.W_i.xr for matrix + # providing output to gate i and operating on an appended vector [x,r] + # notation convention: + # In order to be consistent with the notations which are used in + # nnet-combined-component.cc, we map "\tilde{h_t}" and "h_t" which are + # used in paper to "h_t" and "c_t" + + configs = [] + configs.append("### Begin Gru layer '{0}'".format(name)) + configs.append("# Update gate control : W_z* matrices") + configs.append("component name={0}.W_z.xs type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim + rec_proj_dim, cell_dim, affine_str)) + configs.append("# Reset gate control : W_r* matrices") + configs.append("component name={0}.W_r.xs type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim + rec_proj_dim, rec_proj_dim, affine_str)) + + + configs.append("# hpart_t related matrix : W_hpart matric") + configs.append("component name={0}.W_hpart.x type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim, cell_dim , affine_str)) + + configs.append("# Defining the non-linearities") + configs.append("component name={0}.z type=SigmoidComponent dim={1} {2}".format(name, cell_dim, repair_nonlin_str)) + configs.append("component name={0}.r type=SigmoidComponent dim={1} {2}".format(name, rec_proj_dim, repair_nonlin_str)) + + if dropout_proportion != -1.0: + configs.append("# Defining the dropout component") + configs.append("component name={0}.dropout_z type=DropoutComponent dim={1} " + "dropout-proportion={2} dropout-per-frame={3}" + .format(name, cell_dim, dropout_proportion, dropout_per_frame)) + configs.append("component name={0}.dropout_r type=DropoutComponent dim={1} " + "dropout-proportion={2} dropout-per-frame={3}" + .format(name, rec_proj_dim, dropout_proportion, dropout_per_frame)) + + + recurrent_connection = '{0}.s_t'.format(name) + + configs.append("# z_t") + configs.append("component-node name={0}.z_t_pre component={0}.W_z.xs input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay)) + if dropout_proportion != -1.0: + configs.append("component-node name={0}.z_t_predrop component={0}.z input={0}.z_t_pre".format(name)) + configs.append("component-node name={0}.z_t component={0}.dropout_z input={0}.z_t_predrop".format(name)) + else: + configs.append("component-node name={0}.z_t component={0}.z input={0}.z_t_pre".format(name)) + + configs.append("# r_t") + configs.append("component-node name={0}.r_t_pre component={0}.W_r.xs input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay)) + if dropout_proportion != -1.0: + configs.append("component-node name={0}.r_t_predrop component={0}.r input={0}.r_t_pre".format(name)) + configs.append("component-node name={0}.r_t component={0}.dropout_r input={0}.r_t_predrop".format(name)) + else: + configs.append("component-node name={0}.r_t component={0}.r input={0}.r_t_pre".format(name)) + + configs.append("# hpart_t") + configs.append("component-node name={0}.hpart_t component={0}.W_hpart.x input={1}".format(name, input_descriptor)) + + configs.append("# c_t") + configs.append("# Note: the output of GruNonlinearityComponent is (h_t, c_t), we use the second half.") + configs.append("component name={0}.gru_nonlin type=GruNonlinearityComponent cell-dim={1} recurrent-dim={2} {3}".format(name, cell_dim, rec_proj_dim, gru_nonlin_str)) + configs.append("component-node name={0}.gru_nonlin_t component={0}.gru_nonlin 
input=Append({0}.z_t, {0}.r_t, {0}.hpart_t, IfDefined(Offset({0}.c_t, {2})), IfDefined(Offset({1}, {2})))".format(name, recurrent_connection, delay)) + configs.append("dim-range-node name={0}.c_t input-node={0}.gru_nonlin_t dim-offset={1} dim={1}".format(name, cell_dim)) + + configs.append("# the projected matrix W_y.c and y_t_tmp") + configs.append("component name={0}.W_y.c type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, cell_dim, rec_proj_dim + nonrec_proj_dim, affine_str)) + configs.append("component-node name={0}.y_t_tmp component={0}.W_y.c input={0}.c_t".format(name)) + + configs.append("# s_t : recurrence") + configs.append("component name={0}.renorm type=NormalizeComponent dim={1} target-rms=1.0".format(name, rec_proj_dim)) + configs.append("component name={0}.s_r type=BackpropTruncationComponent dim={1} {2}".format(name, rec_proj_dim, bptrunc_str)) + configs.append("dim-range-node name={0}.s_t_pre input-node={0}.y_t_tmp dim-offset=0 dim={1}".format(name, rec_proj_dim)) + configs.append("component-node name={0}.s_t_renorm component={0}.renorm input={0}.s_t_pre".format(name)) + configs.append("component-node name={0}.s_t component={0}.s_r input={0}.s_t_renorm".format(name)) + + configs.append("# y_t : output") + configs.append("component name={0}.batchnorm type=BatchNormComponent dim={1} target-rms=1.0".format(name, rec_proj_dim + nonrec_proj_dim)) + configs.append("component-node name={0}.y_t component={0}.batchnorm input={0}.y_t_tmp".format(name)) + return configs + + +# This class is for lines like +# 'fast-opgru-layer name=opgru1 input=[-1] delay=-3' +# It generates an PGRU sub-graph with output projections. It can also generate +# outputs without projection, but you could use the XconfigGruLayer for this +# simple RNN. +# The output dimension of the layer may be specified via 'cell-dim=xxx', but if not specified, +# the dimension defaults to the same as the input. +# See other configuration values below. +# +# Parameters of the class, and their defaults: +# input='[-1]' [Descriptor giving the input of the layer.] +# cell-dim=-1 [Dimension of the cell] +# recurrent-projection_dim [Dimension of the projection used in recurrent connections, e.g. cell-dim/4] +# non-recurrent-projection-dim [Dimension of the projection in non-recurrent connections, +# in addition to recurrent-projection-dim, e.g. cell-dim/4] +# delay=-1 [Delay in the recurrent connections of the GRU ] +# clipping-threshold=30 [nnet3 GRU use a gradient clipping component at the recurrent connections. +# This is the threshold used to decide if clipping has to be activated ] +# zeroing-interval=20 [interval at which we (possibly) zero out the recurrent derivatives.] +# zeroing-threshold=15 [We only zero out the derivs every zeroing-interval, if derivs exceed this value.] 
+# self_repair_scale_nonlinearity=1e-5 [It is a constant scaling the self-repair vector computed in derived classes of NonlinearComponent] +# i.e., SigmoidComponent, TanhComponent and RectifiedLinearComponent ] +# ng-per-element-scale-options='' [Additional options used for the diagonal matrices in the GRU ] +# gru-nonlinearity-options=' max-change=0.75' [options for GruNonlinearityComponent, see below for detail] +# ng-affine-options='' [Additional options used for the full matrices in the GRU, can be used to do things like set biases to initialize to 1] +class XconfigFastOpgruLayer(XconfigLayerBase): + def __init__(self, first_token, key_to_value, prev_names = None): + assert first_token == "fast-opgru-layer" + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = {'input' : '[-1]', + 'cell-dim' : -1, # this is a compulsory argument + 'recurrent-projection-dim' : -1, # defaults to cell-dim / 4 + 'non-recurrent-projection-dim' : -1, # defaults to + # recurrent-projection-dim + 'clipping-threshold' : 30.0, + 'delay' : -1, + 'ng-per-element-scale-options' : ' max-change=0.75 ', + 'ng-affine-options' : ' max-change=0.75 ', + 'self-repair-scale-nonlinearity' : 0.00001, + 'zeroing-interval' : 20, + 'zeroing-threshold' : 15.0, + # if you want to set 'self-repair-scale', ' self-repair-threshold' + # or 'param-stddev' for GruNonlinearityComponent + # For default, they are 1.0e-05, 0.2 and 1.0 / sqrt(d) where d is cell-dim. + # you can add somethig like 'self-repair-scale=xxx' to gru-nonlinearity-options. + # you can also see src/nnet3/nnet-combined-component.h for detail + 'gru-nonlinearity-options' : ' max-change=0.75' + } + + def set_derived_configs(self): + if self.config['recurrent-projection-dim'] <= 0: + self.config['recurrent-projection-dim'] = self.config['cell-dim'] / 4 + + if self.config['non-recurrent-projection-dim'] <= 0: + self.config['non-recurrent-projection-dim'] = \ + self.config['recurrent-projection-dim'] + + def check_configs(self): + for key in ['cell-dim', 'recurrent-projection-dim', + 'non-recurrent-projection-dim']: + if self.config[key] <= 0: + raise RuntimeError("{0} has invalid value {1}.".format( + key, self.config[key])) + + if self.config['delay'] == 0: + raise RuntimeError("delay cannot be zero") + + if (self.config['recurrent-projection-dim'] + + self.config['non-recurrent-projection-dim'] > + self.config['cell-dim']): + raise RuntimeError("recurrent+non-recurrent projection dim exceeds " + "cell dim.") + for key in ['self-repair-scale-nonlinearity']: + if self.config[key] < 0.0 or self.config[key] > 1.0: + raise RuntimeError("{0} has invalid value {2}." 
+ .format(self.layer_type, key, + self.config[key])) + + def auxiliary_outputs(self): + return ['c_t'] + + def output_name(self, auxiliary_output = None): + node_name = 'y_t' + if auxiliary_output is not None: + if auxiliary_output in self.auxiliary_outputs(): + node_name = auxiliary_output + else: + raise Exception("In {0} of type {1}, unknown auxiliary output name {1}".format(self.layer_type, auxiliary_output)) + + return '{0}.{1}'.format(self.name, node_name) + + def output_dim(self, auxiliary_output = None): + if auxiliary_output is not None: + if auxiliary_output in self.auxiliary_outputs(): + if node_name == 'c_t': + return self.config['cell-dim'] + # add code for other auxiliary_outputs here when we decide to expose them + else: + raise Exception("In {0} of type {1}, unknown auxiliary output name {1}".format(self.layer_type, auxiliary_output)) + + return self.config['recurrent-projection-dim'] + self.config['non-recurrent-projection-dim'] + + def get_full_config(self): + ans = [] + config_lines = self.generate_pgru_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + # we do not support user specified matrices in LSTM initialization + # so 'ref' and 'final' configs are the same. + ans.append((config_name, line)) + return ans + + # convenience function to generate the OPGRU config + def generate_pgru_config(self): + + # assign some variables to reduce verbosity + name = self.name + # in the below code we will just call descriptor_strings as descriptors for conciseness + input_dim = self.descriptors['input']['dim'] + input_descriptor = self.descriptors['input']['final-string'] + cell_dim = self.config['cell-dim'] + rec_proj_dim = self.config['recurrent-projection-dim'] + nonrec_proj_dim = self.config['non-recurrent-projection-dim'] + delay = self.config['delay'] + repair_nonlin = self.config['self-repair-scale-nonlinearity'] + repair_nonlin_str = "self-repair-scale={0:.10f}".format(repair_nonlin) if repair_nonlin is not None else '' + bptrunc_str = ("clipping-threshold={0}" + " zeroing-threshold={1}" + " zeroing-interval={2}" + " recurrence-interval={3}" + "".format(self.config['clipping-threshold'], + self.config['zeroing-threshold'], + self.config['zeroing-interval'], + abs(delay))) + affine_str = self.config['ng-affine-options'] + pes_str = self.config['ng-per-element-scale-options'] + + # Natural gradient per element scale parameters + # TODO: decide if we want to keep exposing these options + if re.search('param-mean', pes_str) is None and \ + re.search('param-stddev', pes_str) is None: + pes_str += " param-mean=0.0 param-stddev=1.0 " + + # string for GruNonlinearityComponent + gru_nonlin_str = self.config['gru-nonlinearity-options'] + + # formulation like: + # z_t = \sigmoid ( U^z x_t + W^z s_{t-1} ) # update gate + # o_t = \sigmoid ( U^o x_t + W^o s_{t-1} ) # reset gate + # h_t = \tanh ( U^h x_t + W^h \dot c_{t-1} ) + # c_t = ( 1 - z_t ) \dot h_t + z_t \dot c_{t-1} + # y_t = ( c_t \dot o_t ) W^y # dim(y_t) = recurrent_dim + non_recurrent_dim. + # This is the output of the GRU. + # s_t = y_t[0:recurrent_dim-1] # dimension range of y_t + # dim(s_t) = recurrent_dim. + # Note: + # naming convention: + # .W_. e.g. 
Gru1.W_i.xr for matrix + # providing output to gate i and operating on an appended vector [x,r] + # notation convention: + # In order to be consistent with the notations which are used in + # nnet-combined-component.cc, we map "\tilde{h_t}" and "h_t" which are + # used in paper to "h_t" and "c_t" + + configs = [] + configs.append("### Begin Gru layer '{0}'".format(name)) + configs.append("# Update gate control : W_z* matrices") + configs.append("component name={0}.W_z.xs type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim + rec_proj_dim, cell_dim, affine_str)) + configs.append("# Reset gate control : W_o* matrices") + configs.append("component name={0}.W_o.xs type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim + rec_proj_dim, cell_dim, affine_str)) + + + configs.append("# hpart_t related matrix : W_hpart matric") + configs.append("component name={0}.W_hpart.x type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim, cell_dim , affine_str)) + + configs.append("# Defining the non-linearities") + configs.append("component name={0}.z type=SigmoidComponent dim={1} {2}".format(name, cell_dim, repair_nonlin_str)) + configs.append("component name={0}.o type=SigmoidComponent dim={1} {2}".format(name, cell_dim, repair_nonlin_str)) + + recurrent_connection = '{0}.s_t'.format(name) + + configs.append("# z_t and o_t") + configs.append("component-node name={0}.z_t_pre component={0}.W_z.xs input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay)) + configs.append("component-node name={0}.z_t component={0}.z input={0}.z_t_pre".format(name)) + configs.append("component-node name={0}.o_t_pre component={0}.W_o.xs input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay)) + configs.append("component-node name={0}.o_t component={0}.o input={0}.o_t_pre".format(name)) + + configs.append("# hpart_t") + configs.append("component-node name={0}.hpart_t component={0}.W_hpart.x input={1}".format(name, input_descriptor)) + + configs.append("# c_t") + configs.append("# Note: the output of OutputGruNonlinearityComponent is (h_t, c_t), we use the second half.") + configs.append("component name={0}.gru_nonlin type=OutputGruNonlinearityComponent cell-dim={1} {2}".format(name, cell_dim, gru_nonlin_str)) + configs.append("component-node name={0}.gru_nonlin_t component={0}.gru_nonlin input=Append({0}.z_t, {0}.hpart_t, IfDefined(Offset({0}.c_t, {1})))".format(name, delay)) + configs.append("dim-range-node name={0}.c_t input-node={0}.gru_nonlin_t dim-offset={1} dim={1}".format(name, cell_dim)) + + configs.append("# the projected matrix W_y.cdoto and y_t") + configs.append("component name={0}.cdoto type=ElementwiseProductComponent input-dim={1} output-dim={2}".format(name, 2 * cell_dim, cell_dim)) + configs.append("component-node name={0}.cdoto component={0}.cdoto input=Append({0}.c_t, {0}.o_t)".format(name)) + configs.append("component name={0}.W_y.cdoto type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, cell_dim, rec_proj_dim + nonrec_proj_dim, affine_str)) + configs.append("component-node name={0}.y_t component={0}.W_y.cdoto input={0}.cdoto".format(name)) + + configs.append("# s_t recurrence") + configs.append("component name={0}.s_r type=BackpropTruncationComponent dim={1} {2}".format(name, rec_proj_dim, bptrunc_str)) + configs.append("dim-range-node name={0}.s_t_preclip 
input-node={0}.y_t dim-offset=0 dim={1}".format(name, rec_proj_dim)) + configs.append("component-node name={0}.s_t component={0}.s_r input={0}.s_t_preclip".format(name)) + + return configs + + +# This class is for lines like +# 'fast-norm-opgru-layer name=opgru1 input=[-1] delay=-3' + +# Different from the vanilla OPGRU, the NormOPGRU uses batchnorm in the forward direction +# and renorm in the recurrence. + +# The output dimension of the layer may be specified via 'cell-dim=xxx', but if not specified, +# the dimension defaults to the same as the input. +# See other configuration values below. +# +# Parameters of the class, and their defaults: +# input='[-1]' [Descriptor giving the input of the layer.] +# cell-dim=-1 [Dimension of the cell] +# recurrent-projection_dim [Dimension of the projection used in recurrent connections, e.g. cell-dim/4] +# non-recurrent-projection-dim [Dimension of the projection in non-recurrent connections, +# in addition to recurrent-projection-dim, e.g. cell-dim/4] +# delay=-1 [Delay in the recurrent connections of the GRU ] +# clipping-threshold=30 [nnet3 GRU use a gradient clipping component at the recurrent connections. +# This is the threshold used to decide if clipping has to be activated ] +# zeroing-interval=20 [interval at which we (possibly) zero out the recurrent derivatives.] +# zeroing-threshold=15 [We only zero out the derivs every zeroing-interval, if derivs exceed this value.] +# self_repair_scale_nonlinearity=1e-5 [It is a constant scaling the self-repair vector computed in derived classes of NonlinearComponent] +# i.e., SigmoidComponent, TanhComponent and RectifiedLinearComponent ] +# ng-per-element-scale-options='' [Additional options used for the diagonal matrices in the GRU ] +# gru-nonlinearity-options=' max-change=0.75' [options for GruNonlinearityComponent, see below for detail] +# ng-affine-options='' [Additional options used for the full matrices in the GRU, can be used to do things like set biases to initialize to 1] +class XconfigFastNormOpgruLayer(XconfigLayerBase): + def __init__(self, first_token, key_to_value, prev_names = None): + assert first_token == "fast-norm-opgru-layer" + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = {'input' : '[-1]', + 'cell-dim' : -1, # this is a compulsory argument + 'recurrent-projection-dim' : -1, # defaults to cell-dim / 4 + 'non-recurrent-projection-dim' : -1, # defaults to + # recurrent-projection-dim + 'clipping-threshold' : 30.0, + 'delay' : -1, + 'ng-per-element-scale-options' : ' max-change=0.75 ', + 'ng-affine-options' : ' max-change=0.75 ', + 'self-repair-scale-nonlinearity' : 0.00001, + 'zeroing-interval' : 20, + 'zeroing-threshold' : 15.0, + # if you want to set 'self-repair-scale', ' self-repair-threshold' + # or 'param-stddev' for GruNonlinearityComponent + # For default, they are 1.0e-05, 0.2 and 1.0 / sqrt(d) where d is cell-dim. + # you can add somethig like 'self-repair-scale=xxx' to gru-nonlinearity-options. 
+ # you can also see src/nnet3/nnet-combined-component.h for detail + 'gru-nonlinearity-options' : ' max-change=0.75', + 'dropout-proportion' : -1.0, # If -1.0, no dropout components will be added + 'dropout-per-frame' : True # If False, regular dropout, not per frame + } + + def set_derived_configs(self): + if self.config['recurrent-projection-dim'] <= 0: + self.config['recurrent-projection-dim'] = self.config['cell-dim'] / 4 + + if self.config['non-recurrent-projection-dim'] <= 0: + self.config['non-recurrent-projection-dim'] = \ + self.config['recurrent-projection-dim'] + + def check_configs(self): + for key in ['cell-dim', 'recurrent-projection-dim', + 'non-recurrent-projection-dim']: + if self.config[key] <= 0: + raise RuntimeError("{0} has invalid value {1}.".format( + key, self.config[key])) + + if self.config['delay'] == 0: + raise RuntimeError("delay cannot be zero") + + if (self.config['recurrent-projection-dim'] + + self.config['non-recurrent-projection-dim'] > + self.config['cell-dim']): + raise RuntimeError("recurrent+non-recurrent projection dim exceeds " + "cell dim.") + for key in ['self-repair-scale-nonlinearity']: + if self.config[key] < 0.0 or self.config[key] > 1.0: + raise RuntimeError("{0} has invalid value {2}." + .format(self.layer_type, key, + self.config[key])) + if ((self.config['dropout-proportion'] > 1.0 or + self.config['dropout-proportion'] < 0.0) and + self.config['dropout-proportion'] != -1.0 ): + raise RuntimeError("dropout-proportion has invalid value {0}." + .format(self.config['dropout-proportion'])) + + def auxiliary_outputs(self): + return ['c_t'] + + def output_name(self, auxiliary_output = None): + node_name = 'y_t' + if auxiliary_output is not None: + if auxiliary_output in self.auxiliary_outputs(): + node_name = auxiliary_output + else: + raise Exception("In {0} of type {1}, unknown auxiliary output name {1}".format(self.layer_type, auxiliary_output)) + + return '{0}.{1}'.format(self.name, node_name) + + def output_dim(self, auxiliary_output = None): + if auxiliary_output is not None: + if auxiliary_output in self.auxiliary_outputs(): + if node_name == 'c_t': + return self.config['cell-dim'] + # add code for other auxiliary_outputs here when we decide to expose them + else: + raise Exception("In {0} of type {1}, unknown auxiliary output name {1}".format(self.layer_type, auxiliary_output)) + + return self.config['recurrent-projection-dim'] + self.config['non-recurrent-projection-dim'] + + def get_full_config(self): + ans = [] + config_lines = self.generate_pgru_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + # we do not support user specified matrices in LSTM initialization + # so 'ref' and 'final' configs are the same. 
+ ans.append((config_name, line)) + return ans + + # convenience function to generate the Norm-OPGRU config + def generate_pgru_config(self): + + # assign some variables to reduce verbosity + name = self.name + # in the below code we will just call descriptor_strings as descriptors for conciseness + input_dim = self.descriptors['input']['dim'] + input_descriptor = self.descriptors['input']['final-string'] + cell_dim = self.config['cell-dim'] + rec_proj_dim = self.config['recurrent-projection-dim'] + nonrec_proj_dim = self.config['non-recurrent-projection-dim'] + delay = self.config['delay'] + repair_nonlin = self.config['self-repair-scale-nonlinearity'] + repair_nonlin_str = "self-repair-scale={0:.10f}".format(repair_nonlin) if repair_nonlin is not None else '' + bptrunc_str = ("clipping-threshold={0}" + " zeroing-threshold={1}" + " zeroing-interval={2}" + " recurrence-interval={3}" + "".format(self.config['clipping-threshold'], + self.config['zeroing-threshold'], + self.config['zeroing-interval'], + abs(delay))) + affine_str = self.config['ng-affine-options'] + pes_str = self.config['ng-per-element-scale-options'] + dropout_proportion = self.config['dropout-proportion'] + dropout_per_frame = 'true' if self.config['dropout-per-frame'] else 'false' + + # Natural gradient per element scale parameters + # TODO: decide if we want to keep exposing these options + if re.search('param-mean', pes_str) is None and \ + re.search('param-stddev', pes_str) is None: + pes_str += " param-mean=0.0 param-stddev=1.0 " + + # string for GruNonlinearityComponent + gru_nonlin_str = self.config['gru-nonlinearity-options'] + + # formulation like: + # z_t = \sigmoid ( U^z x_t + W^z s_{t-1} ) # update gate + # o_t = \sigmoid ( U^o x_t + W^o s_{t-1} ) # output gate + # h_t = \tanh ( U^h x_t + W^h \dot c_{t-1} ) + # c_t = ( 1 - z_t ) \dot h_t + z_t \dot c_{t-1} + # y_t_tmp = ( c_t \dot o_t ) W^y + # s_t = renorm ( y_t_tmp[0:rec_proj_dim-1] ) # dim(s_t) = recurrent_dim. + # y_t = batchnorm ( y_t_tmp ) # dim(y_t) = recurrent_dim + non_recurrent_dim. + # This is the output of the GRU. + # Note: + # naming convention: + # .W_. e.g. 
Gru1.W_i.xr for matrix + # providing output to gate i and operating on an appended vector [x,r] + # notation convention: + # In order to be consistent with the notations which are used in + # nnet-combined-component.cc, we map "\tilde{h_t}" and "h_t" which are + # used in paper to "h_t" and "c_t" + + configs = [] + configs.append("### Begin Gru layer '{0}'".format(name)) + configs.append("# Update gate control : W_z* matrices") + configs.append("component name={0}.W_z.xs type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim + rec_proj_dim, cell_dim, affine_str)) + configs.append("# Reset gate control : W_o* matrices") + configs.append("component name={0}.W_o.xs type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim + rec_proj_dim, cell_dim, affine_str)) + + + configs.append("# hpart_t related matrix : W_hpart matric") + configs.append("component name={0}.W_hpart.x type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, input_dim, cell_dim , affine_str)) + + configs.append("# Defining the non-linearities") + configs.append("component name={0}.z type=SigmoidComponent dim={1} {2}".format(name, cell_dim, repair_nonlin_str)) + configs.append("component name={0}.o type=SigmoidComponent dim={1} {2}".format(name, cell_dim, repair_nonlin_str)) + + if dropout_proportion != -1.0: + configs.append("# Defining the dropout component") + configs.append("component name={0}.dropout type=DropoutComponent dim={1} " + "dropout-proportion={2} dropout-per-frame={3}" + .format(name, cell_dim, dropout_proportion, dropout_per_frame)) + + recurrent_connection = '{0}.s_t'.format(name) + + configs.append("# z_t") + configs.append("component-node name={0}.z_t_pre component={0}.W_z.xs input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay)) + if dropout_proportion != -1.0: + configs.append("component-node name={0}.z_t_predrop component={0}.z input={0}.z_t_pre".format(name)) + configs.append("component-node name={0}.z_t component={0}.dropout input={0}.z_t_predrop".format(name)) + else: + configs.append("component-node name={0}.z_t component={0}.z input={0}.z_t_pre".format(name)) + + configs.append("# o_t") + configs.append("component-node name={0}.o_t_pre component={0}.W_o.xs input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay)) + if dropout_proportion != -1.0: + configs.append("component-node name={0}.o_t_predrop component={0}.o input={0}.o_t_pre".format(name)) + configs.append("component-node name={0}.o_t component={0}.dropout input={0}.o_t_predrop".format(name)) + else: + configs.append("component-node name={0}.o_t component={0}.o input={0}.o_t_pre".format(name)) + + configs.append("# hpart_t") + configs.append("component-node name={0}.hpart_t component={0}.W_hpart.x input={1}".format(name, input_descriptor)) + + configs.append("# c_t") + configs.append("# Note: the output of OutputGruNonlinearityComponent is (h_t, c_t), we use the second half.") + configs.append("component name={0}.gru_nonlin type=OutputGruNonlinearityComponent cell-dim={1} {2}".format(name, cell_dim, gru_nonlin_str)) + configs.append("component-node name={0}.gru_nonlin_t component={0}.gru_nonlin input=Append({0}.z_t, {0}.hpart_t, IfDefined(Offset({0}.c_t, {1})))".format(name, delay)) + configs.append("dim-range-node name={0}.c_t input-node={0}.gru_nonlin_t dim-offset={1} dim={1}".format(name, cell_dim)) + + configs.append("# the projected 
matrix W_y.cdoto and y_t_tmp") + configs.append("component name={0}.cdoto type=ElementwiseProductComponent input-dim={1} output-dim={2}".format(name, 2 * cell_dim, cell_dim)) + configs.append("component-node name={0}.cdoto component={0}.cdoto input=Append({0}.c_t, {0}.o_t)".format(name)) + configs.append("component name={0}.W_y.cdoto type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, cell_dim, rec_proj_dim + nonrec_proj_dim, affine_str)) + configs.append("component-node name={0}.y_t_tmp component={0}.W_y.cdoto input={0}.cdoto".format(name)) + + configs.append("# s_t : recurrence") + configs.append("component name={0}.renorm type=NormalizeComponent dim={1} target-rms=1.0".format(name, rec_proj_dim)) + configs.append("component name={0}.s_r type=BackpropTruncationComponent dim={1} {2}".format(name, rec_proj_dim, bptrunc_str)) + configs.append("dim-range-node name={0}.s_t_pre input-node={0}.y_t_tmp dim-offset=0 dim={1}".format(name, rec_proj_dim)) + configs.append("component-node name={0}.s_t_renorm component={0}.renorm input={0}.s_t_pre".format(name)) + configs.append("component-node name={0}.s_t component={0}.s_r input={0}.s_t_renorm".format(name)) + + configs.append("# y_t : output") + configs.append("component name={0}.batchnorm type=BatchNormComponent dim={1} target-rms=1.0".format(name, rec_proj_dim + nonrec_proj_dim)) + configs.append("component-node name={0}.y_t component={0}.batchnorm input={0}.y_t_tmp".format(name)) + + return configs diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py index 1d284146e35..b540423e3cd 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py @@ -27,6 +27,7 @@ 'relu-batchnorm-layer' : xlayers.XconfigBasicLayer, 'relu-batchnorm-so-layer' : xlayers.XconfigBasicLayer, 'batchnorm-so-relu-layer' : xlayers.XconfigBasicLayer, + 'batchnorm-layer' : xlayers.XconfigBasicLayer, 'sigmoid-layer' : xlayers.XconfigBasicLayer, 'tanh-layer' : xlayers.XconfigBasicLayer, 'fixed-affine-layer' : xlayers.XconfigFixedAffineLayer, @@ -68,13 +69,22 @@ 'opgru-layer' : xlayers.XconfigOpgruLayer, 'norm-pgru-layer' : xlayers.XconfigNormPgruLayer, 'norm-opgru-layer' : xlayers.XconfigNormOpgruLayer, + 'fast-gru-layer' : xlayers.XconfigFastGruLayer, + 'fast-pgru-layer' : xlayers.XconfigFastPgruLayer, + 'fast-norm-pgru-layer' : xlayers.XconfigFastNormPgruLayer, + 'fast-opgru-layer' : xlayers.XconfigFastOpgruLayer, + 'fast-norm-opgru-layer' : xlayers.XconfigFastNormOpgruLayer, 'tdnnf-layer': xlayers.XconfigTdnnfLayer, 'prefinal-layer': xlayers.XconfigPrefinalLayer, 'renorm-component': xlayers.XconfigRenormComponent, 'batchnorm-component': xlayers.XconfigBatchnormComponent, 'no-op-component': xlayers.XconfigNoOpComponent, 'linear-component': xlayers.XconfigLinearComponent, - 'scale-component': xlayers.XconfigPerElementScaleComponent + 'affine-component': xlayers.XconfigAffineComponent, + 'scale-component': xlayers.XconfigPerElementScaleComponent, + 'dim-range-component': xlayers.XconfigDimRangeComponent, + 'offset-component': xlayers.XconfigPerElementOffsetComponent, + 'combine-feature-maps-layer': xlayers.XconfigCombineFeatureMapsLayer } # Turn a config line and a list of previous layers into diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/trivial_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/trivial_layers.py index 6b8e3c3a5c2..2728ad40639 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/trivial_layers.py +++ 
b/egs/wsj/s5/steps/libs/nnet3/xconfig/trivial_layers.py @@ -206,7 +206,9 @@ def set_default_configs(self): 'dim': -1, 'orthonormal-constraint': '', 'max-change': 0.75, - 'l2-regularize': '' } + 'l2-regularize': '', + 'param-stddev': '', + 'learning-rate-factor': '' } def check_configs(self): if self.config['dim'] <= 0: @@ -240,7 +242,8 @@ def _generate_config(self): output_dim = self.config['dim'] opts = '' - for opt_name in ['orthonormal-constraint', 'max-change', 'l2-regularize']: + for opt_name in ['orthonormal-constraint', 'max-change', 'l2-regularize', + 'param-stddev', 'learning-rate-factor' ]: value = self.config[opt_name] if value != '': opts += ' {0}={1}'.format(opt_name, value) @@ -255,6 +258,181 @@ def _generate_config(self): return configs +class XconfigCombineFeatureMapsLayer(XconfigLayerBase): + """This class is for parsing lines like + 'combine-feature-maps-layer name=combine_features1 height=40 num-filters1=1 num-filters2=4' + or + 'combine-feature-maps-layer name=combine_features1 height=40 num-filters1=1 num-filters2=4 num-filters3=2' + + It produces a PermuteComponent. It expects its input to be two or three things + appended together, where the first is of dimension height * num-filters1 and + the second is of dimension height * num-filters2 (and the third, if present, is + of dimension height * num-filters3); it interpolates the filters + so the output can be interpreted as a single feature map with the same height + as the input and the sum of the num-filters. + + This is to be used in convolutional setups as part of how we combine the + filterbank inputs with ivectors. + """ + + def __init__(self, first_token, key_to_value, prev_names=None): + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = { 'input': '[-1]', + 'num-filters1': -1, + 'num-filters2': -1, + 'num-filters3': 0, + 'height': -1 } + + def check_configs(self): + input_dim = self.descriptors['input']['dim'] + if (self.config['num-filters1'] <= 0 or + self.config['num-filters2'] <= 0 or + self.config['num-filters3'] < 0 or + self.config['height'] <= 0): + raise RuntimeError("invalid values of num-filters1, num-filters2 and/or height") + f1 = self.config['num-filters1'] + f2 = self.config['num-filters2'] + f3 = self.config['num-filters3'] + h = self.config['height'] + if input_dim != (f1 + f2 + f3) * h: + raise RuntimeError("Expected input-dim={0} based on num-filters1={1}, num-filters2={2}, " + "num-filters3={3} and height={4}, but got input-dim={5}".format( + (f1 + f2 + f3) * h, f1, f2, f3, h, input_dim)) + + def output_name(self, auxiliary_output=None): + assert auxiliary_output is None + return self.name + + def output_dim(self, auxiliary_output=None): + assert auxiliary_output is None + input_dim = self.descriptors['input']['dim'] + return input_dim + + def get_full_config(self): + ans = [] + config_lines = self._generate_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + # we do not support user specified matrices in this layer + # so 'ref' and 'final' configs are the same. + ans.append((config_name, line)) + return ans + + def _generate_config(self): + # by 'descriptor_final_string' we mean a string that can appear in + # config-files, i.e. it contains the 'final' names of nodes.
+ input_desc = self.descriptors['input']['final-string'] + dim = self.descriptors['input']['dim'] + num_filters1 = self.config['num-filters1'] + num_filters2 = self.config['num-filters2'] + num_filters3 = self.config['num-filters3'] # normally 0. + height = self.config['height'] + assert dim == (num_filters1 + num_filters2 + num_filters3) * height + + column_map = [] + for h in range(height): + for f in range(num_filters1): + column_map.append(h * num_filters1 + f) + for f in range(num_filters2): + column_map.append(height * num_filters1 + h * num_filters2 + f) + for f in range(num_filters3): + column_map.append(height * (num_filters1 + num_filters2) + h * num_filters3 + f) + + configs = [] + line = ('component name={0} type=PermuteComponent column-map={1} '.format( + self.name, ','.join([str(x) for x in column_map]))) + configs.append(line) + + line = ('component-node name={0} component={0} input={1}'.format( + self.name, input_desc)) + configs.append(line) + return configs + + + + +class XconfigAffineComponent(XconfigLayerBase): + """This class is for parsing lines like + 'affine-component name=linear1 dim=1024 input=Append(-3,0,3)' + which will produce just a single component, of type NaturalGradientAffineComponent, + with output-dim 1024 in this case, and input-dim determined by the dimension + of the input . + + Parameters of the class, and their defaults: + input='[-1]' [Descriptor giving the input of the layer.] + dim=-1 [Dimension of the output] + + The following (shown with their effective defaults) are just passed through + to the component's config line. + + orthonormal-constraint=0.0 + max-change=0.75 + l2-regularize=0.0 + + """ + def __init__(self, first_token, key_to_value, prev_names=None): + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = {'input': '[-1]', + 'dim': -1, + 'orthonormal-constraint': '', + 'max-change': 0.75, + 'param-stddev': '', + 'bias-stddev': '', + 'l2-regularize': '' } + + def check_configs(self): + if self.config['dim'] <= 0: + raise RuntimeError("'dim' must be specified and > 0.") + + def output_name(self, auxiliary_output=None): + assert auxiliary_output is None + return self.name + + def output_dim(self, auxiliary_output=None): + assert auxiliary_output is None + assert self.config['dim'] > 0 + return self.config['dim'] + + def get_full_config(self): + ans = [] + config_lines = self._generate_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + # we do not support user specified matrices in this layer + # so 'ref' and 'final' configs are the same. + ans.append((config_name, line)) + return ans + + def _generate_config(self): + # by 'descriptor_final_string' we mean a string that can appear in + # config-files, i.e. it contains the 'final' names of nodes. 
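For intuition, here is the same column-map construction as a standalone sketch, using made-up sizes (height=3, num-filters1=1, num-filters2=2, no third map), so you can see the permutation that the generated PermuteComponent will apply:

def make_column_map(height, num_filters1, num_filters2, num_filters3=0):
    # Same interleaving as the _generate_config() loop above: for every height
    # index, take that row's filters from map 1, then map 2, then map 3.
    column_map = []
    for h in range(height):
        for f in range(num_filters1):
            column_map.append(h * num_filters1 + f)
        for f in range(num_filters2):
            column_map.append(height * num_filters1 + h * num_filters2 + f)
        for f in range(num_filters3):
            column_map.append(height * (num_filters1 + num_filters2) + h * num_filters3 + f)
    return column_map

# Toy sizes, not from any real setup:
print(make_column_map(3, 1, 2))   # [0, 3, 4, 1, 5, 6, 2, 7, 8]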
+ input_desc = self.descriptors['input']['final-string'] + input_dim = self.descriptors['input']['dim'] + output_dim = self.config['dim'] + + opts = '' + for opt_name in ['orthonormal-constraint', 'max-change', 'l2-regularize', + 'param-stddev', 'bias-stddev']: + value = self.config[opt_name] + if value != '': + opts += ' {0}={1}'.format(opt_name, value) + + configs = [] + line = ('component name={0} type=NaturalGradientAffineComponent input-dim={1} output-dim={2} ' + '{3}'.format(self.name, input_dim, output_dim, opts)) + configs.append(line) + line = ('component-node name={0} component={0} input={1}'.format( + self.name, input_desc)) + configs.append(line) + return configs + + class XconfigPerElementScaleComponent(XconfigLayerBase): """This class is for parsing lines like 'scale-component name=scale1 input=Append(-3,0,3)' @@ -328,3 +506,141 @@ def _generate_config(self): self.name, input_desc)) configs.append(line) return configs + +class XconfigPerElementOffsetComponent(XconfigLayerBase): + """This class is for parsing lines like + 'offset-component name=offset1 input=Append(-3,0,3)' + which will produce just a single component, of type PerElementOffsetComponent, with + output-dim 1024 in this case, and input-dim determined by the dimension of the input . + + Parameters of the class, and their defaults: + input='[-1]' [Descriptor giving the input of the layer.] + + The following (shown with their effective defaults) are just passed through + to the component's config line. (These defaults are mostly set in the + code). + + max-change=0.75 + l2-regularize=0.0 + param-mean=0.0 # affects initialization + param-stddev=0.0 # affects initialization + learning-rate-factor=1.0 + """ + def __init__(self, first_token, key_to_value, prev_names=None): + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = {'input': '[-1]', + 'l2-regularize': '', + 'max-change': 0.75, + 'param-mean': '', + 'param-stddev': '', + 'learning-rate-factor': '' } + + def check_configs(self): + pass + + def output_name(self, auxiliary_output=None): + assert auxiliary_output is None + return self.name + + def output_dim(self, auxiliary_output=None): + assert auxiliary_output is None + return self.descriptors['input']['dim'] + + def get_full_config(self): + ans = [] + config_lines = self._generate_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + # we do not support user specified matrices in this layer + # so 'ref' and 'final' configs are the same. + ans.append((config_name, line)) + return ans + + def _generate_config(self): + # by 'descriptor_final_string' we mean a string that can appear in + # config-files, i.e. it contains the 'final' names of nodes. 
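As a rough illustration of what such a component-only xconfig line expands to, take a hypothetical 'affine-component name=affine1 dim=1024' whose input descriptor resolves to 120 dimensions; with the defaults above, only max-change survives into the option string:

# Hypothetical layer: the name, input-dim and dim below are invented.
name, input_dim, output_dim = 'affine1', 120, 1024
config = {'orthonormal-constraint': '', 'max-change': 0.75, 'l2-regularize': '',
          'param-stddev': '', 'bias-stddev': ''}

opts = ''
for opt_name in ['orthonormal-constraint', 'max-change', 'l2-regularize',
                 'param-stddev', 'bias-stddev']:
    value = config[opt_name]
    if value != '':
        opts += ' {0}={1}'.format(opt_name, value)

print('component name={0} type=NaturalGradientAffineComponent '
      'input-dim={1} output-dim={2} {3}'.format(name, input_dim, output_dim, opts))
# component name=affine1 type=NaturalGradientAffineComponent input-dim=120 output-dim=1024  max-change=0.75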
+ input_desc = self.descriptors['input']['final-string'] + dim = self.descriptors['input']['dim'] + + opts = '' + for opt_name in ['learning-rate-factor', 'max-change', 'l2-regularize', 'param-mean', + 'param-stddev' ]: + value = self.config[opt_name] + if value != '': + opts += ' {0}={1}'.format(opt_name, value) + + configs = [] + line = ('component name={0} type=PerElementOffsetComponent dim={1} {2} ' + ''.format(self.name, dim, opts)) + configs.append(line) + line = ('component-node name={0} component={0} input={1}'.format( + self.name, input_desc)) + configs.append(line) + return configs + + +class XconfigDimRangeComponent(XconfigLayerBase): + """This class is for parsing lines like + 'dim-range-component name=feature1 input=Append(-3,0,3) dim=40 dim-offset=0' + which will produce just a single component, of part of the input. + Parameters of the class, and their defaults: + input='[-1]' [Descriptor giving the input of the layer.] + dim=-1 [Dimension of the output.] + dim-offset=0 [Dimension offset of the input.] + """ + def __init__(self, first_token, key_to_value, prev_names=None): + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = {'input': '[-1]', + 'dim': -1, + 'dim-offset': 0 } + + def check_configs(self): + input_dim = self.descriptors['input']['dim'] + if self.config['dim'] <= 0: + raise RuntimeError("'dim' must be specified and > 0.") + elif self.config['dim'] > input_dim: + raise RuntimeError("'dim' must be specified and lower than the input dim.") + if self.config['dim-offset'] < 0 : + raise RuntimeError("'dim-offset' must be specified and >= 0.") + elif self.config['dim-offset'] + self.config['dim'] > input_dim: + raise RuntimeError("'dim-offset' plus output dim must be lower than the input dim.") + + def output_name(self, auxiliary_output=None): + assert auxiliary_output is None + return self.name + + def output_dim(self, auxiliary_output=None): + assert auxiliary_output is None + output_dim = self.config['dim'] + if output_dim <= 0: + self.config['dim'] = self.descriptors['input']['dim'] + return output_dim + + def get_full_config(self): + ans = [] + config_lines = self._generate_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + # we do not support user specified matrices in this layer + # so 'ref' and 'final' configs are the same. + ans.append((config_name, line)) + return ans + + def _generate_config(self): + # by 'descriptor_final_string' we mean a string that can appear in + # config-files, i.e. it contains the 'final' names of nodes. + input_node = self.descriptors['input']['final-string'] + output_dim = self.config['dim'] + dim_offset = self.config['dim-offset'] + + configs = [] + line = ('dim-range-node name={0} input-node={1} dim={2} dim-offset={3}'.format( + self.name, input_node, output_dim, dim_offset)) + configs.append(line) + return configs diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py index 08de18167cd..0188248d694 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py @@ -184,7 +184,7 @@ def convert_value_to_type(key, dest_type, string_value): # Also, in any place a raw input/layer/output name can appear, we accept things # like [-1] meaning the previous input/layer/output's name, or [-2] meaning the # last-but-one input/layer/output, and so on. 
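For the dim-range-component added above, the whole layer boils down to a single dim-range-node config line; a small sketch with invented names (input node 'idct', dim=40, dim-offset=0) shows what _generate_config() emits:

# Invented example: keep only the first 40 dims of a node called 'idct'.
name, input_node, output_dim, dim_offset = 'feat40', 'idct', 40, 0
print('dim-range-node name={0} input-node={1} dim={2} dim-offset={3}'.format(
    name, input_node, output_dim, dim_offset))
# dim-range-node name=feat40 input-node=idct dim=40 dim-offset=0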
-class Descriptor: +class Descriptor(object): def __init__(self, descriptor_string = None, prev_names = None): @@ -595,7 +595,7 @@ def parse_config_line(orig_config_line): rest_of_line = ' '.join(fields) # rest of the line can be of the form 'a=1 b=" x=1 y=2 " c=Append( i1, i2)' - positions = list(map(lambda x: x.start(), re.finditer('"', rest_of_line))) + positions = [x.start() for x in re.finditer('"', rest_of_line)] if not len(positions) % 2 == 0: raise RuntimeError("Double-quotes should occur in pairs") diff --git a/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh b/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh index 1dbcbe1a192..049e15df303 100755 --- a/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh +++ b/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh @@ -57,12 +57,9 @@ fi oldlm=$oldlang/G.fst if [ -f $oldlang/G.carpa ]; then oldlm=$oldlang/G.carpa -elif [ ! -f $oldlm ]; then - echo "$0: expecting either $oldlang/G.fst or $oldlang/G.carpa to exist" &&\ - exit 1; fi -[ ! -f $oldlm ] && echo "$0: Missing file $oldlm" && exit 1; +[ ! -f $oldlm ] && echo "$0: expecting either $oldlang/G.fst or $oldlang/G.carpa to exist" && exit 1; [ ! -f $rnnlm_dir/rnnlm ] && echo "$0: Missing file $rnnlm_dir/rnnlm" && exit 1; [ ! -f $rnnlm_dir/unk.probs ] &&\ echo "$0: Missing file $rnnlm_dir/unk.probs" && exit 1; diff --git a/egs/wsj/s5/steps/make_mfcc.sh b/egs/wsj/s5/steps/make_mfcc.sh index c88e0d65e65..8514ce4e38d 100755 --- a/egs/wsj/s5/steps/make_mfcc.sh +++ b/egs/wsj/s5/steps/make_mfcc.sh @@ -75,6 +75,8 @@ if [ -f $data/spk2warp ]; then elif [ -f $data/utt2warp ]; then echo "$0 [info]: using VTLN warp factors from $data/utt2warp" vtln_opts="--vtln-map=ark:$data/utt2warp" +else + vtln_opts="" fi for n in $(seq $nj); do diff --git a/egs/wsj/s5/steps/make_mfcc_pitch_online.sh b/egs/wsj/s5/steps/make_mfcc_pitch_online.sh index 26588506053..df51057a00b 100755 --- a/egs/wsj/s5/steps/make_mfcc_pitch_online.sh +++ b/egs/wsj/s5/steps/make_mfcc_pitch_online.sh @@ -70,7 +70,7 @@ required="$scp $mfcc_config $online_pitch_config" for f in $required; do if [ ! 
-f $f ]; then - echo "make_mfcc_pitch.sh: no such file $f" + echo "$0: no such file $f" exit 1; fi done diff --git a/egs/wsj/s5/steps/nnet/train.sh b/egs/wsj/s5/steps/nnet/train.sh index c23a15362c7..50a62837b67 100755 --- a/egs/wsj/s5/steps/nnet/train.sh +++ b/egs/wsj/s5/steps/nnet/train.sh @@ -433,18 +433,6 @@ else ${bn_dim:+ --bottleneck-dim=$bn_dim} \ "$cnn_fea" $num_tgt $hid_layers $hid_dim >>$nnet_proto ;; - cnn2d) - delta_order=$([ -z $delta_opts ] && echo "0" || { echo $delta_opts | tr ' ' '\n' | grep "delta[-_]order" | sed 's:^.*=::'; }) - echo "Debug : $delta_opts, delta_order $delta_order" - utils/nnet/make_cnn2d_proto.py $cnn_proto_opts \ - --splice=$splice --delta-order=$delta_order --dir=$dir \ - $num_fea >$nnet_proto - cnn_fea=$(cat $nnet_proto | grep -v '^$' | tail -n1 | awk '{ print $5; }') - utils/nnet/make_nnet_proto.py $proto_opts \ - --no-smaller-input-weights \ - ${bn_dim:+ --bottleneck-dim=$bn_dim} \ - "$cnn_fea" $num_tgt $hid_layers $hid_dim >>$nnet_proto - ;; lstm) utils/nnet/make_lstm_proto.py $proto_opts \ $num_fea $num_tgt >$nnet_proto diff --git a/egs/wsj/s5/steps/nnet2/make_multisplice_configs.py b/egs/wsj/s5/steps/nnet2/make_multisplice_configs.py index 6e7bff3fa17..b5338b516e8 100755 --- a/egs/wsj/s5/steps/nnet2/make_multisplice_configs.py +++ b/egs/wsj/s5/steps/nnet2/make_multisplice_configs.py @@ -4,14 +4,16 @@ # Creates the nnet.config and hidde_*.config scripts used in train_pnorm_multisplice.sh # Parses the splice string to generate relevant variables for get_egs.sh, get_lda.sh and nnet/hidden.config files +from __future__ import division +from __future__ import print_function import re, argparse, sys, math, warnings # returns the set of frame indices required to perform the convolution # between sequences with frame indices in x and y def get_convolution_index_set(x, y): z = [] - for i in xrange(len(x)): - for j in xrange(len(y)): + for i in range(len(x)): + for j in range(len(y)): z.append(x[i]+y[j]) z = list(set(z)) z.sort() @@ -19,7 +21,7 @@ def get_convolution_index_set(x, y): def parse_splice_string(splice_string): layerwise_splice_indexes = splice_string.split('layer')[1:] - print splice_string.split('layer') + print(splice_string.split('layer')) contexts={} first_right_context = 0 # default value first_left_context = 0 # default value @@ -29,14 +31,14 @@ def parse_splice_string(splice_string): try: for cur_splice_indexes in layerwise_splice_indexes: layer_index, frame_indexes = cur_splice_indexes.split("/") - frame_indexes = map(lambda x: int(x), frame_indexes.split(':')) + frame_indexes = [int(x) for x in frame_indexes.split(':')] layer_index = int(layer_index) assert(layer_index >= 0) if layer_index == 0: first_left_context = min(frame_indexes) first_right_context = max(frame_indexes) try: - assert(frame_indexes == range(first_left_context, first_right_context+1)) + assert(frame_indexes == list(range(first_left_context, first_right_context+1))) except AssertionError: raise Exception('Currently the first splice component just accepts contiguous context.') try: @@ -46,11 +48,11 @@ def parse_splice_string(splice_string): left context provided is %d and right context provided is %d.""" % (first_left_context, first_right_context)) # convolve the current splice indices with the splice indices until last layer nnet_frame_indexes = get_convolution_index_set(frame_indexes, nnet_frame_indexes) - cur_context = ":".join(map(lambda x: str(x), frame_indexes)) + cur_context = ":".join([str(x) for x in frame_indexes]) contexts[layer_index] = cur_context 
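The overall temporal context of a multisplice network is the 'convolution' of the per-layer splice index sets computed by get_convolution_index_set() above; a toy example with invented contexts makes the idea concrete:

def convolve_contexts(x, y):
    # Same idea as get_convolution_index_set(): every sum of one offset from
    # each layer's splice set is a frame the network may depend on.
    return sorted(set(xi + yi for xi in x for yi in y))

# Invented per-layer contexts, e.g. layer0 splices -1..1 and layer1 splices -3,0,3:
print(convolve_contexts([-1, 0, 1], [-3, 0, 3]))
# [-4, -3, -2, -1, 0, 1, 2, 3, 4]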
except ValueError: raise Exception('Unknown format in splice_indexes variable: {0}'.format(params.splice_indexes)) - print nnet_frame_indexes + print(nnet_frame_indexes) max_left_context = min(nnet_frame_indexes) max_right_context = max(nnet_frame_indexes) return [contexts, ' nnet_left_context={0};\n nnet_right_context={1}\n first_left_context={2};\n first_right_context={3}\n'.format(abs(max_left_context), abs(max_right_context), abs(first_left_context), abs(first_right_context) )] @@ -87,7 +89,7 @@ def create_config_files(output_dir, params): except KeyError: raise Exception('A splice layer is expected to be the first layer. Provide a context for the first layer.') - for i in xrange(1, params.num_hidden_layers): #just run till num_hidden_layers-1 since we do not add splice before the final affine transform + for i in range(1, params.num_hidden_layers): #just run till num_hidden_layers-1 since we do not add splice before the final affine transform lines=[] context_len = 1 if i in contexts: @@ -109,7 +111,7 @@ def create_config_files(output_dir, params): if __name__ == "__main__": - print " ".join(sys.argv) + print(" ".join(sys.argv)) parser = argparse.ArgumentParser() parser.add_argument('--splice-indexes', type=str, help='string specifying the indexes for the splice layers throughout the network') parser.add_argument('--total-input-dim', type=int, help='dimension of the input to the network') @@ -127,7 +129,7 @@ def create_config_files(output_dir, params): parser.add_argument("output_dir", type=str, help="output directory to store the files") params = parser.parse_args() - print params + print(params) if params.mode == "contexts": [context, context_variables] = parse_splice_string(params.splice_indexes) var_file = open("{0}/vars".format(params.output_dir), "w") diff --git a/egs/wsj/s5/steps/nnet3/align.sh b/egs/wsj/s5/steps/nnet3/align.sh index cf1cc9124d3..aa2de2ee1a5 100755 --- a/egs/wsj/s5/steps/nnet3/align.sh +++ b/egs/wsj/s5/steps/nnet3/align.sh @@ -24,6 +24,7 @@ extra_right_context=0 extra_left_context_initial=-1 extra_right_context_final=-1 online_ivector_dir= +graphs_scp= # End configuration options. echo "$0 $@" # Print the command line for logging @@ -52,10 +53,9 @@ dir=$4 oov=`cat $lang/oov.int` || exit 1; mkdir -p $dir/log echo $nj > $dir/num_jobs -touch $dir/per_utt -sdata=$data/split${nj}utt +sdata=$data/split${nj} [[ -d $sdata && $data/feats.scp -ot $sdata ]] || \ - split_data.sh --per-utt $data $nj || exit 1; + split_data.sh $data $nj || exit 1; if $use_gpu; then queue_opt="--gpu 1" @@ -97,8 +97,6 @@ fi echo "$0: aligning data in $data using model from $srcdir, putting alignments in $dir" -tra="ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt $sdata/JOB/text|"; - frame_subsampling_opt= if [ -f $srcdir/frame_subsampling_factor ]; then # e.g. for 'chain' systems @@ -114,9 +112,20 @@ if [ -f $srcdir/frame_subsampling_factor ]; then fi fi +if [ ! -z "$graphs_scp" ]; then + if [ ! 
-f $graphs_scp ]; then + echo "Could not find graphs $graphs_scp" && exit 1 + fi + tra="scp:utils/filter_scp.pl $sdata/JOB/utt2spk $graphs_scp |" + prog=compile-train-graphs-fsts +else + tra="ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt $sdata/JOB/text|"; + prog=compile-train-graphs +fi $cmd $queue_opt JOB=1:$nj $dir/log/align.JOB.log \ - compile-train-graphs --read-disambig-syms=$lang/phones/disambig.int $dir/tree $srcdir/${iter}.mdl $lang/L.fst "$tra" ark:- \| \ + $prog --read-disambig-syms=$lang/phones/disambig.int $dir/tree \ + $srcdir/${iter}.mdl $lang/L.fst "$tra" ark:- \| \ nnet3-align-compiled $scale_opts $ivector_opts $frame_subsampling_opt \ --frames-per-chunk=$frames_per_chunk \ --extra-left-context=$extra_left_context \ diff --git a/egs/wsj/s5/steps/nnet3/align_lats.sh b/egs/wsj/s5/steps/nnet3/align_lats.sh index 4edc38751c8..e4ba7309435 100755 --- a/egs/wsj/s5/steps/nnet3/align_lats.sh +++ b/egs/wsj/s5/steps/nnet3/align_lats.sh @@ -50,10 +50,9 @@ dir=$4 oov=`cat $lang/oov.int` || exit 1; mkdir -p $dir/log echo $nj > $dir/num_jobs -touch $dir/per_utt -sdata=$data/split${nj}utt +sdata=$data/split${nj} [[ -d $sdata && $data/feats.scp -ot $sdata ]] || \ - split_data.sh --per-utt $data $nj || exit 1; + split_data.sh $data $nj || exit 1; extra_files= if [ ! -z "$online_ivector_dir" ]; then diff --git a/egs/wsj/s5/steps/nnet3/chain/e2e/compute_biphone_stats.py b/egs/wsj/s5/steps/nnet3/chain/e2e/compute_biphone_stats.py new file mode 100755 index 00000000000..e009cc17a9b --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain/e2e/compute_biphone_stats.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Hossein Hadian +# Apache 2.0 + +import argparse +from os.path import join +import sys +import copy +import random + +parser = argparse.ArgumentParser(description="""This script reads + sequences of phone ids from std input and counts mono/biphone stats + and writes the results to std out. The output can be used with + gmm-init-biphone to create a better tree. 
The first part of the + output is biphone counts with this format for each line: + + and the second part of the output is monophone counts with the + following format: + """) +parser.add_argument('langdir', type=str) +parser.add_argument('--shared-phones', type=str, choices=['true','false'], + default='true', + help="If true, stats will be collected for shared phones.") + +args = parser.parse_args() +args.shared_phones = True if args.shared_phones == 'true' else False + +# Read phone sets +phone_sets = [] +phones = [] +phone_to_shard_phone = {} +phone_to_shard_phone[0] = 0 # The no-left-context case +with open(join(args.langdir, 'phones/sets.int'), 'r', encoding='latin-1') as f: + for line in f: + phone_set = line.strip().split() + phone_sets.append(phone_set) + for phone in phone_set: + phones.append(phone) + phone_to_shard_phone[phone] = phone_set[0] + +print('Loaded {} phone-sets containing {} phones.'.format(len(phone_sets), + len(phones)), + file=sys.stderr) + +biphone_counts = {} +mono_counts = {} +for line in sys.stdin: + line = line.strip().split() + key = line[0] + line_phones = line[1:] + for pair in zip([0] + line_phones, line_phones): # 0 is for the no left-context case + if args.shared_phones: + pair = (phone_to_shard_phone[pair[0]], phone_to_shard_phone[pair[1]]) + if pair not in biphone_counts: + biphone_counts[pair] = 0 + biphone_counts[pair] += 1 + mono_counts[pair[1]] = 1 if pair[1] not in mono_counts else mono_counts[pair[1]] + 1 + +for phone1 in [0] + phones: + for phone2 in phones: + pair = (phone1, phone2) + shared_pair = ((phone_to_shard_phone[pair[0]], phone_to_shard_phone[pair[1]]) + if args.shared_phones else pair) + count = biphone_counts[shared_pair] if shared_pair in biphone_counts else 0 + if count != 0: + print('{} {} {}'.format(pair[0], pair[1], count)) +for phone in phones: + shared = phone_to_shard_phone[phone] if args.shared_phones else phone + count = mono_counts[shared] if shared in mono_counts else 0 + if count != 0: + print('{} {}'.format(phone, count)) diff --git a/egs/wsj/s5/steps/nnet3/chain/e2e/prepare_e2e.sh b/egs/wsj/s5/steps/nnet3/chain/e2e/prepare_e2e.sh index a060f0f3b36..07d5ee8cfb8 100755 --- a/egs/wsj/s5/steps/nnet3/chain/e2e/prepare_e2e.sh +++ b/egs/wsj/s5/steps/nnet3/chain/e2e/prepare_e2e.sh @@ -14,12 +14,23 @@ cmd=run.pl nj=4 stage=0 shared_phones=true -treedir= # if specified, the tree and model will be copied from there +treedir= # If specified, the tree and model will be copied from there # note that it may not be flat start anymore. -type=mono # can be either mono or biphone -- either way +type=mono # Can be either mono or biphone -- either way # the resulting tree is full (i.e. it doesn't do any tying) +ci_silence=false # If true, silence phones will be treated as context independent scale_opts="--transition-scale=0.0 --self-loop-scale=0.0" +tie=false # If true, gmm-init-biphone will do some tying when + # creating the full biphone tree (it won't be full anymore). + # Specifically, it will revert to monophone if the data + # counts for a biphone are smaller than min_biphone_count. + # If the monophone count is also smaller than min_monophone_count, + # it will revert to a shared global phone. Note that this + # only affects biphone models (i.e., type=biphone) which + # use the special chain topology. +min_biphone_count=100 +min_monophone_count=20 # End configuration section.
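The phone statistics passed to gmm-init-biphone via --phone-counts come from compute_biphone_stats.py above; a toy rendering of its counting loop (phone ids invented, the shared-phones mapping omitted for brevity):

from collections import defaultdict

biphone_counts = defaultdict(int)
mono_counts = defaultdict(int)
for utt_phones in [[3, 4, 4, 5], [4, 5]]:          # invented phone-id sequences
    # Pair each phone with its left neighbour; 0 stands for "no left context".
    for left, right in zip([0] + utt_phones, utt_phones):
        biphone_counts[(left, right)] += 1
        mono_counts[right] += 1

print(dict(biphone_counts))   # {(0, 3): 1, (3, 4): 1, (4, 4): 1, (4, 5): 2, (0, 4): 1}
print(dict(mono_counts))      # {3: 1, 4: 3, 5: 2}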
echo "$0 $@" # Print the command line for logging @@ -34,6 +45,7 @@ if [ $# != 3 ]; then echo " --config # config containing options" echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." echo " --type # context dependency type" + echo " --tie # enable/disable count-based tying" exit 1; fi @@ -63,12 +75,28 @@ if $shared_phones; then shared_phones_opt="--shared-phones=$lang/phones/sets.int" fi +ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; +if $ci_silence; then + ci_opt="--ci-phones=$ciphonelist" +fi + +tie_opts= +if $tie && [[ "$type" = "biphone" ]]; then + cat $data/text | steps/chain/e2e/text_to_phones.py --edge-silprob 0 \ + --between-silprob 0 \ + $lang | \ + cut -d' ' -f 2- | utils/sym2int.pl $lang/phones.txt | \ + steps/chain/e2e/compute_biphone_stats.py $lang >$dir/phone-stats.txt + tie_opts="--min-biphone-count=$min_biphone_count \ +--min-monophone-count=$min_monophone_count --phone-counts=$dir/phone-stats.txt" +fi + if [ $stage -le 0 ]; then if [ -z $treedir ]; then echo "$0: Initializing $type system." # feat dim does not matter here. Just set it to 10 $cmd $dir/log/init_${type}_mdl_tree.log \ - gmm-init-$type $shared_phones_opt $lang/topo 10 \ + gmm-init-$type $tie_opts $ci_opt $shared_phones_opt $lang/topo 10 \ $dir/0.mdl $dir/tree || exit 1; else echo "$0: Copied tree/mdl from $treedir." >$dir/log/init_mdl_tree.log diff --git a/egs/wsj/s5/steps/nnet3/chain/e2e/text_to_phones.py b/egs/wsj/s5/steps/nnet3/chain/e2e/text_to_phones.py index 0ff05e3c48e..2c51cb57750 100755 --- a/egs/wsj/s5/steps/nnet3/chain/e2e/text_to_phones.py +++ b/egs/wsj/s5/steps/nnet3/chain/e2e/text_to_phones.py @@ -8,6 +8,7 @@ to phone transcriptions using the provided lexicon, and writes them to standard output. """ +from __future__ import print_function import argparse from os.path import join diff --git a/egs/wsj/s5/steps/nnet3/chain/get_egs.sh b/egs/wsj/s5/steps/nnet3/chain/get_egs.sh index 9996820d6d3..ae4a0474a24 100755 --- a/egs/wsj/s5/steps/nnet3/chain/get_egs.sh +++ b/egs/wsj/s5/steps/nnet3/chain/get_egs.sh @@ -151,36 +151,40 @@ mkdir -p $dir/log $dir/info # Get list of validation utterances. frame_shift=$(utils/data/get_frame_shift.sh $data) || exit 1 +if [ -f $data/utt2uniq ]; then + # Must hold out all augmented versions of the same utterance. + echo "$0: File $data/utt2uniq exists, so ensuring the hold-out set" \ + "includes all perturbed versions of the same source utterance." + utils/utt2spk_to_spk2utt.pl $data/utt2uniq 2>/dev/null | + awk -v max_utt=$num_utts_subset '{ + for (n=2;n<=NF;n++) print $n; + printed += NF-1; + if (printed >= max_utt) nextfile; }' | + sort > $dir/valid_uttlist +else + awk '{print $1}' $data/utt2spk | \ + utils/shuffle_list.pl 2>/dev/null | \ + head -$num_utts_subset > $dir/valid_uttlist +fi +len_valid_uttlist=$(wc -l < $dir/valid_uttlist) + awk '{print $1}' $data/utt2spk | \ - utils/shuffle_list.pl 2>/dev/null | head -$num_utts_subset > $dir/valid_uttlist + utils/filter_scp.pl --exclude $dir/valid_uttlist | \ + utils/shuffle_list.pl 2>/dev/null | \ + head -$num_utts_subset > $dir/train_subset_uttlist +len_trainsub_uttlist=$(wc -l <$dir/train_subset_uttlist) -len_uttlist=$(wc -l < $dir/valid_uttlist) -if [ $len_uttlist -lt $num_utts_subset ]; then - echo "Number of utterances is very small. Please check your data." && exit 1; +if [[ $len_valid_uttlist -lt $num_utts_subset || + $len_trainsub_uttlist -lt $num_utts_subset ]]; then + echo "$0: Number of utterances is very small. Please check your data." 
&& exit 1; fi -if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. - # because of this stage we can again have utts with lengths less than - # frames_per_eg - echo "File $data/utt2uniq exists, so augmenting valid_uttlist to" - echo "include all perturbed versions of the same 'real' utterances." - mv $dir/valid_uttlist $dir/valid_uttlist.tmp - utils/utt2spk_to_spk2utt.pl $data/utt2uniq > $dir/uniq2utt - cat $dir/valid_uttlist.tmp | utils/apply_map.pl $data/utt2uniq | \ - sort | uniq | utils/apply_map.pl $dir/uniq2utt | \ - awk '{for(n=1;n<=NF;n++) print $n;}' | sort > $dir/valid_uttlist - rm $dir/uniq2utt $dir/valid_uttlist.tmp -fi +echo "$0: Holding out $len_valid_uttlist utterances in validation set and" \ + "$len_trainsub_uttlist in training diagnostic set, out of total" \ + "$(wc -l < $data/utt2spk)." -echo "$0: creating egs. To ensure they are not deleted later you can do: touch $dir/.nodelete" -awk '{print $1}' $data/utt2spk | \ - utils/filter_scp.pl --exclude $dir/valid_uttlist | \ - utils/shuffle_list.pl 2>/dev/null | head -$num_utts_subset > $dir/train_subset_uttlist -len_uttlist=$(wc -l <$dir/train_subset_uttlist) -if [ $len_uttlist -lt $num_utts_subset ]; then - echo "Number of utterances is very small. Please check your data." && exit 1; -fi +echo "$0: creating egs. To ensure they are not deleted later you can do: touch $dir/.nodelete" ## Set up features. echo "$0: feature type is raw" @@ -342,9 +346,8 @@ if [ $stage -le 2 ]; then $egs_opts --normalization-fst-scale=$normalization_fst_scale \ $trans_mdl_opt $chaindir/normalization.fst \ "$train_subset_feats" ark,s,cs:- "ark:$dir/train_subset_all.cegs" || exit 1 - wait sleep 5 # wait for file system to sync. - echo "... Getting subsets of validation examples for diagnostics and combination." + echo "$0: Getting subsets of validation examples for diagnostics and combination." if $generate_egs_scp; then valid_diagnostic_output="ark,scp:$dir/valid_diagnostic.cegs,$dir/valid_diagnostic.scp" train_diagnostic_output="ark,scp:$dir/train_diagnostic.cegs,$dir/train_diagnostic.scp" @@ -365,7 +368,6 @@ if [ $stage -le 2 ]; then $cmd $dir/log/create_train_subset_diagnostic.log \ nnet3-chain-subset-egs --n=$num_egs_diagnostic ark:$dir/train_subset_all.cegs \ $train_diagnostic_output || exit 1 - wait sleep 5 # wait for file system to sync. if $generate_egs_scp; then cat $dir/valid_combine.cegs $dir/train_combine.cegs | \ @@ -375,7 +377,7 @@ if [ $stage -le 2 ]; then fi for f in $dir/{combine,train_diagnostic,valid_diagnostic}.cegs; do - [ ! -s $f ] && echo "No examples in file $f" && exit 1; + [ ! -s $f ] && echo "$0: No examples in file $f" && exit 1; done rm $dir/valid_all.cegs $dir/train_subset_all.cegs $dir/{train,valid}_combine.cegs ) || touch $dir/.error & @@ -412,7 +414,7 @@ if [ $stage -le 4 ]; then fi if [ -f $dir/.error ]; then - echo "Error detected while creating train/valid egs" && exit 1 + echo "$0: Error detected while creating train/valid egs" && exit 1 fi if [ $stage -le 5 ]; then @@ -485,11 +487,11 @@ fi wait if [ -f $dir/.error ]; then - echo "Error detected while creating train/valid egs" && exit 1 + echo "$0: Error detected while creating train/valid egs" && exit 1 fi if [ $stage -le 6 ]; then - echo "$0: removing temporary archives" + echo "$0: Removing temporary archives, alignments and lattices" ( cd $dir for f in $(ls -l . 
| grep 'cegs_orig' | awk '{ X=NF-1; Y=NF-2; if ($X == "->") print $Y, $NF; }'); do rm $f; done @@ -501,7 +503,6 @@ if [ $stage -le 6 ]; then # there are some extra soft links that we should delete. for f in $dir/cegs.*.*.ark; do rm $f; done fi - echo "$0: removing temporary alignments, lattices and transforms" rm $dir/ali.{ark,scp} 2>/dev/null rm $dir/lat_special.*.{ark,scp} 2>/dev/null fi diff --git a/egs/wsj/s5/steps/nnet3/chain/make_weighted_den_fst.sh b/egs/wsj/s5/steps/nnet3/chain/make_weighted_den_fst.sh index 7dade75a0ed..3b6371168ce 100755 --- a/egs/wsj/s5/steps/nnet3/chain/make_weighted_den_fst.sh +++ b/egs/wsj/s5/steps/nnet3/chain/make_weighted_den_fst.sh @@ -86,37 +86,44 @@ else fi fi -if [ $stage -le 1 ]; then - all_phones="" # will contain the names of the .gz files containing phones, - # with some members possibly repeated per the --num-repeats - # option - for n in `seq 0 $[num_alignments-1]`; do - this_num_repeats=${num_repeats_array[$n]} - this_alignment_dir=${ali_dirs[$n]} - num_jobs=$(cat $this_alignment_dir/num_jobs) - if ! [ "$this_num_repeats" -gt 0 ]; then - echo "Expected comma-separated list of integers for --num-repeats option, got '$num_repeats'" - exit 1 - fi +all_phones="" # will contain the names of the .gz files containing phones, + # with some members possibly repeated per the --num-repeats + # option +for n in `seq 0 $[num_alignments-1]`; do + this_num_repeats=${num_repeats_array[$n]} + this_alignment_dir=${ali_dirs[$n]} + num_jobs=$(cat $this_alignment_dir/num_jobs) + if ! [ "$this_num_repeats" -ge 0 ]; then + echo "Expected comma-separated list of integers for --num-repeats option, got '$num_repeats'" + exit 1 + fi + if [ $stage -le 1 ]; then for j in $(seq $num_jobs); do gunzip -c $this_alignment_dir/ali.$j.gz; done | \ ali-to-phones $this_alignment_dir/final.mdl ark:- "ark:|gzip -c >$dir/phones.$n.gz" || exit 1; + fi - all_phones="$all_phones $(for r in $(seq $this_num_repeats); do echo $dir/phones.$n.gz; done)" - done + if [ ! 
-s $dir/phones.$n.gz ]; then + echo "$dir/phones.$n.gz is empty or does not exist" + exit 1 + fi + all_phones="$all_phones $(for r in $(seq $this_num_repeats); do echo $dir/phones.$n.gz; done)" +done + +if [ $stage -le 2 ]; then $cmd $dir/log/make_phone_lm_fst.log \ gunzip -c $all_phones \| \ chain-est-phone-lm $lm_opts ark:- $dir/phone_lm.fst || exit 1; rm $dir/phones.*.gz fi -if [ $stage -le 2 ]; then +if [ $stage -le 3 ]; then copy-transition-model ${ali_dirs[0]}/final.mdl $dir/0.trans_mdl || exit 1; fi -if [ $stage -le 3 ]; then +if [ $stage -le 4 ]; then $cmd $dir/log/make_den_fst.log \ chain-make-den-fst $dir/tree $dir/0.trans_mdl \ $dir/phone_lm.fst \ diff --git a/egs/wsj/s5/steps/nnet3/chain/train.py b/egs/wsj/s5/steps/nnet3/chain/train.py index a832f57cd8f..40b65afe273 100755 --- a/egs/wsj/s5/steps/nnet3/chain/train.py +++ b/egs/wsj/s5/steps/nnet3/chain/train.py @@ -6,6 +6,8 @@ """ This script is based on steps/nnet3/chain/train.sh """ +from __future__ import division +from __future__ import print_function import argparse import logging diff --git a/egs/wsj/s5/steps/nnet3/components.py b/egs/wsj/s5/steps/nnet3/components.py index 34443d586ca..8e879579776 100644 --- a/egs/wsj/s5/steps/nnet3/components.py +++ b/egs/wsj/s5/steps/nnet3/components.py @@ -84,7 +84,7 @@ def AddBlockAffineLayer(config_lines, name, input, output_dim, num_blocks): def AddPermuteLayer(config_lines, name, input, column_map): components = config_lines['components'] component_nodes = config_lines['component-nodes'] - permute_indexes = ",".join(map(lambda x: str(x), column_map)) + permute_indexes = ",".join([str(x) for x in column_map]) components.append('component name={0}_permute type=PermuteComponent column-map={1}'.format(name, permute_indexes)) component_nodes.append('component-node name={0}_permute component={0}_permute input={1}'.format(name, input['descriptor'])) diff --git a/egs/wsj/s5/steps/nnet3/compute_output.sh b/egs/wsj/s5/steps/nnet3/compute_output.sh index da3cb704878..e55f705043b 100755 --- a/egs/wsj/s5/steps/nnet3/compute_output.sh +++ b/egs/wsj/s5/steps/nnet3/compute_output.sh @@ -54,7 +54,7 @@ fdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print model=$srcdir/$iter.raw if [ ! -f $srcdir/$iter.raw ]; then - echo "$0: WARNING: no such file $srcdir/$iter.raw. Trying $srcdir/$iter.mdl instead." && exit 1 + echo "$0: WARNING: no such file $srcdir/$iter.raw. Trying $srcdir/$iter.mdl instead." model=$srcdir/$iter.mdl fi @@ -104,12 +104,15 @@ gpu_queue_opt= if $use_gpu; then gpu_queue_opt="--gpu 1" + suffix="-batch" gpu_opt="--use-gpu=yes" +else + gpu_opt="--use-gpu=no" fi if [ $stage -le 2 ]; then $cmd $gpu_queue_opt JOB=1:$nj $dir/log/compute_output.JOB.log \ - nnet3-compute $gpu_opt $ivector_opts $frame_subsampling_opt \ + nnet3-compute$suffix $gpu_opt $ivector_opts $frame_subsampling_opt \ --frames-per-chunk=$frames_per_chunk \ --extra-left-context=$extra_left_context \ --extra-right-context=$extra_right_context \ diff --git a/egs/wsj/s5/steps/nnet3/convert_nnet2_to_nnet3.py b/egs/wsj/s5/steps/nnet3/convert_nnet2_to_nnet3.py index f0a4341d12b..edc2f7e4617 100755 --- a/egs/wsj/s5/steps/nnet3/convert_nnet2_to_nnet3.py +++ b/egs/wsj/s5/steps/nnet3/convert_nnet2_to_nnet3.py @@ -6,6 +6,7 @@ # It requires knowledge of valid components which # can be modified in the configuration section below. 
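The --num-repeats weighting in make_weighted_den_fst.sh above amounts to listing each alignment directory's phones.N.gz that many times before piping everything to chain-est-phone-lm; roughly, in Python terms (paths and counts below are invented):

num_repeats = [1, 3]                                  # e.g. --num-repeats=1,3
phone_files = ['exp/work/phones.0.gz', 'exp/work/phones.1.gz']

all_phones = []
for n, f in enumerate(phone_files):
    all_phones += [f] * num_repeats[n]                # repetition acts as the weight
print(' '.join(all_phones))
# exp/work/phones.0.gz exp/work/phones.1.gz exp/work/phones.1.gz exp/work/phones.1.gz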
+from __future__ import print_function import argparse, os, tempfile, logging, sys, shutil, fileinput, re from collections import defaultdict, namedtuple import numpy as np @@ -51,7 +52,7 @@ SPLICE_COMPONENTS = [c for c in NODE_NAMES if "Splice" in c] AFFINE_COMPONENTS = [c for c in NODE_NAMES if "Affine" in c] -KNOWN_COMPONENTS = NODE_NAMES.keys() +KNOWN_COMPONENTS = list(NODE_NAMES.keys()) # End configuration section logger = logging.getLogger(__name__) @@ -99,6 +100,7 @@ class Nnet3Model(object): def __init__(self): self.input_dim = -1 self.output_dim = -1 + self.ivector_dim = -1 self.counts = defaultdict(int) self.num_components = 0 self.components_read = 0 @@ -117,7 +119,10 @@ def add_component(self, component, pairs): Component = namedtuple("Component", "ident component pairs") if "" in pairs and self.input_dim == -1: - self.input_dim = pairs[""] + self.input_dim = int(pairs[""]) + + if "" in pairs and self.ivector_dim == -1: + self.ivector_dim = int(pairs[""]) # remove nnet2 specific tokens and catch descriptors if component == "" and "

" in pairs: @@ -158,13 +163,18 @@ def write_config(self, filename): config_string=config_string)) f.write("\n# Component nodes\n") - f.write("input-node name=input dim={0}\n".format(self.input_dim)) + if self.ivector_dim != -1: + f.write("input-node name=input dim={0}\n".format(self.input_dim-self.ivector_dim)) + f.write("input-node name=ivector dim={0}\n".format(self.ivector_dim)) + else: + f.write("input-node name=input dim={0}\n".format(self.input_dim)) previous_component = "input" for component in self.components: if component.ident == "splice": # Create splice string for the next node previous_component = make_splice_string(previous_component, - component.pairs[""]) + component.pairs[""], + component.pairs[""]) continue f.write("component-node name={name} component={name} " "input={inp}\n".format(name=component.ident, @@ -263,7 +273,7 @@ def parse_component(line, line_buffer): pairs = {} if component in SPLICE_COMPONENTS: - pairs = parse_splice_component(component, line, line_buffer) + line, pairs = parse_splice_component(component, line, line_buffer) elif component in AFFINE_COMPONENTS: pairs = parse_affine_component(component, line, line_buffer) elif component == "": @@ -334,7 +344,13 @@ def parse_splice_component(component, line, line_buffer): line = consume_token("", line) context = line.strip()[1:-1].split() - return {"" : input_dim, "" : context} + const_component_dim = 0 + line = next(line_buffer) # Context vector adds newline + line = consume_token("", line) + const_component_dim = int(line.strip().split()[0]) + + return line, {"" : input_dim, "" : context, + "" : const_component_dim} def parse_end_of_component(component, line, line_buffer): # Keeps reading until it hits the end tag for component @@ -421,7 +437,7 @@ def consume_token(token, line): return line.partition(token)[2] -def make_splice_string(nodename, context): +def make_splice_string(nodename, context, const_component_dim=0): """Generates splice string from a list of context. E.g. make_splice_string("renorm4", [-4, 4]) @@ -429,6 +445,8 @@ def make_splice_string(nodename, context): """ assert type(context) == list, "context argument must be a list" string = ["Offset({0}, {1})".format(nodename, i) for i in context] + if const_component_dim > 0: + string.append("ReplaceIndex(ivector, t, 0)") string = "Append(" + ", ".join(string) + ")" return string diff --git a/egs/wsj/s5/steps/nnet3/decode.sh b/egs/wsj/s5/steps/nnet3/decode.sh index 5b8374a5a1d..14dda2bd457 100755 --- a/egs/wsj/s5/steps/nnet3/decode.sh +++ b/egs/wsj/s5/steps/nnet3/decode.sh @@ -20,6 +20,10 @@ ivector_scale=1.0 lattice_beam=8.0 # Beam we use in lattice generation. iter=final num_threads=1 # if >1, will use gmm-latgen-faster-parallel +use_gpu=false # If true, will use a GPU, with nnet3-latgen-faster-batch. + # In that case it is recommended to set num-threads to a large + # number, e.g. 20 if you have that many free CPU slots on a GPU + # node, and to use a small number of jobs. scoring_opts= skip_diagnostics=false skip_scoring=false @@ -49,6 +53,9 @@ if [ $# -ne 3 ]; then echo " --iter # Iteration of model to decode; default is final." echo " --scoring-opts # options to local/score.sh" echo " --num-threads # number of threads to use, default 1." + echo " --use-gpu # default: false. If true, we recommend" + echo " # to use large --num-threads as the graph" + echo " # search becomes the limiting factor." 
exit 1; fi @@ -74,7 +81,16 @@ done sdata=$data/split$nj; cmvn_opts=`cat $srcdir/cmvn_opts` || exit 1; thread_string= -[ $num_threads -gt 1 ] && thread_string="-parallel --num-threads=$num_threads" +if $use_gpu; then + if [ $num_threads -eq 1 ]; then + echo "$0: **Warning: we recommend to use --num-threads > 1 for GPU-based decoding." + fi + thread_string="-batch --num-threads=$num_threads" + queue_opt="--num-threads $num_threads --gpu 1" +elif [ $num_threads -gt 1 ]; then + thread_string="-parallel --num-threads=$num_threads" + queue_opt="--num-threads $num_threads" +fi mkdir -p $dir/log [[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1; @@ -104,7 +120,7 @@ if [ -f $srcdir/frame_subsampling_factor ]; then fi if [ $stage -le 1 ]; then - $cmd --num-threads $num_threads JOB=1:$nj $dir/log/decode.JOB.log \ + $cmd $queue_opt JOB=1:$nj $dir/log/decode.JOB.log \ nnet3-latgen-faster$thread_string $ivector_opts $frame_subsampling_opt \ --frames-per-chunk=$frames_per_chunk \ --extra-left-context=$extra_left_context \ diff --git a/egs/wsj/s5/steps/nnet3/decode_grammar.sh b/egs/wsj/s5/steps/nnet3/decode_grammar.sh new file mode 100755 index 00000000000..7ee1efeb7df --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/decode_grammar.sh @@ -0,0 +1,139 @@ +#!/bin/bash + +# Copyright 2012-2015 Johns Hopkins University (Author: Daniel Povey). +# Apache 2.0. + +# This is a version of ./decode.sh that allows you to decode with a GrammarFst. +# See kaldi-asr.org/doc/grammar.html for an overview of what this is about. + +# Begin configuration section. +stage=1 +nj=4 # number of decoding jobs. +acwt=0.1 # Just a default value, used for adaptation and beam-pruning.. +post_decode_acwt=1.0 # can be used in 'chain' systems to scale acoustics by 10 so the + # regular scoring script works. +cmd=run.pl +beam=15.0 +frames_per_chunk=50 +max_active=7000 +min_active=200 +ivector_scale=1.0 +lattice_beam=8.0 # Beam we use in lattice generation. +iter=final +scoring_opts= +skip_diagnostics=false +skip_scoring=false +extra_left_context=0 +extra_right_context=0 +extra_left_context_initial=-1 +extra_right_context_final=-1 +online_ivector_dir= +minimize=false +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. utils/parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: $0 [options] " + echo "e.g.: steps/nnet3/decode.sh --nj 8 \\" + echo "--online-ivector-dir exp/nnet2_online/ivectors_test_eval92 \\" + echo " exp/tri4b/graph_bg data/test_eval92_hires $dir/decode_bg_eval92" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --nj # number of parallel jobs" + echo " --cmd # Command to run in parallel with" + echo " --beam # Decoding beam; default 15.0" + echo " --iter # Iteration of model to decode; default is final." + echo " --scoring-opts # options to local/score.sh" + exit 1; +fi + +graphdir=$1 +data=$2 +dir=$3 +srcdir=`dirname $dir`; # Assume model directory one level up from decoding directory. +model=$srcdir/$iter.mdl + + +extra_files= +if [ ! -z "$online_ivector_dir" ]; then + steps/nnet2/check_ivectors_compatible.sh $srcdir $online_ivector_dir || exit 1 + extra_files="$online_ivector_dir/ivector_online.scp $online_ivector_dir/ivector_period" +fi + +utils/lang/check_phones_compatible.sh {$srcdir,$graphdir}/phones.txt || exit 1 + +for f in $graphdir/HCLG.gra $data/feats.scp $model $extra_files; do + [ ! 
-f $f ] && echo "$0: no such file $f" && exit 1; +done + +sdata=$data/split$nj; +cmvn_opts=`cat $srcdir/cmvn_opts` || exit 1; + +mkdir -p $dir/log +[[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1; +echo $nj > $dir/num_jobs + + +## Set up features. +echo "$0: feature type is raw" + +feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- |" + +if [ ! -z "$online_ivector_dir" ]; then + ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; + ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period" +fi + +if [ "$post_decode_acwt" == 1.0 ]; then + lat_wspecifier="ark:|gzip -c >$dir/lat.JOB.gz" +else + lat_wspecifier="ark:|lattice-scale --acoustic-scale=$post_decode_acwt ark:- ark:- | gzip -c >$dir/lat.JOB.gz" +fi + +frame_subsampling_opt= +if [ -f $srcdir/frame_subsampling_factor ]; then + # e.g. for 'chain' systems + frame_subsampling_opt="--frame-subsampling-factor=$(cat $srcdir/frame_subsampling_factor)" +fi + +if [ $stage -le 1 ]; then + $cmd JOB=1:$nj $dir/log/decode.JOB.log \ + nnet3-latgen-grammar $ivector_opts $frame_subsampling_opt \ + --frames-per-chunk=$frames_per_chunk \ + --extra-left-context=$extra_left_context \ + --extra-right-context=$extra_right_context \ + --extra-left-context-initial=$extra_left_context_initial \ + --extra-right-context-final=$extra_right_context_final \ + --minimize=$minimize --max-active=$max_active --min-active=$min_active --beam=$beam \ + --lattice-beam=$lattice_beam --acoustic-scale=$acwt --allow-partial=true \ + --word-symbol-table=$graphdir/words.txt "$model" \ + $graphdir/HCLG.gra "$feats" "$lat_wspecifier" || exit 1; +fi + + +if [ $stage -le 2 ]; then + if ! $skip_diagnostics ; then + [ ! -z $iter ] && iter_opt="--iter $iter" + steps/diagnostic/analyze_lats.sh --cmd "$cmd" $iter_opt $graphdir $dir + fi +fi + + +# The output of this script is the files "lat.*.gz"-- we'll rescore this at +# different acoustic scales to get the final output. +if [ $stage -le 3 ]; then + if ! $skip_scoring ; then + [ ! -x local/score.sh ] && \ + echo "Not scoring because local/score.sh does not exist or not executable." && exit 1; + echo "score best paths" + [ "$iter" != "final" ] && iter_opt="--iter $iter" + local/score.sh $scoring_opts --cmd "$cmd" $data $graphdir $dir + echo "score confidence and timing with sclite" + fi +fi +echo "Decoding done." +exit 0; diff --git a/egs/wsj/s5/steps/nnet3/decode_score_fusion.sh b/egs/wsj/s5/steps/nnet3/decode_score_fusion.sh index 2fcc4a1944d..cb678e84245 100755 --- a/egs/wsj/s5/steps/nnet3/decode_score_fusion.sh +++ b/egs/wsj/s5/steps/nnet3/decode_score_fusion.sh @@ -38,7 +38,7 @@ extra_right_context=0 extra_left_context_initial=-1 extra_right_context_final=-1 online_ivector_dir= -frame_subsampling_factor=1 +frame_subsampling_factor= frames_per_chunk=150 average=true @@ -76,10 +76,10 @@ write_compact=true # If set to false, then writes the lattice in non-compact f if [ $# -lt 5 ]; then echo "Usage: $0 [options] [ ... ] " - echo "e.g.: local/socal/score_fusion.sh --nj 8 \\" - echo "--online-ivector-dir exp/nnet3/ivectors_test_eval92 \\" - echo " data/test_eval92_hires exp/nnet3/tdnn/graph exp/nnet3/tdnn/output exp/nnet3/tdnn1/output .. 
\\" - echo " exp/nnet3/tdnn_comb/decode_dev" + echo "e.g.: steps/nnet3/decode_score_fusion.sh --nj 8 \\" + echo " --online-ivector-dir exp/nnet3/ivectors_test \\" + echo " data/test_hires exp/nnet3/tdnn/graph exp/nnet3/tdnn/output exp/nnet3/tdnn1/output .. \\" + echo " exp/nnet3/tdnn_comb/decode_test" echo "main options (for others, see top of script file)" echo " --config # config containing options" echo " --nj # number of parallel jobs" @@ -110,15 +110,28 @@ if [ ! -z "$online_ivector_dir" ]; then ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period" fi +# assign frame_subsampling_factor automatically if empty +if [ -z $frame_subsampling_factor ]; then + frame_subsampling_factor=`cat ${model_dirs[0]}/frame_subsampling_factor` || exit 1; +fi + +# check if standard chain system or not. +if [ $frame_subsampling_factor -eq 3 ]; then + if [ $acwt != 1.0 ] || [ $post_decode_acwt != 10.0 ]; then + echo -e '\n\n' + echo "$0 WARNING: In standard chain system, acwt = 1.0, post_decode_acwt = 10.0" + echo "$0 WARNING: Your acwt = $acwt, post_decode_acwt = $post_decode_acwt" + echo "$0 WARNING: This is OK if you know what you are doing." + echo -e '\n\n' + fi +fi + frame_subsampling_opt= if [ $frame_subsampling_factor -ne 1 ]; then # e.g. for 'chain' systems frame_subsampling_opt="--frame-subsampling-factor=$frame_subsampling_factor" fi -# convert $dir to absolute pathname -fdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $dir ${PWD}` - # Possibly use multi-threaded decoder thread_string= [ $num_threads -gt 1 ] && thread_string="-parallel --num-threads=$num_threads" @@ -143,11 +156,13 @@ for i in `seq 0 $[num_sys-1]`; do # check that they have the same frame-subsampling-factor if [ $frame_subsampling_factor -ne `cat $srcdir/frame_subsampling_factor` ]; then - echo "$0 frame_subsampling_factor must be the same." + echo "$0 frame_subsampling_factor must be the same.\\" + echo "Default:$frame_subsampling_factor \\" + echo "In $srcdir:`cat $srcdir/frame_subsampling_factor`" exit 0; fi - for f in $data/feats.scp $model $extra_files; do + for f in $data/feats.scp $model $extra_files; do [ ! 
-f $f ] && echo "$0: no such file $f" && exit 1; done @@ -231,9 +246,9 @@ fi if [ $stage -le 0 ]; then $cmd --num-threads $num_threads JOB=1:$nj $dir/log/decode.JOB.log \ - matrix-sum --average=$average "${models[@]}" ark:- \| \ - latgen-faster-mapped$thread_string --lattice-beam=$lattice_beam --acoustic-scale=$acwt --allow-partial=true \ - --minimize=$minimize --max-active=$max_active --min-active=$min_active --beam=$beam \ + matrix-sum --average=$average "${models[@]}" ark:- \| \ + latgen-faster-mapped$thread_string --lattice-beam=$lattice_beam --acoustic-scale=$acwt --allow-partial=true \ + --minimize=$minimize --max-active=$max_active --min-active=$min_active --beam=$beam \ --word-symbol-table=$graphdir/words.txt ${extra_opts} "$model" \ $graphdir/HCLG.fst ark:- "$lat_wspecifier" fi @@ -259,4 +274,3 @@ fi exit 0 - diff --git a/egs/wsj/s5/steps/nnet3/dot/descriptor_parser.py b/egs/wsj/s5/steps/nnet3/dot/descriptor_parser.py index a46d144d0b6..ee6fa11b5c9 100644 --- a/egs/wsj/s5/steps/nnet3/dot/descriptor_parser.py +++ b/egs/wsj/s5/steps/nnet3/dot/descriptor_parser.py @@ -33,7 +33,7 @@ def ParseSubsegmentsAndArguments(segment_endpoints, sub_segments, arguments, inp else: arguments.append(sub_segment_name) else: - arguments = map(lambda x: re.sub(',','', x.strip()), input_string[segment_endpoints[0]:segment_endpoints[1]+1].split()) + arguments = [re.sub(',','', x.strip()) for x in input_string[segment_endpoints[0]:segment_endpoints[1]+1].split()] sub_segments = [] return sub_segments, arguments diff --git a/egs/wsj/s5/steps/nnet3/dot/nnet3_to_dot.py b/egs/wsj/s5/steps/nnet3/dot/nnet3_to_dot.py index f8cd357fa3b..4230b32aa7c 100755 --- a/egs/wsj/s5/steps/nnet3/dot/nnet3_to_dot.py +++ b/egs/wsj/s5/steps/nnet3/dot/nnet3_to_dot.py @@ -189,7 +189,7 @@ def ProcessSumDescriptor(segment, parent_node_name, affix, edge_attributes = Non sub_segment = segment['sub_segments'][i] part_name = "{0}{1}{2}".format(desc_name, sub_segment['name'], i) names.append("<{0}> part {1}".format(GetDotNodeName(part_name)['node'], i)) - dot_graph += DescriptorSegmentToDot(sub_segment, "{0}:{1}".format(desc_name, part_name), desc_name+"_"+str(i)) + dot_graph += DescriptorSegmentToDot(sub_segment, "{0}:{1}".format(desc_name, part_name), "{0}_{1}".format(desc_name, i)) # link the sum node parts to corresponding segments part_index = len(segment['sub_segments']) @@ -321,7 +321,7 @@ def Nnet3ComponentToDot(component_config, component_attributes = None): label = '' if component_attributes is None: component_attributes = component_config.keys() - attributes_to_print = set(component_attributes).intersection(component_config.keys()) + attributes_to_print = set(component_attributes).intersection(list(component_config.keys())) # process the known fields for key in attributes_to_print: if key in component_config: diff --git a/egs/wsj/s5/steps/nnet3/get_degs.sh b/egs/wsj/s5/steps/nnet3/get_degs.sh index 8098b59c4ad..7853daa4563 100755 --- a/egs/wsj/s5/steps/nnet3/get_degs.sh +++ b/egs/wsj/s5/steps/nnet3/get_degs.sh @@ -471,7 +471,6 @@ if [ $stage -le 10 ] && $cleanup; then fi -exit 0 - - echo "$0: Finished decoding and preparing training examples" + +exit 0 diff --git a/egs/wsj/s5/steps/nnet3/get_egs_targets.sh b/egs/wsj/s5/steps/nnet3/get_egs_targets.sh index 2e368283ed4..784693ee44c 100755 --- a/egs/wsj/s5/steps/nnet3/get_egs_targets.sh +++ b/egs/wsj/s5/steps/nnet3/get_egs_targets.sh @@ -130,8 +130,8 @@ if ! [ $num_utts -gt $[$num_utts_subset_valid*4] ]; then fi # Get list of validation utterances. 
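Here, as in the chain get_egs.sh change above, the point of the utt2uniq handling is that the hold-out list must contain every perturbed copy of a source utterance, never a partial group; in rough Python terms (utterance names invented):

# Invented mapping from perturbed utterance ids to their source utterance.
utt2uniq = {'sp0.9-utt1': 'utt1', 'utt1': 'utt1', 'sp1.1-utt1': 'utt1',
            'sp0.9-utt2': 'utt2', 'utt2': 'utt2', 'sp1.1-utt2': 'utt2'}
num_utts_subset = 3

uniq2utt = {}
for utt, uniq in utt2uniq.items():
    uniq2utt.setdefault(uniq, []).append(utt)

valid_uttlist = []
for utts in uniq2utt.values():
    valid_uttlist.extend(utts)        # take whole groups, never part of one
    if len(valid_uttlist) >= num_utts_subset:
        break
print(sorted(valid_uttlist))
# ['sp0.9-utt1', 'sp1.1-utt1', 'utt1']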
-awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl | head -$num_utts_subset_valid | sort \ - > $dir/valid_uttlist || exit 1; +awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl 2>/dev/null | head -$num_utts_subset_valid | sort \ + > $dir/valid_uttlist if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. echo "File $data/utt2uniq exists, so augmenting valid_uttlist to" @@ -145,7 +145,7 @@ if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. fi awk '{print $1}' $data/utt2spk | utils/filter_scp.pl --exclude $dir/valid_uttlist | \ - utils/shuffle_list.pl | head -$num_utts_subset_train | sort > $dir/train_subset_uttlist || exit 1; + utils/shuffle_list.pl 2>/dev/null | head -$num_utts_subset_train | sort > $dir/train_subset_uttlist ## Set up features. echo "$0: feature type is raw" diff --git a/egs/wsj/s5/steps/nnet3/get_saturation.pl b/egs/wsj/s5/steps/nnet3/get_saturation.pl index ed18fc1c399..979736f0847 100755 --- a/egs/wsj/s5/steps/nnet3/get_saturation.pl +++ b/egs/wsj/s5/steps/nnet3/get_saturation.pl @@ -74,6 +74,14 @@ if (! $ok) { print STDERR "Could not parse at least one of the avg-deriv values in the following info line: $_"; } + } elsif (m/type=.*GruNonlinearityComponent/) { + if (m/deriv-avg=[^m]+mean=([^,]+),/) { + $num_nonlinearities += 1; + my $this_saturation = 1.0 - ($1 / 1.0); + $total_saturation += $this_saturation; + } else { + print STDERR "$0: could not make sense of line (no deriv-avg?): $_"; + } } } diff --git a/egs/wsj/s5/steps/nnet3/get_successful_models.py b/egs/wsj/s5/steps/nnet3/get_successful_models.py index 3661d91b8d5..e6dcf376a51 100755 --- a/egs/wsj/s5/steps/nnet3/get_successful_models.py +++ b/egs/wsj/s5/steps/nnet3/get_successful_models.py @@ -56,7 +56,7 @@ if (loss[max_index] - loss[i]) <= args.difference_threshold: accepted_models.append(i+1) - model_list = " ".join(map(lambda x: str(x), accepted_models)) + model_list = " ".join([str(x) for x in accepted_models]) print(model_list) if len(accepted_models) != args.num_models: diff --git a/egs/wsj/s5/steps/nnet3/lstm/make_configs.py b/egs/wsj/s5/steps/nnet3/lstm/make_configs.py index b80a8d4045b..8a533465f07 100755 --- a/egs/wsj/s5/steps/nnet3/lstm/make_configs.py +++ b/egs/wsj/s5/steps/nnet3/lstm/make_configs.py @@ -181,7 +181,7 @@ def ParseSpliceString(splice_indexes, label_delay=None): splice_array = [] try: for i in range(len(split1)): - indexes = map(lambda x: int(x), split1[i].strip().split(",")) + indexes = [int(x) for x in split1[i].strip().split(",")] print(indexes) if len(indexes) < 1: raise ValueError("invalid --splice-indexes argument, too-short element: " @@ -214,12 +214,12 @@ def ParseLstmDelayString(lstm_delay): lstm_delay_array = [] try: for i in range(len(split1)): - indexes = map(lambda x: int(x), split1[i].strip().lstrip('[').rstrip(']').strip().split(",")) + indexes = [int(x) for x in split1[i].strip().lstrip('[').rstrip(']').strip().split(",")] if len(indexes) < 1: raise ValueError("invalid --lstm-delay argument, too-short element: " + lstm_delay) elif len(indexes) == 2 and indexes[0] * indexes[1] >= 0: - raise ValueError('Warning: ' + str(indexes) + ' is not a standard BLSTM mode. There should be a negative delay for the forward, and a postive delay for the backward.') + raise ValueError('Warning: {} is not a standard BLSTM mode. 
There should be a negative delay for the forward, and a postive delay for the backward.'.format(indexes)) if len(indexes) == 2 and indexes[0] > 0: # always a negative delay followed by a postive delay indexes[0], indexes[1] = indexes[1], indexes[0] lstm_delay_array.append(indexes) @@ -335,9 +335,9 @@ def ProcessSpliceIndexes(config_dir, splice_indexes, label_delay, num_lstm_layer # write the files used by other scripts like steps/nnet3/get_egs.sh f = open(config_dir + "/vars", "w") - print('model_left_context=' + str(left_context), file=f) - print('model_right_context=' + str(right_context), file=f) - print('num_hidden_layers=' + str(num_hidden_layers), file=f) + print('model_left_context={}'.format(left_context), file=f) + print('model_right_context={}'.format(right_context), file=f) + print('num_hidden_layers={}'.format(num_hidden_layers), file=f) # print('initial_right_context=' + str(splice_array[0][-1]), file=f) f.close() diff --git a/egs/wsj/s5/steps/nnet3/make_tdnn_configs.py b/egs/wsj/s5/steps/nnet3/make_tdnn_configs.py index 162fda16d16..d121be6d899 100644 --- a/egs/wsj/s5/steps/nnet3/make_tdnn_configs.py +++ b/egs/wsj/s5/steps/nnet3/make_tdnn_configs.py @@ -98,21 +98,21 @@ input_dim = len(splice_array[0]) * args.feat_dim + args.ivector_dim f = open(args.config_dir + "/vars", "w") -print('left_context=' + str(left_context), file=f) -print('right_context=' + str(right_context), file=f) +print('left_context={}'.format(left_context), file=f) +print('right_context={}'.format(right_context), file=f) # the initial l/r contexts are actually not needed. # print('initial_left_context=' + str(splice_array[0][0]), file=f) # print('initial_right_context=' + str(splice_array[0][-1]), file=f) -print('num_hidden_layers=' + str(num_hidden_layers), file=f) +print('num_hidden_layers={}'.format(num_hidden_layers), file=f) f.close() f = open(args.config_dir + "/init.config", "w") print('# Config file for initializing neural network prior to', file=f) print('# preconditioning matrix computation', file=f) -print('input-node name=input dim=' + str(args.feat_dim), file=f) +print('input-node name=input dim={}'.format(args.feat_dim), file=f) list=[ ('Offset(input, {0})'.format(n) if n != 0 else 'input' ) for n in splice_array[0] ] if args.ivector_dim > 0: - print('input-node name=ivector dim=' + str(args.ivector_dim), file=f) + print('input-node name=ivector dim={}'.format(args.ivector_dim), file=f) list.append('ReplaceIndex(ivector, t, 0)') # example of next line: # output-node name=output input="Append(Offset(input, -3), Offset(input, -2), Offset(input, -1), ... , Offset(input, 3), ReplaceIndex(ivector, t, 0))" diff --git a/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py b/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py index 54c65eb5403..a407869854d 100755 --- a/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py +++ b/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py @@ -40,7 +40,6 @@ """ -from __future__ import print_function import os, argparse, sys, random import logging import traceback @@ -163,7 +162,7 @@ def process_multilingual_egs(args): "not include any examples from this lang.") logger.info("The proportion of egs from lang {} is {:.2f}. The number of blocks " "per archive for this lang is approximately {:.2f}. 
" - "{}".format(lang, lang_to_num_examples[lang] / tot_num_egs, + "{}".format(lang, float(lang_to_num_examples[lang]) / tot_num_egs, blocks_per_archive_this_lang, warning)) @@ -173,11 +172,11 @@ def process_multilingual_egs(args): lang_to_num_remaining_egs = [n for n in lang_to_num_examples] for archive_index in range(num_archives + 1): # +1 is because we write to the last archive in two rounds num_remaining_archives = num_archives - archive_index - num_remaining_blocks = num_remaining_egs / args.block_size + num_remaining_blocks = float(num_remaining_egs) / args.block_size last_round = (archive_index == num_archives) if not last_round: - num_blocks_this_archive = int(round(num_remaining_blocks / num_remaining_archives)) + num_blocks_this_archive = int(round(float(num_remaining_blocks) / num_remaining_archives)) logger.info("Generating archive {} containing {} blocks...".format(archive_index, num_blocks_this_archive)) else: # This is the second round for the last archive. Flush all the remaining egs... archive_index = num_archives - 1 @@ -194,7 +193,7 @@ def process_multilingual_egs(args): for block_index in range(num_blocks_this_archive): # Find the lang with the highest proportion of remaining examples - remaining_proportions = [remain / tot for remain, tot in zip(lang_to_num_remaining_egs, lang_to_num_examples)] + remaining_proportions = [float(remain) / tot for remain, tot in zip(lang_to_num_remaining_egs, lang_to_num_examples)] lang_index, max_proportion = max(enumerate(remaining_proportions), key=lambda a: a[1]) # Read 'block_size' examples from the selected lang and write them to the current output scp file: diff --git a/egs/wsj/s5/steps/nnet3/report/generate_plots.py b/egs/wsj/s5/steps/nnet3/report/generate_plots.py index 93cbc940c33..572e2cf08b7 100755 --- a/egs/wsj/s5/steps/nnet3/report/generate_plots.py +++ b/egs/wsj/s5/steps/nnet3/report/generate_plots.py @@ -4,6 +4,7 @@ # 2016 Vimal Manohar # Apache 2.0. 
+from __future__ import division import argparse import errno import logging @@ -97,7 +98,7 @@ def get_args(): g_plot_colors = ['red', 'blue', 'green', 'black', 'magenta', 'yellow', 'cyan'] -class LatexReport: +class LatexReport(object): """Class for writing a Latex report""" def __init__(self, pdf_file): @@ -422,7 +423,7 @@ def generate_nonlin_stats_plots(exp_dir, output_dir, plot, comparison_dir=None, f.write("\n".join(iter_stat_report)) f.close() if plot: - main_component_names = main_stat_tables.keys() + main_component_names = list(main_stat_tables.keys()) main_component_names.sort() plot_component_names = set(main_component_names) @@ -528,13 +529,13 @@ def generate_clipped_proportion_plots(exp_dir, output_dir, plot, file = open("{dir}/clipped_proportion.log".format(dir=output_dir), "w") iter_stat_report = "" for row in main_cp_stats: - iter_stat_report += "\t".join(map(lambda x: str(x), row)) + "\n" + iter_stat_report += "\t".join([str(x) for x in row]) + "\n" file.write(iter_stat_report) file.close() if plot: main_component_names = ( - stats_per_dir[exp_dir]['cp_per_iter_per_component'].keys()) + list(stats_per_dir[exp_dir]['cp_per_iter_per_component'].keys())) main_component_names.sort() plot_component_names = set(main_component_names) for dir in dirs: @@ -635,22 +636,21 @@ def generate_parameter_diff_plots(exp_dir, output_dir, plot, except KeyError: total_missing_iterations += 1 iter_data.append("NA") - if (total_missing_iterations/len(component_names) > 20 + if (float(total_missing_iterations)/len(component_names) > 20 and not gave_user_warning): logger.warning("There are more than {0} missing " "iterations per component. " "Something might be wrong.".format( - total_missing_iterations - / len(component_names))) + float(total_missing_iterations)/ len(component_names))) gave_user_warning = True f.write(" ".join(iter_data)+"\n") if plot: # get the component names - diff_type = key_file.keys()[0] - main_component_names = stats_per_dir[exp_dir][diff_type][ - 'progress_per_component'].keys() + diff_type = list(key_file.keys())[0] + main_component_names = list(stats_per_dir[exp_dir][diff_type][ + 'progress_per_component'].keys()) main_component_names.sort() plot_component_names = set(main_component_names) diff --git a/egs/wsj/s5/steps/nnet3/report/summarize_compute_debug_timing.py b/egs/wsj/s5/steps/nnet3/report/summarize_compute_debug_timing.py index 442ca4e35cf..5c74eaf128c 100755 --- a/egs/wsj/s5/steps/nnet3/report/summarize_compute_debug_timing.py +++ b/egs/wsj/s5/steps/nnet3/report/summarize_compute_debug_timing.py @@ -7,6 +7,7 @@ # we're using python 3.x style print but want it to work in python 2.x, from __future__ import print_function +from __future__ import division import sys import re import argparse @@ -101,7 +102,7 @@ def Main(): total_time = sum(command_times.values()) sorted_commands = sorted(command_times.items(), key = lambda x: x[1], reverse = True) for item in sorted_commands: - print("{c} : time {t} : fraction {f}".format(c=item[0], t=item[1], f=item[1] / total_time)) + print("{c} : time {t} : fraction {f}".format(c=item[0], t=item[1], f=float(item[1]) / total_time)) if __name__ == "__main__": diff --git a/egs/wsj/s5/steps/nnet3/tdnn/make_configs.py b/egs/wsj/s5/steps/nnet3/tdnn/make_configs.py index 5445b16e165..9e7e92f6768 100755 --- a/egs/wsj/s5/steps/nnet3/tdnn/make_configs.py +++ b/egs/wsj/s5/steps/nnet3/tdnn/make_configs.py @@ -4,6 +4,7 @@ # we're using python 3.x style print but want it to work in python 2.x, from __future__ import print_function +from 
__future__ import division import os import argparse import shlex @@ -519,10 +520,10 @@ def MakeConfigs(config_dir, splice_indexes_string, # write the files used by other scripts like steps/nnet3/get_egs.sh f = open(config_dir + "/vars", "w") - print('model_left_context=' + str(left_context), file=f) - print('model_right_context=' + str(right_context), file=f) - print('num_hidden_layers=' + str(num_hidden_layers), file=f) - print('num_targets=' + str(num_targets), file=f) + print('model_left_context={}'.format(left_context), file=f) + print('model_right_context={}'.format(right_context), file=f) + print('num_hidden_layers={}'.format(num_hidden_layers), file=f) + print('num_targets={}'.format(num_targets), file=f) print('add_lda=' + ('true' if add_lda else 'false'), file=f) print('include_log_softmax=' + ('true' if include_log_softmax else 'false'), file=f) print('objective_type=' + objective_type, file=f) diff --git a/egs/wsj/s5/steps/nnet3/train_dnn.py b/egs/wsj/s5/steps/nnet3/train_dnn.py index 0c881b4dbdf..e72b29297a4 100755 --- a/egs/wsj/s5/steps/nnet3/train_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_dnn.py @@ -9,6 +9,7 @@ """ from __future__ import print_function +from __future__ import division import argparse import logging import os @@ -193,7 +194,7 @@ def train(args, run_opts): shutil.copy('{0}/tree'.format(args.ali_dir), args.dir) with open('{0}/num_jobs'.format(args.dir), 'w') as f: - f.write(str(num_jobs)) + f.write('{}'.format(num_jobs)) if args.input_model is None: config_dir = '{0}/configs'.format(args.dir) @@ -301,8 +302,7 @@ def train(args, run_opts): num_archives_expanded = num_archives * args.frames_per_eg num_archives_to_process = int(args.num_epochs * num_archives_expanded) num_archives_processed = 0 - num_iters = ((num_archives_to_process * 2) - / (args.num_jobs_initial + args.num_jobs_final)) + num_iters = int(num_archives_to_process * 2 / (args.num_jobs_initial + args.num_jobs_final)) # If do_final_combination is True, compute the set of models_to_combine. # Otherwise, models_to_combine will be none. diff --git a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py index 34214169d5d..ffccf443b99 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py @@ -9,6 +9,7 @@ """ from __future__ import print_function +from __future__ import division import argparse import logging import pprint @@ -101,7 +102,14 @@ def get_args(): help="Directory with features used for training " "the neural network.") parser.add_argument("--targets-scp", type=str, required=False, - help="Targets for training neural network.") + help="""Targets for training neural network. + This is a kaldi-format SCP file of target matrices. + . + The target matrix's column dim must match + the neural network output dim, and the + row dim must match the number of output frames + i.e. 
after subsampling if "--frame-subsampling-factor" + option is passed to --egs.opts.""") parser.add_argument("--dir", type=str, required=True, help="Directory to store the models and " "all other files.") @@ -314,8 +322,7 @@ def train(args, run_opts): num_archives_expanded = num_archives * args.frames_per_eg num_archives_to_process = int(args.num_epochs * num_archives_expanded) num_archives_processed = 0 - num_iters = ((num_archives_to_process * 2) - / (args.num_jobs_initial + args.num_jobs_final)) + num_iters = int((num_archives_to_process * 2) / (args.num_jobs_initial + args.num_jobs_final)) # If do_final_combination is True, compute the set of models_to_combine. # Otherwise, models_to_combine will be none. diff --git a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py index e797c86b323..c704b0725d3 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py @@ -10,6 +10,7 @@ raw neural network instead of an acoustic model. """ from __future__ import print_function +from __future__ import division import argparse import logging import pprint @@ -368,8 +369,7 @@ def train(args, run_opts): # avg_num_jobs=(num_jobs_initial+num_jobs_final)/2. num_archives_to_process = int(args.num_epochs * num_archives) num_archives_processed = 0 - num_iters = ((num_archives_to_process * 2) - / (args.num_jobs_initial + args.num_jobs_final)) + num_iters = int((num_archives_to_process * 2) / (args.num_jobs_initial + args.num_jobs_final)) # If do_final_combination is True, compute the set of models_to_combine. # Otherwise, models_to_combine will be none. @@ -509,7 +509,8 @@ def train(args, run_opts): run_opts=run_opts, chunk_width=args.chunk_width, get_raw_nnet_from_am=False, compute_per_dim_accuracy=args.compute_per_dim_accuracy, - max_objective_evaluations=args.max_objective_evaluations) + max_objective_evaluations=args.max_objective_evaluations, + use_multitask_egs=use_multitask_egs) else: common_lib.force_symlink("{0}.raw".format(num_iters), "{0}/final.raw".format(args.dir)) diff --git a/egs/wsj/s5/steps/nnet3/train_rnn.py b/egs/wsj/s5/steps/nnet3/train_rnn.py index 25e7dced19b..ab2aa0c4d8d 100755 --- a/egs/wsj/s5/steps/nnet3/train_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_rnn.py @@ -8,6 +8,7 @@ """ from __future__ import print_function +from __future__ import division import argparse import logging import os @@ -248,7 +249,7 @@ def train(args, run_opts): shutil.copy('{0}/tree'.format(args.ali_dir), args.dir) with open('{0}/num_jobs'.format(args.dir), 'w') as f: - f.write(str(num_jobs)) + f.write('{}'.format(num_jobs)) config_dir = '{0}/configs'.format(args.dir) var_file = '{0}/vars'.format(config_dir) @@ -369,8 +370,7 @@ def train(args, run_opts): # avg_num_jobs=(num_jobs_initial+num_jobs_final)/2. num_archives_to_process = int(args.num_epochs * num_archives) num_archives_processed = 0 - num_iters = ((num_archives_to_process * 2) - / (args.num_jobs_initial + args.num_jobs_final)) + num_iters = int((num_archives_to_process * 2) / (args.num_jobs_initial + args.num_jobs_final)) # If do_final_combination is True, compute the set of models_to_combine. # Otherwise, models_to_combine will be none. 
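The num_iters change repeated above (train_dnn.py, train_raw_dnn.py, train_raw_rnn.py, train_rnn.py) wraps the existing expression in int(): with true division the old form yields a float, and num_iters is later used as an integer loop bound. As the nearby comments note, the factor of 2 comes from dividing by the average number of jobs. A worked example with made-up values; in the real scripts these come from the egs dir and the job options:

    from __future__ import division   # '/' behaves identically under Python 2 and 3

    num_archives_to_process = 1200
    num_jobs_initial, num_jobs_final = 2, 12

    # avg_num_jobs = (num_jobs_initial + num_jobs_final) / 2, so
    # num_iters ~= num_archives_to_process / avg_num_jobs.
    num_iters = int(num_archives_to_process * 2 / (num_jobs_initial + num_jobs_final))
    print(num_iters)   # 171  (1200 archives / 7 average jobs, truncated to an int)
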
diff --git a/egs/wsj/s5/steps/nnet3/xconfig_to_config.py b/egs/wsj/s5/steps/nnet3/xconfig_to_config.py new file mode 100755 index 00000000000..952745cea9f --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/xconfig_to_config.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 + +# Copyright 2016-2018 Johns Hopkins University (Dan Povey) +# 2016 Vijayaditya Peddinti +# 2017 Google Inc. (vpeddinti@google.com) +# Apache 2.0. + +# This is like xconfig_to_configs.py but with a simpler interface; it writes +# to a single named file. + + +import argparse +import os +import sys +from collections import defaultdict + +sys.path.insert(0, 'steps/') +# the following is in case we weren't running this from the normal directory. +sys.path.insert(0, os.path.realpath(os.path.dirname(sys.argv[0])) + '/') + +import libs.nnet3.xconfig.parser as xparser +import libs.common as common_lib + + +def get_args(): + # we add compulsory arguments as named arguments for readability + parser = argparse.ArgumentParser( + description="Reads an xconfig file and creates config files " + "for neural net creation and training", + epilog='Search egs/*/*/local/{nnet3,chain}/*sh for examples') + parser.add_argument('--xconfig-file', required=True, + help='Filename of input xconfig file') + parser.add_argument('--existing-model', + help='Filename of previously trained neural net ' + '(e.g. final.mdl) which is useful in case of ' + 'using nodes from list of component-nodes in ' + 'already trained model ' + 'to generate new config file for new model.' + 'The context info is also generated using ' + 'a model generated by adding final.config ' + 'to the existing model.' + 'e.g. In Transfer learning: generate new model using ' + 'component nodes in existing model.') + parser.add_argument('--config-file-out', required=True, + help='Filename to write nnet config file.'); + parser.add_argument('--nnet-edits', type=str, default=None, + action=common_lib.NullstrToNoneAction, + help="""This option is useful in case the network you + are creating does not have an output node called + 'output' (e.g. for multilingual setups). You can set + this to an edit-string like: 'rename-node old-name=xxx + new-name=output' if node xxx plays the role of the + output node in this network. This is only used for + computing the left/right context.""") + + print(' '.join(sys.argv), file=sys.stderr) + + args = parser.parse_args() + + return args + + + +def write_config_file(config_file_out, all_layers): + # config_basename_to_lines is map from the basename of the + # config, as a string (i.e. 'ref', 'all', 'init') to a list of + # strings representing lines to put in the config file. 
+ config_basename_to_lines = defaultdict(list) + + for layer in all_layers: + try: + pairs = layer.get_full_config() + for config_basename, line in pairs: + config_basename_to_lines[config_basename].append(line) + except Exception as e: + print("{0}: error producing config lines from xconfig " + "line '{1}': error was: {2}".format(sys.argv[0], + str(layer), repr(e)), + file=sys.stderr) + # we use raise rather than raise(e) as using a blank raise + # preserves the backtrace + raise + + with open(config_file_out, 'w') as f: + print('# This file was created by the command:\n' + '# {0} '.format(sys.argv), file=f) + lines = config_basename_to_lines['final'] + for line in lines: + print(line, file=f) + + +def main(): + args = get_args() + existing_layers = [] + if args.existing_model is not None: + existing_layers = xparser.get_model_component_info(args.existing_model) + all_layers = xparser.read_xconfig_file(args.xconfig_file, existing_layers) + write_config_file(args.config_file_out, all_layers) + + +if __name__ == '__main__': + main() + + +# test: +# (echo 'input dim=40 name=input'; echo 'output name=output input=Append(-1,0,1)') >xconfig; steps/nnet3/xconfig_to_config.py --xconfig-file=xconfig --config-file-out=foo diff --git a/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py b/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py index 3b8dc82fe48..f025eb5b343 100755 --- a/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py +++ b/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py @@ -115,7 +115,7 @@ def write_expanded_xconfig_files(config_dir, all_layers): '# See also ./xconfig.expanded.2\n', file=xconfig_file_out) for layer in all_layers: - print(str(layer), file=xconfig_file_out) + print('{}'.format(layer), file=xconfig_file_out) xconfig_file_out.close() try: @@ -135,7 +135,7 @@ def write_expanded_xconfig_files(config_dir, all_layers): for layer in all_layers: layer.normalize_descriptors() - print(str(layer), file=xconfig_file_out) + print('{}'.format(layer), file=xconfig_file_out) xconfig_file_out.close() diff --git a/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh b/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh index 0a5eb340a34..ddbc1a74266 100755 --- a/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh +++ b/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh @@ -42,6 +42,7 @@ max_count=0 # The use of this option (e.g. --max-count 100) can make # posterior-scaling, so assuming the posterior-scale is 0.1, # --max-count 100 starts having effect after 1000 frames, or # 10 seconds of data. +use_vad=false # End configuration section. @@ -69,8 +70,13 @@ data=$1 srcdir=$2 dir=$3 +extra_files= +if $use_vad; then + extra_files=$data/vad.scp +fi + for f in $data/feats.scp $srcdir/final.ie $srcdir/final.dubm $srcdir/global_cmvn.stats $srcdir/splice_opts \ - $srcdir/online_cmvn.conf $srcdir/final.mat; do + $srcdir/online_cmvn.conf $srcdir/final.mat $extra_files; do [ ! 
-f $f ] && echo "$0: No such file $f" && exit 1; done @@ -117,9 +123,15 @@ done if [ $stage -le 0 ]; then echo "$0: extracting iVectors" + extra_opts= + if $use_vad; then + extra_opts="--frame-weights-rspecifier=scp:$data/vad.scp" + fi + $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ - ivector-extract-online2 --config=$ieconf ark:$sdata/JOB/spk2utt scp:$sdata/JOB/feats.scp ark:- \| \ - copy-feats --compress=$compress ark:- \ + ivector-extract-online2 --config=$ieconf $extra_opts \ + ark:$sdata/JOB/spk2utt scp:$sdata/JOB/feats.scp ark:- \| \ + copy-feats --compress=$compress ark:- \ ark,scp:$absdir/ivector_online.JOB.ark,$absdir/ivector_online.JOB.scp || exit 1; fi diff --git a/egs/wsj/s5/steps/online/nnet2/train_diag_ubm.sh b/egs/wsj/s5/steps/online/nnet2/train_diag_ubm.sh index 80a023fed8a..f4383628b1d 100755 --- a/egs/wsj/s5/steps/online/nnet2/train_diag_ubm.sh +++ b/egs/wsj/s5/steps/online/nnet2/train_diag_ubm.sh @@ -35,7 +35,7 @@ subsample=2 # subsample all features with this periodicity, in the main E-M phas cleanup=true min_gaussian_weight=0.0001 remove_low_count_gaussians=true # set this to false if you need #gauss to stay fixed. -num_threads=32 +num_threads=16 parallel_opts= # ignored now. online_cmvn_config=conf/online_cmvn.conf # End configuration section. diff --git a/egs/wsj/s5/steps/online/nnet3/prepare_online_decoding.sh b/egs/wsj/s5/steps/online/nnet3/prepare_online_decoding.sh index 912bf89bc59..045c16f5f6f 100755 --- a/egs/wsj/s5/steps/online/nnet3/prepare_online_decoding.sh +++ b/egs/wsj/s5/steps/online/nnet3/prepare_online_decoding.sh @@ -74,6 +74,11 @@ if [ ! -z "$iedir" ]; then for f in final.{mat,ie,dubm} splice_opts global_cmvn.stats online_cmvn.conf; do [ ! -f $iedir/$f ] && echo "$0: no such file $iedir/$f" && exit 1; done + if $add_pitch; then + iedim=`matrix-dim $iedir/final.mat | awk '{print $1}'` + amdim=`nnet3-am-info $srcdir/${iter}.mdl | grep "input-dim:" | awk '{print $2}'` + [ $(($amdim-$iedim)) -eq 0 ] && echo "$0: remove pitch from the input of ivector extractor" && exit 1; + fi fi diff --git a/egs/wsj/s5/steps/segmentation/ali_to_targets.sh b/egs/wsj/s5/steps/segmentation/ali_to_targets.sh index 78c76a8ea01..56d93df3c6b 100644 --- a/egs/wsj/s5/steps/segmentation/ali_to_targets.sh +++ b/egs/wsj/s5/steps/segmentation/ali_to_targets.sh @@ -82,9 +82,9 @@ nj=$(cat $ali_dir/num_jobs) || exit 1 $cmd JOB=1:$nj $dir/log/get_arc_info.JOB.log \ ali-to-phones --ctm-output --frame-shift=1 \ - $srcdir/final.mdl "ark:gunzip -c $ali_dir/lat.JOB.gz |" - \| \ + $srcdir/final.mdl "ark:gunzip -c $ali_dir/ali.JOB.gz |" - \| \ utils/int2sym.pl -f 5 $lang/phones.txt \| \ - awk '{print $1" "int($3)" "int($4)" 1.0 "$5}' \| \ + awk '{print $1" "int($3)" "int($4)" 1.0 "$5}' \> \ $dir/arc_info_sym.JOB.txt || exit 1 # make $dir an absolute pathname. diff --git a/egs/wsj/s5/steps/segmentation/detect_speech_activity.sh b/egs/wsj/s5/steps/segmentation/detect_speech_activity.sh index 60e3df20df2..831283bb5ec 100755 --- a/egs/wsj/s5/steps/segmentation/detect_speech_activity.sh +++ b/egs/wsj/s5/steps/segmentation/detect_speech_activity.sh @@ -56,7 +56,15 @@ acwt=0.3 # e.g. --speech-in-sil-weight=0.0 --garbage-in-sil-weight=0.0 --sil-in-speech-weight=0.0 --garbage-in-speech-weight=0.3 transform_probs_opts="" +# Postprocessing options segment_padding=0.2 # Duration (in seconds) of padding added to segments +min_segment_dur=0 # Minimum duration (in seconds) required for a segment to be included + # This is before any padding. Segments shorter than this duration will be removed. 
+ # This is an alternative to --min-speech-duration above. +merge_consecutive_max_dur=0 # Merge consecutive segments as long as the merged segment is no longer than this many + # seconds. The segments are only merged if their boundaries are touching. + # This is after padding by --segment-padding seconds. + # 0 means do not merge. Use 'inf' to not limit the duration. echo $* @@ -99,14 +107,14 @@ data_id=`basename $data_dir` sad_dir=${dir}/${sad_name}${affix}_${data_id}_whole${feat_affix} seg_dir=${dir}/${segmentation_name}${affix}_${data_id}_whole${feat_affix} -test_data_dir=data/${data_id}${feat_affix}_hires - if $convert_data_dir_to_whole; then + test_data_dir=data/${data_id}_whole${feat_affix}_hires if [ $stage -le 0 ]; then rm -r ${test_data_dir} || true utils/data/convert_data_dir_to_whole.sh $src_data_dir ${test_data_dir} fi else + test_data_dir=data/${data_id}${feat_affix}_hires if [ $stage -le 0 ]; then rm -r ${test_data_dir} || true utils/copy_data_dir.sh $src_data_dir $test_data_dir @@ -170,7 +178,8 @@ fi ## Prepare FST we search to make speech/silence decisions. ############################################################################### -frame_shift=$(utils/data/get_frame_shift.sh $test_data_dir) +utils/data/get_utt2dur.sh --nj $nj --cmd "$cmd" $test_data_dir || exit 1 +frame_shift=$(utils/data/get_frame_shift.sh $test_data_dir) || exit 1 graph_dir=${dir}/graph_${output_name} if [ $stage -le 5 ]; then @@ -224,7 +233,8 @@ fi if [ $stage -le 7 ]; then steps/segmentation/post_process_sad_to_segments.sh \ - --segment-padding $segment_padding \ + --segment-padding $segment_padding --min-segment-dur $min_segment_dur \ + --merge-consecutive-max-dur $merge_consecutive_max_dur \ --cmd "$cmd" --frame-shift $(perl -e "print $frame_subsampling_factor * $frame_shift") \ ${test_data_dir} ${seg_dir} ${seg_dir} fi diff --git a/egs/wsj/s5/steps/segmentation/internal/find_oov_phone.py b/egs/wsj/s5/steps/segmentation/internal/find_oov_phone.py index 3e9cbbbf178..038640f6271 100644 --- a/egs/wsj/s5/steps/segmentation/internal/find_oov_phone.py +++ b/egs/wsj/s5/steps/segmentation/internal/find_oov_phone.py @@ -8,6 +8,7 @@ /phones/align_lexicon.int. It prints the OOV phone to stdout, if it can find a single phone mapping for the OOV word.""" +from __future__ import print_function import sys diff --git a/egs/wsj/s5/steps/segmentation/internal/get_default_targets_for_out_of_segments.py b/egs/wsj/s5/steps/segmentation/internal/get_default_targets_for_out_of_segments.py index e7000b9de00..0361999d904 100755 --- a/egs/wsj/s5/steps/segmentation/internal/get_default_targets_for_out_of_segments.py +++ b/egs/wsj/s5/steps/segmentation/internal/get_default_targets_for_out_of_segments.py @@ -14,6 +14,7 @@ the application and data, this could be [ 0 0 0 ] or [ 0 0 1 ] or something with fractional weights. 
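The new --min-segment-dur and --merge-consecutive-max-dur options introduced above (detect_speech_activity.sh, post_process_sad_to_segments.sh) are implemented in sad_to_segments.py further down as three ordered steps: drop short segments, pad, then merge touching segments. A simplified, self-contained sketch of that order, with invented segment times (the real class also tracks stats and caps padding at the utterance duration from utt2dur):

    def post_process(segments, min_dur, padding, merge_max_dur):
        # 1. Remove speech segments shorter than min_dur (before any padding).
        segments = [[s, e] for s, e in segments if e - s >= min_dur]
        # 2. Pad each segment, but never below t=0 or into a neighbouring segment.
        padded = []
        for i, (s, e) in enumerate(segments):
            s = max(s - padding, 0.0, padded[-1][1] if padded else 0.0)
            if i + 1 < len(segments):
                e = min(e + padding, segments[i + 1][0])
            else:
                e = e + padding
            padded.append([s, e])
        # 3. Merge segments whose boundaries touch, as long as the merged segment
        #    is no longer than merge_max_dur (0 disables merging).
        merged = padded[:1]
        for s, e in padded[1:]:
            if merge_max_dur > 0 and s == merged[-1][1] and e - merged[-1][0] <= merge_max_dur:
                merged[-1][1] = e
            else:
                merged.append([s, e])
        return merged

    print(post_process([[0.5, 0.6], [1.0, 2.0], [2.4, 3.0]],
                       min_dur=0.3, padding=0.2, merge_max_dur=10))
    # [[0.8, 3.2]] : the 0.1 s segment is dropped, padding makes the two
    # remaining segments touch at t=2.2, and they are merged into one.
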
""" +from __future__ import division import argparse import logging @@ -131,7 +132,7 @@ def run(args): and np.shape(default_targets)[1] == 3) with common_lib.smart_open(args.out_targets_ark, 'w') as f: - for reco, utts in reco2utt.iteritems(): + for reco, utts in reco2utt.items(): reco_mat = np.repeat(default_targets, reco2num_frames[reco], axis=0) utts.sort(key=lambda x: segments[x][1]) # sort on start time diff --git a/egs/wsj/s5/steps/segmentation/internal/merge_segment_targets_to_recording.py b/egs/wsj/s5/steps/segmentation/internal/merge_segment_targets_to_recording.py index 8c53e5e8db9..e48afbeb872 100755 --- a/egs/wsj/s5/steps/segmentation/internal/merge_segment_targets_to_recording.py +++ b/egs/wsj/s5/steps/segmentation/internal/merge_segment_targets_to_recording.py @@ -9,6 +9,7 @@ in any of the segments are assigned the default targets vector, specified by the option --default-targets or [ 0 0 0 ] if unspecified. """ +from __future__ import division import argparse import logging @@ -158,7 +159,7 @@ def run(args): num_reco = 0 with common_lib.smart_open(args.out_targets_ark, 'w') as fh: - for reco, utts in reco2utt.iteritems(): + for reco, utts in reco2utt.items(): # Read a recording and the list of its utterances from the # reco2utt dictionary reco_mat = np.repeat(default_targets, reco2num_frames[reco], diff --git a/egs/wsj/s5/steps/segmentation/internal/merge_targets.py b/egs/wsj/s5/steps/segmentation/internal/merge_targets.py index 8222eddad8f..84b0c884f45 100755 --- a/egs/wsj/s5/steps/segmentation/internal/merge_targets.py +++ b/egs/wsj/s5/steps/segmentation/internal/merge_targets.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2017 Vimal Manohar # Apache 2.0 @@ -16,7 +16,6 @@ option. """ -from __future__ import print_function import argparse import logging import numpy as np @@ -110,7 +109,7 @@ def should_remove_frame(row, dim): # source[2] = [ 0 0 0 ] """ assert len(row) % dim == 0 - num_sources = len(row) / dim + num_sources = len(row) // dim max_idx = np.argmax(row) max_val = row[max_idx] diff --git a/egs/wsj/s5/steps/segmentation/internal/sad_to_segments.py b/egs/wsj/s5/steps/segmentation/internal/sad_to_segments.py index 9b1c0f12b9a..09da9cbecc1 100755 --- a/egs/wsj/s5/steps/segmentation/internal/sad_to_segments.py +++ b/egs/wsj/s5/steps/segmentation/internal/sad_to_segments.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # Copyright 2017 Vimal Manohar +# 2018 Capital One (Author: Zhiyuan Guan) # Apache 2.0 """ @@ -29,6 +30,7 @@ global_verbose = 0 + def get_args(): parser = argparse.ArgumentParser( description=""" @@ -44,18 +46,31 @@ def get_args(): parser.add_argument("--utt2dur", type=str, help="File containing durations of utterances.") + parser.add_argument("--frame-shift", type=float, default=0.01, help="Frame shift to convert frame indexes to time") parser.add_argument("--segment-padding", type=float, default=0.2, help="Additional padding on speech segments. But we " - "ensure that the padding does not go beyond the " - "adjacent segment.") + "ensure that the padding does not go beyond the " + "adjacent segment.") + + parser.add_argument("--min-segment-dur", type=float, default=0, + help="Minimum duration (in seconds) required for a segment " + "to be included. This is before any padding. Segments " + "shorter than this duration will be removed.") + parser.add_argument("--merge-consecutive-max-dur", type=float, default=0, + help="Merge consecutive segments as long as the merged " + "segment is no longer than this many seconds. 
The segments " + "are only merged if their boundaries are touching. " + "This is after padding by --segment-padding seconds." + "0 means do not merge. Use 'inf' to not limit the duration.") parser.add_argument("in_sad", type=str, help="Input file containing alignments in " - "text archive format") + "text archive format") + parser.add_argument("out_segments", type=str, help="Output kaldi segments file") @@ -80,28 +95,45 @@ def to_str(segment): class SegmenterStats(object): """Stores stats about the post-process stages""" + def __init__(self): - self.num_segments = 0 + self.num_segments_initial = 0 + self.num_short_segments_filtered = 0 + self.num_merges = 0 + self.num_segments_final = 0 self.initial_duration = 0.0 self.padding_duration = 0.0 + self.filter_short_duration = 0.0 self.final_duration = 0.0 def add(self, other): """Adds stats from another object""" - self.num_segments += other.num_segments + self.num_segments_initial += other.num_segments_initial + self.num_short_segments_filtered += other.num_short_segments_filtered + self.num_merges += other.num_merges + self.num_segments_final += other.num_segments_final self.initial_duration += other.initial_duration - self.padding_duration = other.padding_duration - self.final_duration = other.final_duration + self.filter_short_duration += other.filter_short_duration + self.padding_duration += other.padding_duration + self.final_duration += other.final_duration def __str__(self): - return ("num-segments={num_segments}, " + return ("num-segments-initial={num_segments_initial}, " + "num-short-segments-filtered={num_short_segments_filtered}, " + "num-merges={num_merges}, " + "num-segments-final={num_segments_final}, " "initial-duration={initial_duration}, " + "filter-short-duration={filter_short_duration}, " "padding-duration={padding_duration}, " "final-duration={final_duration}".format( - num_segments=self.num_segments, - initial_duration=self.initial_duration, - padding_duration=self.padding_duration, - final_duration=self.final_duration)) + num_segments_initial=self.num_segments_initial, + num_short_segments_filtered=self.num_short_segments_filtered, + num_merges=self.num_merges, + num_segments_final=self.num_segments_final, + initial_duration=self.initial_duration, + filter_short_duration=self.filter_short_duration, + padding_duration=self.padding_duration, + final_duration=self.final_duration)) def process_label(text_label): @@ -114,13 +146,14 @@ def process_label(text_label): prev_label = int(text_label) if prev_label not in [1, 2]: raise ValueError("Expecting label to 1 (non-speech) or 2 (speech); " - "got {0}".format(prev_label)) + "got {}".format(prev_label)) return prev_label class Segmentation(object): """Stores segmentation for an utterances""" + def __init__(self): self.segments = None self.stats = SegmenterStats() @@ -141,10 +174,9 @@ def initialize_segments(self, alignment, frame_shift=0.01): self.segments.append( [float(i - prev_length) * frame_shift, float(i) * frame_shift, prev_label]) - + self.stats.initial_duration += (prev_length * frame_shift) prev_label = process_label(text_label) prev_length = 0 - self.stats.initial_duration += (prev_length * frame_shift) elif prev_label is None: prev_label = process_label(text_label) @@ -156,7 +188,27 @@ def initialize_segments(self, alignment, frame_shift=0.01): float(len(alignment)) * frame_shift, prev_label]) self.stats.initial_duration += (prev_length * frame_shift) - self.stats.num_segments = len(self.segments) + self.stats.num_segments_initial = len(self.segments) + 
self.stats.num_segments_final = len(self.segments) + self.stats.final_duration = self.stats.initial_duration + + def filter_short_segments(self, min_dur): + """Filters out segments with durations shorter than 'min_dur'.""" + if min_dur <= 0: + return + + segments_kept = [] + for segment in self.segments: + assert segment[2] == 2, segment + dur = segment[1] - segment[0] + if dur < min_dur: + self.stats.filter_short_duration += dur + self.stats.num_short_segments_filtered += 1 + else: + segments_kept.append(segment) + self.segments = segments_kept + self.stats.num_segments_final = len(self.segments) + self.stats.final_duration -= self.stats.filter_short_duration def pad_speech_segments(self, segment_padding, max_duration=float("inf")): """Pads segments by duration 'segment_padding' on either sides, but @@ -166,19 +218,19 @@ def pad_speech_segments(self, segment_padding, max_duration=float("inf")): max_duration = float("inf") for i, segment in enumerate(self.segments): assert segment[2] == 2, segment - segment[0] -= segment_padding # try adding padding on the left side + segment[0] -= segment_padding # try adding padding on the left side self.stats.padding_duration += segment_padding if segment[0] < 0.0: # Padding takes the segment start to before the beginning of the utterance. # Reduce padding. self.stats.padding_duration += segment[0] segment[0] = 0.0 - if i >= 1 and self.segments[i-1][1] > segment[0]: + if i >= 1 and self.segments[i - 1][1] > segment[0]: # Padding takes the segment start to before the end the previous segment. # Reduce padding. self.stats.padding_duration -= ( - self.segments[i-1][1] - segment[0]) - segment[0] = self.segments[i-1][1] + self.segments[i - 1][1] - segment[0]) + segment[0] = self.segments[i - 1][1] segment[1] += segment_padding self.stats.padding_duration += segment_padding @@ -188,12 +240,35 @@ def pad_speech_segments(self, segment_padding, max_duration=float("inf")): self.stats.padding_duration -= (segment[1] - max_duration) segment[1] = max_duration if (i + 1 < len(self.segments) - and segment[1] > self.segments[i+1][0]): + and segment[1] > self.segments[i + 1][0]): # Padding takes the segment end beyond the start of the next segment. # Reduce padding. self.stats.padding_duration -= ( - segment[1] - self.segments[i+1][0]) - segment[1] = self.segments[i+1][0] + segment[1] - self.segments[i + 1][0]) + segment[1] = self.segments[i + 1][0] + self.stats.final_duration += self.stats.padding_duration + + def merge_consecutive_segments(self, max_dur): + """Merge consecutive segments (happens after padding), provided that + the merged segment is no longer than 'max_dur'.""" + if max_dur <= 0 or not self.segments: + return + + merged_segments = [self.segments[0]] + for segment in self.segments[1:]: + assert segment[2] == 2, segment + if segment[0] == merged_segments[-1][1] and \ + segment[1] - merged_segments[-1][0] <= max_dur: + # The segment starts at the same time the last segment ends, + # and the merged segment is shorter than 'max_dur'. + # Extend the previous segment. 
+ merged_segments[-1][1] = segment[1] + self.stats.num_merges += 1 + else: + merged_segments.append(segment) + + self.segments = merged_segments + self.stats.num_segments_final = len(self.segments) def write(self, key, file_handle): """Write segments to file""" @@ -203,9 +278,9 @@ def write(self, key, file_handle): for segment in self.segments: seg_id = "{key}-{st:07d}-{end:07d}".format( key=key, st=int(segment[0] * 100), end=int(segment[1] * 100)) - print ("{seg_id} {key} {st:.2f} {end:.2f}".format( + print("{seg_id} {key} {st:.2f} {end:.2f}".format( seg_id=seg_id, key=key, st=segment[0], end=segment[1]), - file=file_handle) + file=file_handle) def run(args): @@ -235,9 +310,11 @@ def run(args): segmentation = Segmentation() segmentation.initialize_segments( parts[1:], args.frame_shift) + segmentation.filter_short_segments(args.min_segment_dur) segmentation.pad_speech_segments(args.segment_padding, None if args.utt2dur is None else utt2dur[utt_id]) + segmentation.merge_consecutive_segments(args.merge_consecutive_max_dur) segmentation.write(utt_id, out_segments_fh) global_stats.add(segmentation.stats) logger.info(global_stats) diff --git a/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh b/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh index ca9cea2518b..b168c307b57 100755 --- a/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh +++ b/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh @@ -18,6 +18,8 @@ nj=18 # The values below are in seconds frame_shift=0.01 segment_padding=0.2 +min_segment_dur=0 +merge_consecutive_max_dur=0 . utils/parse_options.sh @@ -53,6 +55,7 @@ if [ $stage -le 0 ]; then copy-int-vector "ark:gunzip -c $vad_dir/ali.JOB.gz |" ark,t:- \| \ steps/segmentation/internal/sad_to_segments.py \ --frame-shift=$frame_shift --segment-padding=$segment_padding \ + --min-segment-dur=$min_segment_dur --merge-consecutive-max-dur=$merge_consecutive_max_dur \ --utt2dur=$data_dir/utt2dur - $dir/segments.JOB fi diff --git a/egs/wsj/s5/steps/tfrnnlm/lstm.py b/egs/wsj/s5/steps/tfrnnlm/lstm.py index 06969fbcb5d..433dc87b4c6 100644 --- a/egs/wsj/s5/steps/tfrnnlm/lstm.py +++ b/egs/wsj/s5/steps/tfrnnlm/lstm.py @@ -16,8 +16,8 @@ # this script trains a vanilla RNNLM with TensorFlow. 
# to call the script, do -# python steps/tfrnnlm/lstm.py --data-path=$datadir \ -# --save-path=$savepath --vocab-path=$rnn.wordlist [--hidden-size=$size] +# python steps/tfrnnlm/lstm.py --data_path=$datadir \ +# --save_path=$savepath --vocab_path=$rnn.wordlist [--hidden-size=$size] # # One example recipe is at egs/ami/s5/local/tfrnnlm/run_lstm.sh @@ -38,15 +38,15 @@ flags = tf.flags logging = tf.logging -flags.DEFINE_integer("hidden-size", 200, "hidden dim of RNN") +flags.DEFINE_integer("hidden_size", 200, "hidden dim of RNN") -flags.DEFINE_string("data-path", None, +flags.DEFINE_string("data_path", None, "Where the training/test data is stored.") -flags.DEFINE_string("vocab-path", None, +flags.DEFINE_string("vocab_path", None, "Where the wordlist file is stored.") -flags.DEFINE_string("save-path", None, +flags.DEFINE_string("save_path", None, "Model output directory.") -flags.DEFINE_bool("use-fp16", False, +flags.DEFINE_bool("use_fp16", False, "Train using 16-bit floats instead of 32bit floats") FLAGS = flags.FLAGS @@ -203,7 +203,7 @@ def attn_cell(): config.max_grad_norm) optimizer = tf.train.GradientDescentOptimizer(self._lr) self._train_op = optimizer.apply_gradients( - zip(grads, tvars), + list(zip(grads, tvars)), global_step=tf.contrib.framework.get_or_create_global_step()) self._new_lr = tf.placeholder( diff --git a/egs/wsj/s5/steps/tfrnnlm/lstm_fast.py b/egs/wsj/s5/steps/tfrnnlm/lstm_fast.py index 9643468ccfb..ff6c7263804 100644 --- a/egs/wsj/s5/steps/tfrnnlm/lstm_fast.py +++ b/egs/wsj/s5/steps/tfrnnlm/lstm_fast.py @@ -16,8 +16,8 @@ # this script trains a vanilla RNNLM with TensorFlow. # to call the script, do -# python steps/tfrnnlm/lstm_fast.py --data-path=$datadir \ -# --save-path=$savepath --vocab-path=$rnn.wordlist [--hidden-size=$size] +# python steps/tfrnnlm/lstm_fast.py --data_path=$datadir \ +# --save_path=$savepath --vocab_path=$rnn.wordlist [--hidden-size=$size] # # One example recipe is at egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh @@ -38,15 +38,15 @@ flags = tf.flags logging = tf.logging -flags.DEFINE_integer("hidden-size", 200, "hidden dim of RNN") +flags.DEFINE_integer("hidden_size", 200, "hidden dim of RNN") -flags.DEFINE_string("data-path", None, +flags.DEFINE_string("data_path", None, "Where the training/test data is stored.") -flags.DEFINE_string("vocab-path", None, +flags.DEFINE_string("vocab_path", None, "Where the wordlist file is stored.") -flags.DEFINE_string("save-path", None, +flags.DEFINE_string("save_path", None, "Model output directory.") -flags.DEFINE_bool("use-fp16", False, +flags.DEFINE_bool("use_fp16", False, "Train using 16-bit floats instead of 32bit floats") FLAGS = flags.FLAGS @@ -218,7 +218,7 @@ def attn_cell(): config.max_grad_norm) optimizer = tf.train.GradientDescentOptimizer(self._lr) self._train_op = optimizer.apply_gradients( - zip(grads, tvars), + list(zip(grads, tvars)), global_step=tf.contrib.framework.get_or_create_global_step()) self._new_lr = tf.placeholder( diff --git a/egs/wsj/s5/steps/tfrnnlm/reader.py b/egs/wsj/s5/steps/tfrnnlm/reader.py index fc3d4d0471c..80cdeccbb26 100644 --- a/egs/wsj/s5/steps/tfrnnlm/reader.py +++ b/egs/wsj/s5/steps/tfrnnlm/reader.py @@ -31,7 +31,7 @@ def _read_words(filename): def _build_vocab(filename): words = _read_words(filename) - word_to_id = dict(zip(words, range(len(words)))) + word_to_id = dict(list(zip(words, list(range(len(words)))))) return word_to_id diff --git a/egs/wsj/s5/steps/tfrnnlm/vanilla_rnnlm.py b/egs/wsj/s5/steps/tfrnnlm/vanilla_rnnlm.py index de263c6923f..ae7a257906e 100644 --- 
a/egs/wsj/s5/steps/tfrnnlm/vanilla_rnnlm.py +++ b/egs/wsj/s5/steps/tfrnnlm/vanilla_rnnlm.py @@ -16,8 +16,8 @@ # this script trains a vanilla RNNLM with TensorFlow. # to call the script, do -# python steps/tfrnnlm/vanilla_rnnlm.py --data-path=$datadir \ -# --save-path=$savepath --vocab-path=$rnn.wordlist [--hidden-size=$size] +# python steps/tfrnnlm/vanilla_rnnlm.py --data_path=$datadir \ +# --save_path=$savepath --vocab_path=$rnn.wordlist [--hidden-size=$size] # # One example recipe is at egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh @@ -38,15 +38,15 @@ flags = tf.flags logging = tf.logging -flags.DEFINE_integer("hidden-size", 200, "hidden dim of RNN") +flags.DEFINE_integer("hidden_size", 200, "hidden dim of RNN") -flags.DEFINE_string("data-path", None, +flags.DEFINE_string("data_path", None, "Where the training/test data is stored.") -flags.DEFINE_string("vocab-path", None, +flags.DEFINE_string("vocab_path", None, "Where the wordlist file is stored.") -flags.DEFINE_string("save-path", None, +flags.DEFINE_string("save_path", None, "Model output directory.") -flags.DEFINE_bool("use-fp16", False, +flags.DEFINE_bool("use_fp16", False, "Train using 16-bit floats instead of 32bit floats") FLAGS = flags.FLAGS @@ -201,7 +201,7 @@ def attn_cell(): config.max_grad_norm) optimizer = tf.train.MomentumOptimizer(self._lr, 0.9) self._train_op = optimizer.apply_gradients( - zip(grads, tvars), + list(zip(grads, tvars)), global_step=tf.contrib.framework.get_or_create_global_step()) self._new_lr = tf.placeholder( diff --git a/egs/wsj/s5/steps/train_diag_ubm.sh b/egs/wsj/s5/steps/train_diag_ubm.sh index 10cc4a4b43e..4389844d478 100755 --- a/egs/wsj/s5/steps/train_diag_ubm.sh +++ b/egs/wsj/s5/steps/train_diag_ubm.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright Johns Hopkins University (Author: Daniel Povey), 2012. +# Copyright Johns Hopkins University (Author: Daniel Povey), 2012. # Apache 2.0. # Train a diagonal mixture of Gaussians. This is trained without @@ -67,7 +67,7 @@ echo "$0: feature type is $feat_type" case $feat_type in delta) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |";; lda) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $alidir/final.mat ark:- ark:- |" - cp $alidir/final.mat $dir + cp $alidir/final.mat $dir ;; *) echo "Invalid feature type $feat_type" && exit 1; esac diff --git a/egs/wsj/s5/steps/train_mono.sh b/egs/wsj/s5/steps/train_mono.sh index 141d128c329..5a0b79a4a1c 100755 --- a/egs/wsj/s5/steps/train_mono.sh +++ b/egs/wsj/s5/steps/train_mono.sh @@ -1,5 +1,6 @@ #!/bin/bash # Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# 2019 Xiaohui Zhang # Apache 2.0 @@ -13,6 +14,9 @@ cmd=run.pl scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" num_iters=40 # Number of iterations of training max_iter_inc=30 # Last iter to increase #Gauss on. +initial_beam=6 # beam used in the first iteration (set smaller to speed up initialization) +regular_beam=10 # beam used after the first iteration +retry_beam=40 totgauss=1000 # Target #Gaussians. 
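Returning to the tfrnnlm scripts above: every flag is renamed from the dashed to the underscored form (--data-path becomes --data_path, and so on), presumably because the parsed values are read back as attributes of FLAGS, which requires identifier-style names. A minimal sketch of the pattern under the TF 1.x flags API; the main body here is a placeholder, not code from the patch:

    import tensorflow as tf

    flags = tf.flags
    flags.DEFINE_integer("hidden_size", 200, "hidden dim of RNN")
    flags.DEFINE_string("data_path", None, "Where the training/test data is stored.")
    FLAGS = flags.FLAGS

    def main(_):
        # Underscored flag names can be read back as plain attributes;
        # "data-path" would not be a valid attribute name here.
        print(FLAGS.hidden_size, FLAGS.data_path)

    if __name__ == "__main__":
        tf.app.run()   # parses --hidden_size=... --data_path=... from sys.argv
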
careful=false boost_silence=1.0 # Factor by which to boost silence likelihoods in alignment @@ -105,8 +109,7 @@ if [ $stage -le 0 ]; then rm $dir/0.*.acc fi - -beam=6 # will change to 10 below after 1st pass +beam=$initial_beam # will change to regular_beam below after 1st pass # note: using slightly wider beams for WSJ vs. RM. x=1 while [ $x -lt $num_iters ]; do @@ -116,7 +119,7 @@ while [ $x -lt $num_iters ]; do echo "$0: Aligning data" mdl="gmm-boost-silence --boost=$boost_silence `cat $lang/phones/optional_silence.csl` $dir/$x.mdl - |" $cmd JOB=1:$nj $dir/log/align.$x.JOB.log \ - gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$[$beam*4] --careful=$careful "$mdl" \ + gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$retry_beam --careful=$careful "$mdl" \ "ark:gunzip -c $dir/fsts.JOB.gz|" "$feats" "ark,t:|gzip -c >$dir/ali.JOB.gz" \ || exit 1; fi @@ -132,7 +135,7 @@ while [ $x -lt $num_iters ]; do if [ $x -le $max_iter_inc ]; then numgauss=$[$numgauss+$incgauss]; fi - beam=10 + beam=$regular_beam x=$[$x+1] done diff --git a/egs/wsj/s5/steps/train_sat.sh b/egs/wsj/s5/steps/train_sat.sh index 0211f7bcf67..92b744dc75c 100755 --- a/egs/wsj/s5/steps/train_sat.sh +++ b/egs/wsj/s5/steps/train_sat.sh @@ -276,4 +276,3 @@ steps/info/gmm_dir_info.pl $dir echo "$0: done training SAT system in $dir" exit 0 - diff --git a/egs/wsj/s5/steps/train_sat_basis.sh b/egs/wsj/s5/steps/train_sat_basis.sh index 45384fe4ecd..5245ea0c619 100755 --- a/egs/wsj/s5/steps/train_sat_basis.sh +++ b/egs/wsj/s5/steps/train_sat_basis.sh @@ -17,6 +17,7 @@ scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" beam=10 retry_beam=40 boost_silence=1.0 # Factor by which to boost silence likelihoods in alignment +basis_fmllr_opts="--fmllr-min-count=22 --num-iters=10 --size-scale=0.2 --step-size-iters=3" context_opts= # e.g. set this to "--context-width 5 --central-position 2" for quinphone. 
realign_iters="10 20 30"; fmllr_iters="2 4 6 12"; @@ -93,7 +94,7 @@ esac ## Get initial fMLLR transforms (possibly from alignment dir) if [ -f $alidir/trans.1 ]; then echo "$0: Using transforms from $alidir" - feats="$sifeats transform-feats ark,s,cs:$alidir/trans.JOB ark:- ark:- |" + feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark,s,cs:$alidir/trans.JOB ark:- ark:- |" cur_trans_dir=$alidir else if [ $stage -le -5 ]; then @@ -114,13 +115,11 @@ else ali-to-post "ark:gunzip -c $alidir/ali.JOB.gz|" ark:- \| \ weight-silence-post $silence_weight $silphonelist $alidir/final.mdl ark:- ark:- \| \ gmm-post-to-gpost $alidir/final.mdl "$sifeats" ark:- ark:- \| \ - gmm-est-basis-fmllr-gpost --fmllr-min-count=22 --num-iters=10 \ - --size-scale=0.2 --step-size-iters=3 \ - --write-weights=ark:$dir/pre_wgt.JOB \ + gmm-est-basis-fmllr-gpost $basis_fmllr_opts --spk2utt=ark:$sdata/JOB/spk2utt \ $alidir/final.mdl $alidir/fmllr.basis "$sifeats" ark,s,cs:- \ ark:$alidir/trans.JOB || exit 1; - feats="$sifeats transform-feats ark,s,cs:$alidir/trans.JOB ark:- ark:- |" + feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark,s,cs:$alidir/trans.JOB ark:- ark:- |" cur_trans_dir=$alidir fi fi @@ -214,14 +213,12 @@ while [ $x -lt $num_iters ]; do ali-to-post "ark:gunzip -c $dir/ali.JOB.gz|" ark:- \| \ weight-silence-post $silence_weight $silphonelist $dir/$x.mdl ark:- ark:- \| \ gmm-post-to-gpost $dir/$x.mdl "$sifeats" ark:- ark:- \| \ - gmm-est-basis-fmllr-gpost --fmllr-min-count=22 --num-iters=10 \ - --size-scale=0.2 --step-size-iters=3 \ - --write-weights=ark:$dir/pre_wgt.JOB \ + gmm-est-basis-fmllr-gpost $basis_fmllr_opts --spk2utt=ark:$sdata/JOB/spk2utt \ $dir/$x.mdl $dir/fmllr.basis "$sifeats" ark,s,cs:- \ ark:$dir/trans.JOB || exit 1; fi - feats="$sifeats transform-feats ark:$dir/trans.JOB ark:- ark:- |" + feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark:$dir/trans.JOB ark:- ark:- |" cur_trans_dir=$dir fi diff --git a/egs/wsj/s5/steps/train_sgmm2.sh b/egs/wsj/s5/steps/train_sgmm2.sh index 29c5346c480..7f7df2e046a 100755 --- a/egs/wsj/s5/steps/train_sgmm2.sh +++ b/egs/wsj/s5/steps/train_sgmm2.sh @@ -10,16 +10,15 @@ # (Computer Speech and Language, 2011). # Begin configuration section. -nj=4 cmd=run.pl -stage=-6 # use this to resume partially finished training +stage=-6 # use this to resume partially finished training context_opts= # e.g. set it to "--context-width=5 --central-position=2" for a # quinphone system. scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" num_iters=25 # Total number of iterations of training num_iters_alimdl=3 # Number of iterations for estimating alignment model. max_iter_inc=15 # Last iter to increase #substates on. -realign_iters="5 10 15"; # Iters to realign on. +realign_iters="5 10 15"; # Iters to realign on. spkvec_iters="5 8 12 17" # Iters to estimate speaker vectors on. increase_iters="6 10 14"; # Iters on which to increase phn dim and/or spk dim; # rarely necessary, and if it is, only the 1st will normally be necessary. @@ -70,7 +69,7 @@ first_spkvec_iter=`echo $spkvec_iters | awk '{print $1}'` || exit 1; ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; # Check some files. -for f in $data/feats.scp $lang/L.fst $alidir/ali.1.gz $alidir/final.mdl $ubm; do +for f in $data/feats.scp $lang/L.fst $alidir/ali.1.gz $alidir/final.mdl $ubm $alidir/num_jobs; do [ ! 
-f $f ] && echo "$0: no such file $f" && exit 1; done @@ -112,7 +111,7 @@ echo "$0: feature type is $feat_type" case $feat_type in delta) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas ark:- ark:- |";; lda) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $alidir/final.mat ark:- ark:- |" - cp $alidir/final.mat $dir + cp $alidir/final.mat $dir ;; *) echo "$0: invalid feature type $feat_type" && exit 1; esac @@ -151,7 +150,7 @@ if [ $stage -le -5 ]; then fi if [ $stage -le -4 ]; then - echo "$0: Initializing the model" + echo "$0: Initializing the model" # Note: if phn_dim > feat_dim+1 or spk_dim > feat_dim, these dims # will be truncated on initialization. $cmd $dir/log/init_sgmm.log \ @@ -176,7 +175,7 @@ if [ $stage -le -2 ]; then fi if [ $stage -le -1 ]; then - echo "$0: converting alignments" + echo "$0: converting alignments" $cmd JOB=1:$nj $dir/log/convert_ali.JOB.log \ convert-ali $alidir/final.mdl $dir/0.mdl $dir/tree "ark:gunzip -c $alidir/ali.JOB.gz|" \ "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; @@ -204,10 +203,10 @@ while [ $x -lt $num_iters ]; do ark:$dir/tmp_vecs.JOB '&&' mv $dir/tmp_vecs.JOB $dir/vecs.JOB || exit 1; fi spkvecs_opt="--spk-vecs=ark:$dir/vecs.JOB" - fi + fi if [ $x -eq 0 ]; then flags=vwcSt # on the first iteration, don't update projections M or N - elif [ $spk_dim -gt 0 -a $[$x%2] -eq 1 -a $x -ge $first_spkvec_iter ]; then + elif [ $spk_dim -gt 0 -a $[$x%2] -eq 1 -a $x -ge $first_spkvec_iter ]; then # Update N if we have speaker-vector space and x is odd, # and we've already updated the speaker vectors... flags=vNwSct @@ -218,9 +217,9 @@ while [ $x -lt $num_iters ]; do flags=vwSct # no M on early iters, if --update-m-iter option given. fi fi - $spk_dep_weights && [ $x -ge $first_spkvec_iter ] && flags=${flags}u; # update + $spk_dep_weights && [ $x -ge $first_spkvec_iter ] && flags=${flags}u; # update # spk-weight projections "u". - + if [ $stage -le $x ]; then $cmd JOB=1:$nj $dir/log/acc.$x.JOB.log \ sgmm2-acc-stats $spkvecs_opt --utt2spk=ark:$sdata/JOB/utt2spk \ @@ -235,7 +234,7 @@ while [ $x -lt $num_iters ]; do if echo $increase_dim_iters | grep -w $x >/dev/null; then increase_dim_opts="--increase-phn-dim=$phn_dim --increase-spk-dim=$spk_dim" # Note: the command below might have a null effect on some iterations. - if [ $spk_dim -gt $feat_dim ]; then + if [ $spk_dim -gt $feat_dim ]; then cmd JOB=1:$nj $dir/log/copy_vecs.$x.JOB.log \ copy-vector --print-args=false --change-dim=$spk_dim \ ark:$dir/vecs.JOB ark:$dir/vecs_tmp.$JOB '&&' \ @@ -292,7 +291,7 @@ if [ $spk_dim -gt 0 ]; then cur_alimdl=$dir/$[$x+1].alimdl x=$[$x+1] done - rm $dir/final.alimdl 2>/dev/null + rm $dir/final.alimdl 2>/dev/null ln -s $x.alimdl $dir/final.alimdl fi diff --git a/egs/wsj/s5/steps/train_sgmm2_group.sh b/egs/wsj/s5/steps/train_sgmm2_group.sh index 4639616aceb..7263e2d5e8e 100755 --- a/egs/wsj/s5/steps/train_sgmm2_group.sh +++ b/egs/wsj/s5/steps/train_sgmm2_group.sh @@ -14,14 +14,14 @@ # Begin configuration section. cmd=run.pl -stage=-6 # use this to resume partially finished training +stage=-6 # use this to resume partially finished training context_opts= # e.g. set it to "--context-width=5 --central-position=2" for a # quinphone system. 
scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" num_iters=25 # Total number of iterations of training num_iters_alimdl=3 # Number of iterations for estimating alignment model. max_iter_inc=15 # Last iter to increase #substates on. -realign_iters="5 10 15"; # Iters to realign on. +realign_iters="5 10 15"; # Iters to realign on. spkvec_iters="5 8 12 17" # Iters to estimate speaker vectors on. increase_iters="6 10 14"; # Iters on which to increase phn dim and/or spk dim; # rarely necessary, and if it is, only the 1st will normally be necessary. @@ -75,7 +75,7 @@ first_spkvec_iter=`echo $spkvec_iters | awk '{print $1}'` || exit 1; ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; # Check some files. -for f in $data/feats.scp $lang/L.fst $alidir/ali.1.gz $alidir/final.mdl $ubm; do +for f in $data/feats.scp $lang/L.fst $alidir/ali.1.gz $alidir/final.mdl $ubm $alidir/num_jobs; do [ ! -f $f ] && echo "$0: no such file $f" && exit 1; done @@ -117,7 +117,7 @@ echo "$0: feature type is $feat_type" case $feat_type in delta) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas ark:- ark:- |";; lda) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $alidir/final.mat ark:- ark:- |" - cp $alidir/final.mat $dir + cp $alidir/final.mat $dir ;; *) echo "$0: invalid feature type $feat_type" && exit 1; esac @@ -156,7 +156,7 @@ if [ $stage -le -5 ]; then fi if [ $stage -le -4 ]; then - echo "$0: Initializing the model" + echo "$0: Initializing the model" # Note: if phn_dim > feat_dim+1 or spk_dim > feat_dim, these dims # will be truncated on initialization. $cmd $dir/log/init_sgmm.log \ @@ -181,7 +181,7 @@ if [ $stage -le -2 ]; then fi if [ $stage -le -1 ]; then - echo "$0: converting alignments" + echo "$0: converting alignments" $cmd JOB=1:$nj $dir/log/convert_ali.JOB.log \ convert-ali $alidir/final.mdl $dir/0.mdl $dir/tree "ark:gunzip -c $alidir/ali.JOB.gz|" \ "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1; @@ -209,10 +209,10 @@ while [ $x -lt $num_iters ]; do ark:$dir/tmp_vecs.JOB '&&' mv $dir/tmp_vecs.JOB $dir/vecs.JOB || exit 1; fi spkvecs_opt="--spk-vecs=ark:$dir/vecs.JOB" - fi + fi if [ $x -eq 0 ]; then flags=vwcSt # on the first iteration, don't update projections M or N - elif [ $spk_dim -gt 0 -a $[$x%2] -eq 1 -a $x -ge $first_spkvec_iter ]; then + elif [ $spk_dim -gt 0 -a $[$x%2] -eq 1 -a $x -ge $first_spkvec_iter ]; then # Update N if we have speaker-vector space and x is odd, # and we've already updated the speaker vectors... flags=vNwSct @@ -223,9 +223,9 @@ while [ $x -lt $num_iters ]; do flags=vwSct # no M on early iters, if --update-m-iter option given. fi fi - $spk_dep_weights && [ $x -ge $first_spkvec_iter ] && flags=${flags}u; # update + $spk_dep_weights && [ $x -ge $first_spkvec_iter ] && flags=${flags}u; # update # spk-weight projections "u". - + # Submit separate jobs for small groups (of size $group) of accumulators. Args=() # bash array of training commands for 1:nj, that put accs to stdout. 
for n in `seq $nj`; do @@ -233,16 +233,16 @@ while [ $x -lt $num_iters ]; do --update-flags=$flags '$gselect_opt' --rand-prune=$rand_prune \ $dir/$x.mdl '$feats' 'ark,s,cs:gunzip -c $dir/ali.JOB.gz | ali-to-post ark:- ark:-|' - |" | sed s/JOB/$n/g` done - + g=0 rm $dir/.error 2>/dev/null if [ $stage -le $x ]; then while [ $[$g*$group] -lt $nj ]; do if [ -s $dir/acc.$x.$g.gz ]; then echo "Skipping creation of acc $dir/acc.$x.$g.gz as it already exists." - else + else start=$[$g*$group + 1]; # start-position in array Args. - # see http://www.thegeekstuff.com/2010/06/bash-array-tutorial/, this uses Bash arrays." + # see http://www.thegeekstuff.com/2010/06/bash-array-tutorial/, this uses Bash arrays." # The syntax "${Args[@]:$start:$group}" is equivalent to, say, # "${Args[3]}" "${Args[4]}" if start=3 and group=2. Except it's smart about the end # of the array, it won't give you empty quoted strings if the length "group" takes you off @@ -258,14 +258,14 @@ while [ $x -lt $num_iters ]; do exit 1; fi fi - + # The next option is needed if the user specifies a phone or speaker sub-space # dimension that's higher than the "normal" one. increase_dim_opts= if echo $increase_dim_iters | grep -w $x >/dev/null; then increase_dim_opts="--increase-phn-dim=$phn_dim --increase-spk-dim=$spk_dim" # Note: the command below might have a null effect on some iterations. - if [ $spk_dim -gt $feat_dim ]; then + if [ $spk_dim -gt $feat_dim ]; then cmd JOB=1:$nj $dir/log/copy_vecs.$x.JOB.log \ copy-vector --print-args=false --change-dim=$spk_dim \ ark:$dir/vecs.JOB ark:$dir/vecs_tmp.$JOB '&&' \ @@ -322,7 +322,7 @@ if [ $spk_dim -gt 0 ]; then while [ $[$g*$group] -lt $nj ]; do if [ -s $dir/acc.$x.$g.gz ]; then echo "Skipping creation of acc $dir/acc.$x.$g.gz as it already exists." - else + else start=$[$g*$group + 1]; # start-position in array Args. $cmd --num-threads "$group" $dir/log/acc.$x.$g.log \ sgmm2-sum-accs --parallel=true "|gzip -c >$dir/acc.$x.$g.gz" "${Args[@]:$start:$group}" || touch $dir/.error & @@ -345,7 +345,7 @@ if [ $spk_dim -gt 0 ]; then cur_alimdl=$dir/$[$x+1].alimdl x=$[$x+1] done - rm $dir/final.alimdl 2>/dev/null + rm $dir/final.alimdl 2>/dev/null ln -s $x.alimdl $dir/final.alimdl fi diff --git a/egs/wsj/s5/steps/train_ubm.sh b/egs/wsj/s5/steps/train_ubm.sh index a78d0639404..5351abbb784 100755 --- a/egs/wsj/s5/steps/train_ubm.sh +++ b/egs/wsj/s5/steps/train_ubm.sh @@ -7,7 +7,6 @@ # We mostly use this for SGMM systems. # Begin configuration section. -nj=4 cmd=run.pl silence_weight= # You can set it to e.g. 0.0, to weight down silence in training. stage=-2 @@ -42,7 +41,7 @@ lang=$3 alidir=$4 dir=$5 -for f in $data/feats.scp $lang/L.fst $alidir/ali.1.gz $alidir/final.mdl; do +for f in $data/feats.scp $lang/L.fst $alidir/ali.1.gz $alidir/final.mdl $alidir/num_jobs; do [ ! 
-f $f ] && echo "No such file $f" && exit 1; done @@ -75,7 +74,7 @@ echo "$0: feature type is $feat_type" case $feat_type in delta) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |";; lda) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $alidir/final.mat ark:- ark:- |" - cp $alidir/final.mat $dir + cp $alidir/final.mat $dir ;; *) echo "$0: invalid feature type $feat_type" && exit 1; esac @@ -90,7 +89,7 @@ if [ -f $alidir/trans.1 ]; then fi elif [ -f $alidir/raw_trans.1 ]; then echo "$0: using raw-FMLLR transforms from $alidir" - feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark,s,cs:$alidir/raw_trans.JOB ark:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $alidir/final.mat ark:- ark:- |" + feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark,s,cs:$alidir/raw_trans.JOB ark:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $alidir/final.mat ark:- ark:- |" fi ## @@ -129,7 +128,7 @@ while [ $x -lt $num_iters ]; do fgmm-global-acc-stats $weights_opt --gselect=ark,s,cs:- $dir/$x.ubm "$feats" \ $dir/$x.JOB.acc || exit 1; lowcount_opt="--remove-low-count-gaussians=false" - [ $[$x+1] -eq $num_iters ] && lowcount_opt= # Only remove low-count Gaussians + [ $[$x+1] -eq $num_iters ] && lowcount_opt= # Only remove low-count Gaussians # on last iter-- we can't do it earlier, or the Gaussian-selection info would # be mismatched. $cmd $dir/log/update.$x.log \ diff --git a/egs/wsj/s5/utils/add_lex_disambig.pl b/egs/wsj/s5/utils/add_lex_disambig.pl index dd8a25de6e1..c4277e8dc06 100755 --- a/egs/wsj/s5/utils/add_lex_disambig.pl +++ b/egs/wsj/s5/utils/add_lex_disambig.pl @@ -122,6 +122,7 @@ if ($sil_probs) { shift @A; # Remove silprob shift @A; # Remove silprob + shift @A; # Remove silprob, there three numbers for sil_probs } while(@A > 0) { pop @A; # Remove last phone diff --git a/egs/wsj/s5/utils/apply_map.pl b/egs/wsj/s5/utils/apply_map.pl index ff9507fd894..a138287170b 100755 --- a/egs/wsj/s5/utils/apply_map.pl +++ b/egs/wsj/s5/utils/apply_map.pl @@ -9,47 +9,59 @@ # be sequences of tokens. See the usage message. -if (@ARGV > 0 && $ARGV[0] eq "-f") { - shift @ARGV; - $field_spec = shift @ARGV; - if ($field_spec =~ m/^\d+$/) { - $field_begin = $field_spec - 1; $field_end = $field_spec - 1; - } - if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10) - if ($1 ne "") { - $field_begin = $1 - 1; # Change to zero-based indexing. +$permissive = 0; + +for ($x = 0; $x <= 2; $x++) { + + if (@ARGV > 0 && $ARGV[0] eq "-f") { + shift @ARGV; + $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; } - if ($2 ne "") { - $field_end = $2 - 1; # Change to zero-based indexing. + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. 
+ } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; } } - if (!defined $field_begin && !defined $field_end) { - die "Bad argument to -f option: $field_spec"; - } -} -# Mapping is obligatory -$permissive = 0; -if (@ARGV > 0 && $ARGV[0] eq '--permissive') { - shift @ARGV; - # Mapping is optional (missing key is printed to output) - $permissive = 1; + if (@ARGV > 0 && $ARGV[0] eq '--permissive') { + shift @ARGV; + # Mapping is optional (missing key is printed to output) + $permissive = 1; + } } if(@ARGV != 1) { print STDERR "Invalid usage: " . join(" ", @ARGV) . "\n"; - print STDERR "Usage: apply_map.pl [options] map <input >output\n" . - "options: [-f <field-range> ]\n" . - "Applies the map 'map' to all input text, where each line of the map\n" . - "is interpreted as a map from the first field to the list of the other fields\n" . - "Note: <field-range> can look like 4-5, or 4-, or 5-, or 1, it means the field\n" . - "range in the input to apply the map to.\n" . - "e.g.: echo A B | apply_map.pl a.txt\n" . - "where a.txt is:\n" . - "A a1 a2\n" . - "B b\n" . - "will produce:\n" . - "a1 a2 b\n"; + print STDERR <<'EOF'; +Usage: apply_map.pl [options] map <input >output + options: [-f <field-range>] [--permissive] + This applies a map to some specified fields of some input text: + For each line in the map file: the first field is the thing we + map from, and the remaining fields are the sequence we map it to. + The -f (field-range) option says which fields of the input file the + map should apply to. + If the --permissive option is supplied, fields which are not present + in the map will be left as they were. + Applies the map 'map' to all input text, where each line of the map + is interpreted as a map from the first field to the list of the other fields + Note: <field-range> can look like 4-5, or 4-, or 5-, or 1, it means the field + range in the input to apply the map to. + e.g.: echo A B | apply_map.pl a.txt + where a.txt is: + A a1 a2 + B b + will produce: + a1 a2 b +EOF exit(1); } @@ -72,12 +84,12 @@ $a = $A[$x]; if (!defined $map{$a}) { if (!$permissive) { - die "apply_map.pl: undefined key $a in $map_file\n"; + die "apply_map.pl: undefined key $a in $map_file\n"; } else { print STDERR "apply_map.pl: warning!
missing key $a in $map_file\n"; } } else { - $A[$x] = $map{$a}; + $A[$x] = $map{$a}; } } } diff --git a/egs/wsj/s5/utils/build_const_arpa_lm.sh b/egs/wsj/s5/utils/build_const_arpa_lm.sh index ec067df0d39..51aca1bb2ad 100755 --- a/egs/wsj/s5/utils/build_const_arpa_lm.sh +++ b/egs/wsj/s5/utils/build_const_arpa_lm.sh @@ -34,8 +34,8 @@ mkdir -p $new_lang cp -r $old_lang/* $new_lang unk=`cat $new_lang/oov.int` -bos=`grep -w "<s>" $new_lang/words.txt | awk '{print $2}'` -eos=`grep "</s>" $new_lang/words.txt | awk '{print $2}'` +bos=`grep "^<s>\s" $new_lang/words.txt | awk '{print $2}'` +eos=`grep "^</s>\s" $new_lang/words.txt | awk '{print $2}'` if [[ -z $bos || -z $eos ]]; then echo "$0: <s> and </s> symbols are not in $new_lang/words.txt" exit 1 diff --git a/egs/wsj/s5/utils/ctm/resolve_ctm_overlaps.py b/egs/wsj/s5/utils/ctm/resolve_ctm_overlaps.py index deb8207c5b7..61c9a3014aa 100755 --- a/egs/wsj/s5/utils/ctm/resolve_ctm_overlaps.py +++ b/egs/wsj/s5/utils/ctm/resolve_ctm_overlaps.py @@ -17,6 +17,7 @@ """ from __future__ import print_function +from __future__ import division import argparse import collections import logging @@ -231,7 +232,7 @@ def resolve_overlaps(ctms, segments): try: index = next( (i for i, line in enumerate(ctm_for_next_utt) - if line[2] + line[3] / 2.0 > overlap / 2.0)) + if line[2] + line[3] / 2.0 > overlap / 2.0)) except StopIteration: # This can happen if there is no word hypothesized after # half the overlap region. @@ -277,7 +278,7 @@ def run(args): segments, reco2utt = read_segments(args.segments) ctms = read_ctm(args.ctm_in, segments) - for reco, utts in reco2utt.iteritems(): + for reco, utts in reco2utt.items(): ctms_for_reco = [] for utt in sorted(utts, key=lambda x: segments[x][1]): if (reco, utt) in ctms: diff --git a/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh b/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh index dd315cc405b..c113bb512ef 100755 --- a/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh +++ b/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh @@ -35,9 +35,15 @@ if [ -f $data/reco2file_and_channel ]; then fi mkdir -p $dir/.backup -mv $dir/feats.scp $dir/cmvn.scp $dir/.backup - -rm $dir/utt2spk || true +if [ -f $dir/feats.scp ]; then + mv $dir/feats.scp $dir/.backup +fi +if [ -f $dir/cmvn.scp ]; then + mv $dir/cmvn.scp $dir/.backup +fi +if [ -f $dir/utt2spk ]; then + mv $dir/utt2spk $dir/.backup +fi [ -f $data/stm ] && cp $data/stm $dir [ -f $data/glm ] && cp $data/glm $dir diff --git a/egs/wsj/s5/utils/data/get_uniform_subsegments.py b/egs/wsj/s5/utils/data/get_uniform_subsegments.py index c61b96e0dbb..cc3015564a5 100755 --- a/egs/wsj/s5/utils/data/get_uniform_subsegments.py +++ b/egs/wsj/s5/utils/data/get_uniform_subsegments.py @@ -4,6 +4,7 @@ # 2017 Matthew Maciejewski # Apache 2.0. +from __future__ import print_function import argparse import logging import sys diff --git a/egs/wsj/s5/utils/data/get_utt2num_frames.sh b/egs/wsj/s5/utils/data/get_utt2num_frames.sh index a6d4f0ecb10..d8b006a5fc0 100755 --- a/egs/wsj/s5/utils/data/get_utt2num_frames.sh +++ b/egs/wsj/s5/utils/data/get_utt2num_frames.sh @@ -10,13 +10,14 @@ frame_shift=0.01 frame_overlap=0.015 . utils/parse_options.sh +. ./path.sh if [ $# -ne 1 ]; then echo "This script writes a file utt2num_frames with the " echo "number of frames in each utterance as measured based on the " echo "duration of the utterances (in utt2dur) and the specified " echo "frame_shift and frame_overlap."
- echo "Usage: $0 " + echo "Usage: $0 " exit 1 fi diff --git a/egs/wsj/s5/utils/data/internal/choose_utts_to_combine.py b/egs/wsj/s5/utils/data/internal/choose_utts_to_combine.py index 740b9aa612b..875c238abd5 100755 --- a/egs/wsj/s5/utils/data/internal/choose_utts_to_combine.py +++ b/egs/wsj/s5/utils/data/internal/choose_utts_to_combine.py @@ -89,7 +89,7 @@ def CombineList(min_duration, durations): # for each utterance-index i, group_start[i] gives us the # start-index of the group of utterances of which it's currently # a member. - group_start = range(num_utts) + group_start = list(range(num_utts)) # if utterance-index i currently corresponds to the start of a group # of utterances, then group_durations[i] is the total duration of # that utterance-group, otherwise undefined. @@ -327,7 +327,7 @@ def GetUtteranceGroups(min_duration, spk2utt, utt2dur): utt_groups = GetUtteranceGroups(args.min_duration, spk2utt, utt2dur) # set utt_group names to an array like [ 'utt1', 'utt2-comb2', 'utt4', ... ] -utt_group_names = [ group[0] if len(group)==1 else group[0] + "-comb" + str(len(group)) +utt_group_names = [ group[0] if len(group)==1 else "{0}-comb{1}".format(group[0], len(group)) for group in utt_groups ] diff --git a/egs/wsj/s5/utils/data/internal/perturb_volume.py b/egs/wsj/s5/utils/data/internal/perturb_volume.py index b3bd4225191..c1dfd936358 100755 --- a/egs/wsj/s5/utils/data/internal/perturb_volume.py +++ b/egs/wsj/s5/utils/data/internal/perturb_volume.py @@ -8,6 +8,7 @@ volume of the recordings and writes to stdout the contents of a new wav.scp file. """ +from __future__ import print_function import argparse import re diff --git a/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh b/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh index dae440b03a3..e357ba8cbfb 100755 --- a/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh +++ b/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh @@ -52,15 +52,15 @@ for line in sys.stdin.readlines(): parts = line.strip().split() if line.strip()[-1] == '|': if re.search('sox --vol', ' '.join(parts[-11:])): - print 'true' + print('true') sys.exit(0) elif re.search(':[0-9]+$', line.strip()) is not None: continue else: if ' '.join(parts[1:3]) == 'sox --vol': - print 'true' + print('true') sys.exit(0) -print 'false' +print('false') "` || exit 1 if $volume_perturb_done; then diff --git a/egs/wsj/s5/utils/data/perturb_speed_to_allowed_lengths.py b/egs/wsj/s5/utils/data/perturb_speed_to_allowed_lengths.py index c6bdb95cb2f..ae16e63c945 100755 --- a/egs/wsj/s5/utils/data/perturb_speed_to_allowed_lengths.py +++ b/egs/wsj/s5/utils/data/perturb_speed_to_allowed_lengths.py @@ -60,13 +60,13 @@ def get_args(): args.speed_perturb = True if args.speed_perturb == 'true' else False return args -class Utterance: +class Utterance(object): """ This class represents a Kaldi utterance in a data directory like data/train """ def __init__(self, uid, wavefile, speaker, transcription, dur): - self.wavefile = (wavefile if wavefile.rstrip().endswith('|') else + self.wavefile = (wavefile if wavefile.rstrip(" \t\r\n").endswith('|') else 'cat {} |'.format(wavefile)) self.speaker = speaker self.transcription = transcription @@ -130,7 +130,7 @@ def read_kaldi_mapfile(path): m = {} with open(path, 'r', encoding='latin-1') as f: for line in f: - line = line.strip() + line = line.strip(" \t\r\n") sp_pos = line.find(' ') key = line[:sp_pos] val = line[sp_pos+1:] @@ -321,7 +321,7 @@ def main(): "Coverage rate: {}%".format(start_dur, end_dur, 100.0 - args.coverage_factor * 2)) logger.info("There will 
be {} unique allowed lengths " - "for the utterances.".format(int(math.log(end_dur / start_dur) / + "for the utterances.".format(int(math.log(end_dur / start_dur)/ math.log(args.factor)))) allowed_durations = find_allowed_durations(start_dur, end_dur, args) diff --git a/egs/wsj/s5/utils/data/resample_data_dir.sh b/egs/wsj/s5/utils/data/resample_data_dir.sh index b972bcc119a..8d96667092f 100755 --- a/egs/wsj/s5/utils/data/resample_data_dir.sh +++ b/egs/wsj/s5/utils/data/resample_data_dir.sh @@ -39,7 +39,6 @@ for line in sys.stdin.readlines(): if splits[-1] == '|': out_line = line.strip() + ' $sox -t wav - -c 1 -b 16 -t wav - rate $freq |' else: - out_line = 'cat {0} {1} | $sox -t wav - -c 1 -b 16 -t wav - rate $freq |'.format(splits[0], ' '.join(splits[1:])) + out_line = '{0} cat {1} | $sox -t wav - -c 1 -b 16 -t wav - rate $freq |'.format(splits[0], ' '.join(splits[1:])) print (out_line)" > ${dir}/wav.scp rm $dir/wav.scp.tmp - diff --git a/egs/wsj/s5/utils/filt.py b/egs/wsj/s5/utils/filt.py index 2847c0034dd..9201d9e493f 100755 --- a/egs/wsj/s5/utils/filt.py +++ b/egs/wsj/s5/utils/filt.py @@ -2,6 +2,7 @@ # Apache 2.0 +from __future__ import print_function import sys vocab=set() @@ -11,4 +12,4 @@ with open(sys.argv[2]) as textfile: for line in textfile: - print " ".join(map(lambda word: word if word in vocab else '', line.strip().split())) + print(" ".join([word if word in vocab else '' for word in line.strip().split()])) diff --git a/egs/wsj/s5/utils/format_lm_sri.sh b/egs/wsj/s5/utils/format_lm_sri.sh index 4ef31d925ca..08f842a08f5 100755 --- a/egs/wsj/s5/utils/format_lm_sri.sh +++ b/egs/wsj/s5/utils/format_lm_sri.sh @@ -48,8 +48,6 @@ else out_dir=$3 fi -mkdir -p $out_dir - for f in $lm $lang_dir/words.txt; do if [ ! -f $f ]; then echo "$0: expected input file $f to exist." @@ -73,7 +71,6 @@ trap 'rm -rf "$tmpdir"' EXIT mkdir -p $out_dir cp -r $lang_dir/* $out_dir || exit 1; -lm_base=$(basename $lm '.gz') awk '{print $1}' $out_dir/words.txt > $tmpdir/voc || exit 1; # Change the LM vocabulary to be the intersection of the current LM vocabulary diff --git a/egs/wsj/s5/utils/lang/bpe/bidi.py b/egs/wsj/s5/utils/lang/bpe/bidi.py new file mode 100755 index 00000000000..447313a5d02 --- /dev/null +++ b/egs/wsj/s5/utils/lang/bpe/bidi.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2018 Chun-Chieh Chang + +# This script is largely written by Stephen Rawls +# and uses the python package https://pypi.org/project/PyICU_BiDi/ +# The code leaves right to left text alone and reverses left to right text. 
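A note on the log message in perturb_speed_to_allowed_lengths.py above: the allowed durations form a geometric series from start_dur to end_dur with ratio args.factor, so the reported count is floor(log(end_dur/start_dur)/log(factor)). A tiny sketch with assumed example values:

import math

def num_allowed_lengths(start_dur, end_dur, factor):
    # Durations start_dur, start_dur*factor, start_dur*factor**2, ... up to end_dur.
    return int(math.log(end_dur / start_dur) / math.log(factor))

# Assumed example: a 1.0 s .. 16.0 s range with factor 1.1 gives
# num_allowed_lengths(1.0, 16.0, 1.1) -> 29 distinct allowed lengths.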
+ +import icu_bidi +import io +import sys +import unicodedata +# R=strong right-to-left; AL=strong arabic right-to-left +rtl_set = set(chr(i) for i in range(sys.maxunicode) + if unicodedata.bidirectional(chr(i)) in ['R','AL']) +def determine_text_direction(text): + # Easy case first + for char in text: + if char in rtl_set: + return icu_bidi.UBiDiLevel.UBIDI_RTL + # If we made it here we did not encounter any strongly rtl char + return icu_bidi.UBiDiLevel.UBIDI_LTR + +def utf8_visual_to_logical(text): + text_dir = determine_text_direction(text) + + bidi = icu_bidi.Bidi() + bidi.inverse = True + bidi.reordering_mode = icu_bidi.UBiDiReorderingMode.UBIDI_REORDER_INVERSE_LIKE_DIRECT + bidi.reordering_options = icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_DEFAULT # icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_INSERT_MARKS + + bidi.set_para(text, text_dir, None) + + res = bidi.get_reordered(0 | icu_bidi.UBidiWriteReorderedOpt.UBIDI_DO_MIRRORING | icu_bidi.UBidiWriteReorderedOpt.UBIDI_KEEP_BASE_COMBINING) + + return res + +def utf8_logical_to_visual(text): + text_dir = determine_text_direction(text) + + bidi = icu_bidi.Bidi() + + bidi.reordering_mode = icu_bidi.UBiDiReorderingMode.UBIDI_REORDER_DEFAULT + bidi.reordering_options = icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_DEFAULT #icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_INSERT_MARKS + + bidi.set_para(text, text_dir, None) + + res = bidi.get_reordered(0 | icu_bidi.UBidiWriteReorderedOpt.UBIDI_DO_MIRRORING | icu_bidi.UBidiWriteReorderedOpt.UBIDI_KEEP_BASE_COMBINING) + + return res + + +##main## +sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf8") +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf8") +for line in sys.stdin: + line = line.strip() + line = utf8_logical_to_visual(line)[::-1] + sys.stdout.write(line + '\n') diff --git a/egs/wsj/s5/utils/lang/bpe/learn_bpe.py b/egs/wsj/s5/utils/lang/bpe/learn_bpe.py index 70f18f2d1d9..f6c6d5a0ebb 100755 --- a/egs/wsj/s5/utils/lang/bpe/learn_bpe.py +++ b/egs/wsj/s5/utils/lang/bpe/learn_bpe.py @@ -13,6 +13,8 @@ """ from __future__ import unicode_literals +from __future__ import division +from __future__ import print_function import sys import codecs diff --git a/egs/wsj/s5/utils/lang/bpe/prepend_words.py b/egs/wsj/s5/utils/lang/bpe/prepend_words.py new file mode 100755 index 00000000000..4a11895a712 --- /dev/null +++ b/egs/wsj/s5/utils/lang/bpe/prepend_words.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +# This script, prepend '|' to every words in the transcript to mark +# the beginning of the words for finding the initial-space of every word +# after decoding. 
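The direction test in bidi.py only needs the Unicode bidirectional classes; the snippet below restates it without the icu_bidi dependency, as a standalone illustration rather than part of the patch: a line is treated as right-to-left as soon as it contains one strongly-RTL character.

import unicodedata

def is_rtl(text):
    # True if the text contains any strongly right-to-left character
    # (Unicode bidi class 'R', e.g. Hebrew, or 'AL' for Arabic letters).
    return any(unicodedata.bidirectional(ch) in ('R', 'AL') for ch in text)

# is_rtl("hello")                     -> False (left-to-right, so it gets reversed)
# is_rtl("\u05e9\u05dc\u05d5\u05dd")  -> True  (Hebrew "shalom", left alone)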
+ +import sys +import io +import re + +whitespace = re.compile("[ \t]+") +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='latin-1') +output = io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1') +for line in infile: + words = whitespace.split(line.strip(" \t\r\n")) + output.write(' '.join([ "|"+word for word in words]) + '\n') diff --git a/egs/madcat_ar/v1/local/reverse.py b/egs/wsj/s5/utils/lang/bpe/reverse.py similarity index 100% rename from egs/madcat_ar/v1/local/reverse.py rename to egs/wsj/s5/utils/lang/bpe/reverse.py diff --git a/egs/wsj/s5/utils/lang/check_g_properties.pl b/egs/wsj/s5/utils/lang/check_g_properties.pl index ee0f6ddb515..e092b606157 100755 --- a/egs/wsj/s5/utils/lang/check_g_properties.pl +++ b/egs/wsj/s5/utils/lang/check_g_properties.pl @@ -28,6 +28,7 @@ ($sym, $int) = @A; if ($sym eq "" || $sym eq "") { $is_forbidden{$int} = 1; } if ($sym eq "#0") { $hash_zero = $int; } + if ($sym =~ m/^#nonterm/) { $is_nonterminal{$int} = 1; } } if (-e "$lang/phones/wdisambig_words.int") { @@ -65,9 +66,9 @@ } elsif ($A[2] == 0) { print I $_; $has_epsilons = 1; - } elsif ($A[2] != $A[3]) { + } elsif ($A[2] != $A[3] && !$is_nonterminal{$A[2]} ) { chop; - print "$0: validating $lang: error: line $_ in G.fst has inputs and outputs different but input is not disambig symbol.\n"; + print "$0: validating $lang: error: line $_ in G.fst has inputs and outputs different but input is not disambig symbol or nonterminal.\n"; exit(1); } } diff --git a/egs/wsj/s5/utils/lang/compute_sentence_probs_arpa.py b/egs/wsj/s5/utils/lang/compute_sentence_probs_arpa.py index 5a7743badee..dc480903db4 100755 --- a/egs/wsj/s5/utils/lang/compute_sentence_probs_arpa.py +++ b/egs/wsj/s5/utils/lang/compute_sentence_probs_arpa.py @@ -99,13 +99,13 @@ def compute_begin_prob(sub_list): for i in range(1, len(sub_list) - 1): logprob += compute_sublist_prob(sub_list[:i + 1]) return logprob - + # The probability is computed in this way: # p(word_N | word_N-1 ... word_1) = ngram_dict[word_1 ... word_N][0]. # Here gram_dict is a dictionary stores a tuple corresponding to ngrams. # The first element of tuple is probablity and the second is backoff probability (if exists). # If the particular ngram (word_1 ... word_N) is not in the dictionary, then -# p(word_N | word_N-1 ... word_1) = p(word_N | word_(N-1) ... word_2) * backoff_weight(word_(N-1) | word_(N-2) ... word_1) +# p(word_N | word_N-1 ... word_1) = p(word_N | word_(N-1) ... word_2) * backoff_weight(word_(N-1) | word_(N-2) ... word_1) # If the sequence (word_(N-1) ... 
word_1) is not in the dictionary, then the backoff_weight gets replaced with 0.0 (log1) # More details can be found in https://cmusphinx.github.io/wiki/arpaformat/ def compute_sentence_prob(sentence, ngram_order): @@ -127,7 +127,7 @@ def compute_sentence_prob(sentence, ngram_order): logprob += compute_sublist_prob(cur_sublist) return logprob - + def output_result(text_in_handle, output_file_handle, ngram_order): lines = text_in_handle.readlines() @@ -139,8 +139,8 @@ def output_result(text_in_handle, output_file_handle, ngram_order): output_file_handle.write("{}\n".format(new_logprob)) text_in_handle.close() output_file_handle.close() - - + + if __name__ == "__main__": check_args(args) ngram_dict, tot_num = load_model(args.arpa_lm) @@ -149,7 +149,7 @@ def output_result(text_in_handle, output_file_handle, ngram_order): if not num_valid: sys.exit("compute_sentence_probs_arpa.py: Wrong loading model.") if args.ngram_order <= 0 or args.ngram_order > max_ngram_order: - sys.exit("compute_sentence_probs_arpa.py: " + + sys.exit("compute_sentence_probs_arpa.py: " + "Invalid ngram_order (either negative or greater than maximum ngram number ({}) allowed)".format(max_ngram_order)) output_result(args.text_in_handle, args.prob_file_handle, args.ngram_order) diff --git a/egs/wsj/s5/utils/lang/extend_lang.sh b/egs/wsj/s5/utils/lang/extend_lang.sh new file mode 100755 index 00000000000..c13d5d3e78b --- /dev/null +++ b/egs/wsj/s5/utils/lang/extend_lang.sh @@ -0,0 +1,184 @@ +#!/bin/bash +# Copyright 2018 Johns Hopkins University (Author: Daniel Povey); + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# derived files, that go in data/lang/. + +# Begin configuration section. +sil_prob=0.5 +# end configuration section + +echo "$0 $@" # Print the command line for logging + +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "Usage: utils/extend_lang.sh " + echo "e.g.: utils/extend_lang.sh data/lang data/local/dict_new_words/lexiconp.txt data/lang_new_words" + echo "" + echo "This script creates a lang/ directory with L.fst and L_disambig.fst" + echo "derived from the provided lexicon, but all other information being the same as the old" + echo "lang/ directory, including the phones.txt and words.txt being compatible (however," + echo "words.txt may have new words, and phones.txt may have extra disambiguation symbols" + echo "if needed). We do not allow new phones." + echo "" + echo "CAUTION: the lexicon generated will only cover the words in the provided lexicon," + echo "which might not include all the words in words.txt. You should make sure your" + echo "lexicon is a superset of the original lexicon used to generate ," + echo "if this would be a problem for your scenario." + echo "" + echo "The basename of must be either lexicon.txt, lexiconp.txt or lexiconp_silprob.txt." + echo "" + echo "Options" + echo " --sil-prob # default: 0.5 [must have 0 <= silprob < 1]" + exit 1; +fi + +srcdir=$1 +lexicon=$2 +dir=$3 + +[ -f path.sh ] && . 
./path.sh + +for f in $srcdir/phones.txt $lexicon; do + if [ ! -f $f ]; then + echo "$0: expected file $f to exist" + exit 1 + fi +done + +if ! awk '{if(NF < 2) exit(1)} END{if(NR==0) exit(1)}' <$lexicon; then + echo "$0: it looks like there words without pronunciations or.." + echo " ...blank lines in $lexicon, or it is empty." + exit 1 +fi + +mkdir -p $dir + +if [ -d $dir/phones ]; then rm -r $dir/phones; fi + +cp -r $srcdir/phones $dir/ + +for f in oov.int oov.txt phones.txt topo words.txt; do + cp $srcdir/$f $dir/ +done + +tmpdir=$dir/temp +rm -r $tmpdir 2>/dev/null +mkdir -p $tmpdir + + +# TODO: more checking. +if [ $(basename $lexicon) != lexiconp.txt ]; then + echo "$0: currently this script only supports the lexiconp.txt format; your lexicon" + echo " ... has to have that filename." +fi + +# Get the list of extra words. +awk -v w=$srcdir/words.txt 'BEGIN{while(getline $tmpdir/extra_words.txt + +# Add entries to words.txt for all the words that were not previously in the +# lexicon. +highest_number=$(tail -n 1 $srcdir/words.txt | awk '{print $2}') +awk -v start=$highest_number '{print $1, NR+start}' <$tmpdir/extra_words.txt >>$dir/words.txt +echo "$0: added $(wc -l <$tmpdir/extra_words.txt) extra words to words.txt" + +if [ -f $dir/phones/nonterminals.txt ]; then + # extra grammar-decoding-related options for getting the lexicon. + grammar_opts="--left-context-phones=$dir/phones/left_context_phones.txt --nonterminals=$srcdir/phones/nonterminals.txt" +else + grammar_opts="" +fi + +if [ -f $dir/phones/word_boundary.txt ]; then + # was `if $position_dependent_phones; then..` in prepare_lang.sh + # TODO: add support for silprobs + perl -ane '@A=split(" ",$_); $w = shift @A; $p = shift @A; @A>0||die; + if(@A==1) { print "$w $p $A[0]_S\n"; } else { print "$w $p $A[0]_B "; + for($n=1;$n<@A-1;$n++) { print "$A[$n]_I "; } print "$A[$n]_E\n"; } ' \ + < $lexicon > $tmpdir/lexiconp.txt || exit 1; +else + cp $lexicon $tmpdir/lexiconp.txt +fi + +# Check that there are no unseen phones in the lexicon. +if ! utils/sym2int.pl -f 3- $srcdir/phones.txt $tmpdir/lexiconp.txt >/dev/null; then + echo "$0: it looks like there are unseen phones in your lexicon $lexicon" + exit 1 +fi + +ndisambig=$(utils/add_lex_disambig.pl --pron-probs $tmpdir/lexiconp.txt $tmpdir/lexiconp_disambig.txt) + +ndisambig=$[ndisambig+1] # Add one to disambiguate silence. + +# we'll need to figure out whether any of these disambiguation symbols are +# absent from our current disambiguation phones.. if they are, then we need to +# add them as new disambiguation symbols to phones.txt. +for n in $(seq 0 $ndisambig); do + sym='#'$n; if ! grep -w -q "$sym" $dir/phones/disambig.txt; then echo "$sym"; fi +done > $tmpdir/extra_disambig.txt +highest_number=$(tail -n 1 $srcdir/phones.txt | awk '{print $2}') +awk -v start=$highest_number '{print $1, NR+start}' <$tmpdir/extra_disambig.txt >>$dir/words.txt +echo "$0: added $(wc -l <$tmpdir/extra_disambig.txt) extra disambiguation symbols to phones.txt" + + +silphone=`cat $srcdir/phones/optional_silence.txt` || exit 1; +[ -z "$silphone" ] && \ + ( echo "You have no optional-silence phone; it is required in the current scripts" + echo "but you may use the option --sil-prob 0.0 to stop it being used." ) && \ + exit 1; + + +# First remove pron-probs from the lexicon. +perl -ape 's/(\S+\s+)\S+\s+(.+)/$1$2/;' <$tmpdir/lexiconp.txt >$tmpdir/align_lexicon.txt + +# Note: here, $silphone will have no suffix e.g. _S because it occurs as optional-silence, +# and is not part of a word. +[ ! 
-z "$silphone" ] && echo " $silphone" >> $tmpdir/align_lexicon.txt + +cat $tmpdir/align_lexicon.txt | \ + perl -ane '@A = split; print $A[0], " ", join(" ", @A), "\n";' | sort | uniq > $dir/phones/align_lexicon.txt + +# create phones/align_lexicon.int from phones/align_lexicon.txt +cat $dir/phones/align_lexicon.txt | utils/sym2int.pl -f 3- $dir/phones.txt | \ + utils/sym2int.pl -f 1-2 $dir/words.txt > $dir/phones/align_lexicon.int + +# Create the basic L.fst without disambiguation symbols, for use +# in training. + +utils/lang/make_lexicon_fst.py $grammar_opts --sil-prob=$sil_prob --sil-phone=$silphone \ + $tmpdir/lexiconp.txt | \ + fstcompile --isymbols=$dir/phones.txt --osymbols=$dir/words.txt \ + --keep_isymbols=false --keep_osymbols=false | \ + fstarcsort --sort_type=olabel > $dir/L.fst || exit 1; + + +# and create the version that has disambiguation symbols. +utils/lang/make_lexicon_fst.py $grammar_opts \ + --sil-prob=$sil_prob --sil-phone=$silphone --sil-disambig='#'$ndisambig \ + $tmpdir/lexiconp_disambig.txt | \ + fstcompile --isymbols=$dir/phones.txt --osymbols=$dir/words.txt \ + --keep_isymbols=false --keep_osymbols=false | \ + fstaddselfloops $dir/phones/wdisambig_phones.int $dir/phones/wdisambig_words.int | \ + fstarcsort --sort_type=olabel > $dir/L_disambig.fst || exit 1; + + +echo "$(basename $0): validating output directory" +# the --skip-generate-words-check option is needed because L.fst may not actually +# contain all the words in words.txt. +! utils/validate_lang.pl --skip-generate-words-check $dir && echo "$(basename $0): error validating output" && exit 1; + +exit 0; diff --git a/egs/wsj/s5/utils/lang/grammar/augment_phones_txt.py b/egs/wsj/s5/utils/lang/grammar/augment_phones_txt.py new file mode 100755 index 00000000000..f0087680a4b --- /dev/null +++ b/egs/wsj/s5/utils/lang/grammar/augment_phones_txt.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 + + +import argparse +import re +import os +import sys + +def get_args(): + parser = argparse.ArgumentParser(description="""This script augments a phones.txt + file (a phone-level symbol table) by adding certain special symbols + relating to grammar support. See ../add_nonterminals.sh for context.""") + + parser.add_argument('input_phones_txt', type=str, + help='Filename of input phones.txt file, to be augmented') + parser.add_argument('nonterminal_symbols_list', type=str, + help='Filename of a file containing a list of nonterminal ' + 'symbols, one per line. E.g. #nonterm:contact_list') + parser.add_argument('output_phones_txt', type=str, help='Filename of output ' + 'phones.txt file. May be the same as input-phones-txt.') + args = parser.parse_args() + return args + + + + +def read_phones_txt(filename): + """Reads the phones.txt file in 'filename', returns a 2-tuple (lines, highest_symbol) + where 'lines' is all the lines the phones.txt as a list of strings, + and 'highest_symbol' is the integer value of the highest-numbered symbol + in the symbol table. It is an error if the phones.txt is empty or mis-formatted.""" + + # The use of latin-1 encoding does not preclude reading utf-8. latin-1 + # encoding means "treat words as sequences of bytes", and it is compatible + # with utf-8 encoding as well as other encodings such as gbk, as long as the + # spaces are also spaces in ascii (which we check). It is basically how we + # emulate the behavior of python before python3. 
+ whitespace = re.compile("[ \t]+") + with open(filename, 'r', encoding='latin-1') as f: + lines = [line.strip(" \t\r\n") for line in f] + highest_numbered_symbol = 0 + for line in lines: + s = whitespace.split(line) + try: + i = int(s[1]) + if i > highest_numbered_symbol: + highest_numbered_symbol = i + except: + raise RuntimeError("Could not interpret line '{0}' in file '{1}'".format( + line, filename)) + if s[0] == '#nonterm_bos': + raise RuntimeError("It looks like the symbol table {0} already has nonterminals " + "in it.".format(filename)) + return lines, highest_numbered_symbol + + +def read_nonterminals(filename): + """Reads the user-defined nonterminal symbols in 'filename', checks that + it has the expected format and has no duplicates, and returns the nonterminal + symbols as a list of strings, e.g. + ['#nonterm:contact_list', '#nonterm:phone_number', ... ]. """ + ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')] + if len(ans) == 0: + raise RuntimeError("The file {0} contains no nonterminal symbols.".format(filename)) + for nonterm in ans: + if nonterm[:9] != '#nonterm:': + raise RuntimeError("In file '{0}', expected nonterminal symbols to start with '#nonterm:', found '{1}'" + .format(filename, nonterm)) + if len(set(ans)) != len(ans): + raise RuntimeError("Duplicate nonterminal symbols are present in file {0}".format(filename)) + return ans + +def write_phones_txt(orig_lines, highest_numbered_symbol, nonterminals, filename): + """Writes updated phones.txt to 'filename'. 'orig_lines' is the original lines + in the phones.txt file as a list of strings (without the newlines); + highest_numbered_symbol is the highest numbered symbol in the original + phones.txt; nonterminals is a list of strings like '#nonterm:foo'.""" + with open(filename, 'w', encoding='latin-1') as f: + for l in orig_lines: + print(l, file=f) + cur_symbol = highest_numbered_symbol + 1 + for n in ['#nonterm_bos', '#nonterm_begin', '#nonterm_end', '#nonterm_reenter' ] + nonterminals: + print("{0} {1}".format(n, cur_symbol), file=f) + cur_symbol = cur_symbol + 1 + + + +def main(): + args = get_args() + (lines, highest_symbol) = read_phones_txt(args.input_phones_txt) + nonterminals = read_nonterminals(args.nonterminal_symbols_list) + write_phones_txt(lines, highest_symbol, nonterminals, args.output_phones_txt) + + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/utils/lang/grammar/augment_words_txt.py b/egs/wsj/s5/utils/lang/grammar/augment_words_txt.py new file mode 100755 index 00000000000..1bfe02a2c9d --- /dev/null +++ b/egs/wsj/s5/utils/lang/grammar/augment_words_txt.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 + + +import argparse +import os +import sys +import re + +def get_args(): + parser = argparse.ArgumentParser(description="""This script augments a words.txt + file (a word-level symbol table) by adding certain special symbols + relating to grammar support. See ../add_nonterminals.sh for context, + and augment_phones_txt.py.""") + + parser.add_argument('input_words_txt', type=str, + help='Filename of input words.txt file, to be augmented') + parser.add_argument('nonterminal_symbols_list', type=str, + help='Filename of a file containing a list of nonterminal ' + 'symbols, one per line. E.g. #nonterm:contact_list') + parser.add_argument('output_words_txt', type=str, help='Filename of output ' + 'words.txt file. 
May be the same as input-words-txt.') + args = parser.parse_args() + return args + + + + +def read_words_txt(filename): + """Reads the words.txt file in 'filename', returns a 2-tuple (lines, highest_symbol) + where 'lines' is all the lines the words.txt as a list of strings, + and 'highest_symbol' is the integer value of the highest-numbered symbol + in the symbol table. It is an error if the words.txt is empty or mis-formatted.""" + + # The use of latin-1 encoding does not preclude reading utf-8. latin-1 + # encoding means "treat words as sequences of bytes", and it is compatible + # with utf-8 encoding as well as other encodings such as gbk, as long as the + # spaces are also spaces in ascii (which we check). It is basically how we + # emulate the behavior of python before python3. + whitespace = re.compile("[ \t]+") + with open(filename, 'r', encoding='latin-1') as f: + lines = [line.strip(" \t\r\n") for line in f] + highest_numbered_symbol = 0 + for line in lines: + s = whitespace.split(line) + try: + i = int(s[1]) + if i > highest_numbered_symbol: + highest_numbered_symbol = i + except: + raise RuntimeError("Could not interpret line '{0}' in file '{1}'".format( + line, filename)) + if s[0] in [ '#nonterm_begin', '#nonterm_end' ]: + raise RuntimeError("It looks like the symbol table {0} already has nonterminals " + "in it.".format(filename)) + return lines, highest_numbered_symbol + + +def read_nonterminals(filename): + """Reads the user-defined nonterminal symbols in 'filename', checks that + it has the expected format and has no duplicates, and returns the nonterminal + symbols as a list of strings, e.g. + ['#nonterm:contact_list', '#nonterm:phone_number', ... ]. """ + ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')] + if len(ans) == 0: + raise RuntimeError("The file {0} contains no nonterminal symbols.".format(filename)) + for nonterm in ans: + if nonterm[:9] != '#nonterm:': + raise RuntimeError("In file '{0}', expected nonterminal symbols to start with '#nonterm:', found '{1}'" + .format(filename, nonterm)) + if len(set(ans)) != len(ans): + raise RuntimeError("Duplicate nonterminal symbols are present in file {0}".format(filename)) + return ans + +def write_words_txt(orig_lines, highest_numbered_symbol, nonterminals, filename): + """Writes updated words.txt to 'filename'. 'orig_lines' is the original lines + in the words.txt file as a list of strings (without the newlines); + highest_numbered_symbol is the highest numbered symbol in the original + words.txt; nonterminals is a list of strings like '#nonterm:foo'.""" + with open(filename, 'w', encoding='latin-1') as f: + for l in orig_lines: + print(l, file=f) + cur_symbol = highest_numbered_symbol + 1 + for n in [ '#nonterm_begin', '#nonterm_end' ] + nonterminals: + print("{0} {1}".format(n, cur_symbol), file=f) + cur_symbol = cur_symbol + 1 + + +def main(): + args = get_args() + (lines, highest_symbol) = read_words_txt(args.input_words_txt) + nonterminals = read_nonterminals(args.nonterminal_symbols_list) + write_words_txt(lines, highest_symbol, nonterminals, args.output_words_txt) + + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/utils/lang/internal/arpa2fst_constrained.py b/egs/wsj/s5/utils/lang/internal/arpa2fst_constrained.py index 19acd311c3d..31dfd08fbd2 100755 --- a/egs/wsj/s5/utils/lang/internal/arpa2fst_constrained.py +++ b/egs/wsj/s5/utils/lang/internal/arpa2fst_constrained.py @@ -4,6 +4,7 @@ # Apache 2.0. 
from __future__ import print_function +from __future__ import division import sys import argparse import math @@ -44,7 +45,7 @@ print(' '.join(sys.argv), file = sys.stderr) -class HistoryState: +class HistoryState(object): def __init__(self): # note: neither backoff_prob nor the floats # in word_to_prob are in log space. @@ -56,7 +57,7 @@ def __init__(self): self.word_to_prob = dict() -class ArpaModel: +class ArpaModel(object): def __init__(self): # self.orders is indexed by history-length [i.e. 0 for unigram, # 1 for bigram and so on], and is then a dict indexed diff --git a/egs/wsj/s5/utils/lang/limit_arpa_unk_history.py b/egs/wsj/s5/utils/lang/limit_arpa_unk_history.py index 81c0df36d2b..68f7b4b5639 100755 --- a/egs/wsj/s5/utils/lang/limit_arpa_unk_history.py +++ b/egs/wsj/s5/utils/lang/limit_arpa_unk_history.py @@ -58,6 +58,7 @@ def get_ngram_stats(old_lm_lines): def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows): ngram_diffs = defaultdict(int) + whitespace_pattern = re.compile("[ \t]+") unk_pattern = re.compile( "[0-9.-]+(?:[\s\\t]\S+){1,3}[\s\\t]" + args.oov_dict_entry + "[\s\\t](?!-[0-9]+\.[0-9]+).*") @@ -70,13 +71,17 @@ def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows): new_lm_lines = old_lm_lines[:skip_rows] for i in range(skip_rows, len(old_lm_lines)): - line = old_lm_lines[i].strip() + line = old_lm_lines[i].strip(" \t\r\n") if "\{}-grams:".format(3) in line: passed_2grams = True if "\{}-grams:".format(max_ngrams) in line: last_ngram = True + for i in range(max_ngrams): + if "\{}-grams:".format(i+1) in line: + ngram = i+1 + # remove any n-gram states of the form: foo -> X # that is, any n-grams of order > 2 where # is the second-to-last word @@ -85,7 +90,6 @@ def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows): if passed_2grams: g_unk = unk_pattern.search(line) if g_unk: - ngram = len(g_unk.group(0).split()) - 1 ngram_diffs[ngram] = ngram_diffs[ngram] - 1 unk_row_count += 1 continue @@ -98,7 +102,7 @@ def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows): if not last_ngram: g_backoff = backoff_pattern.search(line) if g_backoff: - updated_row = g_backoff.group(0).split()[:-1] + updated_row = whitespace_pattern.split(g_backoff.group(0))[:-1] updated_row = updated_row[0] + \ "\t" + " ".join(updated_row[1:]) + "\n" new_lm_lines.append(updated_row) diff --git a/egs/wsj/s5/utils/lang/make_kn_lm.py b/egs/wsj/s5/utils/lang/make_kn_lm.py new file mode 100755 index 00000000000..00c4c8f2378 --- /dev/null +++ b/egs/wsj/s5/utils/lang/make_kn_lm.py @@ -0,0 +1,379 @@ +#!/usr/bin/env python3 + +# Copyright 2016 Johns Hopkins University (Author: Daniel Povey) +# 2018 Ruizhe Huang +# Apache 2.0. + +# This is an implementation of computing Kneser-Ney smoothed language model +# in the same way as srilm. This is a back-off, unmodified version of +# Kneser-Ney smoothing, which produces the same results as the following +# command (as an example) of srilm: +# +# $ ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \ +# -text corpus.txt -lm lm.arpa +# +# The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py +# The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html + +import sys +import os +import re +import io +import math +import argparse +from collections import Counter, defaultdict + + +parser = argparse.ArgumentParser(description=""" + Generate kneser-ney language model as arpa format. 
By default, + it will read the corpus from standard input, and output to standard output. + """) +parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram") +parser.add_argument("-text", type=str, default=None, help="Path to the corpus file") +parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models") +parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level") +args = parser.parse_args() + +default_encoding = "latin-1" # For encoding-agnostic scripts, we assume byte stream as input. + # Need to be very careful about the use of strip() and split() + # in this case, because there is a latin-1 whitespace character + # (nbsp) which is part of the unicode encoding range. + # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717 +strip_chars = " \t\r\n" +whitespace = re.compile("[ \t]+") + + +class CountsForHistory: + # This class (which is more like a struct) stores the counts seen in a + # particular history-state. It is used inside class NgramCounts. + # It really does the job of a dict from int to float, but it also + # keeps track of the total count. + def __init__(self): + # The 'lambda: defaultdict(float)' is an anonymous function taking no + # arguments that returns a new defaultdict(float). + self.word_to_count = defaultdict(int) + self.word_to_context = defaultdict(set) # using a set to count the number of unique contexts + self.word_to_f = dict() # discounted probability + self.word_to_bow = dict() # back-off weight + self.total_count = 0 + + def words(self): + return self.word_to_count.keys() + + def __str__(self): + # e.g. returns ' total=12: 3->4, 4->6, -1->2' + return ' total={0}: {1}'.format( + str(self.total_count), + ', '.join(['{0} -> {1}'.format(word, count) + for word, count in self.word_to_count.items()])) + + def add_count(self, predicted_word, context_word, count): + assert count >= 0 + + self.total_count += count + self.word_to_count[predicted_word] += count + if context_word is not None: + self.word_to_context[predicted_word].add(context_word) + + +class NgramCounts: + # A note on data-structure. Firstly, all words are represented as + # integers. We store n-gram counts as an array, indexed by (history-length + # == n-gram order minus one) (note: python calls arrays "lists") of dicts + # from histories to counts, where histories are arrays of integers and + # "counts" are dicts from integer to float. For instance, when + # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd + # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an + # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict. + def __init__(self, ngram_order, bos_symbol='', eos_symbol=''): + assert ngram_order >= 2 + + self.ngram_order = ngram_order + self.bos_symbol = bos_symbol + self.eos_symbol = eos_symbol + + self.counts = [] + for n in range(ngram_order): + self.counts.append(defaultdict(lambda: CountsForHistory())) + + self.d = [] # list of discounting factor for each order of ngram + + # adds a raw count (called while processing input data). + # Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history' + # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be + # 1. 
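The data-structure comment in NgramCounts is easier to see with a toy example. The sketch below is standalone (not part of make_kn_lm.py) and builds the same nested layout with plain defaultdicts; note that in practice histories are stored as tuples, since Python lists are not hashable.

from collections import defaultdict

ngram_order = 4
# counts[history_length][history_tuple][predicted_word] -> count
counts = [defaultdict(lambda: defaultdict(int)) for _ in range(ngram_order)]

# Accumulate the 4-gram count for '8' in the word sequence 5 6 7 8:
counts[3][(5, 6, 7)][8] += 1

assert counts[3][(5, 6, 7)][8] == 1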
+ def add_count(self, history, predicted_word, context_word, count): + self.counts[len(history)][history].add_count(predicted_word, context_word, count) + + # 'line' is a string containing a sequence of integer word-ids. + # This function adds the un-smoothed counts from this line of text. + def add_raw_counts_from_line(self, line): + words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol] + + for i in range(len(words)): + for n in range(1, self.ngram_order+1): + if i + n > len(words): + break + + ngram = words[i: i + n] + predicted_word = ngram[-1] + history = tuple(ngram[: -1]) + if i == 0 or n == self.ngram_order: + context_word = None + else: + context_word = words[i-1] + + self.add_count(history, predicted_word, context_word, 1) + + def add_raw_counts_from_standard_input(self): + lines_processed = 0 + infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding) # byte stream as input + for line in infile: + line = line.strip(strip_chars) + if line == '': + break + self.add_raw_counts_from_line(line) + lines_processed += 1 + if lines_processed == 0 or args.verbose > 0: + print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr) + + def add_raw_counts_from_file(self, filename): + lines_processed = 0 + with open(filename, encoding=default_encoding) as fp: + for line in fp: + line = line.strip(strip_chars) + if line == '': + break + self.add_raw_counts_from_line(line) + lines_processed += 1 + if lines_processed == 0 or args.verbose > 0: + print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr) + + def cal_discounting_constants(self): + # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N), + # where n1_N is the number of unique N-grams with count = 1 (counts-of-counts). + # This constant is used similarly to absolute discounting. + # Return value: d is a list of floats, where d[N+1] = D_N + + self.d = [0] # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0 + # This is a special case: as we currently assumed having seen all vocabularies in the dictionary, + # but perhaps this is not the case for some other scenarios. + for n in range(1, self.ngram_order): + this_order_counts = self.counts[n] + n1 = 0 + n2 = 0 + for hist, counts_for_hist in this_order_counts.items(): + stat = Counter(counts_for_hist.word_to_count.values()) + n1 += stat[1] + n2 += stat[2] + assert n1 + 2 * n2 > 0 + self.d.append(n1 * 1.0 / (n1 + 2 * n2)) + + def cal_f(self): + # f(a_z) is a probability distribution of word sequence a_z. + # Typically f(a_z) is discounted to be less than the ML estimate so we have + # some leftover probability for the z words unseen in the context (a_). 
+ # + # f(a_z) = (c(a_z) - D0) / c(a_) ;; for highest order N-grams + # f(_z) = (n(*_z) - D1) / n(*_*) ;; for lower order N-grams + + # highest order N-grams + n = self.ngram_order - 1 + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w, c in counts_for_hist.word_to_count.items(): + counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count + + # lower order N-grams + for n in range(0, self.ngram_order - 1): + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + + n_star_star = 0 + for w in counts_for_hist.word_to_count.keys(): + n_star_star += len(counts_for_hist.word_to_context[w]) + + if n_star_star != 0: + for w in counts_for_hist.word_to_count.keys(): + n_star_z = len(counts_for_hist.word_to_context[w]) + counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star + else: # patterns begin with , they do not have "modified count", so use raw count instead + for w in counts_for_hist.word_to_count.keys(): + n_star_z = counts_for_hist.word_to_count[w] + counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count + + def cal_bow(self): + # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram. + # Thus, two sorts of ngrams do not have a bow: + # 1) highest order ngram + # 2) ngrams ending in + # + # bow(a_) = (1 - Sum_Z1 f(a_z)) / (1 - Sum_Z1 f(_z)) + # Note that Z1 is the set of all words with c(a_z) > 0 + + # highest order N-grams + n = self.ngram_order - 1 + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + counts_for_hist.word_to_bow[w] = None + + # lower order N-grams + for n in range(0, self.ngram_order - 1): + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + if w == self.eos_symbol: + counts_for_hist.word_to_bow[w] = None + else: + a_ = hist + (w,) + + assert len(a_) < self.ngram_order + assert a_ in self.counts[len(a_)].keys() + + a_counts_for_hist = self.counts[len(a_)][a_] + + sum_z1_f_a_z = 0 + for u in a_counts_for_hist.word_to_count.keys(): + sum_z1_f_a_z += a_counts_for_hist.word_to_f[u] + + sum_z1_f_z = 0 + _ = a_[1:] + _counts_for_hist = self.counts[len(_)][_] + for u in a_counts_for_hist.word_to_count.keys(): # Should be careful here: what is Z1 + sum_z1_f_z += _counts_for_hist.word_to_f[u] + + counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z) + + def print_raw_counts(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])) + res.sort(reverse=True) + for r in res: + print(r) + + def print_modified_counts(self, info_string): + # these are useful for debug. 
+ print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + modified_count = len(counts_for_hist.word_to_context[w]) + raw_count = counts_for_hist.word_to_count[w] + + if modified_count == 0: + res.append("{0}\t{1}".format(ngram, raw_count)) + else: + res.append("{0}\t{1}".format(ngram, modified_count)) + res.sort(reverse=True) + for r in res: + print(r) + + def print_f(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + f = counts_for_hist.word_to_f[w] + if f == 0: # f() is always 0 + f = 1e-99 + + res.append("{0}\t{1}".format(ngram, math.log(f, 10))) + res.sort(reverse=True) + for r in res: + print(r) + + def print_f_and_bow(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + f = counts_for_hist.word_to_f[w] + if f == 0: # f() is always 0 + f = 1e-99 + + bow = counts_for_hist.word_to_bow[w] + if bow is None: + res.append("{1}\t{0}".format(ngram, math.log(f, 10))) + else: + res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10))) + res.sort(reverse=True) + for r in res: + print(r) + + def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')): + # print as ARPA format. + + print('\\data\\', file=fout) + for hist_len in range(self.ngram_order): + # print the number of n-grams. 
+ print('ngram {0}={1}'.format( + hist_len + 1, + sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])), + file=fout + ) + + print('', file=fout) + + for hist_len in range(self.ngram_order): + print('\\{0}-grams:'.format(hist_len + 1), file=fout) + + this_order_counts = self.counts[hist_len] + for hist, counts_for_hist in this_order_counts.items(): + for word in counts_for_hist.word_to_count.keys(): + ngram = hist + (word,) + prob = counts_for_hist.word_to_f[word] + bow = counts_for_hist.word_to_bow[word] + + if prob == 0: # f() is always 0 + prob = 1e-99 + + line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram)) + if bow is not None: + line += '\t{0}'.format('%.7f' % math.log10(bow)) + print(line, file=fout) + print('', file=fout) + print('\\end\\', file=fout) + + +if __name__ == "__main__": + + ngram_counts = NgramCounts(args.ngram_order) + + if args.text is None: + ngram_counts.add_raw_counts_from_standard_input() + else: + assert os.path.isfile(args.text) + ngram_counts.add_raw_counts_from_file(args.text) + + ngram_counts.cal_discounting_constants() + ngram_counts.cal_f() + ngram_counts.cal_bow() + + if args.lm is None: + ngram_counts.print_as_arpa() + else: + with open(args.lm, 'w', encoding=default_encoding) as f: + ngram_counts.print_as_arpa(fout=f) diff --git a/egs/wsj/s5/utils/lang/make_lexicon_fst.py b/egs/wsj/s5/utils/lang/make_lexicon_fst.py new file mode 100755 index 00000000000..790af2f2314 --- /dev/null +++ b/egs/wsj/s5/utils/lang/make_lexicon_fst.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0. + +# see get_args() below for usage message. +import argparse +import os +import sys +import math +import re + +# The use of latin-1 encoding does not preclude reading utf-8. latin-1 +# encoding means "treat words as sequences of bytes", and it is compatible +# with utf-8 encoding as well as other encodings such as gbk, as long as the +# spaces are also spaces in ascii (which we check). It is basically how we +# emulate the behavior of python before python3. +sys.stdout = open(1, 'w', encoding='latin-1', closefd=False) +sys.stderr = open(2, 'w', encoding='latin-1', closefd=False) + +def get_args(): + parser = argparse.ArgumentParser(description="""This script creates the + text form of a lexicon FST, to be compiled by fstcompile using the + appropriate symbol tables (phones.txt and words.txt) . It will mostly + be invoked indirectly via utils/prepare_lang.sh. The output goes to + the stdout.""") + + parser.add_argument('--sil-phone', dest='sil_phone', type=str, + help="""Text form of optional-silence phone, e.g. 'SIL'. See also + the --silprob option.""") + parser.add_argument('--sil-prob', dest='sil_prob', type=float, default=0.0, + help="""Probability of silence between words (including at the + beginning and end of word sequences). Must be in the range [0.0, 1.0]. + This refers to the optional silence inserted by the lexicon; see + the --silphone option.""") + parser.add_argument('--sil-disambig', dest='sil_disambig', type=str, + help="""Disambiguation symbol to disambiguate silence, e.g. #5. + Will only be supplied if you are creating the version of L.fst + with disambiguation symbols, intended for use with cyclic G.fst. 
+ This symbol was introduced to fix a rather obscure source of + nondeterminism of CLG.fst, that has to do with reordering of + disambiguation symbols and phone symbols.""") + parser.add_argument('--left-context-phones', dest='left_context_phones', type=str, + help="""Only relevant if --nonterminals is also supplied; this relates + to grammar decoding (see http://kaldi-asr.org/doc/grammar.html or + src/doc/grammar.dox). Format is a list of left-context phones, + in text form, one per line. E.g. data/lang/phones/left_context_phones.txt""") + parser.add_argument('--nonterminals', type=str, + help="""If supplied, --left-context-phones must also be supplied. + List of user-defined nonterminal symbols such as #nonterm:contact_list, + one per line. E.g. data/local/dict/nonterminals.txt.""") + parser.add_argument('lexiconp', type=str, + help="""Filename of lexicon with pronunciation probabilities + (normally lexiconp.txt), with lines of the form 'word prob p1 p2...', + e.g. 'a 1.0 ay'""") + args = parser.parse_args() + return args + + +def read_lexiconp(filename): + """Reads the lexiconp.txt file in 'filename', with lines like 'word pron p1 p2 ...'. + Returns a list of tuples (word, pron_prob, pron), where 'word' is a string, + 'pron_prob', a float, is the pronunciation probability (which must be >0.0 + and would normally be <=1.0), and 'pron' is a list of strings representing phones. + An element in the returned list might be ('hello', 1.0, ['h', 'eh', 'l', 'ow']). + """ + + ans = [] + found_empty_prons = False + found_large_pronprobs = False + # See the comment near the top of this file, RE why we use latin-1. + with open(filename, 'r', encoding='latin-1') as f: + whitespace = re.compile("[ \t]+") + for line in f: + a = whitespace.split(line.strip(" \t\r\n")) + if len(a) < 2: + print("{0}: error: found bad line '{1}' in lexicon file {2} ".format( + sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr) + sys.exit(1) + word = a[0] + if word == "": + # This would clash with the epsilon symbol normally used in OpenFst. + print("{0}: error: found as a word in lexicon file " + "{1}".format(line.strip(" \t\r\n"), filename), file=sys.stderr) + sys.exit(1) + try: + pron_prob = float(a[1]) + except: + print("{0}: error: found bad line '{1}' in lexicon file {2}, 2nd field " + "should be pron-prob".format(sys.argv[0], line.strip(" \t\r\n"), filename), + file=sys.stderr) + sys.exit(1) + prons = a[2:] + if pron_prob <= 0.0: + print("{0}: error: invalid pron-prob in line '{1}' of lexicon file {1} ".format( + sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr) + sys.exit(1) + if len(prons) == 0: + found_empty_prons = True + ans.append( (word, pron_prob, prons) ) + if pron_prob > 1.0: + found_large_pronprobs = True + if found_empty_prons: + print("{0}: warning: found at least one word with an empty pronunciation " + "in lexicon file {1}.".format(sys.argv[0], filename), + file=sys.stderr) + if found_large_pronprobs: + print("{0}: warning: found at least one word with pron-prob >1.0 " + "in {1}".format(sys.argv[0], filename), file=sys.stderr) + + + if len(ans) == 0: + print("{0}: error: found no pronunciations in lexicon file {1}".format( + sys.argv[0], filename), file=sys.stderr) + sys.exit(1) + return ans + + +def write_nonterminal_arcs(start_state, loop_state, next_state, + nonterminals, left_context_phones): + """This function relates to the grammar-decoding setup, see + kaldi-asr.org/doc/grammar.html. 
It is called from write_fst_no_silence + and write_fst_silence, and writes to the stdout some extra arcs + in the lexicon FST that relate to nonterminal symbols. + See the section "Special symbols in L.fst, + kaldi-asr.org/doc/grammar.html#grammar_special_l. + start_state: the start-state of L.fst. + loop_state: the state of high out-degree in L.fst where words leave + and enter. + next_state: the number from which this function can start allocating its + own states. the updated value of next_state will be returned. + nonterminals: the user-defined nonterminal symbols as a list of + strings, e.g. ['#nonterm:contact_list', ... ]. + left_context_phones: a list of phones that may appear as left-context, + e.g. ['a', 'ah', ... '#nonterm_bos']. + """ + shared_state = next_state + next_state += 1 + final_state = next_state + next_state += 1 + + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=start_state, dest=shared_state, + phone='#nonterm_begin', word='#nonterm_begin', + cost=0.0)) + + for nonterminal in nonterminals: + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=loop_state, dest=shared_state, + phone=nonterminal, word=nonterminal, + cost=0.0)) + # this_cost equals log(len(left_context_phones)) but the expression below + # better captures the meaning. Applying this cost to arcs keeps the FST + # stochatic (sum-to-one, like an HMM), so that if we do weight pushing + # things won't get weird. In the grammar-FST code when we splice things + # together we will cancel out this cost, see the function CombineArcs(). + this_cost = -math.log(1.0 / len(left_context_phones)) + + for left_context_phone in left_context_phones: + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=shared_state, dest=loop_state, + phone=left_context_phone, word='', cost=this_cost)) + # arc from loop-state to a final-state with #nonterm_end as ilabel and olabel + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=loop_state, dest=final_state, + phone='#nonterm_end', word='#nonterm_end', cost=0.0)) + print("{state}\t{final_cost}".format( + state=final_state, final_cost=0.0)) + return next_state + + + +def write_fst_no_silence(lexicon, nonterminals=None, left_context_phones=None): + """Writes the text format of L.fst to the standard output. This version is for + when --sil-prob=0.0, meaning there is no optional silence allowed. + + 'lexicon' is a list of 3-tuples (word, pron-prob, prons) as returned by + read_lexiconp(). + 'nonterminals', which relates to grammar decoding (see kaldi-asr.org/doc/grammar.html), + is either None, or the user-defined nonterminal symbols as a list of + strings, e.g. ['#nonterm:contact_list', ... ]. + 'left_context_phones', which also relates to grammar decoding, and must be + supplied if 'nonterminals' is supplied is either None or a list of + phones that may appear as left-context, e.g. ['a', 'ah', ... '#nonterm_bos']. + """ + + loop_state = 0 + next_state = 1 # the next un-allocated state, will be incremented as we go. + for (word, pronprob, pron) in lexicon: + cost = -math.log(pronprob) + cur_state = loop_state + for i in range(len(pron) - 1): + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=cur_state, + dest=next_state, + phone=pron[i], + word=(word if i == 0 else ''), + cost=(cost if i == 0 else 0.0))) + cur_state = next_state + next_state += 1 + + i = len(pron) - 1 # note: i == -1 if pron is empty. 
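# Worked example (comment only): for the docstring entry
# ('hello', 1.0, ['h', 'eh', 'l', 'ow']) added to an otherwise empty FST, the
# loop above plus the print below write these arcs (src dest phone word cost),
# where "eps" stands for the epsilon word label used on all but the first arc:
#   0 1 h  hello 0.0
#   1 2 eh eps   0.0
#   2 3 l  eps   0.0
#   3 0 ow eps   0.0
# The word label and the pron-prob cost (-log 1.0 = 0.0) go on the first arc,
# and the last phone's arc returns to loop_state (state 0).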
+ print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=cur_state, + dest=loop_state, + phone=(pron[i] if i >= 0 else ''), + word=(word if i <= 0 else ''), + cost=(cost if i <= 0 else 0.0))) + + if nonterminals is not None: + next_state = write_nonterminal_arcs( + start_state, loop_state, next_state, + nonterminals, left_context_phones) + + print("{state}\t{final_cost}".format( + state=loop_state, + final_cost=0.0)) + + +def write_fst_with_silence(lexicon, sil_prob, sil_phone, sil_disambig, + nonterminals=None, left_context_phones=None): + """Writes the text format of L.fst to the standard output. This version is for + when --sil-prob != 0.0, meaning there is optional silence + 'lexicon' is a list of 3-tuples (word, pron-prob, prons) + as returned by read_lexiconp(). + 'sil_prob', which is expected to be strictly between 0.. and 1.0, is the + probability of silence + 'sil_phone' is the silence phone, e.g. "SIL". + 'sil_disambig' is either None, or the silence disambiguation symbol, e.g. "#5". + 'nonterminals', which relates to grammar decoding (see kaldi-asr.org/doc/grammar.html), + is either None, or the user-defined nonterminal symbols as a list of + strings, e.g. ['#nonterm:contact_list', ... ]. + 'left_context_phones', which also relates to grammar decoding, and must be + supplied if 'nonterminals' is supplied is either None or a list of + phones that may appear as left-context, e.g. ['a', 'ah', ... '#nonterm_bos']. + """ + + assert sil_prob > 0.0 and sil_prob < 1.0 + sil_cost = -math.log(sil_prob) + no_sil_cost = -math.log(1.0 - sil_prob); + + start_state = 0 + loop_state = 1 # words enter and leave from here + sil_state = 2 # words terminate here when followed by silence; this state + # has a silence transition to loop_state. + next_state = 3 # the next un-allocated state, will be incremented as we go. + + + print('{src}\t{dest}\t{phone}\t{word}\t{cost}'.format( + src=start_state, dest=loop_state, + phone='', word='', cost=no_sil_cost)) + print('{src}\t{dest}\t{phone}\t{word}\t{cost}'.format( + src=start_state, dest=sil_state, + phone='', word='', cost=sil_cost)) + if sil_disambig is None: + print('{src}\t{dest}\t{phone}\t{word}\t{cost}'.format( + src=sil_state, dest=loop_state, + phone=sil_phone, word='', cost=0.0)) + else: + sil_disambig_state = next_state + next_state += 1 + print('{src}\t{dest}\t{phone}\t{word}\t{cost}'.format( + src=sil_state, dest=sil_disambig_state, + phone=sil_phone, word='', cost=0.0)) + print('{src}\t{dest}\t{phone}\t{word}\t{cost}'.format( + src=sil_disambig_state, dest=loop_state, + phone=sil_disambig, word='', cost=0.0)) + + + for (word, pronprob, pron) in lexicon: + pron_cost = -math.log(pronprob) + cur_state = loop_state + for i in range(len(pron) - 1): + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=cur_state, dest=next_state, + phone=pron[i], + word=(word if i == 0 else ''), + cost=(pron_cost if i == 0 else 0.0))) + cur_state = next_state + next_state += 1 + + i = len(pron) - 1 # note: i == -1 if pron is empty. 
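# Note (comment only): the last phone's arc is written twice below -- once back
# to loop_state with no_sil_cost added (no optional silence follows the word)
# and once to sil_state with sil_cost added (optional silence follows, emitted
# by the sil_state -> loop_state arc created near the top of this function).
# This is how a nonzero --sil-prob is encoded in the FST.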
+ print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=cur_state, + dest=loop_state, + phone=(pron[i] if i >= 0 else ''), + word=(word if i <= 0 else ''), + cost=no_sil_cost + (pron_cost if i <= 0 else 0.0))) + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=cur_state, + dest=sil_state, + phone=(pron[i] if i >= 0 else ''), + word=(word if i <= 0 else ''), + cost=sil_cost + (pron_cost if i <= 0 else 0.0))) + + if nonterminals is not None: + next_state = write_nonterminal_arcs( + start_state, loop_state, next_state, + nonterminals, left_context_phones) + + print("{state}\t{final_cost}".format( + state=loop_state, + final_cost=0.0)) + + + + +def write_words_txt(orig_lines, highest_numbered_symbol, nonterminals, filename): + """Writes updated words.txt to 'filename'. 'orig_lines' is the original lines + in the words.txt file as a list of strings (without the newlines); + highest_numbered_symbol is the highest numbered symbol in the original + words.txt; nonterminals is a list of strings like '#nonterm:foo'.""" + with open(filename, 'w', encoding='latin-1') as f: + for l in orig_lines: + print(l, file=f) + cur_symbol = highest_numbered_symbol + 1 + for n in [ '#nonterm_begin', '#nonterm_end' ] + nonterminals: + print("{0} {1}".format(n, cur_symbol), file=f) + cur_symbol = cur_symbol + 1 + + +def read_nonterminals(filename): + """Reads the user-defined nonterminal symbols in 'filename', checks that + it has the expected format and has no duplicates, and returns the nonterminal + symbols as a list of strings, e.g. + ['#nonterm:contact_list', '#nonterm:phone_number', ... ]. """ + ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')] + if len(ans) == 0: + raise RuntimeError("The file {0} contains no nonterminals symbols.".format(filename)) + for nonterm in ans: + if nonterm[:9] != '#nonterm:': + raise RuntimeError("In file '{0}', expected nonterminal symbols to start with '#nonterm:', found '{1}'" + .format(filename, nonterm)) + if len(set(ans)) != len(ans): + raise RuntimeError("Duplicate nonterminal symbols are present in file {0}".format(filename)) + return ans + +def read_left_context_phones(filename): + """Reads, checks, and returns a list of left-context phones, in text form, one + per line. Returns a list of strings, e.g. 
['a', 'ah', ..., '#nonterm_bos' ]""" + ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')] + if len(ans) == 0: + raise RuntimeError("The file {0} contains no left-context phones.".format(filename)) + whitespace = re.compile("[ \t]+") + for s in ans: + if len(whitespace.split(s)) != 1: + raise RuntimeError("The file {0} contains an invalid line '{1}'".format(filename, s) ) + + if len(set(ans)) != len(ans): + raise RuntimeError("Duplicate nonterminal symbols are present in file {0}".format(filename)) + return ans + + +def is_token(s): + """Returns true if s is a string and is space-free.""" + if not isinstance(s, str): + return False + whitespace = re.compile("[ \t\r\n]+") + split_str = whitespace.split(s); + return len(split_str) == 1 and s == split_str[0] + + +def main(): + args = get_args() + + lexicon = read_lexiconp(args.lexiconp) + + if args.nonterminals is None: + nonterminals, left_context_phones = None, None + else: + if args.left_context_phones is None: + print("{0}: if --nonterminals is specified, --left-context-phones must also " + "be specified".format(sys.argv[0])) + sys.exit(1) + nonterminals = read_nonterminals(args.nonterminals) + left_context_phones = read_left_context_phones(args.left_context_phones) + + if args.sil_prob == 0.0: + write_fst_no_silence(lexicon, + nonterminals=nonterminals, + left_context_phones=left_context_phones) + else: + # Do some checking that the options make sense. + if args.sil_prob < 0.0 or args.sil_prob >= 1.0: + print("{0}: invalid value specified --sil-prob={1}".format( + sys.argv[0], args.sil_prob), file=sys.stderr) + sys.exit(1) + + if not is_token(args.sil_phone): + print("{0}: you specified --sil-prob={1} but --sil-phone is set " + "to '{2}'".format(sys.argv[0], args.sil_prob, args.sil_phone), + file=sys.stderr) + sys.exit(1) + if args.sil_disambig is not None and not is_token(args.sil_disambig): + print("{0}: invalid value --sil-disambig='{1}' was specified." + "".format(sys.argv[0], args.sil_disambig), file=sys.stderr) + sys.exit(1) + write_fst_with_silence(lexicon, args.sil_prob, args.sil_phone, + args.sil_disambig, + nonterminals=nonterminals, + left_context_phones=left_context_phones) + + + +# (lines, highest_symbol) = read_words_txt(args.input_words_txt) +# nonterminals = read_nonterminals(args.nonterminal_symbols_list) +# write_words_txt(lines, highest_symbol, nonterminals, args.output_words_txt) + + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/utils/lang/make_lexicon_fst_silprob.py b/egs/wsj/s5/utils/lang/make_lexicon_fst_silprob.py new file mode 100755 index 00000000000..0633c4bec73 --- /dev/null +++ b/egs/wsj/s5/utils/lang/make_lexicon_fst_silprob.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +# Copyright 2018 Johns Hopkins University (author: Daniel Povey) +# 2018 Jiedan Zhu +# Apache 2.0. +# see get_args() below for usage message. + +import argparse +import os +import sys +import math +import re + +# The use of latin-1 encoding does not preclude reading utf-8. latin-1 +# encoding means "treat words as sequences of bytes", and it is compatible +# with utf-8 encoding as well as other encodings such as gbk, as long as the +# spaces are also spaces in ascii (which we check). It is basically how we +# emulate the behavior of python before python3. 
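The byte-transparency property described in the comment above can be checked in isolation; here is a minimal standalone sketch (not part of the script) showing a UTF-8 string surviving a latin-1 round trip with its ASCII spaces intact:

# Standalone sanity check of the "treat words as sequences of bytes" behaviour.
utf8_bytes = u"\u4f60\u597d hello".encode("utf-8")  # multi-byte text plus ASCII
as_latin1 = utf8_bytes.decode("latin-1")            # every byte maps to one code point
assert as_latin1.encode("latin-1") == utf8_bytes    # round trip reproduces the bytes
assert " " in as_latin1                             # the ASCII space is still a space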
+ +sys.stdout = open(1, 'w', encoding='latin-1', closefd=False) +sys.stderr = open(2, 'w', encoding='latin-1', closefd=False) + + +def get_args(): + parser = argparse.ArgumentParser(description="""This script creates the + text form of a lexicon FST, to be compiled by fstcompile using the + appropriate symbol tables (phones.txt and words.txt) . It will mostly + be invoked indirectly via utils/prepare_lang.sh. The output goes to + the stdout. + + This version is for a lexicon with word-specific silence probabilities, + see http://www.danielpovey.com/files/2015_interspeech_silprob.pdf + for an explanation""") + + parser.add_argument('--sil-phone', dest='sil_phone', type=str, + help="Text form of optional-silence phone, e.g. 'SIL'.") + parser.add_argument('--sil-disambig', dest='sil_disambig', type=str, default="", + help="""Disambiguation symbol to disambiguate silence, e.g. #5. + Will only be supplied if you are creating the version of L.fst + with disambiguation symbols, intended for use with cyclic G.fst. + This symbol was introduced to fix a rather obscure source of + nondeterminism of CLG.fst, that has to do with reordering of + disambiguation symbols and phone symbols.""") + parser.add_argument('lexiconp', type=str, + help="""Filename of lexicon with pronunciation probabilities + (normally lexiconp.txt), with lines of the form + 'word pron-prob prob-of-sil correction-term-for-sil correction-term-for-no-sil p1 p2...', + e.g. 'a 1.0 0.8 1.2 0.6 ay'""") + parser.add_argument('silprobs', type=str, + help="""Filename with silence probabilities, with lines of the form + ' p(sil-after|) // + _s correction-term-for-sil-for- // + _n correction-term-for-no-sil-for- // + overall p(overall-sil), where // represents line break. + See also utils/dict_dir_add_pronprobs.sh, + which creates this file as silprob.txt.""") + parser.add_argument('--left-context-phones', dest='left_context_phones', type=str, + help="""Only relevant if --nonterminals is also supplied; this relates + to grammar decoding (see http://kaldi-asr.org/doc/grammar.html or + src/doc/grammar.dox). Format is a list of left-context phones, + in text form, one per line. E.g. data/lang/phones/left_context_phones.txt""") + parser.add_argument('--nonterminals', type=str, + help="""If supplied, --left-context-phones must also be supplied. + List of user-defined nonterminal symbols such as #nonterm:contact_list, + one per line. E.g. data/local/dict/nonterminals.txt.""") + + args = parser.parse_args() + return args + + +def read_silprobs(filename): + """ Reads the silprobs file (e.g. silprobs.txt) which will have a format like this: + 0.99 + _s 2.50607106867326 + _n 0.00653829808100956 + overall 0.20 + and returns it as a 4-tuple, e.g. in this example (0.99, 2.50, 0.006, 0.20) + """ + silbeginprob = -1 + silendcorrection = -1 + nonsilendcorrection = -1 + siloverallprob = -1 + with open(filename, 'r', encoding='latin-1') as f: + whitespace = re.compile("[ \t]+") + for line in f: + a = whitespace.split(line.strip(" \t\r\n")) + if len(a) != 2: + print("{0}: error: found bad line '{1}' in silprobs file {1} ".format( + sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr) + sys.exit(1) + label = a[0] + try: + if label == "": + silbeginprob = float(a[1]) + elif label == "_s": + silendcorrection = float(a[1]) + elif label == "_n": + nonsilendcorrection = float(a[1]) + elif label == "overall": + siloverallprob = float(a[1]) # this is not in use, still keep it? 
+ else: + raise RuntimeError() + except: + print("{0}: error: found bad line '{1}' in silprobs file {1}" + .format(sys.argv[0], line.strip(" \t\r\n"), filename), + file=sys.stderr) + sys.exit(1) + if (silbeginprob <= 0.0 or silbeginprob > 1.0 or + silendcorrection <= 0.0 or nonsilendcorrection <= 0.0 or + siloverallprob <= 0.0 or siloverallprob > 1.0): + print("{0}: error: prob is not correct in silprobs file {1}." + .format(sys.argv[0], filename), file=sys.stderr) + sys.exit(1) + return (silbeginprob, silendcorrection, nonsilendcorrection, siloverallprob) + + +def read_lexiconp(filename): + """Reads the lexiconp.txt file in 'filename', with lines like + 'word p(pronunciation|word) p(sil-after|word) correction-term-for-sil + correction-term-for-no-sil p1 p2 ...'. + Returns a list of tuples (word, pron_prob, word_sil_prob, + sil_word_correction, non_sil_word_correction, prons), where 'word' is a string, + 'pron_prob', a float, is the pronunciation probability (which must be >0.0 + and would normally be <=1.0), 'word_sil_prob' is a float, + 'sil_word_correction' is a float, 'non_sil_word_correction' is a float, + and 'pron' is a list of strings representing phones. + An element in the returned list might be + ('hello', 1.0, 0.5, 0.3, 0.6, ['h', 'eh', 'l', 'ow']). + """ + ans = [] + found_empty_prons = False + found_large_pronprobs = False + # See the comment near the top of this file, RE why we use latin-1. + whitespace = re.compile("[ \t]+") + with open(filename, 'r', encoding='latin-1') as f: + for line in f: + a = whitespace.split(line.strip(" \t\r\n")) + if len(a) < 2: + print("{0}: error: found bad line '{1}' in lexicon file {1} ".format( + sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr) + sys.exit(1) + word = a[0] + if word == "": + # This would clash with the epsilon symbol normally used in OpenFst. + print("{0}: error: found as a word in lexicon file " + "{1}".format(line.strip(" \t\r\n"), filename), file=sys.stderr) + sys.exit(1) + try: + pron_prob = float(a[1]) + word_sil_prob = float(a[2]) + sil_word_correction = float(a[3]) + non_sil_word_correction = float(a[4]) + except: + print("{0}: error: found bad line '{1}' in lexicon file {2}, 2nd field " + "through 5th field should be numbers".format(sys.argv[0], + line.strip(" \t\r\n"), filename), + file=sys.stderr) + sys.exit(1) + prons = a[5:] + if pron_prob <= 0.0: + print("{0}: error: invalid pron-prob in line '{1}' of lexicon file {2} ".format( + sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr) + sys.exit(1) + if len(prons) == 0: + found_empty_prons = True + ans.append(( + word, pron_prob, word_sil_prob, + sil_word_correction, non_sil_word_correction, prons)) + if pron_prob > 1.0: + found_large_pronprobs = True + if found_empty_prons: + print("{0}: warning: found at least one word with an empty pronunciation " + "in lexicon file {1}.".format(sys.argv[0], filename), + file=sys.stderr) + if found_large_pronprobs: + print("{0}: warning: found at least one word with pron-prob >1.0 " + "in {1}".format(sys.argv[0], filename), file=sys.stderr) + if len(ans) == 0: + print("{0}: error: found no pronunciations in lexicon file {1}".format( + sys.argv[0], filename), file=sys.stderr) + sys.exit(1) + return ans + + +def write_nonterminal_arcs(start_state, sil_state, non_sil_state, + next_state, sil_phone, + nonterminals, left_context_phones): + """This function relates to the grammar-decoding setup, see + kaldi-asr.org/doc/grammar.html. 
It is called from write_fst, and writes to + the stdout some extra arcs in the lexicon FST that relate to nonterminal + symbols. + + See the section "Special symbols in L.fst, + kaldi-asr.org/doc/grammar.html#grammar_special_l. + start_state: the start-state of L.fst. + sil_state: the state of high out-degree in L.fst where words leave + when preceded by optional silence + non_sil_state: the state of high out-degree in L.fst where words leave + when not preceded by optional silence + next_state: the number from which this function can start allocating its + own states. the updated value of next_state will be returned. + sil_phone: the optional-silence phone (a string, e.g 'sil') + nonterminals: the user-defined nonterminal symbols as a list of + strings, e.g. ['#nonterm:contact_list', ... ]. + left_context_phones: a list of phones that may appear as left-context, + e.g. ['a', 'ah', ... '#nonterm_bos']. + """ + shared_state = next_state + next_state += 1 + final_state = next_state + next_state += 1 + + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=start_state, dest=shared_state, + phone='#nonterm_begin', word='#nonterm_begin', + cost=0.0)) + + for nonterminal in nonterminals: + # What we are doing here could be viewed as a little lazy, by going to + # 'shared_state' instead of a state specific to nonsilence vs. silence + # left-context vs. unknown (for #nonterm_begin). If we made them + # separate we could improve (by half) the correctness of how it + # interacts with sil-probs in the hard-to-handle case where + # word-position-dependent phones are not used and some words end + # in the optional-silence phone. + for src in [sil_state, non_sil_state]: + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=src, dest=shared_state, + phone=nonterminal, word=nonterminal, + cost=0.0)) + + # this_cost equals log(len(left_context_phones)) but the expression below + # better captures the meaning. Applying this cost to arcs keeps the FST + # stochatic (sum-to-one, like an HMM), so that if we do weight pushing + # things won't get weird. In the grammar-FST code when we splice things + # together we will cancel out this cost, see the function CombineArcs(). + this_cost = -math.log(1.0 / len(left_context_phones)) + + for left_context_phone in left_context_phones: + # The following line is part of how we get this to interact correctly with + # the silence probabilities: if the left-context phone was the silence + # phone, it goes to sil_state, else nonsil_state. This won't always + # do the right thing if you have a system without word-position-dependent + # phones (--position-dependent-phones false to prepare_lang.sh) and + # you have words that end in the optional-silence phone. + dest = (sil_state if left_context_phone == sil_phone else non_sil_state) + + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=shared_state, dest=dest, + phone=left_context_phone, word='', cost=this_cost)) + + # arc from sil_state and non_sil_state to a final-state with #nonterm_end as + # ilabel and olabel. The costs on these arcs are zero because if you take + # that arc, you are not really terminating the sequence, you are just + # skipping to sil_state or non_sil_state in the FST one level up. It + # takes the correct path because of the code around 'dest = ...' a few + # lines above this, after reaching 'shared_state' because it saw the + # user-defined nonterminal. 
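# Concrete picture (comment only, phones chosen for illustration): with one
# nonterminal '#nonterm:contact_list', left_context_phones ['a', 'b', 'sil']
# and sil_phone 'sil', this function writes:
#   start_state  -> shared_state on #nonterm_begin (cost 0)
#   sil/non_sil  -> shared_state on #nonterm:contact_list (cost 0)
#   shared_state -> non_sil_state on 'a' and on 'b', and -> sil_state on 'sil',
#                   each with cost -log(1/3) ~= 1.10, so the three arcs sum to 1
#   sil/non_sil  -> final_state on #nonterm_end (cost 0), written just below.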
+ for src in [sil_state, non_sil_state]: + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=src, dest=final_state, + phone='#nonterm_end', word='#nonterm_end', cost=0.0)) + print("{state}\t{final_cost}".format( + state=final_state, final_cost=0.0)) + return next_state + +def write_fst(lexicon, silprobs, sil_phone, sil_disambig, + nonterminals = None, left_context_phones = None): + """Writes the text format of L.fst (or L_disambig.fst) to the standard output. + 'lexicon' is a list of 5-tuples + (word, pronprob, wordsilprob, silwordcorrection, nonsilwordcorrection, pron) + as returned by read_lexiconp(). + 'silprobs' is a 4-tuple of probabilities as returned by read_silprobs(). + 'sil_phone' is the silence phone, e.g. "SIL". + 'sil_disambig' is either '', or the silence disambiguation symbol, e.g. "#5". + 'nonterminals', which relates to grammar decoding (see kaldi-asr.org/doc/grammar.html), + is either None, or the user-defined nonterminal symbols as a list of + strings, e.g. ['#nonterm:contact_list', ... ]. + 'left_context_phones', which also relates to grammar decoding, and must be + supplied if 'nonterminals' is supplied is either None or a list of + phones that may appear as left-context, e.g. ['a', 'ah', ... '#nonterm_bos']. + """ + silbeginprob, silendcorrection, nonsilendcorrection, siloverallprob = silprobs + initial_sil_cost = -math.log(silbeginprob) + initial_non_sil_cost = -math.log(1.0 - silbeginprob); + sil_end_correction_cost = -math.log(silendcorrection) + non_sil_end_correction_cost = -math.log(nonsilendcorrection); + start_state = 0 + non_sil_state = 1 # words enter and leave from here + sil_state = 2 # words terminate here when followed by silence; this state + # has a silence transition to loop_state. + next_state = 3 # the next un-allocated state, will be incremented as we go. + + # Arcs from the start state to the silence and nonsilence loop states + # The one to the nonsilence state has the silence disambiguation symbol + # (We always use that symbol on the *non*-silence-containing arcs, which + # avoids having to introduce extra arcs). + print('{src}\t{dest}\t{phone}\t{word}\t{cost}'.format( + src=start_state, dest=non_sil_state, + phone=sil_disambig, word='', cost=initial_non_sil_cost)) + print('{src}\t{dest}\t{phone}\t{word}\t{cost}'.format( + src=start_state, dest=sil_state, + phone=sil_phone, word='', cost=initial_sil_cost)) + + for (word, pronprob, wordsilprob, silwordcorrection, nonsilwordcorrection, pron) in lexicon: + pron_cost = -math.log(pronprob) + word_to_sil_cost = -math.log(wordsilprob) + word_to_non_sil_cost = -math.log(1.0 - wordsilprob) + sil_to_word_cost = -math.log(silwordcorrection) + non_sil_to_word_cost = -math.log(nonsilwordcorrection) + + if len(pron) == 0: + # this is not really expected but we try to handle it gracefully. + pron = [''] + + new_state = next_state # allocate a new state + next_state += 1 + # Create transitions from both non_sil_state and sil_state to 'new_state', + # with the word label and the word's first phone on them + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=non_sil_state, dest=new_state, + phone=pron[0], word=word, cost=(pron_cost + non_sil_to_word_cost))) + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=sil_state, dest=new_state, + phone=pron[0], word=word, cost=(pron_cost + sil_to_word_cost))) + cur_state = new_state + + # add states and arcs for all but the first phone. 
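# Roadmap for the rest of this word (comment only): the loop below allocates
# one new state per remaining phone, and the two prints after it return to the
# two high-out-degree states -- to non_sil_state carrying sil_disambig with
# cost -log(1 - p(sil-after|word)), and to sil_state carrying sil_phone with
# cost -log(p(sil-after|word)). Together with the entry arcs just above, whose
# costs add the per-word sil/non-sil correction terms, this is the
# word-dependent silence model from the paper cited in the --help text.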
+ for i in range(1, len(pron)): + new_state = next_state + next_state += 1 + print("{src}\t{dest}\t{phone}\t".format( + src=cur_state, dest=new_state, phone=pron[i])) + cur_state = new_state + + # ... and from there we return via two arcs to the silence and + # nonsilence state. the silence-disambig symbol, if used,q + # goes on the nonsilence arc; this saves us having to insert an epsilon. + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=cur_state, dest=non_sil_state, + phone=sil_disambig, word='', + cost=word_to_non_sil_cost)) + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=cur_state, dest=sil_state, + phone=sil_phone, word='', + cost=word_to_sil_cost)) + + if nonterminals is not None: + next_state = write_nonterminal_arcs( + start_state, sil_state, non_sil_state, + next_state, sil_phone, + nonterminals, left_context_phones) + + print('{src}\t{cost}'.format(src=sil_state, cost=sil_end_correction_cost)) + print('{src}\t{cost}'.format(src=non_sil_state, cost=non_sil_end_correction_cost)) + +def read_nonterminals(filename): + """Reads the user-defined nonterminal symbols in 'filename', checks that + it has the expected format and has no duplicates, and returns the nonterminal + symbols as a list of strings, e.g. + ['#nonterm:contact_list', '#nonterm:phone_number', ... ]. """ + ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')] + if len(ans) == 0: + raise RuntimeError("The file {0} contains no nonterminals symbols.".format(filename)) + for nonterm in ans: + if nonterm[:9] != '#nonterm:': + raise RuntimeError("In file '{0}', expected nonterminal symbols to start with '#nonterm:', found '{1}'" + .format(filename, nonterm)) + if len(set(ans)) != len(ans): + raise RuntimeError("Duplicate nonterminal symbols are present in file {0}".format(filename)) + return ans + +def read_left_context_phones(filename): + """Reads, checks, and returns a list of left-context phones, in text form, one + per line. Returns a list of strings, e.g. ['a', 'ah', ..., '#nonterm_bos' ]""" + ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')] + if len(ans) == 0: + raise RuntimeError("The file {0} contains no left-context phones.".format(filename)) + for s in ans: + if len(s.split()) != 1: + raise RuntimeError("The file {0} contains an invalid line '{1}'".format(filename, s) ) + + if len(set(ans)) != len(ans): + raise RuntimeError("Duplicate nonterminal symbols are present in file {0}".format(filename)) + return ans + + +def main(): + args = get_args() + silprobs = read_silprobs(args.silprobs) + lexicon = read_lexiconp(args.lexiconp) + + + if args.nonterminals is None: + nonterminals, left_context_phones = None, None + else: + if args.left_context_phones is None: + print("{0}: if --nonterminals is specified, --left-context-phones must also " + "be specified".format(sys.argv[0])) + sys.exit(1) + nonterminals = read_nonterminals(args.nonterminals) + left_context_phones = read_left_context_phones(args.left_context_phones) + + write_fst(lexicon, silprobs, args.sil_phone, args.sil_disambig, + nonterminals, left_context_phones) + + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/utils/lang/make_phone_lm.py b/egs/wsj/s5/utils/lang/make_phone_lm.py index 47d2a45d229..5cc9a8de832 100755 --- a/egs/wsj/s5/utils/lang/make_phone_lm.py +++ b/egs/wsj/s5/utils/lang/make_phone_lm.py @@ -4,6 +4,7 @@ # Apache 2.0. 
from __future__ import print_function +from __future__ import division import sys import argparse import math @@ -65,7 +66,7 @@ -class CountsForHistory: +class CountsForHistory(object): ## This class (which is more like a struct) stores the counts seen in a ## particular history-state. It is used inside class NgramCounts. ## It really does the job of a dict from int to float, but it also @@ -77,7 +78,7 @@ def __init__(self): self.total_count = 0 def Words(self): - return self.word_to_count.keys() + return list(self.word_to_count.keys()) def __str__(self): # e.g. returns ' total=12 3->4 4->6 -1->2' @@ -109,7 +110,7 @@ def AddCount(self, predicted_word, count): else: self.word_to_count[predicted_word] = new_count -class NgramCounts: +class NgramCounts(object): ## A note on data-structure. Firstly, all words are represented as ## integers. We store n-gram counts as an array, indexed by (history-length ## == n-gram order minus one) (note: python calls arrays "lists") of dicts @@ -187,7 +188,7 @@ def ApplyBackoff(self): # there will be no unigram. if args.verbose >= 1: initial_num_ngrams = self.GetNumNgrams() - for n in reversed(range(args.no_backoff_ngram_order, args.ngram_order)): + for n in reversed(list(range(args.no_backoff_ngram_order, args.ngram_order))): this_order_counts = self.counts[n] for hist, counts_for_hist in this_order_counts.items(): backoff_hist = hist[1:] @@ -276,8 +277,8 @@ def PruneEmptyStates(self): states_removed_per_hist_len = [ 0 ] * args.ngram_order - for n in reversed(range(args.no_backoff_ngram_order, - args.ngram_order)): + for n in reversed(list(range(args.no_backoff_ngram_order, + args.ngram_order))): num_states_removed = 0 for hist, counts_for_hist in self.counts[n].items(): l = len(counts_for_hist.word_to_count) @@ -304,14 +305,14 @@ def EnsureStructurallyNeededNgramsExist(self): # we have a unigram state]. if args.verbose >= 1: num_ngrams_initial = self.GetNumNgrams() - for n in reversed(range(args.no_backoff_ngram_order, - args.ngram_order)): + for n in reversed(list(range(args.no_backoff_ngram_order, + args.ngram_order))): for hist, counts_for_hist in self.counts[n].items(): # This loop ensures that if we have an n-gram like (6, 7, 8) -> 9, # then, say, (7, 8) -> 9 and (8) -> 9 exist. reduced_hist = hist - for m in reversed(range(args.no_backoff_ngram_order, n)): + for m in reversed(list(range(args.no_backoff_ngram_order, n))): reduced_hist = reduced_hist[1:] # shift an element off # the history. counts_for_backoff_hist = self.counts[m][reduced_hist] @@ -321,7 +322,7 @@ def EnsureStructurallyNeededNgramsExist(self): # then, say, (6, 7) -> 8 and (6) -> 7 exist. This will be needed # for FST representations of the ARPA LM. reduced_hist = hist - for m in reversed(range(args.no_backoff_ngram_order, n)): + for m in reversed(list(range(args.no_backoff_ngram_order, n))): this_word = reduced_hist[-1] reduced_hist = reduced_hist[:-1] # pop an element off the # history @@ -346,7 +347,7 @@ def PrintAsFst(self, word_disambig_symbol): # History will map from history (as a tuple) to integer FST-state. hist_to_state = self.GetHistToStateMap() - for n in [ 1, 0 ] + range(2, args.ngram_order): + for n in [ 1, 0 ] + list(range(2, args.ngram_order)): this_order_counts = self.counts[n] # For order 1, make sure the keys are sorted. keys = this_order_counts.keys() if n != 1 else sorted(this_order_counts.keys()) @@ -388,7 +389,7 @@ def GetProtectedNgrams(self): # add the backed-off n-grams (7, 8) -> 9 and (8) -> 9 to # 'protected-ngrams'. 
reduced_hist = hist - for m in reversed(range(args.no_backoff_ngram_order, n)): + for m in reversed(list(range(args.no_backoff_ngram_order, n))): reduced_hist = reduced_hist[1:] # shift an element off # the history. @@ -399,7 +400,7 @@ def GetProtectedNgrams(self): # history-state (6, 7, 8), then n-grams (6, 7, 8) and (6, 7) are # protected. This assures that the FST states are accessible. reduced_hist = hist - for m in reversed(range(args.no_backoff_ngram_order, n)): + for m in reversed(list(range(args.no_backoff_ngram_order, n))): ans.add(reduced_hist) reduced_hist = reduced_hist[:-1] # pop an element off the # history @@ -499,7 +500,7 @@ def PruningLogprobChange(self, count, discount, backoff_count, backoff_total): # and the 'count' term is zero in the numerator part of the log expression, # because symbol 'a' is completely backed off in 'this' state. this_a_change = augmented_count * \ - math.log((new_discount * new_backoff_count / new_backoff_total) / \ + math.log((new_discount * new_backoff_count / new_backoff_total)/ \ augmented_count) # other_a_change is the log-like change of symbol 'a' coming from all @@ -511,7 +512,7 @@ def PruningLogprobChange(self, count, discount, backoff_count, backoff_total): # doing so gives us an upper bound on the divergence. other_a_change = \ a_other_count * math.log((new_backoff_count / new_backoff_total) / \ - (backoff_count / backoff_total)) + (backoff_count / backoff_total)) # b_change is the log-like change of phantom symbol 'b' coming from # 'this' state (and note: it only comes from this state, that's how we diff --git a/egs/wsj/s5/utils/make_lexicon_fst.pl b/egs/wsj/s5/utils/make_lexicon_fst.pl index f97129c05cb..cd39ef98b4c 100755 --- a/egs/wsj/s5/utils/make_lexicon_fst.pl +++ b/egs/wsj/s5/utils/make_lexicon_fst.pl @@ -1,4 +1,9 @@ #!/usr/bin/env perl + +# THIS SCRIPT IS DEPRECATED AND WILL BE REMOVED. See +# utils/lang/make_lexicon_fst.py which is the python-based replacement. + + use warnings; #sed replacement for -w perl parameter # Copyright 2010-2011 Microsoft Corporation # 2013 Johns Hopkins University (author: Daniel Povey) diff --git a/egs/wsj/s5/utils/make_lexicon_fst_silprob.pl b/egs/wsj/s5/utils/make_lexicon_fst_silprob.pl index 557af4fe65e..cef26caf2f5 100755 --- a/egs/wsj/s5/utils/make_lexicon_fst_silprob.pl +++ b/egs/wsj/s5/utils/make_lexicon_fst_silprob.pl @@ -1,4 +1,8 @@ #!/usr/bin/env perl + +# THIS SCRIPT IS DEPRECATED AND WILL BE REMOVED. See +# utils/lang/make_lexicon_fst_silprob.py which is the python-based replacement. + use warnings; #sed replacement for -w perl parameter # Copyright 2010-2011 Microsoft Corporation # 2013 Johns Hopkins University (author: Daniel Povey) @@ -19,8 +23,8 @@ # limitations under the License. -# makes lexicon FST, in text form, from lexicon which contains (optional) -# probabilities of pronuniations, and (mandatory) probabilities of silence +# makes lexicon FST, in text form, from lexicon which contains (optional) +# probabilities of pronuniations, and (mandatory) probabilities of silence # before and after the pronunciation. 
This script is almost the same with # the make_lexicon_fst.pl script except for the word-dependent silprobs part @@ -68,7 +72,7 @@ $w = shift @A; if ($w eq "") { $silbeginprob = shift @A; - } + } if ($w eq "_s") { $silendcorrection = shift @A; } @@ -142,6 +146,6 @@ } } $cost = -log($silendcorrection); -print "$silstart\t$cost\n"; +print "$silstart\t$cost\n"; $cost = -log($nonsilendcorrection); print "$nonsilstart\t$cost\n"; diff --git a/egs/wsj/s5/utils/mkgraph.sh b/egs/wsj/s5/utils/mkgraph.sh index 1becfc45be3..31e86cd38f6 100755 --- a/egs/wsj/s5/utils/mkgraph.sh +++ b/egs/wsj/s5/utils/mkgraph.sh @@ -78,6 +78,19 @@ P=$(tree-info $tree | grep "central-position" | cut -d' ' -f2) || { echo "Error [[ -f $2/frame_subsampling_factor && "$loopscale" == "0.1" ]] && \ echo "$0: WARNING: chain models need '--self-loop-scale 1.0'"; +if [ -f $lang/phones/nonterm_phones_offset.int ]; then + if [[ $N != 2 || $P != 1 ]]; then + echo "$0: when doing grammar decoding, you can only build graphs for left-biphone trees." + exit 1 + fi + nonterm_phones_offset=$(cat $lang/phones/nonterm_phones_offset.int) + nonterm_opt="--nonterm-phones-offset=$nonterm_phones_offset" + prepare_grammar_command="make-grammar-fst --nonterm-phones-offset=$nonterm_phones_offset - -" +else + prepare_grammar_command="cat" + nonterm_opt= +fi + mkdir -p $lang/tmp trap "rm -f $lang/tmp/LG.fst.$$" EXIT HUP INT PIPE TERM # Note: [[ ]] is like [ ] but enables certain extra constructs, e.g. || in @@ -85,8 +98,7 @@ trap "rm -f $lang/tmp/LG.fst.$$" EXIT HUP INT PIPE TERM if [[ ! -s $lang/tmp/LG.fst || $lang/tmp/LG.fst -ot $lang/G.fst || \ $lang/tmp/LG.fst -ot $lang/L_disambig.fst ]]; then fsttablecompose $lang/L_disambig.fst $lang/G.fst | fstdeterminizestar --use-log=true | \ - fstminimizeencoded | fstpushspecial | \ - fstarcsort --sort_type=ilabel > $lang/tmp/LG.fst.$$ || exit 1; + fstminimizeencoded | fstpushspecial > $lang/tmp/LG.fst.$$ || exit 1; mv $lang/tmp/LG.fst.$$ $lang/tmp/LG.fst fstisstochastic $lang/tmp/LG.fst || echo "[info]: LG not stochastic." fi @@ -98,10 +110,10 @@ ilabels_tmp=$ilabels.$$ trap "rm -f $clg_tmp $ilabels_tmp" EXIT HUP INT PIPE TERM if [[ ! -s $clg || $clg -ot $lang/tmp/LG.fst \ || ! -s $ilabels || $ilabels -ot $lang/tmp/LG.fst ]]; then - fstcomposecontext --context-size=$N --central-position=$P \ + fstcomposecontext $nonterm_opt --context-size=$N --central-position=$P \ --read-disambig-syms=$lang/phones/disambig.int \ --write-disambig-syms=$lang/tmp/disambig_ilabels_${N}_${P}.int \ - $ilabels_tmp < $lang/tmp/LG.fst |\ + $ilabels_tmp $lang/tmp/LG.fst |\ fstarcsort --sort_type=ilabel > $clg_tmp mv $clg_tmp $clg mv $ilabels_tmp $ilabels @@ -111,7 +123,7 @@ fi trap "rm -f $dir/Ha.fst.$$" EXIT HUP INT PIPE TERM if [[ ! -s $dir/Ha.fst || $dir/Ha.fst -ot $model \ || $dir/Ha.fst -ot $lang/tmp/ilabels_${N}_${P} ]]; then - make-h-transducer --disambig-syms-out=$dir/disambig_tid.int \ + make-h-transducer $nonterm_opt --disambig-syms-out=$dir/disambig_tid.int \ --transition-scale=$tscale $lang/tmp/ilabels_${N}_${P} $tree $model \ > $dir/Ha.fst.$$ || exit 1; mv $dir/Ha.fst.$$ $dir/Ha.fst @@ -134,8 +146,9 @@ fi trap "rm -f $dir/HCLG.fst.$$" EXIT HUP INT PIPE TERM if [[ ! 
-s $dir/HCLG.fst || $dir/HCLG.fst -ot $dir/HCLGa.fst ]]; then - add-self-loops --self-loop-scale=$loopscale --reorder=true \ - $model < $dir/HCLGa.fst | fstconvert --fst_type=const > $dir/HCLG.fst.$$ || exit 1; + add-self-loops --self-loop-scale=$loopscale --reorder=true $model $dir/HCLGa.fst | \ + $prepare_grammar_command | \ + fstconvert --fst_type=const > $dir/HCLG.fst.$$ || exit 1; mv $dir/HCLG.fst.$$ $dir/HCLG.fst if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then # No point doing this test if transition-scale not 1, as it is bound to fail. @@ -162,7 +175,8 @@ mkdir -p $dir/phones cp $lang/phones/word_boundary.* $dir/phones/ 2>/dev/null # might be needed for ctm scoring, cp $lang/phones/align_lexicon.* $dir/phones/ 2>/dev/null # might be needed for ctm scoring, cp $lang/phones/optional_silence.* $dir/phones/ 2>/dev/null # might be needed for analyzing alignments. - # but ignore the error if it's not there. + # but ignore the error if it's not there. + cp $lang/phones/disambig.{txt,int} $dir/phones/ 2> /dev/null cp $lang/phones/silence.csl $dir/phones/ || exit 1; diff --git a/egs/wsj/s5/utils/nnet/gen_dct_mat.py b/egs/wsj/s5/utils/nnet/gen_dct_mat.py index d0f043ad7a4..77461112d0b 100755 --- a/egs/wsj/s5/utils/nnet/gen_dct_mat.py +++ b/egs/wsj/s5/utils/nnet/gen_dct_mat.py @@ -16,16 +16,21 @@ # limitations under the License. # ./gen_dct_mat.py -# script generates matrix with DCT transform, which is sparse -# and takes into account that data-layout is along frequency axis, +# script generates matrix with DCT transform, which is sparse +# and takes into account that data-layout is along frequency axis, # while DCT is done along temporal axis. +from __future__ import division +from __future__ import print_function from math import * import sys from optparse import OptionParser +def print_on_same_line(text): + print(text, end=' ') + parser = OptionParser() parser.add_option('--fea-dim', dest='dim', help='feature dimension') parser.add_option('--splice', dest='splice', help='applied splice value') @@ -49,19 +54,19 @@ #generate sparse DCT matrix -print '[' +print('[') for k in range(dct_basis): for m in range(dim): for n in range(timeContext): - if(n==0): - print m*'0 ', - else: - print (dim-1)*'0 ', - print str(sqrt(2.0/timeContext)*cos(M_PI/timeContext*k*(n+0.5))), + if(n==0): + print_on_same_line(m*'0 ') + else: + print_on_same_line((dim-1)*'0 ') + print_on_same_line(str(sqrt(2.0/timeContext)*cos(M_PI/timeContext*k*(n+0.5)))) if(n==timeContext-1): - print (dim-m-1)*'0 ', - print - print + print_on_same_line((dim-m-1)*'0 ') + print() + print() -print ']' +print(']') diff --git a/egs/wsj/s5/utils/nnet/gen_hamm_mat.py b/egs/wsj/s5/utils/nnet/gen_hamm_mat.py index a4262a8cffd..110178c6702 100755 --- a/egs/wsj/s5/utils/nnet/gen_hamm_mat.py +++ b/egs/wsj/s5/utils/nnet/gen_hamm_mat.py @@ -18,12 +18,17 @@ # ./gen_hamm_mat.py # script generates diagonal matrix with hamming window values +from __future__ import division +from __future__ import print_function from math import * import sys from optparse import OptionParser +def print_on_same_line(text): + print(text, end=' ') + parser = OptionParser() parser.add_option('--fea-dim', dest='dim', help='feature dimension') parser.add_option('--splice', dest='splice', help='applied splice value') @@ -42,16 +47,16 @@ dim_mat=(2*splice+1)*dim timeContext=2*splice+1 -print '[' +print('[') for row in range(dim_mat): for col in range(dim_mat): if col!=row: - print '0', + print_on_same_line('0') else: i=int(row/dim) - print str(0.54 - 0.46*cos((M_2PI * i) / 
(timeContext-1))), - print + print_on_same_line(str(0.54 - 0.46*cos((M_2PI * i) / (timeContext-1)))) + print() -print ']' +print(']') diff --git a/egs/wsj/s5/utils/nnet/gen_splice.py b/egs/wsj/s5/utils/nnet/gen_splice.py index 0241aeed6ba..f3a2c8b39ac 100755 --- a/egs/wsj/s5/utils/nnet/gen_splice.py +++ b/egs/wsj/s5/utils/nnet/gen_splice.py @@ -18,12 +18,16 @@ # ./gen_splice.py # generates Component +from __future__ import print_function from math import * import sys from optparse import OptionParser +def print_on_same_line(text): + print(text, end=' ') + parser = OptionParser() parser.add_option('--fea-dim', dest='dim_in', help='feature dimension') parser.add_option('--splice', dest='splice', help='number of frames to concatenate with the central frame') @@ -40,12 +44,12 @@ dim_out=(2*splice+1)*dim_in -print '', dim_out, dim_in -print '[', +print(' {0} {1}'.format(dim_out, dim_in)) +print_on_same_line('[') -splice_vec = range(-splice*splice_step, splice*splice_step+1, splice_step) +splice_vec = list(range(-splice*splice_step, splice*splice_step+1, splice_step)) for idx in range(len(splice_vec)): - print splice_vec[idx], + print_on_same_line(splice_vec[idx]) -print ']' +print(']') diff --git a/egs/wsj/s5/utils/nnet/make_blstm_proto.py b/egs/wsj/s5/utils/nnet/make_blstm_proto.py index 6e540ec791a..4d269cfdef0 100755 --- a/egs/wsj/s5/utils/nnet/make_blstm_proto.py +++ b/egs/wsj/s5/utils/nnet/make_blstm_proto.py @@ -17,6 +17,7 @@ # Generated Nnet prototype, to be initialized by 'nnet-initialize'. +from __future__ import print_function import sys ### @@ -54,7 +55,7 @@ parser.print_help() sys.exit(1) -(feat_dim, num_leaves) = map(int,args); +(feat_dim, num_leaves) = [int(i) for i in args]; # Original prototype from Jiayu, # @@ -77,18 +78,18 @@ # The BLSTM layers, if o.num_layers == 1: # Single BLSTM, - print " %d %d %s" % (feat_dim, 2*o.proj_dim_last, o.cell_dim) + lstm_extra_opts + print(" %d %d %s" % (feat_dim, 2*o.proj_dim_last, o.cell_dim) + lstm_extra_opts) else: # >1 BLSTM, - print " %d %d %s" % (feat_dim, 2*o.proj_dim, o.cell_dim) + lstm_extra_opts + print(" %d %d %s" % (feat_dim, 2*o.proj_dim, o.cell_dim) + lstm_extra_opts) for l in range(o.num_layers - 2): - print " %d %d %s" % (2*o.proj_dim, 2*o.proj_dim, o.cell_dim) + lstm_extra_opts - print " %d %d %s" % (2*o.proj_dim, 2*o.proj_dim_last, o.cell_dim) + lstm_extra_opts + print(" %d %d %s" % (2*o.proj_dim, 2*o.proj_dim, o.cell_dim) + lstm_extra_opts) + print(" %d %d %s" % (2*o.proj_dim, 2*o.proj_dim_last, o.cell_dim) + lstm_extra_opts) # Adding for more stability, -print " %d %d" % (2*o.proj_dim_last, 2*o.proj_dim_last) +print(" %d %d" % (2*o.proj_dim_last, 2*o.proj_dim_last)) # Softmax layer, -print " %d %d 0.0 0.0" % (2*o.proj_dim_last, num_leaves) + softmax_affine_opts -print " %d %d" % (num_leaves, num_leaves) +print(" %d %d 0.0 0.0" % (2*o.proj_dim_last, num_leaves) + softmax_affine_opts) +print(" %d %d" % (num_leaves, num_leaves)) diff --git a/egs/wsj/s5/utils/nnet/make_cnn2d_proto.py b/egs/wsj/s5/utils/nnet/make_cnn2d_proto.py deleted file mode 100755 index 73455563b51..00000000000 --- a/egs/wsj/s5/utils/nnet/make_cnn2d_proto.py +++ /dev/null @@ -1,257 +0,0 @@ -#!/usr/bin/python - -# Copyright 2014 Brno University of Technology (author: Karel Vesely) - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - -# Generated Nnet prototype, to be initialized by 'nnet-initialize'. - -import math, random, sys, warnings -from optparse import OptionParser - -### -### Parse options -### -usage="%prog [options] >nnet-proto-file" -parser = OptionParser(usage) - -parser.add_option('--activation-type', dest='activation_type', - help='Select type of activation function : (|) [default: %default]', - default='', type='string'); - -parser.add_option('--cnn1-num-filters', dest='cnn1_num_filters', - help='Number of filters in first convolutional layer [default: %default]', - default=128, type='int') -# this is given by splice -# parser.add_option('--cnn1-fmap-x-len', dest='cnn1_fmap_x_len', -# help='Size of cnn1-fmap-x-len [default: %default]', -# default=11, type='int') - -# this should be equal to feat_raw_dim -# parser.add_option('--cnn1-fmap-y-len', dest='cnn1_fmap_y_len', -# help='Size of cnn1-fmap-y-len [default: %default]', -# default=32, type='int') - -parser.add_option('--cnn1-filt-x-len', dest='cnn1_filt_x_len', - help='Size of cnn1-filt-x-len [default: %default]', - default=9, type='int') -parser.add_option('--cnn1-filt-y-len', dest='cnn1_filt_y_len', - help='Size of cnn1-filt-y-len [default: %default]', - default=9, type='int') - -parser.add_option('--cnn1-filt-x-step', dest='cnn1_filt_x_step', - help='Size of cnn1-filt-x-step [default: %default]', - default=1, type='int') -parser.add_option('--cnn1-filt-y-step', dest='cnn1_filt_y_step', - help='Size of cnn1-filt-y-step [default: %default]', - default=1, type='int') -parser.add_option('--cnn1-connect-fmap', dest='cnn1_connect_fmap', - help='Size of cnn1-connect-fmap [default: %default]', - default=0, type='int') - -parser.add_option('--pool1-x-len', dest='pool1_x_len', - help='Size of pool1-filt-x-len [default: %default]', - default=1, type='int') -parser.add_option('--pool1-x-step', dest='pool1_x_step', - help='Size of pool1-x-step [default: %default]', - default=1, type='int') - - -# -parser.add_option('--pool1-y-len', dest='pool1_y_len', - help='Size of pool1-y-len [default: %default]', - default=3, type='int') -parser.add_option('--pool1-y-step', dest='pool1_y_step', - help='Size of pool1-y-step [default: %default]', - default=3, type='int') - -parser.add_option('--pool1-type', dest='pool1_type', - help='Type of pooling (Max || Average) [default: %default]', - default='Max', type='string') - -parser.add_option('--cnn2-num-filters', dest='cnn2_num_filters', - help='Number of filters in first convolutional layer [default: %default]', - default=256, type='int') -parser.add_option('--cnn2-filt-x-len', dest='cnn2_filt_x_len', - help='Size of cnn2-filt-x-len [default: %default]', - default=3, type='int') -parser.add_option('--cnn2-filt-y-len', dest='cnn2_filt_y_len', - help='Size of cnn2-filt-y-len [default: %default]', - default=4, type='int') -parser.add_option('--cnn2-filt-x-step', dest='cnn2_filt_x_step', - help='Size of cnn2-filt-x-step [default: %default]', - default=1, type='int') -parser.add_option('--cnn2-filt-y-step', dest='cnn2_filt_y_step', - help='Size of 
cnn2-filt-y-step [default: %default]', - default=1, type='int') -parser.add_option('--cnn2-connect-fmap', dest='cnn2_connect_fmap', - help='Size of cnn2-connect-fmap [default: %default]', - default=1, type='int') - -parser.add_option('--pitch-dim', dest='pitch_dim', - help='Number of features representing pitch [default: %default]', - default=0, type='int') -parser.add_option('--delta-order', dest='delta_order', - help='Order of delta features [default: %default]', - default=2, type='int') -parser.add_option('--splice', dest='splice', - help='Length of splice [default: %default]', - default=5,type='int') -parser.add_option('--dir', dest='dirct', - help='Directory, where network prototypes will be saved [default: %default]', - default='.', type='string') -parser.add_option('--num-pitch-neurons', dest='num_pitch_neurons', - help='Number of neurons in layers processing pitch features [default: %default]', - default='200', type='int') - - -(o,args) = parser.parse_args() -if len(args) != 1 : - parser.print_help() - sys.exit(1) - -feat_dim=int(args[0]) -### End parse options - -feat_raw_dim = feat_dim / (o.delta_order+1) / (o.splice*2+1) - o.pitch_dim # we need number of feats without deltas and splice and pitch -o.cnn1_fmap_y_len = feat_raw_dim -o.cnn1_fmap_x_len = o.splice*2+1 - -# Check -assert(feat_dim > 0) -assert(o.pool1_type == 'Max' or o.pool1_type == 'Average') - -## Extra checks if dimensions are matching, if not match them by -## producing a warning -# cnn1 -assert( (o.cnn1_fmap_y_len - o.cnn1_filt_y_len) % o.cnn1_filt_y_step == 0 ) -assert( (o.cnn1_fmap_x_len - o.cnn1_filt_x_len) % o.cnn1_filt_x_step == 0 ) - -# subsample1 -cnn1_out_fmap_y_len=((1 + (o.cnn1_fmap_y_len - o.cnn1_filt_y_len) / o.cnn1_filt_y_step)) -cnn1_out_fmap_x_len=((1 + (o.cnn1_fmap_x_len - o.cnn1_filt_x_len) / o.cnn1_filt_x_step)) - -# fix filt_len and filt_step -def fix_filt_step(inp_len, filt_len, filt_step): - - if ((inp_len - filt_len) % filt_step == 0): - return filt_step - else: - # filt_step <= filt_len - for filt_step in xrange(filt_len, 0, -1): - if ((inp_len - filt_len) % filt_step == 0): - return filt_step - -o.pool1_y_step = fix_filt_step(cnn1_out_fmap_y_len, o.pool1_y_len, o.pool1_y_step) -if o.pool1_y_step == 1 and o.pool1_y_len != 1: - warnings.warn('WARNING: Choose different pool1_y_len as subsampling is not happening'); - -o.pool1_x_step = fix_filt_step(cnn1_out_fmap_x_len, o.pool1_x_len, o.pool1_x_step) -if o.pool1_x_step == 1 and o.pool1_x_len != 1: - warnings.warn('WARNING: Choose different pool1_x_len as subsampling is not happening'); - - -### -### Print prototype of the network -### - -# Begin the prototype -print "" - -# Convolutional part of network -'''1st CNN layer''' -cnn1_input_dim=feat_raw_dim * (o.delta_order+1) * (o.splice*2+1) -cnn1_out_fmap_x_len=((1 + (o.cnn1_fmap_x_len - o.cnn1_filt_x_len) / o.cnn1_filt_x_step)) -cnn1_out_fmap_y_len=((1 + (o.cnn1_fmap_y_len - o.cnn1_filt_y_len) / o.cnn1_filt_y_step)) -cnn1_output_dim=o.cnn1_num_filters * cnn1_out_fmap_x_len * cnn1_out_fmap_y_len - -'''1st Pooling layer''' -pool1_input_dim=cnn1_output_dim -pool1_fmap_x_len=cnn1_out_fmap_x_len -pool1_out_fmap_x_len=((1 + (pool1_fmap_x_len - o.pool1_x_len) / o.pool1_x_step)) -pool1_fmap_y_len=cnn1_out_fmap_y_len -pool1_out_fmap_y_len=((1 + (pool1_fmap_y_len - o.pool1_y_len) / o.pool1_y_step)) -pool1_output_dim=o.cnn1_num_filters*pool1_out_fmap_x_len*pool1_out_fmap_y_len - -'''2nd CNN layer''' -cnn2_input_dim=pool1_output_dim -cnn2_fmap_x_len=pool1_out_fmap_x_len -cnn2_out_fmap_x_len=((1 + 
(cnn2_fmap_x_len - o.cnn2_filt_x_len) / o.cnn2_filt_x_step)) -cnn2_fmap_y_len=pool1_out_fmap_y_len -cnn2_out_fmap_y_len=((1 + (cnn2_fmap_y_len - o.cnn2_filt_y_len) / o.cnn2_filt_y_step)) -cnn2_output_dim=o.cnn2_num_filters * cnn2_out_fmap_x_len * cnn2_out_fmap_y_len - - -convolution_proto = '' - -convolution_proto += " %d %d %d %d %d %d %d %d %d %f %f %f\n" % \ - ( cnn1_input_dim, cnn1_output_dim, o.cnn1_fmap_x_len, o.cnn1_fmap_y_len, o.cnn1_filt_x_len, o.cnn1_filt_y_len, o.cnn1_filt_x_step, o.cnn1_filt_y_step, o.cnn1_connect_fmap, 0.0, 0.0, 0.01 ) -convolution_proto += "<%sPooling2DComponent> %d %d %d %d %d %d %d %d\n" % \ - ( o.pool1_type, pool1_input_dim, pool1_output_dim, pool1_fmap_x_len, pool1_fmap_y_len, o.pool1_x_len, o.pool1_y_len, o.pool1_x_step, o.pool1_y_step ) -convolution_proto += " %d %d %f\n" % \ - ( pool1_output_dim, pool1_output_dim, 1.0 ) -convolution_proto += " %d %d %f\n" % \ - ( pool1_output_dim, pool1_output_dim, 0.0 ) -convolution_proto += "%s %d %d\n" % \ - ( o.activation_type, pool1_output_dim, pool1_output_dim ) -convolution_proto += " %d %d %d %d %d %d %d %d %d %f %f %f\n" % \ - ( cnn2_input_dim, cnn2_output_dim, cnn2_fmap_x_len, cnn2_fmap_y_len, o.cnn2_filt_x_len, o.cnn2_filt_y_len, o.cnn2_filt_x_step, o.cnn2_filt_y_step, o.cnn2_connect_fmap, -2.0, 4.0, 0.1 ) -convolution_proto += " %d %d %f\n" % \ - ( cnn2_output_dim, cnn2_output_dim, 1.0) -convolution_proto += " %d %d %f\n" % \ - ( cnn2_output_dim, cnn2_output_dim, 0.0) -convolution_proto += "%s %d %d\n" % \ - ( o.activation_type, cnn2_output_dim, cnn2_output_dim) - -if (o.pitch_dim > 0): - # convolutional part - f_conv = open('%s/nnet.proto.convolution' % o.dirct, 'w') - f_conv.write('\n') - f_conv.write(convolution_proto) - f_conv.write('\n') - f_conv.close() - - # pitch part - f_pitch = open('%s/nnet.proto.pitch' % o.dirct, 'w') - f_pitch.write('\n') - f_pitch.write(' %d %d %f %f %f\n' % \ - ((o.pitch_dim * (o.delta_order+1) * (o.splice*2+1)), o.num_pitch_neurons, -2.0, 4.0, 0.109375)) - f_pitch.write('%s %d %d\n' % \ - (o.activation_type, o.num_pitch_neurons, o.num_pitch_neurons)) - f_pitch.write(' %d %d %f %f %f\n' % \ - (o.num_pitch_neurons, o.num_pitch_neurons, -2.0, 4.0, 0.109375)) - f_pitch.write('%s %d %d\n' % \ - (o.activation_type, o.num_pitch_neurons, o.num_pitch_neurons)) - f_pitch.write('\n') - f_pitch.close() - - # paralell part - vector = '' - for i in range(1, (feat_raw_dim + o.pitch_dim) * (o.delta_order+1) * (o.splice*2+1), feat_raw_dim + o.pitch_dim): - vector += '%d:1:%d ' % (i, i + feat_raw_dim - 1) - for i in range(feat_raw_dim+1, (feat_raw_dim + o.pitch_dim) * (o.delta_order+1) * (o.splice*2+1), feat_raw_dim + o.pitch_dim): - vector += '%d:1:%d ' % (i, i + o.pitch_dim - 1) - print ' %d %d %s ' % \ - ((feat_raw_dim + o.pitch_dim) * (o.delta_order+1) * (o.splice*2+1), (feat_raw_dim + o.pitch_dim) * (o.delta_order+1) * (o.splice*2+1), vector) - print ' %d %d %s %s ' % \ - ((feat_raw_dim + o.pitch_dim) * (o.delta_order+1) * (o.splice*2+1), o.num_pitch_neurons + cnn2_output_dim, '%s/nnet.proto.convolution' % o.dirct, '%s/nnet.proto.pitch' % o.dirct) - - num_convolution_output = o.num_pitch_neurons + cnn2_output_dim -else: # no pitch - print convolution_proto - -# We are done! 
-sys.exit(0) - - diff --git a/egs/wsj/s5/utils/nnet/make_cnn_proto.py b/egs/wsj/s5/utils/nnet/make_cnn_proto.py index c6aa519ea96..4d8b9ca2946 100755 --- a/egs/wsj/s5/utils/nnet/make_cnn_proto.py +++ b/egs/wsj/s5/utils/nnet/make_cnn_proto.py @@ -17,6 +17,8 @@ # Generated Nnet prototype, to be initialized by 'nnet-initialize'. +from __future__ import division +from __future__ import print_function import math, random, sys from optparse import OptionParser @@ -88,7 +90,7 @@ ### # Begin the prototype -print "" +print("") # Convolutional part of network num_patch1 = 1 + (feat_raw_dim - o.patch_dim1) / o.patch_step1 @@ -150,13 +152,13 @@ vector += '%d:1:%d ' % (i, i + feat_raw_dim - 1) for i in range(feat_raw_dim+1, inputdim_of_cnn + 1, feat_raw_dim + o.pitch_dim): vector += '%d:1:%d ' % (i, i + o.pitch_dim - 1) - print ' %d %d %s ' % \ - (inputdim_of_cnn, inputdim_of_cnn, vector) - print ' %d %d %s %s ' % \ - (inputdim_of_cnn, o.num_pitch_neurons + outputdim_of_cnn, '%s/nnet.proto.convolution' % o.protodir, '%s/nnet.proto.pitch' % o.protodir) + print(' %d %d %s ' % \ + (inputdim_of_cnn, inputdim_of_cnn, vector)) + print(' %d %d %s %s ' % \ + (inputdim_of_cnn, o.num_pitch_neurons + outputdim_of_cnn, '%s/nnet.proto.convolution' % o.protodir, '%s/nnet.proto.pitch' % o.protodir)) else: # no pitch - print convolution_proto + print(convolution_proto) # We are done! sys.exit(0) diff --git a/egs/wsj/s5/utils/nnet/make_lstm_proto.py b/egs/wsj/s5/utils/nnet/make_lstm_proto.py index a2da0a194fc..6818c860ed0 100755 --- a/egs/wsj/s5/utils/nnet/make_lstm_proto.py +++ b/egs/wsj/s5/utils/nnet/make_lstm_proto.py @@ -17,6 +17,7 @@ # Generated Nnet prototype, to be initialized by 'nnet-initialize'. +from __future__ import print_function import sys ### @@ -52,7 +53,7 @@ parser.print_help() sys.exit(1) -(feat_dim, num_leaves) = map(int,args); +(feat_dim, num_leaves) = [int(i) for i in args]; # Original prototype from Jiayu, # @@ -73,14 +74,14 @@ if None != o.param_stddev: softmax_affine_opts += " %f " % o.param_stddev # The LSTM layers, -print " %d %d %s" % (feat_dim, o.proj_dim, o.cell_dim) + lstm_extra_opts +print(" %d %d %s" % (feat_dim, o.proj_dim, o.cell_dim) + lstm_extra_opts) for l in range(o.num_layers - 1): - print " %d %d %s" % (o.proj_dim, o.proj_dim, o.cell_dim) + lstm_extra_opts + print(" %d %d %s" % (o.proj_dim, o.proj_dim, o.cell_dim) + lstm_extra_opts) # Adding for more stability, -print " %d %d" % (o.proj_dim, o.proj_dim) +print(" %d %d" % (o.proj_dim, o.proj_dim)) # Softmax layer, -print " %d %d 0.0 0.0" % (o.proj_dim, num_leaves) + softmax_affine_opts -print " %d %d" % (num_leaves, num_leaves) +print(" %d %d 0.0 0.0" % (o.proj_dim, num_leaves) + softmax_affine_opts) +print(" %d %d" % (num_leaves, num_leaves)) diff --git a/egs/wsj/s5/utils/nnet/make_nnet_proto.py b/egs/wsj/s5/utils/nnet/make_nnet_proto.py index 7b5c50beeb8..4f60be6c9d0 100755 --- a/egs/wsj/s5/utils/nnet/make_nnet_proto.py +++ b/egs/wsj/s5/utils/nnet/make_nnet_proto.py @@ -17,14 +17,13 @@ # Generated Nnet prototype, to be initialized by 'nnet-initialize'. 
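The same pair of __future__ imports is added to each of these prototype generators; the point is that print becomes a function and / becomes true division, so any size computation that has to stay integral needs // (or an input that divides evenly). A two-line reminder, with illustrative numbers:

    from __future__ import division, print_function  # harmless no-op on Python 3

    print((40 - 8) / 1)    # 32.0 -- true division always returns a float
    print((40 - 8) // 1)   # 32   -- floor division keeps patch counts integral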
+from __future__ import division +from __future__ import print_function import math, random, sys, re ### ### Parse options ### - -print >> sys.stderr, sys.argv - from optparse import OptionParser usage="%prog [options] >nnet-proto-file" parser = OptionParser(usage) @@ -90,7 +89,7 @@ o.affine_opts = o.affine_opts.replace("_"," ") o.dropout_opts = o.dropout_opts.replace("_"," ") -(feat_dim, num_leaves, num_hid_layers, num_hid_neurons) = map(int,args); +(feat_dim, num_leaves, num_hid_layers, num_hid_neurons) = [int(i) for i in args]; ### End parse options @@ -123,46 +122,46 @@ def Glorot(dim1, dim2): assert(num_hid_layers == 0) if o.bottleneck_trick: # 25% smaller stddev -> small bottleneck range, 10x smaller learning rate - print " %d %d %f %f" % \ + print(" %d %d %f %f" % \ (feat_dim, o.bottleneck_dim, \ - (o.param_stddev_factor * Glorot(feat_dim, o.bottleneck_dim) * 0.75 ), 0.1) + (o.param_stddev_factor * Glorot(feat_dim, o.bottleneck_dim) * 0.75 ), 0.1)) # 25% smaller stddev -> smaller gradient in prev. layer, 10x smaller learning rate for weigts & biases - print " %d %d %f %f %f %f %f %f" % \ + print(" %d %d %f %f %f %f %f %f" % \ (o.bottleneck_dim, num_hid_neurons, o.hid_bias_mean, o.hid_bias_range, \ - (o.param_stddev_factor * Glorot(o.bottleneck_dim, num_hid_neurons) * 0.75 ), 0.1, 0.1, o.max_norm) + (o.param_stddev_factor * Glorot(o.bottleneck_dim, num_hid_neurons) * 0.75 ), 0.1, 0.1, o.max_norm)) else: - print " %d %d %f" % \ + print(" %d %d %f" % \ (feat_dim, o.bottleneck_dim, \ - (o.param_stddev_factor * Glorot(feat_dim, o.bottleneck_dim))) - print " %d %d %f %f %f %f" % \ + (o.param_stddev_factor * Glorot(feat_dim, o.bottleneck_dim)))) + print(" %d %d %f %f %f %f" % \ (o.bottleneck_dim, num_hid_neurons, o.hid_bias_mean, o.hid_bias_range, \ - (o.param_stddev_factor * Glorot(o.bottleneck_dim, num_hid_neurons)), o.max_norm) - print "%s %d %d %s" % (o.activation_type, num_hid_neurons, num_hid_neurons, o.activation_opts) # Non-linearity + (o.param_stddev_factor * Glorot(o.bottleneck_dim, num_hid_neurons)), o.max_norm)) + print("%s %d %d %s" % (o.activation_type, num_hid_neurons, num_hid_neurons, o.activation_opts)) # Non-linearity # Last AffineTransform (10x smaller learning rate on bias) - print " %d %d %f %f %f %f %f" % \ + print(" %d %d %f %f %f %f %f" % \ (num_hid_neurons, num_leaves, 0.0, 0.0, \ - (o.param_stddev_factor * Glorot(num_hid_neurons, num_leaves)), 1.0, 0.1) + (o.param_stddev_factor * Glorot(num_hid_neurons, num_leaves)), 1.0, 0.1)) # Optionaly append softmax if o.with_softmax: if o.block_softmax_dims == "": - print " %d %d" % (num_leaves, num_leaves) + print(" %d %d" % (num_leaves, num_leaves)) else: - print " %d %d %s" % (num_leaves, num_leaves, o.block_softmax_dims) - print "" + print(" %d %d %s" % (num_leaves, num_leaves, o.block_softmax_dims)) + print("") # We are done! sys.exit(0) # NO HIDDEN LAYERS! # Add only last layer (logistic regression) if num_hid_layers == 0: - print " %d %d %f %f %f" % \ - (feat_dim, num_leaves, 0.0, 0.0, (o.param_stddev_factor * Glorot(feat_dim, num_leaves))) + print(" %d %d %f %f %f" % \ + (feat_dim, num_leaves, 0.0, 0.0, (o.param_stddev_factor * Glorot(feat_dim, num_leaves)))) if o.with_softmax: if o.block_softmax_dims == "": - print " %d %d" % (num_leaves, num_leaves) + print(" %d %d" % (num_leaves, num_leaves)) else: - print " %d %d %s" % (num_leaves, num_leaves, o.block_softmax_dims) - print "" + print(" %d %d %s" % (num_leaves, num_leaves, o.block_softmax_dims)) + print("") # We are done! 
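make_nnet_proto.py scales each layer's initial weight stddev by a Glorot-style factor of the layer's fan-in and fan-out (multiplied by --param-stddev-factor). As a reminder of the idea only — the script's exact constant may differ — the classic Xavier/Glorot term looks like this, with made-up layer sizes:

    import math

    def glorot_scale(fan_in, fan_out):
        # classic Xavier/Glorot term, sqrt(6 / (fan_in + fan_out))
        return math.sqrt(6.0 / (fan_in + fan_out))

    # e.g. a 440-dim spliced input feeding a 2048-unit hidden layer:
    print(glorot_scale(440, 2048))   # ~0.049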
sys.exit(0) @@ -173,63 +172,63 @@ def Glorot(dim1, dim2): # Begin the prototype, # First AffineTranform, -print " %d %d %f %f %f %f %s" % \ +print(" %d %d %f %f %f %f %s" % \ (feat_dim, num_hid_neurons, o.hid_bias_mean, o.hid_bias_range, \ (o.param_stddev_factor * Glorot(feat_dim, num_hid_neurons) * \ - (math.sqrt(1.0/12.0) if o.smaller_input_weights else 1.0)), o.max_norm, o.affine_opts) + (math.sqrt(1.0/12.0) if o.smaller_input_weights else 1.0)), o.max_norm, o.affine_opts)) # Note.: compensating dynamic range mismatch between input features and Sigmoid-hidden layers, # i.e. mapping the std-dev of N(0,1) (input features) to std-dev of U[0,1] (sigmoid-outputs). # This is done by multiplying with stddev(U[0,1]) = sqrt(1/12). # The stddev of weights is consequently reduced with scale 0.29, -print "%s %d %d %s" % (o.activation_type, num_hid_neurons, num_hid_neurons, o.activation_opts) +print("%s %d %d %s" % (o.activation_type, num_hid_neurons, num_hid_neurons, o.activation_opts)) if o.with_dropout: - print " %d %d %s" % (num_hid_neurons, num_hid_neurons, o.dropout_opts) + print(" %d %d %s" % (num_hid_neurons, num_hid_neurons, o.dropout_opts)) # Internal AffineTransforms, for i in range(num_hid_layers-1): - print " %d %d %f %f %f %f %s" % \ + print(" %d %d %f %f %f %f %s" % \ (num_hid_neurons, num_hid_neurons, o.hid_bias_mean, o.hid_bias_range, \ - (o.param_stddev_factor * Glorot(num_hid_neurons, num_hid_neurons)), o.max_norm, o.affine_opts) - print "%s %d %d %s" % (o.activation_type, num_hid_neurons, num_hid_neurons, o.activation_opts) + (o.param_stddev_factor * Glorot(num_hid_neurons, num_hid_neurons)), o.max_norm, o.affine_opts)) + print("%s %d %d %s" % (o.activation_type, num_hid_neurons, num_hid_neurons, o.activation_opts)) if o.with_dropout: - print " %d %d %s" % (num_hid_neurons, num_hid_neurons, o.dropout_opts) + print(" %d %d %s" % (num_hid_neurons, num_hid_neurons, o.dropout_opts)) # Optionaly add bottleneck, if o.bottleneck_dim != 0: assert(o.bottleneck_dim > 0) if o.bottleneck_trick: # 25% smaller stddev -> small bottleneck range, 10x smaller learning rate - print " %d %d %f %f" % \ + print(" %d %d %f %f" % \ (num_hid_neurons, o.bottleneck_dim, \ - (o.param_stddev_factor * Glorot(num_hid_neurons, o.bottleneck_dim) * 0.75 ), 0.1) + (o.param_stddev_factor * Glorot(num_hid_neurons, o.bottleneck_dim) * 0.75 ), 0.1)) # 25% smaller stddev -> smaller gradient in prev. 
layer, 10x smaller learning rate for weigts & biases - print " %d %d %f %f %f %f %f %f %s" % \ + print(" %d %d %f %f %f %f %f %f %s" % \ (o.bottleneck_dim, num_hid_neurons, o.hid_bias_mean, o.hid_bias_range, \ - (o.param_stddev_factor * Glorot(o.bottleneck_dim, num_hid_neurons) * 0.75 ), 0.1, 0.1, o.max_norm, o.affine_opts) + (o.param_stddev_factor * Glorot(o.bottleneck_dim, num_hid_neurons) * 0.75 ), 0.1, 0.1, o.max_norm, o.affine_opts)) else: # Same learninig-rate and stddev-formula everywhere, - print " %d %d %f" % \ + print(" %d %d %f" % \ (num_hid_neurons, o.bottleneck_dim, \ - (o.param_stddev_factor * Glorot(num_hid_neurons, o.bottleneck_dim))) - print " %d %d %f %f %f %f %s" % \ + (o.param_stddev_factor * Glorot(num_hid_neurons, o.bottleneck_dim)))) + print(" %d %d %f %f %f %f %s" % \ (o.bottleneck_dim, num_hid_neurons, o.hid_bias_mean, o.hid_bias_range, \ - (o.param_stddev_factor * Glorot(o.bottleneck_dim, num_hid_neurons)), o.max_norm, o.affine_opts) - print "%s %d %d %s" % (o.activation_type, num_hid_neurons, num_hid_neurons, o.activation_opts) + (o.param_stddev_factor * Glorot(o.bottleneck_dim, num_hid_neurons)), o.max_norm, o.affine_opts)) + print("%s %d %d %s" % (o.activation_type, num_hid_neurons, num_hid_neurons, o.activation_opts)) if o.with_dropout: - print " %d %d %s" % (num_hid_neurons, num_hid_neurons, o.dropout_opts) + print(" %d %d %s" % (num_hid_neurons, num_hid_neurons, o.dropout_opts)) # Last AffineTransform (10x smaller learning rate on bias) -print " %d %d %f %f %f %f %f" % \ +print(" %d %d %f %f %f %f %f" % \ (num_hid_neurons, num_leaves, 0.0, 0.0, \ - (o.param_stddev_factor * Glorot(num_hid_neurons, num_leaves)), 1.0, 0.1) + (o.param_stddev_factor * Glorot(num_hid_neurons, num_leaves)), 1.0, 0.1)) # Optionaly append softmax if o.with_softmax: if o.block_softmax_dims == "": - print " %d %d" % (num_leaves, num_leaves) + print(" %d %d" % (num_leaves, num_leaves)) else: - print " %d %d %s" % (num_leaves, num_leaves, o.block_softmax_dims) + print(" %d %d %s" % (num_leaves, num_leaves, o.block_softmax_dims)) # We are done! sys.exit(0) diff --git a/egs/wsj/s5/utils/parallel/limit_num_gpus.sh b/egs/wsj/s5/utils/parallel/limit_num_gpus.sh index d9707a816c4..9d7caddd1f6 100755 --- a/egs/wsj/s5/utils/parallel/limit_num_gpus.sh +++ b/egs/wsj/s5/utils/parallel/limit_num_gpus.sh @@ -18,8 +18,8 @@ if [ "$1" == "--num-gpus" ]; then shift fi -if ! printf "%d" "$num_gpus" >/dev/null || [ $num_gpus -le 0 ]; then - echo $0: Must pass a positive interger after --num-gpus +if ! printf "%d" "$num_gpus" >/dev/null || [ $num_gpus -le -1 ]; then + echo $0: Must pass a positive interger or 0 after --num-gpus echo e.g. $0 --num-gpus 2 local/tfrnnlm/run_lstm.sh exit 1 fi @@ -35,18 +35,24 @@ CUDA_VISIBLE_DEVICES= num_total_gpus=`nvidia-smi -L | wc -l` num_gpus_assigned=0 -for i in `seq 0 $[$num_total_gpus-1]`; do -# going over all GPUs and check if it is idle, and add to the list if yes - if nvidia-smi -i $i | grep "No running processes found" >/dev/null; then - CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}$i, && num_gpus_assigned=$[$num_gpus_assigned+1] - fi -# once we have enough GPUs, break out of the loop - [ $num_gpus_assigned -eq $num_gpus ] && break -done +if [ $num_gpus -eq 0 ] ; then + echo "$0: Running the job on CPU. 
Disabling submitting to gpu" + export CUDA_VISIBLE_DEVICES="" +else + for i in `seq 0 $[$num_total_gpus-1]`; do + # going over all GPUs and check if it is idle, and add to the list if yes + if nvidia-smi -i $i | grep "No running processes found" >/dev/null; then + CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}$i, && num_gpus_assigned=$[$num_gpus_assigned+1] + fi + # once we have enough GPUs, break out of the loop + [ $num_gpus_assigned -eq $num_gpus ] && break + done -[ $num_gpus_assigned -ne $num_gpus ] && echo Could not find enough idle GPUs && exit 1 + [ $num_gpus_assigned -ne $num_gpus ] && echo Could not find enough idle GPUs && exit 1 -export CUDA_VISIBLE_DEVICES=$(echo $CUDA_VISIBLE_DEVICES | sed "s=,$==g") + export CUDA_VISIBLE_DEVICES=$(echo $CUDA_VISIBLE_DEVICES | sed "s=,$==g") + + echo "$0: Running the job on GPU(s) $CUDA_VISIBLE_DEVICES" +fi -echo "$0: Running the job on GPU(s) $CUDA_VISIBLE_DEVICES" "$@" diff --git a/egs/wsj/s5/utils/parallel/pbs.pl b/egs/wsj/s5/utils/parallel/pbs.pl index 6c8d4488882..d61bb1d4566 100755 --- a/egs/wsj/s5/utils/parallel/pbs.pl +++ b/egs/wsj/s5/utils/parallel/pbs.pl @@ -11,19 +11,17 @@ use Cwd; use Getopt::Long; -# This is a version of the queue.pl modified so that it works under PBS +# This is a version of the queue.pl modified so that it works under PBS # The PBS is one of the several "almost compatible" queueing systems. The # command switches and environment variables are different, so we are adding # a this script. An optimal solution might probably be to make the variable # names and the commands configurable, as similar problems can be expected # with Torque, Univa... and who knows what else # -# queue.pl has the same functionality as run.pl, except that -# it runs the job in question on the queue (Sun GridEngine). -# This version of queue.pl uses the task array functionality -# of the grid engine. Note: it's different from the queue.pl -# in the s4 and earlier scripts. - +# pbs.pl has the same functionality as run.pl, except that +# it runs the job in question on the queue (PBS). +# This version of pbs.pl uses the task array functionality +# of PBS. # The script now supports configuring the queue system using a config file # (default in conf/pbs.conf; but can be passed specified with --config option) # and a set of command line options. @@ -78,12 +76,12 @@ sub print_usage() { print STDERR - "Usage: queue.pl [options] [JOB=1:n] log-file command-line arguments...\n" . - "e.g.: queue.pl foo.log echo baz\n" . + "Usage: pbs.pl [options] [JOB=1:n] log-file command-line arguments...\n" . + "e.g.: pbs.pl foo.log echo baz\n" . " (which will echo \"baz\", with stdout and stderr directed to foo.log)\n" . - "or: queue.pl -q all.q\@xyz foo.log echo bar \| sed s/bar/baz/ \n" . + "or: pbs.pl -q all.q\@xyz foo.log echo bar \| sed s/bar/baz/ \n" . " (which is an example of using a pipe; you can provide other escaped bash constructs)\n" . - "or: queue.pl -q all.q\@qyz JOB=1:10 foo.JOB.log echo JOB \n" . + "or: pbs.pl -q all.q\@qyz JOB=1:10 foo.JOB.log echo JOB \n" . " (which illustrates the mechanism to submit parallel jobs; note, you can use \n" . " another string other than JOB)\n" . "Note: if you pass the \"-sync y\" option to qsub, this script will take note\n" . 
@@ -113,7 +111,7 @@ () } else { my $argument = shift @ARGV; if ($argument =~ m/^--/) { - print STDERR "queue.pl: Warning: suspicious argument '$argument' to $switch; starts with '-'\n"; + print STDERR "pbs.pl: Warning: suspicious argument '$argument' to $switch; starts with '-'\n"; } if ($switch eq "-sync" && $argument =~ m/^[yY]/) { $sync = 1; @@ -141,7 +139,7 @@ () $jobend = $3; shift; if ($jobstart > $jobend) { - die "queue.pl: invalid job range $ARGV[0]"; + die "pbs.pl: invalid job range $ARGV[0]"; } if ($jobstart <= 0) { die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is a GridEngine limitation)."; @@ -153,7 +151,7 @@ () $jobend = $2; shift; } elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) { - print STDERR "queue.pl: Warning: suspicious first argument to queue.pl: $ARGV[0]\n"; + print STDERR "pbs.pl: Warning: suspicious first argument to queue.pl: $ARGV[0]\n"; } } @@ -248,7 +246,7 @@ () $cli_options{$option} = $value; } } else { - print STDERR "queue.pl: unable to parse line '$line' in config file ($config)\n"; + print STDERR "pbs.pl: unable to parse line '$line' in config file ($config)\n"; exit(1); } } @@ -256,7 +254,7 @@ () close(CONFIG); if ($read_command != 1) { - print STDERR "queue.pl: config file ($config) does not contain the line \"command .*\"\n"; + print STDERR "pbs.pl: config file ($config) does not contain the line \"command .*\"\n"; exit(1); } @@ -271,7 +269,7 @@ () $qsub_opts .= "$cli_config_options{$option} "; } else { if ($opened_config_file == 0) { $config = "default config file"; } - die "queue.pl: Command line option $option not described in $config (or value '$value' not allowed)\n"; + die "pbs.pl: Command line option $option not described in $config (or value '$value' not allowed)\n"; } } @@ -280,7 +278,7 @@ () if ($array_job == 1 && $logfile !~ m/$jobname/ && $jobend > $jobstart) { - print STDERR "queue.pl: you are trying to run a parallel job but " + print STDERR "pbs.pl: you are trying to run a parallel job but " . "you are putting the output into just one log file ($logfile)\n"; exit(1); } @@ -289,7 +287,7 @@ () # Work out the command; quote escaping is done here. # Note: the rules for escaping stuff are worked out pretty # arbitrarily, based on what we want it to do. Some things that -# we pass as arguments to queue.pl, such as "|", we want to be +# we pass as arguments to pbs.pl, such as "|", we want to be # interpreted by bash, so we don't escape them. Other things, # such as archive specifiers like 'ark:gunzip -c foo.gz|', we want # to be passed, in quotes, to the Kaldi program. Our heuristic @@ -394,16 +392,16 @@ () if ($ret != 0) { if ($sync && $ret == 256) { # this is the exit status when a job failed (bad exit status) if (defined $jobname) { $logfile =~ s/\$PBS_ARRAY_INDEX/*/g; } - print STDERR "queue.pl: job writing to $logfile failed\n"; + print STDERR "pbs.pl: job writing to $logfile failed\n"; } else { - print STDERR "queue.pl: error submitting jobs to queue (return status was $ret)\n"; + print STDERR "pbs.pl: error submitting jobs to queue (return status was $ret)\n"; print STDERR "queue log file is $queue_logfile, command was $qsub_cmd\n"; print STDERR `tail $queue_logfile`; } exit(1); } -my $sge_job_id; +my $pbs_job_id; if (! $sync) { # We're not submitting with -sync y, so we # need to wait for the jobs to finish. We wait for the # sync-files we "touched" in the script to exist. 
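The wait loop sketched in that comment is conceptually just filesystem polling; a rough, simplified Python rendering (file names hypothetical, and the periodic qstat re-check is omitted):

    import os
    import time

    def wait_for_syncfiles(syncfiles, poll=5.0, timeout=3600.0):
        # return True once every sync file exists, False on timeout
        deadline = time.time() + timeout
        pending = set(syncfiles)
        while pending and time.time() < deadline:
            pending = {f for f in pending if not os.path.exists(f)}
            if pending:
                time.sleep(poll)
        return not pending

    # wait_for_syncfiles(['q/done.15421.1', 'q/done.15421.2'])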
@@ -415,25 +413,25 @@ () push @syncfiles, "$syncfile.$jobid"; } } - # We will need the sge_job_id, to check that job still exists - { # Get the SGE job-id from the log file in q/ - open(L, "<$queue_logfile") || die "Error opening log file $queue_logfile"; - undef $sge_job_id; - while () { - if (m/Your job\S* (\d+)[. ].+ has been submitted/) { - if (defined $sge_job_id) { + # We will need the pbs_job_id, to check that job still exists + { # Get the PBS job-id from the log file in q/ + open my $L, '<', $queue_logfile || die "Error opening log file $queue_logfile"; + undef $pbs_job_id; + while (<$L>) { + if (/(\d+.+\.pbsserver)/) { + if (defined $pbs_job_id) { die "Error: your job was submitted more than once (see $queue_logfile)"; } else { - $sge_job_id = $1; + $pbs_job_id = $1; } } } - close(L); - if (!defined $sge_job_id) { - die "Error: log file $queue_logfile does not specify the SGE job-id."; + close $L; + if (!defined $pbs_job_id) { + die "Error: log file $queue_logfile does not specify the PBS job-id."; } } - my $check_sge_job_ctr=1; + my $check_pbs_job_ctr=1; # my $wait = 0.1; my $counter = 0; @@ -462,11 +460,11 @@ () } } - # Check that the job exists in SGE. Job can be killed if duration + # Check that the job exists in PBS. Job can be killed if duration # exceeds some hard limit, or in case of a machine shutdown. - if (($check_sge_job_ctr++ % 10) == 0) { # Don't run qstat too often, avoid stress on SGE. + if (($check_pbs_job_ctr++ % 10) == 0) { # Don't run qstat too often, avoid stress on PBS. if ( -f $f ) { next; }; #syncfile appeared: OK. - $ret = system("qstat -t $sge_job_id >/dev/null 2>/dev/null"); + $ret = system("qstat -t $pbs_job_id >/dev/null 2>/dev/null"); # system(...) : To get the actual exit value, shift $ret right by eight bits. if ($ret>>8 == 1) { # Job does not seem to exist # Don't consider immediately missing job as error, first wait some @@ -501,13 +499,13 @@ () # time elapsed between file modification and the start of this # program], then we assume the program really finished OK, # and maybe something is up with the file system. - print STDERR "**queue.pl: syncfile $f was not created but job seems\n" . + print STDERR "**pbs.pl: syncfile $f was not created but job seems\n" . "**to have finished OK. Probably your file-system has problems.\n" . "**This is just a warning.\n"; last; } else { chop $last_line; - print STDERR "queue.pl: Error, unfinished job no " . + print STDERR "pbs.pl: Error, unfinished job no " . "longer exists, log is in $logfile, last line is '$last_line', " . "syncfile is $f, return status of qstat was $ret\n" . "Possible reasons: a) Exceeded time limit? -> Use more jobs!" . @@ -515,7 +513,7 @@ () exit(1); } } elsif ($ret != 0) { - print STDERR "queue.pl: Warning: qstat command returned status $ret (qstat -t $sge_job_id,$!)\n"; + print STDERR "pbs.pl: Warning: qstat command returned status $ret (qstat -t $pbs_job_id,$!)\n"; } } } @@ -574,14 +572,14 @@ () else { # we failed. 
if (@logfiles == 1) { if (defined $jobname) { $logfile =~ s/\$PBS_ARRAY_INDEX/$jobstart/g; } - print STDERR "queue.pl: job failed with status $status, log is in $logfile\n"; + print STDERR "pbs.pl: job failed with status $status, log is in $logfile\n"; if ($logfile =~ m/JOB/) { - print STDERR "queue.pl: probably you forgot to put JOB=1:\$nj in your script.\n"; + print STDERR "pbs.pl: probably you forgot to put JOB=1:\$nj in your script.\n"; } } else { if (defined $jobname) { $logfile =~ s/\$PBS_ARRAY_INDEX/*/g; } my $numjobs = 1 + $jobend - $jobstart; - print STDERR "queue.pl: $num_failed / $numjobs failed, log is in $logfile\n"; + print STDERR "pbs.pl: $num_failed / $numjobs failed, log is in $logfile\n"; } exit(1); } diff --git a/egs/wsj/s5/utils/parallel/queue.pl b/egs/wsj/s5/utils/parallel/queue.pl index e14af5ef6e3..bddcb4fec23 100755 --- a/egs/wsj/s5/utils/parallel/queue.pl +++ b/egs/wsj/s5/utils/parallel/queue.pl @@ -176,7 +176,7 @@ sub caught_signal { option max_jobs_run=* -tc $0 default gpu=0 option gpu=0 -option gpu=* -l gpu=$0 -q g.q +option gpu=* -l gpu=$0 -q '*.q' EOF # Here the configuration options specified by the user on the command line diff --git a/egs/wsj/s5/utils/parallel/retry.pl b/egs/wsj/s5/utils/parallel/retry.pl index a039d6f5a74..618e9fb01bc 100755 --- a/egs/wsj/s5/utils/parallel/retry.pl +++ b/egs/wsj/s5/utils/parallel/retry.pl @@ -94,7 +94,6 @@ sub get_log_file { # Later on we might want to figure out which array jobs failed # and have to be rerun, but for now we just die. print STDERR "$0: job failed and log file $log_file does not exist (array job?).\n"; - exit($return_status) } else { rename($log_file, $log_file . ".bak"); print STDERR "$0: job failed; renaming log file to ${log_file}.bak and rerunning\n"; diff --git a/egs/wsj/s5/utils/parse_options.sh b/egs/wsj/s5/utils/parse_options.sh index 34476fdb37a..335e69e9ac7 100755 --- a/egs/wsj/s5/utils/parse_options.sh +++ b/egs/wsj/s5/utils/parse_options.sh @@ -42,7 +42,7 @@ done ### -### No we process the command line options +### Now we process the command line options ### while true; do [ -z "${1:-}" ] && break; # break if there are no arguments diff --git a/egs/wsj/s5/utils/perturb_data_dir_speed.sh b/egs/wsj/s5/utils/perturb_data_dir_speed.sh index a50cdb04be4..99c9cbdb1f0 100755 --- a/egs/wsj/s5/utils/perturb_data_dir_speed.sh +++ b/egs/wsj/s5/utils/perturb_data_dir_speed.sh @@ -102,6 +102,9 @@ fi if [ -f $srcdir/spk2gender ]; then utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender fi +if [ -f $srcdir/utt2lang ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2lang >$destdir/utt2lang +fi #prepare speed-perturbed utt2dur if [ ! -f $srcdir/utt2dur ]; then diff --git a/egs/wsj/s5/utils/prepare_lang.sh b/egs/wsj/s5/utils/prepare_lang.sh index fa5ff7856b0..7c018fd94f9 100755 --- a/egs/wsj/s5/utils/prepare_lang.sh +++ b/egs/wsj/s5/utils/prepare_lang.sh @@ -67,6 +67,8 @@ extra_word_disambig_syms= # if set, add disambiguation symbols from this num_extra_phone_disambig_syms=1 # Standard one phone disambiguation symbol is used for optional silence. # Increasing this number does not harm, but is only useful if you later # want to introduce this labels to L_disambig.fst + + # end configuration sections echo "$0 $@" # Print the command line for logging @@ -74,12 +76,14 @@ echo "$0 $@" # Print the command line for logging . 
utils/parse_options.sh if [ $# -ne 4 ]; then - echo "usage: utils/prepare_lang.sh " + echo "Usage: utils/prepare_lang.sh " echo "e.g.: utils/prepare_lang.sh data/local/dict data/local/lang data/lang" echo " should contain the following files:" echo " extra_questions.txt lexicon.txt nonsilence_phones.txt optional_silence.txt silence_phones.txt" echo "See http://kaldi-asr.org/doc/data_prep.html#data_prep_lang_creating for more info." echo "options: " + echo " may also, for the grammar-decoding case (see http://kaldi-asr.org/doc/grammar.html)" + echo "contain a file nonterminals.txt containing symbols like #nonterm:contact_list, one per line." echo " --num-sil-states # default: 5, #states in silence models." echo " --num-nonsil-states # default: 3, #states in non-silence models." echo " --position-dependent-phones (true|false) # default: true; if true, use _B, _E, _S & _I" @@ -104,6 +108,11 @@ srcdir=$1 oov_word=$2 tmpdir=$3 dir=$4 + + +if [ -d $dir/phones ]; then + rm -r $dir/phones +fi mkdir -p $dir $tmpdir $dir/phones silprob=false @@ -209,7 +218,6 @@ else paste -d' ' $tmpdir/phones $tmpdir/phones > $tmpdir/phone_map.txt fi -mkdir -p $dir/phones # various sets of phones... # Sets of phones for use in clustering, and making monophone systems. @@ -380,9 +388,9 @@ fi # format of $dir/words.txt: # 0 -#!EXCLAMATION-POINT 1 -#!SIL 2 -#"CLOSE-QUOTE 3 +#a 1 +#aa 2 +#aarvark 3 #... silphone=`cat $srcdir/optional_silence.txt` || exit 1; @@ -403,9 +411,40 @@ perl -ape 's/(\S+\s+)\S+\s+(.+)/$1$2/;' <$tmpdir/lexiconp.txt >$tmpdir/align_lex [ ! -z "$silphone" ] && echo " $silphone" >> $tmpdir/align_lexicon.txt cat $tmpdir/align_lexicon.txt | \ - perl -ane '@A = split; print $A[0], " ", join(" ", @A), "\n";' | sort | uniq > $dir/phones/align_lexicon.txt + perl -ane '@A = split; print $A[0], " ", join(" ", @A), "\n";' | sort | uniq > $dir/phones/align_lexicon.txt + +if [ -f $srcdir/nonterminals.txt ]; then + utils/lang/grammar/augment_phones_txt.py $dir/phones.txt $srcdir/nonterminals.txt $dir/phones.txt + utils/lang/grammar/augment_words_txt.py $dir/words.txt $srcdir/nonterminals.txt $dir/words.txt + cp $srcdir/nonterminals.txt $dir/phones/nonterminals.txt + utils/sym2int.pl $dir/phones.txt <$dir/phones/nonterminals.txt >$dir/phones/nonterminals.int + + for w in "#nonterm_begin" "#nonterm_end" $(cat $srcdir/nonterminals.txt); do + echo $w $w # These are words without pronunciations, so leave those prons + # empty. + done >> $dir/phones/align_lexicon.txt + nonterm_phones_offset=$(grep '#nonterm_bos' <$dir/phones.txt | awk '{print $2}') + echo $nonterm_phones_offset > $dir/phones/nonterm_phones_offset.int + echo '#nonterm_bos' > $dir/phones/nonterm_phones_offset.txt # temporary. + + if [ -f $dir/phones/word_boundary.txt ]; then + # word-position-dependent system. Only include the optional-silence phone, + # and phones that can end a word, plus the special symbol #nonterm_bos, in the + # left-context phones. + awk '{if ($2 == "end" || $2 == "singleton") print $1; }' <$dir/phones/word_boundary.txt | \ + cat - $dir/phones/optional_silence.txt $dir/phones/nonterm_phones_offset.txt > $dir/phones/left_context_phones.txt + else + cat $dir/phones/{silence,nonsilence}.txt $dir/phones/nonterm_phones_offset.txt > $dir/phones/left_context_phones.txt + fi + utils/sym2int.pl $dir/phones.txt <$dir/phones/left_context_phones.txt >$dir/phones/left_context_phones.int -# create phones/align_lexicon.int + # we need to write utils/lang/make_lexicon_fst_silprob.py before this can work. 
+ grammar_opts="--left-context-phones=$dir/phones/left_context_phones.txt --nonterminals=$srcdir/nonterminals.txt" +else + grammar_opts= +fi + +# create phones/align_lexicon.int from phones/align_lexicon.txt cat $dir/phones/align_lexicon.txt | utils/sym2int.pl -f 3- $dir/phones.txt | \ utils/sym2int.pl -f 1-2 $dir/words.txt > $dir/phones/align_lexicon.int @@ -413,18 +452,20 @@ cat $dir/phones/align_lexicon.txt | utils/sym2int.pl -f 3- $dir/phones.txt | \ # in training. if $silprob; then - # Add silence probabilities (modlels the prob. of silence before and after each + # Add silence probabilities (models the prob. of silence before and after each # word). On some setups this helps a bit. See utils/dict_dir_add_pronprobs.sh # and where it's called in the example scripts (run.sh). - utils/make_lexicon_fst_silprob.pl $tmpdir/lexiconp_silprob.txt $srcdir/silprob.txt $silphone "" | \ + utils/lang/make_lexicon_fst_silprob.py $grammar_opts --sil-phone=$silphone \ + $tmpdir/lexiconp_silprob.txt $srcdir/silprob.txt | \ fstcompile --isymbols=$dir/phones.txt --osymbols=$dir/words.txt \ - --keep_isymbols=false --keep_osymbols=false | \ + --keep_isymbols=false --keep_osymbols=false | \ fstarcsort --sort_type=olabel > $dir/L.fst || exit 1; else - utils/make_lexicon_fst.pl --pron-probs $tmpdir/lexiconp.txt $sil_prob $silphone | \ + utils/lang/make_lexicon_fst.py $grammar_opts --sil-prob=$sil_prob --sil-phone=$silphone \ + $tmpdir/lexiconp.txt | \ fstcompile --isymbols=$dir/phones.txt --osymbols=$dir/words.txt \ - --keep_isymbols=false --keep_osymbols=false | \ - fstarcsort --sort_type=olabel > $dir/L.fst || exit 1; + --keep_isymbols=false --keep_osymbols=false | \ + fstarcsort --sort_type=olabel > $dir/L.fst || exit 1; fi # The file oov.txt contains a word that we will map any OOVs to during @@ -490,15 +531,19 @@ utils/gen_topo.pl $num_nonsil_states $num_sil_states $nonsilphonelist $silphonel # disambiguation symbols from G.fst. 
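To make the fstcompile input above concrete, here is a toy, hand-rolled stand-in for what a lexicon-FST generator emits in OpenFst text format — greatly simplified: one word, one pronunciation, and none of the optional-silence, pronunciation-probability or disambiguation machinery that the real make_lexicon_fst.py handles:

    lexicon = {"cat": ["k", "ae", "t"]}   # toy entry, not from any real lexicon

    arcs, next_state = [], 1
    for word, phones in lexicon.items():
        src = 0
        for i, ph in enumerate(phones):
            olabel = word if i == 0 else "<eps>"       # word label on the first arc only
            dst = 0 if i == len(phones) - 1 else next_state
            if dst != 0:
                next_state += 1
            arcs.append("%d %d %s %s" % (src, dst, ph, olabel))
            src = dst
    print("\n".join(arcs))
    print(0)   # state 0 doubles as start and final state
    # piping this through fstcompile --isymbols=phones.txt --osymbols=words.txt
    # would give a (toy) L.fst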
if $silprob; then - utils/make_lexicon_fst_silprob.pl $tmpdir/lexiconp_silprob_disambig.txt $srcdir/silprob.txt $silphone '#'$ndisambig | \ + utils/lang/make_lexicon_fst_silprob.py $grammar_opts \ + --sil-phone=$silphone --sil-disambig='#'$ndisambig \ + $tmpdir/lexiconp_silprob_disambig.txt $srcdir/silprob.txt | \ fstcompile --isymbols=$dir/phones.txt --osymbols=$dir/words.txt \ - --keep_isymbols=false --keep_osymbols=false | \ + --keep_isymbols=false --keep_osymbols=false | \ fstaddselfloops $dir/phones/wdisambig_phones.int $dir/phones/wdisambig_words.int | \ fstarcsort --sort_type=olabel > $dir/L_disambig.fst || exit 1; else - utils/make_lexicon_fst.pl --pron-probs $tmpdir/lexiconp_disambig.txt $sil_prob $silphone '#'$ndisambig | \ + utils/lang/make_lexicon_fst.py $grammar_opts \ + --sil-prob=$sil_prob --sil-phone=$silphone --sil-disambig='#'$ndisambig \ + $tmpdir/lexiconp_disambig.txt | \ fstcompile --isymbols=$dir/phones.txt --osymbols=$dir/words.txt \ - --keep_isymbols=false --keep_osymbols=false | \ + --keep_isymbols=false --keep_osymbols=false | \ fstaddselfloops $dir/phones/wdisambig_phones.int $dir/phones/wdisambig_words.int | \ fstarcsort --sort_type=olabel > $dir/L_disambig.fst || exit 1; fi diff --git a/egs/wsj/s5/utils/reverse_arpa.py b/egs/wsj/s5/utils/reverse_arpa.py index 5437aec4341..e154a6e0813 100755 --- a/egs/wsj/s5/utils/reverse_arpa.py +++ b/egs/wsj/s5/utils/reverse_arpa.py @@ -2,11 +2,12 @@ # -*- coding: utf-8 -*- # Copyright 2012 Mirko Hannemann BUT, mirko.hannemann@gmail.com +from __future__ import print_function import sys import codecs # for UTF-8/unicode if len(sys.argv) != 2: - print 'usage: reverse_arpa arpa.in' + print('usage: reverse_arpa arpa.in') sys.exit() arpaname = sys.argv[1] @@ -34,13 +35,13 @@ try: file = codecs.open(arpaname, "r", "utf-8") except IOError: - print 'file not found: ' + arpaname + print('file not found: ' + arpaname) sys.exit() text=file.readline() while (text and text[:6] != "\\data\\"): text=file.readline() if not text: - print "invalid ARPA file" + print("invalid ARPA file") sys.exit() #print text, while (text and text[:5] != "ngram"): text=file.readline() @@ -54,7 +55,7 @@ r = ind[0].split() read_n = int(r[1].strip()) if read_n != n+1: - print "invalid ARPA file:", text + print("invalid ARPA file: {}".format(text)) sys.exit() n = read_n cngrams.append(counts) @@ -68,7 +69,7 @@ for n in range(1,len(cngrams)+1): # unigrams, bigrams, trigrams while (text and "-grams:" not in text): text=file.readline() if n != int(text[1]): - print "invalid ARPA file:", text + print("invalid ARPA file:{}".format(text)) sys.exit() #print text,cngrams[n-1] this_ngrams={} # stores all read ngrams @@ -115,7 +116,7 @@ while (text and text[:5] != "\\end\\"): text=file.readline() if not text: - print "invalid ARPA file" + print("invalid ARPA file") sys.exit() file.close() #print text, @@ -133,14 +134,13 @@ #p(ABCD)+b(ABCD)-p(BCD)+p(ABC)-p(BC)+p(AB)-p(B)+p(A) DCBA 0 # compute new reversed ARPA model -print "\\data\\" +print("\\data\\") for n in range(1,len(cngrams)+1): # unigrams, bigrams, trigrams - print "ngram "+str(n)+"="+str(len(ngrams[n-1].keys())) + print("ngram {0} = {1}".format(n, len(ngrams[n-1].keys()))) offset = 0.0 for n in range(1,len(cngrams)+1): # unigrams, bigrams, trigrams - print "\\"+str(n)+"-grams:" - keys = ngrams[n-1].keys() - keys.sort() + print("\\{}-grams:".format(n)) + keys = sorted(ngrams[n-1].keys()) for ngram in keys: prob = ngrams[n-1][ngram] # reverse word order @@ -179,10 +179,10 @@ elif n == 2: revprob = revprob + offset # 
add weight to bigrams starting with if (prob[1] != inf): # only backoff weights from not newly created ngrams - print revprob,rev_ngram.encode("utf-8"),back + print(revprob,rev_ngram.encode("utf-8"),back) else: - print revprob,rev_ngram.encode("utf-8"),"-100000.0" + print(revprob,rev_ngram.encode("utf-8"),"-100000.0") else: # highest order - no backoff weights if (n==2) and (rev_ngram[:3] == ""): revprob = revprob + offset - print revprob,rev_ngram.encode("utf-8") -print "\\end\\" + print(revprob,rev_ngram.encode("utf-8")) +print("\\end\\") diff --git a/egs/wsj/s5/utils/scoring/wer_per_spk_details.pl b/egs/wsj/s5/utils/scoring/wer_per_spk_details.pl index 217448e9fb0..f44c0d9cfb3 100755 --- a/egs/wsj/s5/utils/scoring/wer_per_spk_details.pl +++ b/egs/wsj/s5/utils/scoring/wer_per_spk_details.pl @@ -130,10 +130,15 @@ sub format_sys { } open(UTT2SPK,$ARGV[0]) or die "Could not open the utt2spk file $ARGV[0]"; -while() { - chomp; - my @F=split; - die "Incompatible format of the utt2spk file: $_" if @F != 2; + +(my $utt_is_utf8, my @utt_lines) = get_utf8_or_bytestream(\*UTT2SPK); +die "Cannot read file" unless @utt_lines; + +while (@utt_lines) { + my $line = shift @utt_lines; + chomp $line; + my @F=split(" ", $line); + die "Incompatible format of the utt2spk file: $_" if @F != 2; $UTTMAP{$F[0]} = $F[1]; # Set width of speaker column by its longest label, if($SPK_WIDTH < length($F[1])) { $SPK_WIDTH = length($F[1]) } diff --git a/egs/wsj/s5/utils/subset_data_dir.sh b/egs/wsj/s5/utils/subset_data_dir.sh index ba52d140ccc..93ee0971b88 100755 --- a/egs/wsj/s5/utils/subset_data_dir.sh +++ b/egs/wsj/s5/utils/subset_data_dir.sh @@ -124,8 +124,10 @@ function do_filtering { [ -f $srcdir/reco2file_and_channel ] && \ utils/filter_scp.pl $destdir/reco <$srcdir/reco2file_and_channel >$destdir/reco2file_and_channel - # Filter the STM file for proper sclite scoring (this will also remove the comments lines) - [ -f $srcdir/stm ] && utils/filter_scp.pl $destdir/reco < $srcdir/stm > $destdir/stm + # Filter the STM file for proper sclite scoring + # Copy over the comments from STM file + [ -f $srcdir/stm ] && grep "^;;" $srcdir/stm > $destdir/stm + [ -f $srcdir/stm ] && utils/filter_scp.pl $destdir/reco < $srcdir/stm >> $destdir/stm rm $destdir/reco else diff --git a/egs/wsj/s5/utils/validate_data_dir.sh b/egs/wsj/s5/utils/validate_data_dir.sh index a8b0542c1bb..dc06b6fa59e 100755 --- a/egs/wsj/s5/utils/validate_data_dir.sh +++ b/egs/wsj/s5/utils/validate_data_dir.sh @@ -79,6 +79,7 @@ trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM export LC_ALL=C function check_sorted_and_uniq { + ! perl -ne '((substr $_,-1) eq "\n") or die "file $ARGV has invalid newline";' $1 && exit 1; ! awk '{print $1}' $1 | sort | uniq | cmp -s - <(awk '{print $1}' $1) && \ echo "$0: file $1 is not in sorted order or has duplicates" && exit 1; } diff --git a/egs/wsj/s5/utils/validate_dict_dir.pl b/egs/wsj/s5/utils/validate_dict_dir.pl index 981dc005116..8f8534c329b 100755 --- a/egs/wsj/s5/utils/validate_dict_dir.pl +++ b/egs/wsj/s5/utils/validate_dict_dir.pl @@ -5,7 +5,7 @@ # 2015 Daniel Povey # 2017 Johns Hopkins University (Jan "Yenda" Trmal ) # -# Validation script for data/local/dict +# Validation script for 'dict' directories (e.g. data/local/dict) # this function reads the opened file (supplied as a first # parameter) into an array of lines. 
For each @@ -56,6 +56,10 @@ sub validate_utf8_whitespaces { use feature 'unicode_strings'; for (my $i = 0; $i < scalar @{$unicode_lines}; $i++) { my $current_line = $unicode_lines->[$i]; + if ((substr $current_line, -1) ne "\n"){ + print STDERR "$0: The current line (nr. $i) has invalid newline\n"; + return 1; + } # we replace TAB, LF, CR, and SPACE # this is to simplify the test if ($current_line =~ /\x{000d}/) { @@ -442,7 +446,7 @@ sub check_lexicon_pair { } foreach (0 .. @col-1) { if(!$silence{@col[$_]} and !$nonsilence{@col[$_]}) { - set_to_fail(); print "--> ERROR: phone \"@col[$_]\" is not in {, non}silence.txt (line $idx, block ", $_+1, ")\n"; + set_to_fail(); print "--> ERROR: phone \"@col[$_]\" is not in {, non}silence_phones.txt (line $idx, block ", $_+1, ")\n"; } $idx ++; } @@ -464,6 +468,22 @@ sub check_lexicon_pair { $success == 0 || print "--> $dict/extra_questions.txt is OK\n"; } else { print "--> $dict/extra_questions.txt is empty (this is OK)\n";} +if (-f "$dict/nonterminals.txt") { + open(NT, "<$dict/nonterminals.txt") || die "opening $dict/nonterminals.txt"; + my %nonterminals = (); + my $line_number = 1; + while () { + chop; + my @line = split(" ", $_); + if (@line != 1 || ! m/^#nonterm:/ || defined $nonterminals{$line[0]}) { + print "--> ERROR: bad (or duplicate) line $line_number: '$_' in $dict/nonterminals.txt\n"; exit 1; + } + $nonterminals{$line[0]} = 1; + $line_number++; + } + print "--> $dict/nonterminals.txt is OK\n"; +} + # check nonsilence_phones.txt again for phone-pairs that are never # distnguishable. (note: this situation is normal and expected for silence diff --git a/egs/wsj/s5/utils/validate_lang.pl b/egs/wsj/s5/utils/validate_lang.pl index 2501d25c8f3..ea2272f3cda 100755 --- a/egs/wsj/s5/utils/validate_lang.pl +++ b/egs/wsj/s5/utils/validate_lang.pl @@ -56,6 +56,10 @@ sub validate_utf8_whitespaces { use feature 'unicode_strings'; for (my $i = 0; $i < scalar @{$unicode_lines}; $i++) { my $current_line = $unicode_lines->[$i]; + if ((substr $current_line, -1) ne "\n"){ + print STDERR "$0: The current line (nr. $i) has invalid newline\n"; + return 1; + } # we replace TAB, LF, CR, and SPACE # this is to simplify the test if ($current_line =~ /\x{000d}/) { @@ -96,15 +100,21 @@ sub check_allowed_whitespace { $skip_det_check = 0; $skip_disambig_check = 0; +$skip_generate_words_check = 0; -if (@ARGV > 0 && $ARGV[0] eq "--skip-determinization-check") { - $skip_det_check = 1; - shift @ARGV; -} - -if (@ARGV > 0 && $ARGV[0] eq "--skip-disambig-check") { - $skip_disambig_check = 1; - shift @ARGV; +for ($x=0; $x <= 3; $x++) { + if (@ARGV > 0 && $ARGV[0] eq "--skip-determinization-check") { + $skip_det_check = 1; + shift @ARGV; + } + if (@ARGV > 0 && $ARGV[0] eq "--skip-disambig-check") { + $skip_disambig_check = 1; + shift @ARGV; + } + if (@ARGV > 0 && $ARGV[0] eq "--skip-generate-words-check") { + $skip_generate_words_check = 1; + shift @ARGV; + } } if (@ARGV != 1) { @@ -479,32 +489,15 @@ sub check_summation { %sum = (%silence, %nonsilence, %disambig); $sum{""} = 1; - my @itset = intersect(\%sum, \%psymtab); - my @key1 = keys %sum; - my @key2 = keys %psymtab; - my %itset = (); foreach(@itset) {$itset{$_} = 1;} - if (@itset < @key1) { - $exit = 1; print "--> ERROR: phones in silence.txt, nonsilence.txt, disambig.txt but not in phones.txt -- "; - foreach (@key1) { - if (!$itset{$_}) { - print "$_ "; - } + my $ok = 1; + foreach $p (keys %psymtab) { + if (! 
defined $sum{$p} && $p !~ m/^#nonterm/) { + $exit = 1; $ok = 0; print("--> ERROR: phone $p is not in silence.txt, nonsilence.txt or disambig.txt..."); } - print "\n"; } - if (@itset < @key2) { - $exit = 1; print "--> ERROR: phones in phones.txt but not in silence.txt, nonsilence.txt, disambig.txt -- "; - foreach (@key2) { - if (!$itset{$_}) { - print "$_ "; - } - } - print "\n"; - } - - if (@itset == @key1 and @itset == @key2) { - print "--> summation property is OK\n"; + if ($ok) { + print "--> found no unexplainable phones in phones.txt\n"; } return; } @@ -817,6 +810,9 @@ sub check_summation { } foreach $fst ("L.fst", "L_disambig.fst") { + if ($skip_generate_words_check) { + next; + } $wlen = int(rand(100)) + 1; print "--> generating a $wlen word sequence\n"; $wordseq = ""; @@ -824,10 +820,11 @@ sub check_summation { $wordseq_syms = ""; foreach (1 .. $wlen) { $id = int(rand(scalar(keys %wint2sym))); - # exclude disambiguation symbols, BOS and EOS and epsilon from the word - # sequence. + # exclude disambiguation symbols, BOS and EOS, epsilon, and + # grammar-related symbols from the word sequence. while (defined $wdisambig_words_hash{$id} or - $wint2sym{$id} eq "" or $wint2sym{$id} eq "" or $id == 0) { + $wint2sym{$id} eq "" or $wint2sym{$id} eq "" or + $wint2sym{$id} =~ m/^#nonterm/ or $id == 0) { $id = int(rand(scalar(keys %wint2sym))); } $wordseq_syms = $wordseq_syms . $wint2sym{$id} . " "; diff --git a/egs/wsj/s5/utils/validate_text.pl b/egs/wsj/s5/utils/validate_text.pl index 172396c867e..7f75cf12f20 100755 --- a/egs/wsj/s5/utils/validate_text.pl +++ b/egs/wsj/s5/utils/validate_text.pl @@ -74,6 +74,10 @@ sub validate_utf8_whitespaces { use feature 'unicode_strings'; for (my $i = 0; $i < scalar @{$unicode_lines}; $i++) { my $current_line = $unicode_lines->[$i]; + if ((substr $current_line, -1) ne "\n"){ + print STDERR "$0: The current line (nr. $i) has invalid newline\n"; + return 1; + } my @A = split(" ", $current_line); my $utt_id = $A[0]; # we replace TAB, LF, CR, and SPACE diff --git a/egs/yomdle_fa/README.txt b/egs/yomdle_fa/README.txt new file mode 100644 index 00000000000..984ffdb53b5 --- /dev/null +++ b/egs/yomdle_fa/README.txt @@ -0,0 +1,3 @@ +This directory contains example scripts for OCR on the Yomdle and Slam datasets. +Training is done on the Yomdle dataset and testing is done on Slam. +LM rescoring is also done with extra corpus data obtained from various newswires (e.g. Hamshahri) diff --git a/egs/yomdle_fa/v1/cmd.sh b/egs/yomdle_fa/v1/cmd.sh new file mode 100755 index 00000000000..3c8eb9f93a5 --- /dev/null +++ b/egs/yomdle_fa/v1/cmd.sh @@ -0,0 +1,13 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. 
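Returning to the validator hunks a little further up: validate_data_dir.sh, validate_dict_dir.pl, validate_lang.pl and validate_text.pl all gain the same guard, namely that every line of an input file — including the last one — must end in a real newline. A rough Python equivalent of that check:

    def check_trailing_newlines(path):
        # reject files whose final line lacks a terminating "\n"
        with open(path, 'rb') as f:
            for lineno, line in enumerate(f, 1):
                if not line.endswith(b'\n'):
                    raise ValueError('%s: line %d has no trailing newline'
                                     % (path, lineno))

    # check_trailing_newlines('data/train/text')   # e.g. on a Kaldi data-dir file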
+ +export cmd="queue.pl" diff --git a/egs/yomdle_fa/v1/image b/egs/yomdle_fa/v1/image new file mode 120000 index 00000000000..1668ee99922 --- /dev/null +++ b/egs/yomdle_fa/v1/image @@ -0,0 +1 @@ +../../cifar/v1/image/ \ No newline at end of file diff --git a/egs/yomdle_fa/v1/local/augment_data.sh b/egs/yomdle_fa/v1/local/augment_data.sh new file mode 100755 index 00000000000..1c38bcb072d --- /dev/null +++ b/egs/yomdle_fa/v1/local/augment_data.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Copyright 2018 Hossein Hadian +# 2018 Ashish Arora + +# Apache 2.0 +# This script performs data augmentation. + +nj=4 +cmd=run.pl +feat_dim=40 +fliplr=false +verticle_shift=0 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +srcdir=$1 +outdir=$2 +datadir=$3 + +mkdir -p $datadir/augmentations +echo "copying $srcdir to $datadir/augmentations/aug1, allowed length, creating feats.scp" + +for set in aug1; do + image/copy_data_dir.sh --spk-prefix $set- --utt-prefix $set- \ + $srcdir $datadir/augmentations/$set + cat $srcdir/allowed_lengths.txt > $datadir/augmentations/$set/allowed_lengths.txt + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim $feat_dim \ + --vertical-shift $verticle_shift \ + --fliplr $fliplr --augment 'random_scale' $datadir/augmentations/$set + +done + +echo " combine original data and data from different augmentations" +utils/combine_data.sh --extra-files images.scp $outdir $srcdir $datadir/augmentations/aug1 +cat $srcdir/allowed_lengths.txt > $outdir/allowed_lengths.txt diff --git a/egs/yomdle_fa/v1/local/bidi.py b/egs/yomdle_fa/v1/local/bidi.py new file mode 100755 index 00000000000..447313a5d02 --- /dev/null +++ b/egs/yomdle_fa/v1/local/bidi.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2018 Chun-Chieh Chang + +# This script is largely written by Stephen Rawls +# and uses the python package https://pypi.org/project/PyICU_BiDi/ +# The code leaves right to left text alone and reverses left to right text. 
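A minimal illustration of the direction test that the script below relies on: a line is treated as right-to-left as soon as it contains one character whose Unicode bidi class is 'R' or 'AL' (escaped code points are used here so the example stays ASCII):

    import unicodedata

    def has_strong_rtl(text):
        # same criterion as bidi.py's rtl_set: strong RTL classes 'R' and 'AL'
        return any(unicodedata.bidirectional(ch) in ('R', 'AL') for ch in text)

    print(has_strong_rtl('hello world'))                # False
    print(has_strong_rtl('\u0633\u0644\u0627\u0645'))   # True (Arabic-script text)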
+ +import icu_bidi +import io +import sys +import unicodedata +# R=strong right-to-left; AL=strong arabic right-to-left +rtl_set = set(chr(i) for i in range(sys.maxunicode) + if unicodedata.bidirectional(chr(i)) in ['R','AL']) +def determine_text_direction(text): + # Easy case first + for char in text: + if char in rtl_set: + return icu_bidi.UBiDiLevel.UBIDI_RTL + # If we made it here we did not encounter any strongly rtl char + return icu_bidi.UBiDiLevel.UBIDI_LTR + +def utf8_visual_to_logical(text): + text_dir = determine_text_direction(text) + + bidi = icu_bidi.Bidi() + bidi.inverse = True + bidi.reordering_mode = icu_bidi.UBiDiReorderingMode.UBIDI_REORDER_INVERSE_LIKE_DIRECT + bidi.reordering_options = icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_DEFAULT # icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_INSERT_MARKS + + bidi.set_para(text, text_dir, None) + + res = bidi.get_reordered(0 | icu_bidi.UBidiWriteReorderedOpt.UBIDI_DO_MIRRORING | icu_bidi.UBidiWriteReorderedOpt.UBIDI_KEEP_BASE_COMBINING) + + return res + +def utf8_logical_to_visual(text): + text_dir = determine_text_direction(text) + + bidi = icu_bidi.Bidi() + + bidi.reordering_mode = icu_bidi.UBiDiReorderingMode.UBIDI_REORDER_DEFAULT + bidi.reordering_options = icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_DEFAULT #icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_INSERT_MARKS + + bidi.set_para(text, text_dir, None) + + res = bidi.get_reordered(0 | icu_bidi.UBidiWriteReorderedOpt.UBIDI_DO_MIRRORING | icu_bidi.UBidiWriteReorderedOpt.UBIDI_KEEP_BASE_COMBINING) + + return res + + +##main## +sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf8") +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf8") +for line in sys.stdin: + line = line.strip() + line = utf8_logical_to_visual(line)[::-1] + sys.stdout.write(line + '\n') diff --git a/egs/yomdle_fa/v1/local/chain/compare_wer.sh b/egs/yomdle_fa/v1/local/chain/compare_wer.sh new file mode 100755 index 00000000000..ab880c1adb5 --- /dev/null +++ b/egs/yomdle_fa/v1/local/chain/compare_wer.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/cnn{1a,1b} + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... ]" + echo "e.g.: $0 exp/chain/cnn{1a,1b}" + exit 1 +fi + +echo "# $0 $*" +used_epochs=false + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +echo -n "# WER " +for x in $*; do + wer=$(cat $x/decode_test/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER " +for x in $*; do + cer=$(cat $x/decode_test/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. 
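The WER/CER rows above just pull the second whitespace-separated field out of scoring_kaldi/best_wer (and best_cer) for each system. A Python sketch of the same thing, assuming the usual one-line "%WER <value> [ ... ]" layout and hypothetical experiment directories:

    def best_metric(exp_dir, which='wer'):
        # field 2 of e.g. "%WER 18.45 [ ... ] exp/.../wer_11_0.5"
        with open('%s/decode_test/scoring_kaldi/best_%s' % (exp_dir, which)) as f:
            return f.read().split()[1]

    for d in ('exp/chain/e2e_cnn_1a', 'exp/chain/cnn_e2eali_1b'):
        print('%-26s WER %s  CER %s' % (d, best_metric(d, 'wer'), best_metric(d, 'cer')))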
+fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo diff --git a/egs/yomdle_fa/v1/local/chain/run_cnn_e2eali_1b.sh b/egs/yomdle_fa/v1/local/chain/run_cnn_e2eali_1b.sh new file mode 100755 index 00000000000..700b57d9fce --- /dev/null +++ b/egs/yomdle_fa/v1/local/chain/run_cnn_e2eali_1b.sh @@ -0,0 +1,244 @@ +#!/bin/bash + +# e2eali_1b is the same as chainali_1a but uses the e2e chain model to get the +# lattice alignments and to build a tree + +# local/chain/compare_wer.sh scale_baseline2/exp_yomdle_farsi/chain/e2e_cnn_1a scale_baseline2/exp_yomdle_farsi/chain/cnn_e2eali_1b +# System e2e_cnn_1a cnn_e2eali_1b +# WER 19.55 18.45 +# CER 5.64 4.94 +# Final train prob -0.0065 -0.0633 +# Final valid prob 0.0015 -0.0619 +# Final train prob (xent) -0.2636 +# Final valid prob (xent) -0.2511 + +set -e -o pipefail + +data_dir=data +exp_dir=exp + +stage=0 + +nj=30 +train_set=train +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1b #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=500 +# we don't need extra left/right context for TDNN systems. +chunk_left_context=0 +chunk_right_context=0 +tdnn_dim=450 +# training options +srand=0 +remove_egs=true +lang_test=lang_test +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} $data_dir/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts + +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=72" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=144" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=196" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=120 name=input + + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-8,-4,0,4,8) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=500" \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=16 \ + --trainer.frames-per-iter=1000000 \ + --trainer.optimization.num-jobs-initial=4 \ + --trainer.optimization.num-jobs-final=8 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $data_dir/$lang_test \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --extra-left-context $chunk_left_context \ + --extra-right-context $chunk_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph $data_dir/test $dir/decode_test || exit 1; +fi diff --git a/egs/yomdle_fa/v1/local/chain/run_flatstart_cnn1a.sh b/egs/yomdle_fa/v1/local/chain/run_flatstart_cnn1a.sh new file mode 100755 index 00000000000..bb5352943f6 --- /dev/null +++ b/egs/yomdle_fa/v1/local/chain/run_flatstart_cnn1a.sh @@ -0,0 +1,170 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +# This script does end2end chain training (i.e. 
from scratch) + +# local/chain/compare_wer.sh exp_yomdle_farsi/chain/e2e_cnn_1a exp_yomdle_farsi/chain/cnn_e2eali_1b +# System e2e_cnn_1a cnn_e2eali_1b +# WER 19.55 18.45 +# CER 5.64 4.94 +# Final train prob -0.0065 -0.0633 +# Final valid prob 0.0015 -0.0619 +# Final train prob (xent) -0.2636 +# Final valid prob (xent) -0.2511 + +set -e + +data_dir=data +exp_dir=exp + +# configs for 'chain' +stage=0 +nj=30 +train_stage=-10 +get_egs_stage=-10 +affix=1a + +# training options +tdnn_dim=450 +num_epochs=4 +num_jobs_initial=4 +num_jobs_final=8 +minibatch_size=150=64,32/300=32,16/600=16,8/1200=8,4 +common_egs_dir= +l2_regularize=0.00005 +frames_per_iter=1000000 +cmvn_opts="--norm-means=false --norm-vars=false" +train_set=train +lang_test=lang_test + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj $nj --cmd "$cmd" \ + --shared-phones true \ + --type mono \ + $data_dir/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat $data_dir/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py $data_dir/lang \| \ + utils/sym2int.pl -f 2- $data_dir/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=72" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=144" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=144" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=120 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-8,-4,0,4,8) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $output_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + # no need to store the egs in a shared storage because we always + # remove them. Anyway, it takes only 5 minutes to generate them. 
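Before train_e2e.py below starts training, it can help to know how much temporal context the CNN-TDNN defined in the network.xconfig above consumes per output frame: each conv or TDNN layer widens the receptive field by its largest absolute time-offset, and the layers compose additively. The stand-alone Python sketch below is not part of the recipe; the offset lists are transcribed by hand from the xconfig above, so update them if you change the architecture.

#!/usr/bin/env python3
# Back-of-the-envelope receptive-field tally for the CNN-TDNN defined above.
# Offsets are copied by hand from network.xconfig (time-offsets= and Append()).
layer_time_offsets = {
    'cnn1':  [-3, -2, -1, 0, 1, 2, 3],
    'cnn2':  [-2, -1, 0, 1, 2],
    'cnn3':  [-4, -2, 0, 2, 4],
    'cnn4':  [-4, -2, 0, 2, 4],
    'cnn5':  [-4, -2, 0, 2, 4],
    'cnn6':  [-4, 0, 4],
    'cnn7':  [-4, 0, 4],
    'tdnn1': [-8, -4, 0, 4, 8],
    'tdnn2': [-4, 0, 4],
    'tdnn3': [-4, 0, 4],
}
left_context = sum(-min(offs) for offs in layer_time_offsets.values())
right_context = sum(max(offs) for offs in layer_time_offsets.values())
print(f"total context per output frame: {left_context} left, {right_context} right")
# With the offsets above this prints 41 left, 41 right.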
+ + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts "$cmvn_opts" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize $l2_regularize \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter $frames_per_iter \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir $data_dir/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi + +if [ $stage -le 4 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $data_dir/$lang_test \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 5 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$cmd" \ + $dir/graph $data_dir/test $dir/decode_test || exit 1; +fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/yomdle_fa/v1/local/create_download.sh b/egs/yomdle_fa/v1/local/create_download.sh new file mode 100755 index 00000000000..1040ecc2165 --- /dev/null +++ b/egs/yomdle_fa/v1/local/create_download.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Copyright 2018 Chun-Chieh Chang + +# The original format of the dataset given is GEDI and page images. +# This script is written to create line images from page images. +# It also creates csv files from the GEDI files. + +database_slam=/export/corpora5/slam/SLAM/Farsi/transcribed +database_yomdle=/export/corpora5/slam/YOMDLE/final_farsi +slam_dir=download/slam_farsi +yomdle_dir=download/yomdle_farsi + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +echo "$0: Processing SLAM ${language}" +echo "Date: $(date)." +mkdir -p ${slam_dir}/{truth_csv,truth_csv_raw,truth_line_image} +local/GEDI2CSV_enriched.py \ + --inputDir ${database_slam} \ + --outputDir ${slam_dir}/truth_csv_raw \ + --log ${slam_dir}/GEDI2CSV_enriched.log +local/create_line_image_from_page_image.py \ + ${database_slam} \ + ${slam_dir}/truth_csv_raw \ + ${slam_dir} + +echo "$0: Processing YOMDLE ${language}" +echo "Date: $(date)." 
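Both the SLAM pass above and the YOMDLE pass that continues below write one CSV per page into truth_csv_raw; each row carries the four corner points (col1,row1 ... col4,row4) of a possibly rotated text box plus its transcription, and create_line_image_from_page_image.py (added next in this diff) turns those rows into cropped line images. As a simplified, hypothetical sketch of the cropping step only, the snippet below reduces one CSV row to an axis-aligned PIL crop box; the real script additionally fits a minimum-area rectangle over the corners and rotates the page so the text line ends up horizontal.

#!/usr/bin/env python3
# Simplified sketch: four corner points from one CSV row -> axis-aligned crop box.
# The padding value and file names are hypothetical; the actual script
# (create_line_image_from_page_image.py) defaults to --padding 100 and also
# de-rotates the line before saving it.
import csv
from PIL import Image

PAD = 50  # hypothetical padding, in pixels

def crop_box(row, pad=PAD):
    """Return (left, upper, right, lower) for PIL Image.crop() from one CSV row."""
    xs = [int(row[i]) for i in (2, 4, 6, 8)]   # col1..col4
    ys = [int(row[i]) for i in (3, 5, 7, 9)]   # row1..row4
    return (max(0, min(xs) - pad), max(0, min(ys) - pad),
            max(xs) + pad, max(ys) + pad)

if __name__ == '__main__':
    page = Image.open('page_0001.png')                 # hypothetical page image
    with open('page_0001.csv', encoding='utf-8') as f: # hypothetical truth CSV
        for row in list(csv.reader(f))[1:]:            # skip the header row
            page.crop(crop_box(row)).save(row[1])      # row[1] is the line-image name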
+mkdir -p ${yomdle_dir}/{truth_csv,truth_csv_raw,truth_line_image} +local/YOMDLE2CSV.py \ + --inputDir ${database_yomdle} \ + --outputDir ${yomdle_dir}/truth_csv_raw/ \ + --log ${yomdle_dir}/YOMDLE2CSV.log +local/create_line_image_from_page_image.py \ + --im-format "jpg" \ + ${database_yomdle}/images \ + ${yomdle_dir}/truth_csv_raw \ + ${yomdle_dir} diff --git a/egs/yomdle_fa/v1/local/create_line_image_from_page_image.py b/egs/yomdle_fa/v1/local/create_line_image_from_page_image.py new file mode 100755 index 00000000000..7135bb1b242 --- /dev/null +++ b/egs/yomdle_fa/v1/local/create_line_image_from_page_image.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# Apache 2.0 +# minimum bounding box part in this script is originally from +#https://github.com/BebeSparkelSparkel/MinimumBoundingBox +#https://startupnextdoor.com/computing-convex-hull-in-python/ +""" This module will be used for extracting line images from page image. + Given the word segmentation (bounding box around a word) for every word, it will + extract line segmentation. To extract line segmentation, it will take word bounding + boxes of a line as input, will create a minimum area bounding box that will contain + all corner points of word bounding boxes. The obtained bounding box (will not necessarily + be vertically or horizontally aligned). Hence to extract line image from line bounding box, + page image is rotated and line image is cropped and saved. +""" + +import argparse +import csv +import itertools +import sys +import os +import numpy as np +from math import atan2, cos, sin, pi, degrees, sqrt +from collections import namedtuple + +from scipy.spatial import ConvexHull +from PIL import Image +from scipy.misc import toimage + +parser = argparse.ArgumentParser(description="Creates line images from page image") +parser.add_argument('image_dir', type=str, help='Path to full page images') +parser.add_argument('csv_dir', type=str, help='Path to csv files') +parser.add_argument('out_dir', type=str, help='Path to output directory') +parser.add_argument('--im-format', type=str, default='png', help='What file format are the images') +parser.add_argument('--padding', type=int, default=100, help='Padding so BBox does not exceed image area') +parser.add_argument('--head', type=int, default=-1, help='Number of csv files to process') +args = parser.parse_args() + +""" +bounding_box is a named tuple which contains: + area (float): area of the rectangle + length_parallel (float): length of the side that is parallel to unit_vector + length_orthogonal (float): length of the side that is orthogonal to unit_vector + rectangle_center(int, int): coordinates of the rectangle center + (use rectangle_corners to get the corner points of the rectangle) + unit_vector (float, float): direction of the length_parallel side. + (it's orthogonal vector can be found with the orthogonal_vector function + unit_vector_angle (float): angle of the unit vector to be in radians. + corner_points [(float, float)]: set that contains the corners of the rectangle +""" + +bounding_box_tuple = namedtuple('bounding_box_tuple', 'area ' + 'length_parallel ' + 'length_orthogonal ' + 'rectangle_center ' + 'unit_vector ' + 'unit_vector_angle ' + 'corner_points' + ) + + +def unit_vector(pt0, pt1): + """ Given two points pt0 and pt1, return a unit vector that + points in the direction of pt0 to pt1. 
+ Returns + ------- + (float, float): unit vector + """ + dis_0_to_1 = sqrt((pt0[0] - pt1[0])**2 + (pt0[1] - pt1[1])**2) + return (pt1[0] - pt0[0]) / dis_0_to_1, \ + (pt1[1] - pt0[1]) / dis_0_to_1 + + +def orthogonal_vector(vector): + """ Given a vector, returns a orthogonal/perpendicular vector of equal length. + Returns + ------ + (float, float): A vector that points in the direction orthogonal to vector. + """ + return -1 * vector[1], vector[0] + + +def bounding_area(index, hull): + """ Given index location in an array and convex hull, it gets two points + hull[index] and hull[index+1]. From these two points, it returns a named + tuple that mainly contains area of the box that bounds the hull. This + bounding box orintation is same as the orientation of the lines formed + by the point hull[index] and hull[index+1]. + Returns + ------- + a named tuple that contains: + area: area of the rectangle + length_parallel: length of the side that is parallel to unit_vector + length_orthogonal: length of the side that is orthogonal to unit_vector + rectangle_center: coordinates of the rectangle center + unit_vector: direction of the length_parallel side. + (it's orthogonal vector can be found with the orthogonal_vector function) + """ + unit_vector_p = unit_vector(hull[index], hull[index+1]) + unit_vector_o = orthogonal_vector(unit_vector_p) + + dis_p = tuple(np.dot(unit_vector_p, pt) for pt in hull) + dis_o = tuple(np.dot(unit_vector_o, pt) for pt in hull) + + min_p = min(dis_p) + min_o = min(dis_o) + len_p = max(dis_p) - min_p + len_o = max(dis_o) - min_o + + return {'area': len_p * len_o, + 'length_parallel': len_p, + 'length_orthogonal': len_o, + 'rectangle_center': (min_p + float(len_p) / 2, min_o + float(len_o) / 2), + 'unit_vector': unit_vector_p, + } + + +def to_xy_coordinates(unit_vector_angle, point): + """ Given angle from horizontal axis and a point from origin, + returns converted unit vector coordinates in x, y coordinates. + angle of unit vector should be in radians. + Returns + ------ + (float, float): converted x,y coordinate of the unit vector. + """ + angle_orthogonal = unit_vector_angle + pi / 2 + return point[0] * cos(unit_vector_angle) + point[1] * cos(angle_orthogonal), \ + point[0] * sin(unit_vector_angle) + point[1] * sin(angle_orthogonal) + + +def rotate_points(center_of_rotation, angle, points): + """ Rotates a point cloud around the center_of_rotation point by angle + input + ----- + center_of_rotation (float, float): angle of unit vector to be in radians. + angle (float): angle of rotation to be in radians. + points [(float, float)]: Points to be a list or tuple of points. Points to be rotated. + Returns + ------ + [(float, float)]: Rotated points around center of rotation by angle + """ + rot_points = [] + ang = [] + for pt in points: + diff = tuple([pt[d] - center_of_rotation[d] for d in range(2)]) + diff_angle = atan2(diff[1], diff[0]) + angle + ang.append(diff_angle) + diff_length = sqrt(sum([d**2 for d in diff])) + rot_points.append((center_of_rotation[0] + diff_length * cos(diff_angle), + center_of_rotation[1] + diff_length * sin(diff_angle))) + + return rot_points + + +def rectangle_corners(rectangle): + """ Given rectangle center and its inclination, returns the corner + locations of the rectangle. + Returns + ------ + [(float, float)]: 4 corner points of rectangle. 
+ """ + corner_points = [] + for i1 in (.5, -.5): + for i2 in (i1, -1 * i1): + corner_points.append((rectangle['rectangle_center'][0] + i1 * rectangle['length_parallel'], + rectangle['rectangle_center'][1] + i2 * rectangle['length_orthogonal'])) + + return rotate_points(rectangle['rectangle_center'], rectangle['unit_vector_angle'], corner_points) + + +def get_orientation(origin, p1, p2): + """ + Given origin and two points, return the orientation of the Point p1 with + regards to Point p2 using origin. + Returns + ------- + integer: Negative if p1 is clockwise of p2. + """ + difference = ( + ((p2[0] - origin[0]) * (p1[1] - origin[1])) + - ((p1[0] - origin[0]) * (p2[1] - origin[1])) + ) + return difference + + +def compute_hull(points): + """ + Given input list of points, return a list of points that + made up the convex hull. + Returns + ------- + [(float, float)]: convexhull points + """ + hull_points = [] + start = points[0] + min_x = start[0] + for p in points[1:]: + if p[0] < min_x: + min_x = p[0] + start = p + + point = start + hull_points.append(start) + + far_point = None + while far_point is not start: + p1 = None + for p in points: + if p is point: + continue + else: + p1 = p + break + + far_point = p1 + + for p2 in points: + if p2 is point or p2 is p1: + continue + else: + direction = get_orientation(point, far_point, p2) + if direction > 0: + far_point = p2 + + hull_points.append(far_point) + point = far_point + return hull_points + + +def minimum_bounding_box(points): + """ Given a list of 2D points, it returns the minimum area rectangle bounding all + the points in the point cloud. + Returns + ------ + returns a namedtuple that contains: + area: area of the rectangle + length_parallel: length of the side that is parallel to unit_vector + length_orthogonal: length of the side that is orthogonal to unit_vector + rectangle_center: coordinates of the rectangle center + unit_vector: direction of the length_parallel side. 
RADIANS + unit_vector_angle: angle of the unit vector + corner_points: set that contains the corners of the rectangle + """ + + if len(points) <= 2: raise ValueError('More than two points required.') + + hull_ordered = [points[index] for index in ConvexHull(points).vertices] + hull_ordered.append(hull_ordered[0]) + #hull_ordered = compute_hull(points) + hull_ordered = tuple(hull_ordered) + + min_rectangle = bounding_area(0, hull_ordered) + for i in range(1, len(hull_ordered)-1): + rectangle = bounding_area(i, hull_ordered) + if rectangle['area'] < min_rectangle['area']: + min_rectangle = rectangle + + min_rectangle['unit_vector_angle'] = atan2(min_rectangle['unit_vector'][1], min_rectangle['unit_vector'][0]) + min_rectangle['rectangle_center'] = to_xy_coordinates(min_rectangle['unit_vector_angle'], min_rectangle['rectangle_center']) + + return bounding_box_tuple( + area = min_rectangle['area'], + length_parallel = min_rectangle['length_parallel'], + length_orthogonal = min_rectangle['length_orthogonal'], + rectangle_center = min_rectangle['rectangle_center'], + unit_vector = min_rectangle['unit_vector'], + unit_vector_angle = min_rectangle['unit_vector_angle'], + corner_points = set(rectangle_corners(min_rectangle)) + ) + + +def get_center(im): + """ Given image, returns the location of center pixel + Returns + ------- + (int, int): center of the image + """ + center_x = float(im.size[0]) / 2 + center_y = float(im.size[1]) / 2 + return int(center_x), int(center_y) + + +def get_horizontal_angle(unit_vector_angle): + """ Given an angle in radians, returns angle of the unit vector in + first or fourth quadrant. + Returns + ------ + (float): updated angle of the unit vector to be in radians. + It is only in first or fourth quadrant. + """ + if unit_vector_angle > pi / 2 and unit_vector_angle <= pi: + unit_vector_angle = unit_vector_angle - pi + elif unit_vector_angle > -pi and unit_vector_angle < -pi / 2: + unit_vector_angle = unit_vector_angle + pi + + return unit_vector_angle + + +def get_smaller_angle(bounding_box): + """ Given a rectangle, returns its smallest absolute angle from horizontal axis. + Returns + ------ + (float): smallest angle of the rectangle to be in radians. + """ + unit_vector = bounding_box.unit_vector + unit_vector_angle = bounding_box.unit_vector_angle + ortho_vector = orthogonal_vector(unit_vector) + ortho_vector_angle = atan2(ortho_vector[1], ortho_vector[0]) + + unit_vector_angle_updated = get_horizontal_angle(unit_vector_angle) + ortho_vector_angle_updated = get_horizontal_angle(ortho_vector_angle) + + if abs(unit_vector_angle_updated) < abs(ortho_vector_angle_updated): + return unit_vector_angle_updated + else: + return ortho_vector_angle_updated + + +def rotated_points(bounding_box, center): + """ Given the rectangle, returns corner points of rotated rectangle. + It rotates the rectangle around the center by its smallest angle. + Returns + ------- + [(int, int)]: 4 corner points of rectangle. 
+ """ + p1, p2, p3, p4 = bounding_box.corner_points + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + x4, y4 = p4 + center_x, center_y = center + rotation_angle_in_rad = -get_smaller_angle(bounding_box) + x_dash_1 = (x1 - center_x) * cos(rotation_angle_in_rad) - (y1 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_2 = (x2 - center_x) * cos(rotation_angle_in_rad) - (y2 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_3 = (x3 - center_x) * cos(rotation_angle_in_rad) - (y3 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_4 = (x4 - center_x) * cos(rotation_angle_in_rad) - (y4 - center_y) * sin(rotation_angle_in_rad) + center_x + + y_dash_1 = (y1 - center_y) * cos(rotation_angle_in_rad) + (x1 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_2 = (y2 - center_y) * cos(rotation_angle_in_rad) + (x2 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_3 = (y3 - center_y) * cos(rotation_angle_in_rad) + (x3 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_4 = (y4 - center_y) * cos(rotation_angle_in_rad) + (x4 - center_x) * sin(rotation_angle_in_rad) + center_y + return x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 + + +def pad_image(image): + """ Given an image, returns a padded image around the border. + This routine save the code from crashing if bounding boxes that are + slightly outside the page boundary. + Returns + ------- + image: page image + """ + offset = int(args.padding // 2) + padded_image = Image.new('RGB', (image.size[0] + int(args.padding), image.size[1] + int(args.padding)), "white") + padded_image.paste(im = image, box = (offset, offset)) + return padded_image + +def update_minimum_bounding_box_input(bounding_box_input): + """ Given list of 2D points, returns list of 2D points shifted by an offset. + Returns + ------ + points [(float, float)]: points, a list or tuple of 2D coordinates + """ + updated_minimum_bounding_box_input = [] + offset = int(args.padding // 2) + for point in bounding_box_input: + x, y = point + new_x = x + offset + new_y = y + offset + word_coordinate = (new_x, new_y) + updated_minimum_bounding_box_input.append(word_coordinate) + + return updated_minimum_bounding_box_input + + +### main ### +csv_count = 0 +for filename in sorted(os.listdir(args.csv_dir)): + if filename.endswith('.csv') and (csv_count < args.head or args.head < 0): + csv_count = csv_count + 1 + with open(os.path.join(args.csv_dir, filename), 'r', encoding='utf-8') as f: + image_file = os.path.join(args.image_dir, os.path.splitext(filename)[0] + '.' 
+ args.im_format) + if not os.path.isfile(image_file): + continue + csv_out_file = os.path.join(args.out_dir, 'truth_csv', filename) + csv_out_fh = open(csv_out_file, 'w', encoding='utf-8') + csv_out_writer = csv.writer(csv_out_fh) + im = Image.open(image_file) + im = pad_image(im) + count = 1 + for row in itertools.islice(csv.reader(f), 0, None): + if count == 1: + count = 0 + continue + + points = [] + points.append((int(row[2]), int(row[3]))) + points.append((int(row[4]), int(row[5]))) + points.append((int(row[6]), int(row[7]))) + points.append((int(row[8]), int(row[9]))) + + x = [int(row[2]), int(row[4]), int(row[6]), int(row[8])] + y = [int(row[3]), int(row[5]), int(row[7]), int(row[9])] + min_x, min_y = min(x), min(y) + max_x, max_y = max(x), max(y) + if min_x == max_x or min_y == max_y: + continue + + try: + updated_mbb_input = update_minimum_bounding_box_input(points) + bounding_box = minimum_bounding_box(updated_mbb_input) + except Exception as e: + print("Error: Skipping Image " + row[1]) + continue + + p1, p2, p3, p4 = bounding_box.corner_points + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + x4, y4 = p4 + min_x = int(min(x1, x2, x3, x4)) + min_y = int(min(y1, y2, y3, y4)) + max_x = int(max(x1, x2, x3, x4)) + max_y = int(max(y1, y2, y3, y4)) + box = (min_x, min_y, max_x, max_y) + region_initial = im.crop(box) + rot_points = [] + p1_new = (x1 - min_x, y1 - min_y) + p2_new = (x2 - min_x, y2 - min_y) + p3_new = (x3 - min_x, y3 - min_y) + p4_new = (x4 - min_x, y4 - min_y) + rot_points.append(p1_new) + rot_points.append(p2_new) + rot_points.append(p3_new) + rot_points.append(p4_new) + + cropped_bounding_box = bounding_box_tuple(bounding_box.area, + bounding_box.length_parallel, + bounding_box.length_orthogonal, + bounding_box.length_orthogonal, + bounding_box.unit_vector, + bounding_box.unit_vector_angle, + set(rot_points)) + + rotation_angle_in_rad = get_smaller_angle(cropped_bounding_box) + img2 = region_initial.rotate(degrees(rotation_angle_in_rad), resample = Image.BICUBIC) + x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 = rotated_points( + cropped_bounding_box, get_center(region_initial)) + + min_x = int(min(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + min_y = int(min(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + max_x = int(max(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + max_y = int(max(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + box = (min_x, min_y, max_x, max_y) + region_final = img2.crop(box) + csv_out_writer.writerow(row) + image_out_file = os.path.join(args.out_dir, 'truth_line_image', row[1]) + region_final.save(image_out_file) diff --git a/egs/yomdle_fa/v1/local/extract_features.sh b/egs/yomdle_fa/v1/local/extract_features.sh new file mode 100755 index 00000000000..f75837ae5b3 --- /dev/null +++ b/egs/yomdle_fa/v1/local/extract_features.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright 2017 Yiwen Shao +# 2018 Ashish Arora + +nj=4 +cmd=run.pl +feat_dim=40 +fliplr=false +augment='no_aug' +num_channels=3 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. 
./utils/parse_options.sh || exit 1; + +data=$1 +featdir=$data/data +scp=$data/images.scp +logdir=$data/log + +mkdir -p $logdir +mkdir -p $featdir + +# make $featdir an absolute pathname +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +for n in $(seq $nj); do + split_scps="$split_scps $logdir/images.$n.scp" +done + +# split images.scp +utils/split_scp.pl $scp $split_scps || exit 1; + +$cmd JOB=1:$nj $logdir/extract_features.JOB.log \ + image/ocr/make_features.py $logdir/images.JOB.scp \ + --allowed_len_file_path $data/allowed_lengths.txt \ + --feat-dim $feat_dim --num-channels $num_channels --fliplr $fliplr --augment_type $augment \| \ + copy-feats --compress=true --compression-method=7 \ + ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp + +## aggregates the output scp's to get feats.scp +for n in $(seq $nj); do + cat $featdir/images.$n.scp || exit 1; +done > $data/feats.scp || exit 1 diff --git a/egs/yomdle_fa/v1/local/gedi2csv.py b/egs/yomdle_fa/v1/local/gedi2csv.py new file mode 100755 index 00000000000..0b80c2e80bb --- /dev/null +++ b/egs/yomdle_fa/v1/local/gedi2csv.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 + +""" +GEDI2CSV +Convert GEDI-type bounding boxes to CSV format + +GEDI Format Example: + + + + + + + + + +CSV Format Example +ID,name,col1,row1,col2,row2,col3,row3,col4,row4,confidence,truth,pgrot,bbrot,qual,script,lang +0,chinese_scanned_books_0001_0.png,99,41,99,14,754,14,754,41,100,凡我的邻人说是好的,有一大部分在我灵魂中却,0,0.0,0,,zh-cn +""" + +import logging +import os +import sys +import time +import glob +import csv +import imghdr +from PIL import Image +import argparse +import pdb +import cv2 +import numpy as np +import xml.etree.ElementTree as ET + +sin = np.sin +cos = np.cos +pi = np.pi + +def Rotate2D(pts, cnt, ang=90): + M = np.array([[cos(ang),-sin(ang)],[sin(ang),cos(ang)]]) + res = np.dot(pts-cnt,M)+cnt + return M, res + +def npbox2string(npar): + if np.shape(npar)[0] != 1: + print('Error during CSV conversion\n') + c1,r1 = npar[0][0],npar[0][1] + c2,r2 = npar[0][2],npar[0][3] + c3,r3 = npar[0][4],npar[0][5] + c4,r4 = npar[0][6],npar[0][7] + + return c1,r1,c2,r2,c3,r3,c4,r4 + +# cv2.minAreaRect() returns a Box2D structure which contains following detals - ( center (x,y), (width, height), angle of rotation ) +# Get 4 corners of the rectangle using cv2.boxPoints() + +class GEDI2CSV(object): + + """ Initialize the extractor""" + def __init__(self, logger, args): + self._logger = logger + self._args = args + + """ + Segment image with GEDI bounding box information + """ + def csvfile(self, coords, polys, baseName, pgrot): + + """ for writing the files """ + writePath = self._args.outputDir + writePath = os.path.join(writePath,'') + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + rotlist = [] + + header=['ID','name','col1','row1','col2','row2','col3','row3','col4','row4','confidence','truth','pgrot','bbrot','qual','script','text_type'] + conf=100 + write_ctr = 0 + if len(coords) == 0 and len(polys) == 0: + self._logger.info('Found %s with no text content',(baseName)) + print('...Found %s with no text content' % (baseName)) + return + + strPos = writePath + baseName + + """ for each group of coordinates """ + for i in coords: + + [id,x,y,w,h,degrees,text,qual,script,text_type] = i + + contour = np.array([(x,y),(x+w,y),(x+w,y+h),(x,y+h)]) + + """ + First rotate around upper left corner based on orientationD keyword + """ + M, rot = Rotate2D(contour, np.array([x,y]), degrees*pi/180) + rot = 
np.int0(rot) + + # rot is the 8 points rotated by degrees + # pgrot is the rotation after extraction, so save + + # save rotated points to list or array + rot = np.reshape(rot,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(rot) + + text = text.replace(u'\ufeff','') + + bbrot = degrees + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,text_type]) + + # if there are polygons, first save the text + for j in polys: + arr = [] + [id,poly_val,text,qual,script,text_type] = j + for i in poly_val: + arr.append(eval(i)) + + contour = np.asarray(arr) + convex = cv2.convexHull(contour) + rect = cv2.minAreaRect(convex) + box = cv2.boxPoints(rect) + box = np.int0(box) + box = np.reshape(box,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(box) + + bbrot = 0.0 + + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,text_type]) + + # then write out all of list to file + with open(strPos + ".csv", "w", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(header) + for row in rotlist: + writer.writerow(row) + write_ctr += 1 + + return write_ctr + + +def main(args): + + startTime = time.clock() + + writePath = args.outputDir + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + """ Setup logging """ + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + if args.log: + handler = logging.FileHandler(args.log) + handler.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + gtconverter = GEDI2CSV(logger, args) + namespaces = {"gedi" : "http://lamp.cfar.umd.edu/media/projects/GEDI/"} + keyCnt=0 + + fileCnt = 0 + line_write_ctr = 0 + line_error_ctr = 0 + + """ + Get all XML files in the directory and sub folders + """ + for root, dirnames, filenames in os.walk(args.inputDir, followlinks=True): + for file in filenames: + if file.lower().endswith('.xml'): + fullName = os.path.join(root,file) + baseName = os.path.splitext(fullName) + + fileCnt += 1 + + """ read the XML file """ + tree = ET.parse(fullName) + gedi_root = tree.getroot() + child = gedi_root.findall('gedi:DL_DOCUMENT',namespaces)[0] + totalpages = int(child.attrib['NrOfPages']) + coordinates=[] + polygons = [] + if args.ftype == 'boxed': + fileTypeStr = 'col' + elif args.ftype == 'transcribed': + fileTypeStr = 'Text_Content' + else: + print('Filetype must be either boxed or transcribed!') + logger.info('Filetype must be either boxed or transcribed!') + sys.exit(-1) + + if args.quality == 'both': + qualset = {'Regular','Low-Quality'} + elif args.quality == 'low': + qualset = {'Low-Quality'} + elif args.quality == 'regular': + qualset = {'Regular'} + else: + print('Quality must be both, low or regular!') + logger.info('Quality must be both, low or regular!') + sys.exit(-1) + + + + """ and for each page """ + for i, pgs in enumerate(child.iterfind('gedi:DL_PAGE',namespaces)): + + if 'GEDI_orientation' not in pgs.attrib: + pageRot=0 + else: + pageRot = int(pgs.attrib['GEDI_orientation']) + logger.info(' PAGE ROTATION %s, %s' % (fullName, str(pageRot))) + + """ find children for each page """ + for zone in pgs.findall('gedi:DL_ZONE',namespaces): + + if zone.attrib['gedi_type']=='Text' and zone.attrib['Type'] in \ + ('Machine_Print','Confusable_Allograph','Handwriting') and zone.attrib['Quality'] in qualset: + if zone.get('polygon'): + keyCnt+=1 + 
polygons.append([zone.attrib['id'],zone.get('polygon').split(';'), + zone.get('Text_Content'),zone.get('Quality'),zone.get('Script'),zone.get('Type')]) + elif zone.get(fileTypeStr) != None: + keyCnt+=1 + coord = [zone.attrib['id'],int(zone.attrib['col']),int(zone.attrib['row']), + int(zone.attrib['width']), int(zone.attrib['height']), + float(zone.get('orientationD',0.0)), + zone.get('Text_Content'),zone.get('Quality'),zone.get('Script'),zone.get('Type')] + coordinates.append(coord) + + if len(coordinates) > 0 or len(polygons) > 0: + line_write_ctr += gtconverter.csvfile(coordinates, polygons, os.path.splitext(file)[0], pageRot) + else: + print('...%s has no applicable content' % (baseName[0])) + + print('complete...total files %d, lines written %d' % (fileCnt, line_write_ctr)) + + +def parse_arguments(argv): + """ Args and defaults """ + parser = argparse.ArgumentParser() + + parser.add_argument('--inputDir', type=str, help='Input directory', required=True) + parser.add_argument('--outputDir', type=str, help='Output directory', required=True) + parser.add_argument('--ftype', type=str, help='GEDI file type (either "boxed" or "transcribed")', default='transcribed') + parser.add_argument('--quality', type=str, help='GEDI file quality (either "both" or "low" or "regular")', default='regular') + parser.add_argument('--log', type=str, help='Log directory', default='./GEDI2CSV_enriched.log') + + return parser.parse_args(argv) + +if __name__ == '__main__': + """ Run """ + main(parse_arguments(sys.argv[1:])) + + + + + + diff --git a/egs/yomdle_fa/v1/local/prepare_dict.sh b/egs/yomdle_fa/v1/local/prepare_dict.sh new file mode 100755 index 00000000000..8d14130d8c0 --- /dev/null +++ b/egs/yomdle_fa/v1/local/prepare_dict.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +# Copyright 2017 Hossein Hadian +# 2017 Chun Chieh Chang +# 2017 Ashish Arora + +# This script prepares the dictionary. + +set -e +dir=data/local/dict +data_dir=data + +. ./utils/parse_options.sh || exit 1; + +base_dir=$(echo "$DIRECTORY" | cut -d "/" -f2) + +mkdir -p $dir + +local/prepare_lexicon.py --data-dir $data_dir $dir + +perl -i -ne 'print if /\S/' $dir/lexicon.txt +cut -d' ' -f2- $dir/lexicon.txt | sed 's/SIL//g' | tr ' ' '\n' | sort -u | sed '/^$/d' >$dir/nonsilence_phones.txt || exit 1; + +echo ' SIL' >> $dir/lexicon.txt + +echo SIL > $dir/silence_phones.txt + +echo SIL >$dir/optional_silence.txt + +echo -n "" >$dir/extra_questions.txt diff --git a/egs/yomdle_fa/v1/local/prepare_lexicon.py b/egs/yomdle_fa/v1/local/prepare_lexicon.py new file mode 100755 index 00000000000..46be4f37970 --- /dev/null +++ b/egs/yomdle_fa/v1/local/prepare_lexicon.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora + +import argparse +import os + +parser = argparse.ArgumentParser(description="""Creates the list of characters and words in lexicon""") +parser.add_argument('dir', type=str, help='output path') +parser.add_argument('--data-dir', type=str, default='data', help='Path to text file') +args = parser.parse_args() + +### main ### +lex = {} +text_path = os.path.join(args.data_dir, 'train', 'text') +text_fh = open(text_path, 'r', encoding='utf-8') + +with open(text_path, 'r', encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split(' ') + for i in range(1, len(line_vect)): + characters = list(line_vect[i]) + # Put SIL instead of "|". 
Because every "|" in the beginning of the words is for initial-space of that word + characters = " ".join([ 'SIL' if char == '|' else char for char in characters]) + characters = characters.replace('#','') + lex[line_vect[i]] = characters + +with open(os.path.join(args.dir, 'lexicon.txt'), 'w', encoding='utf-8') as fp: + for key in sorted(lex): + fp.write(key + " " + lex[key] + "\n") diff --git a/egs/yomdle_fa/v1/local/process_data.py b/egs/yomdle_fa/v1/local/process_data.py new file mode 100755 index 00000000000..3423cc5380e --- /dev/null +++ b/egs/yomdle_fa/v1/local/process_data.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# 2018 Chun Chieh Chang + +""" This script reads the extracted Farsi OCR (yomdle and slam) database files + and creates the following files (for the data subset selected via --dataset): + text, utt2spk, images.scp. + Eg. local/process_data.py data/download/ data/local/splits/train.txt data/train + Eg. text file: english_phone_books_0001_1 To sum up, then, it would appear that + utt2spk file: english_phone_books_0001_0 english_phone_books_0001 + images.scp file: english_phone_books_0001_0 \ + data/download/truth_line_image/english_phone_books_0001_0.png +""" + +import argparse +import os +import sys +import csv +import itertools +import unicodedata + +parser = argparse.ArgumentParser(description="Creates text, utt2spk, and images.scp files") +parser.add_argument('database_path', type=str, help='Path to data') +parser.add_argument('out_dir', type=str, help='directory to output files') +parser.add_argument('--head', type=int, default=-1, help='limit on number of synth data') +args = parser.parse_args() + +### main ### +print("Processing '{}' data...".format(args.out_dir)) + +text_file = os.path.join(args.out_dir, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') +utt2spk_file = os.path.join(args.out_dir, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join(args.out_dir, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') + +count = 0 +for filename in sorted(os.listdir(os.path.join(args.database_path, 'truth_csv'))): + if filename.endswith('.csv') and (count < args.head or args.head < 0): + count = count + 1 + csv_filepath = os.path.join(args.database_path, 'truth_csv', filename) + csv_file = open(csv_filepath, 'r', encoding='utf-8') + row_count = 0 + for row in csv.reader(csv_file): + if row_count == 0: + row_count = 1 + continue + image_id = os.path.splitext(row[1])[0] + image_filepath = os.path.join(args.database_path, 'truth_line_image', row[1]) + text = unicodedata.normalize('NFC', row[11]) + file_info = os.stat(image_filepath) + if file_info.st_size != 0: + if text: + text_fh.write(image_id + ' ' + text + '\n') + utt2spk_fh.write(image_id + ' ' + '_'.join(image_id.split('_')[:-1]) + '\n') + image_fh.write(image_id + ' ' + image_filepath + ' ' + row[13] + '\n') diff --git a/egs/yomdle_fa/v1/local/score.sh b/egs/yomdle_fa/v1/local/score.sh new file mode 100755 index 00000000000..f2405205f02 --- /dev/null +++ b/egs/yomdle_fa/v1/local/score.sh @@ -0,0 +1,5 @@ +#!/bin/bash + + +steps/scoring/score_kaldi_wer.sh --max-lmwt 10 "$@" +steps/scoring/score_kaldi_cer.sh --max-lmwt 10 --stage 2 "$@" diff --git a/egs/yomdle_fa/v1/local/train_lm.sh b/egs/yomdle_fa/v1/local/train_lm.sh new file mode 100755 index 00000000000..bc738f217da --- /dev/null +++ b/egs/yomdle_fa/v1/local/train_lm.sh @@ -0,0 +1,110 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University 
(author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the YOMDLE training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +data_dir=data + +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +#bypass_metaparam_optim_opt= +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + nr=`cat $data_dir/train/text | wc -l` + nr_dev=$(($nr / 10 )) + nr_train=$(( $nr - $nr_dev )) + + # use the training data as an additional data source. + # we can later fold the dev data into this. + head -n $nr_train $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + tail -n $nr_dev $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < $data_dir/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from MADCAT text + cat ${dir}/data/text/train.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +order=3 + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. 
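The word-list referred to in the note above is the one stage 0 just derived from the training text with the tr/sort/uniq -c pipeline. For readers who find the shell pipeline opaque, here is a rough Python equivalent (a sketch only, assuming the script's default dir of data/local/local_lm); it writes the same word_count and wordlist files that train_lm.py later consumes via --wordlist.

#!/usr/bin/env python3
# Rough Python equivalent of the stage-0 wordlist pipeline above:
#   tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr
# Paths assume the script's default dir=data/local/local_lm.
from collections import Counter
from pathlib import Path

data = Path('data/local/local_lm/data')
counts = Counter()
with open(data / 'text' / 'train.txt', encoding='utf-8') as f:
    for line in f:
        counts.update(line.split())

with open(data / 'word_count', 'w', encoding='utf-8') as wc, \
     open(data / 'wordlist', 'w', encoding='utf-8') as wl:
    for word, n in counts.most_common():   # most frequent first, like sort -bnr
        wc.write(f"{n} {word}\n")
        wl.write(word + '\n')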
+ # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=5 --warm-start-ratio=1 \ + --min-counts="$min_counts" \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi diff --git a/egs/yomdle_fa/v1/local/train_lm_lr.sh b/egs/yomdle_fa/v1/local/train_lm_lr.sh new file mode 100755 index 00000000000..5bfc20acdeb --- /dev/null +++ b/egs/yomdle_fa/v1/local/train_lm_lr.sh @@ -0,0 +1,113 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the YOMDLE+Extra training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +data_dir=data +extra_lm=download/extra_lm.txt +order=3 + +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +#bypass_metaparam_optim_opt= +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + cat ${extra_lm} | local/bidi.py | utils/lang/bpe/prepend_words.py --encoding 'utf-8' | python3 utils/lang/bpe/apply_bpe.py -c $data_dir/train/bpe.out | sed 's/@@//g' > ${dir}/data/text/extra_lm.txt + + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + nr=`cat $data_dir/train/text | wc -l` + nr_dev=$(($nr / 10 )) + nr_train=$(( $nr - $nr_dev )) + + # use the training data as an additional data source. + # we can later fold the dev data into this. 
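The counts computed above feed the head/tail split that follows: roughly 10% of the training transcripts go to ${dir}/data/text/dev.txt (which pocolm automatically treats as the held-out set for working out interpolation weights) and the remainder to train.txt, with the utterance-id column stripped. A Python sketch of the same split, under the script's default paths, for readers who prefer it spelled out:

#!/usr/bin/env python3
# Sketch of the 90/10 split performed by the head/tail commands below,
# using the script's default paths (data/train/text in,
# data/local/local_lm/data/text/{train,dev}.txt out).
import os

src = 'data/train/text'
out_dir = 'data/local/local_lm/data/text'
os.makedirs(out_dir, exist_ok=True)

with open(src, encoding='utf-8') as f:
    # Drop the utterance-id column, mirroring `cut -d " " -f 2-`.
    lines = [line.split(' ', 1)[1] if ' ' in line else line for line in f]

n_dev = len(lines) // 10          # last ~10% becomes pocolm's dev set
n_train = len(lines) - n_dev
with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f:
    f.writelines(lines[:n_train])
with open(os.path.join(out_dir, 'dev.txt'), 'w', encoding='utf-8') as f:
    f.writelines(lines[n_train:])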
+ head -n $nr_train $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + tail -n $nr_dev $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < $data_dir/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from MADCAT text + cat ${dir}/data/text/{train,extra_lm}.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + #cat ${dir}/data/text/extra_fa.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='extra_lm=10 train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=30 --warm-start-ratio=1 \ + --min-counts="$min_counts" \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi diff --git a/egs/yomdle_fa/v1/local/wer_output_filter b/egs/yomdle_fa/v1/local/wer_output_filter new file mode 100755 index 00000000000..08d5563bca4 --- /dev/null +++ b/egs/yomdle_fa/v1/local/wer_output_filter @@ -0,0 +1,151 @@ +#!/usr/bin/env perl +# Copyright 2012-2014 Johns Hopkins University (Author: Yenda Trmal) +# Apache 2.0 + +use utf8; + +use open qw(:encoding(utf8)); +binmode STDIN, ":utf8"; +binmode STDOUT, ":utf8"; +binmode STDERR, ":utf8"; + +# Arabic-specific normalization +while (<>) { + @F = split " "; + print "$F[0] "; + foreach $s (@F[1..$#F]) { + # Normalize tabs, spaces, and no-break spaces + $s =~ s/[\x{0009}\x{0020}\x{00A0}]+/ /g; + # Normalize "dots"/"filled-circles" to periods + $s =~ s/[\x{25CF}\x{u2022}\x{2219}]+/\x{002E}/g; + # Normalize dashes to regular hyphen + $s =~ s/[\x{2010}\x{2011}\x{2012}\x{2013}\x{2014}\x{2015}]+/\x{002D}/g; + # Normalize various parenthesis to regular parenthesis + $s =~ s/\x{UFF09}/\x{0029}/g; + $s =~ s/\x{UFF08}/\x{0028}/g; + + # Convert various presentation forms to base form + $s =~ s/[\x{FED1}\x{FED3}\x{FED4}\x{FED2}]+/\x{0641}/g; + $s =~ s/[\x{FBB0}\x{FBB1}]+/\x{06D3}/g; + $s =~ s/[\x{FECD}\x{FECF}\x{FED0}\x{FECE}]+/\x{063A}/g; + $s =~ s/[\x{FBDD}]+/\x{0677}/g; + $s =~ s/[\x{FBA6}\x{FBA8}\x{FBA9}\x{FBA7}]+/\x{06C1}/g; + $s =~ s/[\x{FEC1}\x{FEC3}\x{FEC4}\x{FEC2}]+/\x{0637}/g; + $s =~ s/[\x{FE85}\x{FE86}]+/\x{0624}/g; + $s =~ s/[\x{FEA5}\x{FEA7}\x{FEA8}\x{FEA6}]+/\x{062E}/g; + $s =~ s/[\x{FBD9}\x{FBDA}]+/\x{06C6}/g; + $s =~ s/[\x{FE8F}\x{FE91}\x{FE92}\x{FE90}]+/\x{0628}/g; + $s =~ 
s/[\x{FEED}\x{FEEE}]+/\x{0648}/g; + $s =~ s/[\x{FE99}\x{FE9B}\x{FE9C}\x{FE9A}]+/\x{062B}/g; + $s =~ s/[\x{FEBD}\x{FEBF}\x{FEC0}\x{FEBE}]+/\x{0636}/g; + $s =~ s/[\x{FEE5}\x{FEE7}\x{FEE8}\x{FEE6}]+/\x{0646}/g; + $s =~ s/[\x{FBFC}\x{FBFE}\x{FBFF}\x{FBFD}]+/\x{06CC}/g; + $s =~ s/[\x{FBA4}\x{FBA5}]+/\x{06C0}/g; + $s =~ s/[\x{FB72}\x{FB74}\x{FB75}\x{FB73}]+/\x{0684}/g; + $s =~ s/[\x{FBD3}\x{FBD5}\x{FBD6}\x{FBD4}]+/\x{06AD}/g; + $s =~ s/[\x{FB6A}\x{FB6C}\x{FB6D}\x{FB6B}]+/\x{06A4}/g; + $s =~ s/[\x{FB66}\x{FB68}\x{FB69}\x{FB67}]+/\x{0679}/g; + $s =~ s/[\x{FB5E}\x{FB60}\x{FB61}\x{FB5F}]+/\x{067A}/g; + $s =~ s/[\x{FB88}\x{FB89}]+/\x{0688}/g; + $s =~ s/[\x{FB7E}\x{FB80}\x{FB81}\x{FB7F}]+/\x{0687}/g; + $s =~ s/[\x{FB8E}\x{FB90}\x{FB91}\x{FB8F}]+/\x{06A9}/g; + $s =~ s/[\x{FB86}\x{FB87}]+/\x{068E}/g; + $s =~ s/[\x{FE83}\x{FE84}]+/\x{0623}/g; + $s =~ s/[\x{FB8A}\x{FB8B}]+/\x{0698}/g; + $s =~ s/[\x{FED5}\x{FED7}\x{FED8}\x{FED6}]+/\x{0642}/g; + $s =~ s/[\x{FED9}\x{FEDB}\x{FEDC}\x{FEDA}]+/\x{0643}/g; + $s =~ s/[\x{FBE0}\x{FBE1}]+/\x{06C5}/g; + $s =~ s/[\x{FEB9}\x{FEBB}\x{FEBC}\x{FEBA}]+/\x{0635}/g; + $s =~ s/[\x{FEC5}\x{FEC7}\x{FEC8}\x{FEC6}]+/\x{0638}/g; + $s =~ s/[\x{FE8D}\x{FE8E}]+/\x{0627}/g; + $s =~ s/[\x{FB9A}\x{FB9C}\x{FB9D}\x{FB9B}]+/\x{06B1}/g; + $s =~ s/[\x{FEAD}\x{FEAE}]+/\x{0631}/g; + $s =~ s/[\x{FEF1}\x{FEF3}\x{FEF4}\x{FEF2}]+/\x{064A}/g; + $s =~ s/[\x{FE93}\x{FE94}]+/\x{0629}/g; + $s =~ s/[\x{FBE4}\x{FBE6}\x{FBE7}\x{FBE5}]+/\x{06D0}/g; + $s =~ s/[\x{FE89}\x{FE8B}\x{FE8C}\x{FE8A}]+/\x{0626}/g; + $s =~ s/[\x{FB84}\x{FB85}]+/\x{068C}/g; + $s =~ s/[\x{FE9D}\x{FE9F}\x{FEA0}\x{FE9E}]+/\x{062C}/g; + $s =~ s/[\x{FB82}\x{FB83}]+/\x{068D}/g; + $s =~ s/[\x{FEA1}\x{FEA3}\x{FEA4}\x{FEA2}]+/\x{062D}/g; + $s =~ s/[\x{FB52}\x{FB54}\x{FB55}\x{FB53}]+/\x{067B}/g; + $s =~ s/[\x{FB92}\x{FB94}\x{FB95}\x{FB93}]+/\x{06AF}/g; + $s =~ s/[\x{FB7A}\x{FB7C}\x{FB7D}\x{FB7B}]+/\x{0686}/g; + $s =~ s/[\x{FBDB}\x{FBDC}]+/\x{06C8}/g; + $s =~ s/[\x{FB56}\x{FB58}\x{FB59}\x{FB57}]+/\x{067E}/g; + $s =~ s/[\x{FEB5}\x{FEB7}\x{FEB8}\x{FEB6}]+/\x{0634}/g; + $s =~ s/[\x{FBE2}\x{FBE3}]+/\x{06C9}/g; + $s =~ s/[\x{FB96}\x{FB98}\x{FB99}\x{FB97}]+/\x{06B3}/g; + $s =~ s/[\x{FE80}]+/\x{0621}/g; + $s =~ s/[\x{FBAE}\x{FBAF}]+/\x{06D2}/g; + $s =~ s/[\x{FB62}\x{FB64}\x{FB65}\x{FB63}]+/\x{067F}/g; + $s =~ s/[\x{FEE9}\x{FEEB}\x{FEEC}\x{FEEA}]+/\x{0647}/g; + $s =~ s/[\x{FE81}\x{FE82}]+/\x{0622}/g; + $s =~ s/[\x{FBDE}\x{FBDF}]+/\x{06CB}/g; + $s =~ s/[\x{FE87}\x{FE88}]+/\x{0625}/g; + $s =~ s/[\x{FB6E}\x{FB70}\x{FB71}\x{FB6F}]+/\x{06A6}/g; + $s =~ s/[\x{FBA0}\x{FBA2}\x{FBA3}\x{FBA1}]+/\x{06BB}/g; + $s =~ s/[\x{FBAA}\x{FBAC}\x{FBAD}\x{FBAB}]+/\x{06BE}/g; + $s =~ s/[\x{FEA9}\x{FEAA}]+/\x{062F}/g; + $s =~ s/[\x{FEE1}\x{FEE3}\x{FEE4}\x{FEE2}]+/\x{0645}/g; + $s =~ s/[\x{FEEF}\x{FBE8}\x{FBE9}\x{FEF0}]+/\x{0649}/g; + $s =~ s/[\x{FB8C}\x{FB8D}]+/\x{0691}/g; + $s =~ s/[\x{FB76}\x{FB78}\x{FB79}\x{FB77}]+/\x{0683}/g; + $s =~ s/[\x{FB5A}\x{FB5C}\x{FB5D}\x{FB5B}]+/\x{0680}/g; + $s =~ s/[\x{FB9E}\x{FB9F}]+/\x{06BA}/g; + $s =~ s/[\x{FEC9}\x{FECB}\x{FECC}\x{FECA}]+/\x{0639}/g; + $s =~ s/[\x{FEDD}\x{FEDF}\x{FEE0}\x{FEDE}]+/\x{0644}/g; + $s =~ s/[\x{FB50}\x{FB51}]+/\x{0671}/g; + $s =~ s/[\x{FEB1}\x{FEB3}\x{FEB4}\x{FEB2}]+/\x{0633}/g; + $s =~ s/[\x{FE95}\x{FE97}\x{FE98}\x{FE96}]+/\x{062A}/g; + $s =~ s/[\x{FBD7}\x{FBD8}]+/\x{06C7}/g; + $s =~ s/[\x{FEAF}\x{FEB0}]+/\x{0632}/g; + $s =~ s/[\x{FEAB}\x{FEAC}]+/\x{0630}/g; + + # Remove tatweel + $s =~ s/\x{0640}//g; + # Remove vowels and hamza + $s =~ s/[\x{064B}-\x{0655}]+//g; + # Remove right-to-left and left-to-right + $s =~ 
s/[\x{200F}\x{200E}]+//g; + # Arabic Keheh to Arabic Kaf + $s =~ s/\x{06A9}/\x{0643}/g; + # Arabic Yeh to Farsi Yeh + $s =~ s/\x{064A}/\x{06CC}/g; + # Decompose RIAL + $s =~ s/\x{FDFC}/\x{0631}\x{06CC}\x{0627}\x{0644}/g; + # Farsi arabic-indic digits to arabic-indic digits + $s =~ s/\x{06F0}/\x{0660}/g; + $s =~ s/\x{06F1}/\x{0661}/g; + $s =~ s/\x{06F2}/\x{0662}/g; + $s =~ s/\x{06F3}/\x{0663}/g; + $s =~ s/\x{06F4}/\x{0664}/g; + $s =~ s/\x{06F5}/\x{0665}/g; + $s =~ s/\x{06F6}/\x{0666}/g; + $s =~ s/\x{06F7}/\x{0667}/g; + $s =~ s/\x{06F8}/\x{0668}/g; + $s =~ s/\x{06F9}/\x{0669}/g; + # Arabic-indic digits to digits + $s =~ s/\x{0660}/0/g; + $s =~ s/\x{0661}/1/g; + $s =~ s/\x{0662}/2/g; + $s =~ s/\x{0663}/3/g; + $s =~ s/\x{0664}/4/g; + $s =~ s/\x{0665}/5/g; + $s =~ s/\x{0666}/6/g; + $s =~ s/\x{0667}/7/g; + $s =~ s/\x{0668}/8/g; + $s =~ s/\x{0669}/9/g; + # Arabic comma to comma + $s =~ s/\x{060C}/\x{002C}/g; + + $s =~ s/\|/ /g; + if ($s ne "") { + print "$s"; + } else { + print ""; + } + } + print "\n"; +} + diff --git a/egs/yomdle_fa/v1/local/yomdle2csv.py b/egs/yomdle_fa/v1/local/yomdle2csv.py new file mode 100755 index 00000000000..8f208e2d968 --- /dev/null +++ b/egs/yomdle_fa/v1/local/yomdle2csv.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 + +""" +GEDI2CSV +Convert GEDI-type bounding boxes to CSV format + +GEDI Format Example: + + + + + + + + + +CSV Format Example +ID,name,col1,row1,col2,row2,col3,row3,col4,row4,confidence,truth,pgrot,bbrot,qual,script,lang +0,chinese_scanned_books_0001_0.png,99,41,99,14,754,14,754,41,100,凡我的邻人说是好的,有一大部分在我灵魂中却,0,0.0,0,,zh-cn +""" + +import logging +import os +import sys +import time +import glob +import csv +import imghdr +from PIL import Image +import argparse +import pdb +import cv2 +import numpy as np +import xml.etree.ElementTree as ET + +sin = np.sin +cos = np.cos +pi = np.pi + +def Rotate2D(pts, cnt, ang=90): + M = np.array([[cos(ang),-sin(ang)],[sin(ang),cos(ang)]]) + res = np.dot(pts-cnt,M)+cnt + return M, res + +def npbox2string(npar): + if np.shape(npar)[0] != 1: + print('Error during CSV conversion\n') + c1,r1 = npar[0][0],npar[0][1] + c2,r2 = npar[0][2],npar[0][3] + c3,r3 = npar[0][4],npar[0][5] + c4,r4 = npar[0][6],npar[0][7] + + return c1,r1,c2,r2,c3,r3,c4,r4 + +# cv2.minAreaRect() returns a Box2D structure which contains following detals - ( center (x,y), (width, height), angle of rotation ) +# Get 4 corners of the rectangle using cv2.boxPoints() + +class GEDI2CSV(object): + + """ Initialize the extractor""" + def __init__(self, logger, args): + self._logger = logger + self._args = args + + """ + Segment image with GEDI bounding box information + """ + def csvfile(self, coords, polys, baseName, pgrot): + + """ for writing the files """ + writePath = self._args.outputDir + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + rotlist = [] + + header=['ID','name','col1','row1','col2','row2','col3','row3','col4','row4','confidence','truth','pgrot','bbrot','qual','script','lang'] + conf=100 + pgrot = 0 + bbrot = 0 + qual = 0 + script = '' + + write_ctr = 0 + if len(coords) == 0 and len(polys) == 0: + self._logger.info('Found %s with no text content',(baseName)) + print('...Found %s with no text content' % (baseName)) + return + + strPos = writePath + baseName + + for j in polys: + try: + arr = [] + [id,poly_val,text,qual,lang] = j + script=None + #print(j) + for i in poly_val: + if len(i.strip()) > 0: + #print(i) + arr.append(eval(i)) + + contour = np.asarray(arr) + #print(contour) + convex = cv2.convexHull(contour) + rect = 
cv2.minAreaRect(convex) + box = cv2.boxPoints(rect) + box = np.int0(box) + box = np.reshape(box,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(box) + + bbrot = 0.0 + + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,lang]) + + except: + print('...polygon error %s, %s' % (j, baseName)) + continue + + # then write out all of list to file + with open(strPos + ".csv", "w", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(header) + for row in rotlist: + writer.writerow(row) + write_ctr += 1 + + return write_ctr + + +def main(args): + + startTime = time.clock() + + writePath = args.outputDir + print('write to %s' % (writePath)) + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + """ Setup logging """ + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + if args.log: + handler = logging.FileHandler(args.log) + handler.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + gtconverter = GEDI2CSV(logger, args) + namespaces = {"gedi" : "http://lamp.cfar.umd.edu/media/projects/GEDI/"} + keyCnt=0 + + fileCnt = 0 + line_write_ctr = 0 + line_error_ctr = 0 + file_error_ctr = 0 + """ + Get all XML files in the directory and sub folders + """ + print('reading %s' % (args.inputDir)) + for root, dirnames, filenames in os.walk(args.inputDir, followlinks=True): + for file in filenames: + if file.lower().endswith('.xml'): + fullName = os.path.join(root,file) + baseName = os.path.splitext(fullName) + + fileCnt += 1 + + try: + """ read the XML file """ + tree = ET.parse(fullName) + except: + print('...ERROR parsing %s' % (fullName)) + file_error_ctr += 1 + continue + + gedi_root = tree.getroot() + child = gedi_root.findall('gedi:DL_DOCUMENT',namespaces)[0] + totalpages = int(child.attrib['NrOfPages']) + coordinates=[] + polygons = [] + + """ and for each page """ + for i, pgs in enumerate(child.iterfind('gedi:DL_PAGE',namespaces)): + + if 'GEDI_orientation' not in pgs.attrib: + pageRot=0 + else: + pageRot = int(pgs.attrib['GEDI_orientation']) + logger.info(' PAGE ROTATION %s, %s' % (fullName, str(pageRot))) + + """ find children for each page """ + for zone in pgs.findall('gedi:DL_ZONE',namespaces): + + if zone.attrib['gedi_type']=='Text' : + if zone.get('polygon'): + keyCnt+=1 + polygons.append([zone.attrib['id'],zone.get('polygon').split(';'), + zone.get('Text_Content'),zone.get('Illegible'),zone.get('Language')]) + else: + print('...Not polygon') + + + if len(coordinates) > 0 or len(polygons) > 0: + line_write_ctr += gtconverter.csvfile(coordinates, polygons, os.path.splitext(file)[0], pageRot) + else: + print('...%s has no text content' % (baseName[0])) + + + print('complete...total files %d, lines written %d, img errors %d, line error %d' % (fileCnt, line_write_ctr, file_error_ctr, line_error_ctr)) + + +def parse_arguments(argv): + """ Args and defaults """ + parser = argparse.ArgumentParser() + + parser.add_argument('--inputDir', type=str, help='Input directory', default='/data/YOMDLE/final_arabic/xml') + parser.add_argument('--outputDir', type=str, help='Output directory', default='/exp/YOMDLE/final_arabic/csv_truth/') + parser.add_argument('--log', type=str, help='Log directory', default='/exp/logs.txt') + + return parser.parse_args(argv) + + +if __name__ == '__main__': + """ Run """ + main(parse_arguments(sys.argv[1:])) diff --git a/egs/yomdle_fa/v1/path.sh 
b/egs/yomdle_fa/v1/path.sh new file mode 100644 index 00000000000..2d17b17a84a --- /dev/null +++ b/egs/yomdle_fa/v1/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/yomdle_fa/v1/run.sh b/egs/yomdle_fa/v1/run.sh new file mode 100755 index 00000000000..a7547b1ee69 --- /dev/null +++ b/egs/yomdle_fa/v1/run.sh @@ -0,0 +1,120 @@ +#!/bin/bash + +set -e +stage=0 +nj=60 + +database_slam=/export/corpora5/slam/SLAM/Farsi/transcribed +database_yomdle=/export/corpora5/slam/YOMDLE/final_farsi +download_dir=data_yomdle_farsi/download/ +extra_lm=download/extra_lm.txt +data_dir=data_yomdle_farsi +exp_dir=exp_yomdle_farsi + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + local/create_download.sh --database-slam $database_slam \ + --database-yomdle $database_yomdle \ + --slam-dir download/slam_farsi \ + --yomdle-dir download/yomdle_farsi +fi + +if [ $stage -le 0 ]; then + mkdir -p data_slam_farsi/slam + mkdir -p data_yomdle_farsi/yomdle + local/process_data.py download/slam_farsi data_slam_farsi/slam + local/process_data.py download/yomdle_farsi data_yomdle_farsi/yomdle + ln -s ../data_slam_farsi/slam ${data_dir}/test + ln -s ../data_yomdle_farsi/yomdle ${data_dir}/train + image/fix_data_dir.sh ${data_dir}/test + image/fix_data_dir.sh ${data_dir}/train +fi + +mkdir -p $data_dir/{train,test}/data +if [ $stage -le 1 ]; then + echo "$0: Obtaining image groups. calling get_image2num_frames" + echo "Date: $(date)." + image/get_image2num_frames.py --feat-dim 40 $data_dir/train + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 $data_dir/train + + for datasplit in train test; do + echo "$0: Extracting features and calling compute_cmvn_stats for dataset: $datasplit. " + echo "Date: $(date)." + local/extract_features.sh --nj $nj --cmd "$cmd" \ + --feat-dim 40 --num-channels 3 --fliplr true \ + $data_dir/${datasplit} + steps/compute_cmvn_stats.sh $data_dir/${datasplit} || exit 1; + done + + echo "$0: Fixing data directory for train dataset" + echo "Date: $(date)." + utils/fix_data_dir.sh $data_dir/train +fi + +if [ $stage -le 2 ]; then + for datasplit in train; do + echo "$(date) stage 2: Performing augmentation, it will double training data" + local/augment_data.sh --nj $nj --cmd "$cmd" --feat-dim 40 --fliplr false $data_dir/${datasplit} $data_dir/${datasplit}_aug $data_dir + steps/compute_cmvn_stats.sh $data_dir/${datasplit}_aug || exit 1; + done +fi + +if [ $stage -le 3 ]; then + echo "$0: Preparing dictionary and lang..." + if [ ! 
-f $data_dir/train/bpe.out ]; then + cut -d' ' -f2- $data_dir/train/text | local/bidi.py | utils/lang/bpe/prepend_words.py | python3 utils/lang/bpe/learn_bpe.py -s 700 > $data_dir/train/bpe.out + for datasplit in test train train_aug; do + cut -d' ' -f1 $data_dir/$datasplit/text > $data_dir/$datasplit/ids + cut -d' ' -f2- $data_dir/$datasplit/text | local/bidi.py | utils/lang/bpe/prepend_words.py | python3 utils/lang/bpe/apply_bpe.py -c $data_dir/train/bpe.out | sed 's/@@//g' > $data_dir/$datasplit/bpe_text + mv $data_dir/$datasplit/text $data_dir/$datasplit/text.old + paste -d' ' $data_dir/$datasplit/ids $data_dir/$datasplit/bpe_text > $data_dir/$datasplit/text + done + fi + + local/prepare_dict.sh --data-dir $data_dir --dir $data_dir/local/dict + # This recipe uses byte-pair encoding, the silences are part of the words' pronunciations. + # So we set --sil-prob to 0.0 + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ + $data_dir/local/dict "" $data_dir/lang/temp $data_dir/lang + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 $data_dir/lang +fi + +if [ $stage -le 4 ]; then + echo "$0: Estimating a language model for decoding..." + local/train_lm.sh --data-dir $data_dir --dir $data_dir/local/local_lm + utils/format_lm.sh $data_dir/lang $data_dir/local/local_lm/data/arpa/3gram_unpruned.arpa.gz \ + $data_dir/local/dict/lexicon.txt $data_dir/lang_test +fi + +if [ $stage -le 5 ]; then + echo "$0: Calling the flat-start chain recipe..." + echo "Date: $(date)." + local/chain/run_flatstart_cnn1a.sh --nj $nj --train-set train_aug --data-dir $data_dir --exp-dir $exp_dir +fi + +if [ $stage -le 6 ]; then + echo "$0: Aligning the training data using the e2e chain model..." + echo "Date: $(date)." + steps/nnet3/align.sh --nj $nj --cmd "$cmd" \ + --scale-opts '--transition-scale=1.0 --acoustic-scale=1.0 --self-loop-scale=1.0' \ + $data_dir/train_aug $data_dir/lang $exp_dir/chain/e2e_cnn_1a $exp_dir/chain/e2e_ali_train +fi + +if [ $stage -le 7 ]; then + echo "$0: Building a tree and training a regular chain model using the e2e alignments..." + echo "Date: $(date)." + local/chain/run_cnn_e2eali_1b.sh --nj $nj --train-set train_aug --data-dir $data_dir --exp-dir $exp_dir +fi + +if [ $stage -le 8 ]; then + echo "$0: Estimating a language model for lattice rescoring...$(date)" + local/train_lm_lr.sh --data-dir $data_dir --dir $data_dir/local/local_lm_lr --extra-lm $extra_lm --order 6 + + utils/build_const_arpa_lm.sh $data_dir/local/local_lm_lr/data/arpa/6gram_unpruned.arpa.gz \ + $data_dir/lang_test $data_dir/lang_test_lr + steps/lmrescore_const_arpa.sh $data_dir/lang_test $data_dir/lang_test_lr \ + $data_dir/test $exp_dir/chain/cnn_e2eali_1b/decode_test $exp_dir/chain/cnn_e2eali_1b/decode_test_lr +fi diff --git a/egs/yomdle_fa/v1/steps b/egs/yomdle_fa/v1/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/yomdle_fa/v1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/yomdle_fa/v1/utils b/egs/yomdle_fa/v1/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/yomdle_fa/v1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/yomdle_korean/README.txt b/egs/yomdle_korean/README.txt new file mode 100644 index 00000000000..3bf4cc8cd2d --- /dev/null +++ b/egs/yomdle_korean/README.txt @@ -0,0 +1,3 @@ +This directory contains example scripts for OCR on the Yomdle and Slam datasets. 
+Training is done on the Yomdle dataset and testing is done on Slam. +LM rescoring is also done with extra corpus data obtained from various sources diff --git a/egs/yomdle_korean/v1/cmd.sh b/egs/yomdle_korean/v1/cmd.sh new file mode 100755 index 00000000000..3d69546dfe8 --- /dev/null +++ b/egs/yomdle_korean/v1/cmd.sh @@ -0,0 +1,12 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. +export cmd="queue.pl" diff --git a/egs/yomdle_korean/v1/image b/egs/yomdle_korean/v1/image new file mode 120000 index 00000000000..1668ee99922 --- /dev/null +++ b/egs/yomdle_korean/v1/image @@ -0,0 +1 @@ +../../cifar/v1/image/ \ No newline at end of file diff --git a/egs/yomdle_korean/v1/local/augment_data.sh b/egs/yomdle_korean/v1/local/augment_data.sh new file mode 100755 index 00000000000..136bfd24eb2 --- /dev/null +++ b/egs/yomdle_korean/v1/local/augment_data.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright 2018 Hossein Hadian +# 2018 Ashish Arora + +# Apache 2.0 +# This script performs data augmentation. + +nj=4 +cmd=run.pl +feat_dim=40 +verticle_shift=0 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +srcdir=$1 +outdir=$2 +datadir=$3 + +mkdir -p $datadir/augmentations +echo "copying $srcdir to $datadir/augmentations/aug1, allowed length, creating feats.scp" + +for set in aug1; do + image/copy_data_dir.sh --spk-prefix $set- --utt-prefix $set- \ + $srcdir $datadir/augmentations/$set + cat $srcdir/allowed_lengths.txt > $datadir/augmentations/$set/allowed_lengths.txt + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim $feat_dim \ + --vertical-shift $verticle_shift \ + --fliplr false --augment 'random_scale' $datadir/augmentations/$set +done + +echo " combine original data and data from different augmentations" +utils/combine_data.sh --extra-files images.scp $outdir $srcdir $datadir/augmentations/aug1 +cat $srcdir/allowed_lengths.txt > $outdir/allowed_lengths.txt diff --git a/egs/yomdle_korean/v1/local/chain/compare_wer.sh b/egs/yomdle_korean/v1/local/chain/compare_wer.sh new file mode 100755 index 00000000000..80f31e0f311 --- /dev/null +++ b/egs/yomdle_korean/v1/local/chain/compare_wer.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/cnn{1a,1b} + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... 
]" + echo "e.g.: $0 exp/chain/cnn{1a,1b}" + exit 1 +fi + +echo "# $0 $*" +used_epochs=false + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +echo -n "# WER " +for x in $*; do + wer=$(cat $x/decode_test/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# WER (rescored) " +for x in $*; do + wer=$(cat $x/decode_test_rescored/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER " +for x in $*; do + cer=$(cat $x/decode_test/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +echo -n "# CER (rescored) " +for x in $*; do + cer=$(cat $x/decode_test_rescored/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo diff --git a/egs/yomdle_korean/v1/local/chain/run_cnn_e2eali.sh b/egs/yomdle_korean/v1/local/chain/run_cnn_e2eali.sh new file mode 120000 index 00000000000..fcf59f917c1 --- /dev/null +++ b/egs/yomdle_korean/v1/local/chain/run_cnn_e2eali.sh @@ -0,0 +1 @@ +tuning/run_cnn_e2eali_1b.sh \ No newline at end of file diff --git a/egs/yomdle_korean/v1/local/chain/run_e2e_cnn.sh b/egs/yomdle_korean/v1/local/chain/run_e2e_cnn.sh new file mode 100755 index 00000000000..cea60a221a1 --- /dev/null +++ b/egs/yomdle_korean/v1/local/chain/run_e2e_cnn.sh @@ -0,0 +1,132 @@ +#!/bin/bash + +# Copyright 2017 Hossein Hadian + +# This script does end2end chain training (i.e. from scratch) +# local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ +# System e2e_cnn_1a +# score_basic score_nomalized +# WER 13.64 10.6 +# WER (rescored) 13.13 10.2 +# CER 2.99 3.0 +# CER (rescored) 2.88 2.9 +# Final train prob 0.0113 +# Final valid prob 0.0152 +# steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1a +# exp/chain/e2e_cnn_1a: num-iters=48 nj=5..8 num-params=3.0M dim=40->352 combine=0.047->0.047 (over 2) logprob:train/valid[31,47,final]=(0.002,0.008,0.011/0.008,0.013,0.015) + +set -e +# configs for 'chain' +stage=0 +nj=30 +train_stage=-10 +get_egs_stage=-10 +affix=1a + +# training options +tdnn_dim=450 +minibatch_size=150=64,32/300=32,16/600=16,8/1200=8,4 +cmvn_opts="--norm-means=false --norm-vars=false" +train_set=train +lang_decode=data/lang +decode_e2e=true +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! 
cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj $nj --cmd "$cmd" \ + --shared-phones true \ + --type mono \ + data/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat data/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py data/lang \| \ + utils/sym2int.pl -f 2- data/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $output_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts "$cmvn_opts" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.apply-deriv-weights true \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs 3 \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial 5 \ + --trainer.optimization.num-jobs-final 8 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir data/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi diff --git a/egs/yomdle_korean/v1/local/chain/tuning/run_cnn_e2eali_1a.sh 
b/egs/yomdle_korean/v1/local/chain/tuning/run_cnn_e2eali_1a.sh new file mode 100755 index 00000000000..03333f6d229 --- /dev/null +++ b/egs/yomdle_korean/v1/local/chain/tuning/run_cnn_e2eali_1a.sh @@ -0,0 +1,236 @@ +#!/bin/bash + +# e2eali_1a is the same as 1a but uses the e2e chain model to get the +# lattice alignments and to build a tree + +# local/chain/compare_wer.sh exp/old/chain/cnn_e2eali_1a/ +# System cnn_e2eali_1a +# WER 15.68 +# CER 3.18 +# Final train prob -0.0331 +# Final valid prob -0.0395 + +# steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1a/ +# exp/old/chain/cnn_e2eali_1a/: num-iters=33 nj=3..16 num-params=5.2M dim=40->456 combine=-0.035->-0.035 (over 1) xent:train/valid[21,32,final]=(-0.226,-0.175,-0.169/-0.248,-0.202,-0.195) logprob:train/valid[21,32,final]=(-0.039,-0.034,-0.033/-0.046,-0.040,-0.039) + +# Normalize scoring +# WER = 11.7 +# CER = 3.3 + +set -e -o pipefail +stage=0 +nj=30 +train_set=train +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=500 +# we don't need extra left/right context for TDNN systems. +tdnn_dim=450 +# training options +srand=0 +remove_egs=false +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +decode_chain=false +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=90" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=900" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=4 \ + --trainer.frames-per-iter=1000000 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=16 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ] && $decode_chain; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ] && $decode_chain; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --beam 12 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/test $dir/decode_test || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test $dir/decode_test{,_rescored} || exit 1 + + echo "Done. Date: $(date). Results:" + local/chain/compare_wer.sh $dir +fi diff --git a/egs/yomdle_korean/v1/local/chain/tuning/run_cnn_e2eali_1b.sh b/egs/yomdle_korean/v1/local/chain/tuning/run_cnn_e2eali_1b.sh new file mode 100755 index 00000000000..fd9cdc8921d --- /dev/null +++ b/egs/yomdle_korean/v1/local/chain/tuning/run_cnn_e2eali_1b.sh @@ -0,0 +1,208 @@ +#!/bin/bash + +# e2eali_1b is the same as e2eali_1a but has fewer CNN layers, smaller +# l2-regularize, more epochs and uses dropout. 
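+# (Rough guide to the dropout schedule used below: '0,0@0.20,0.2@0.50,0' is
+# interpreted by steps/nnet3/chain/train.py as a piecewise-linear function of
+# the fraction of training completed, i.e. no dropout for roughly the first
+# 20% of training, ramping up to 0.2 at the halfway point, then decaying
+# back to 0 by the end.)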
+ +#local/chain/compare_wer.sh exp/chain/cnn_e2eali_1b/ +# System cnn_e2eali_1b +# score_basic score_nomalized +# WER 13.01 10.0 +# WER (rescored) 12.69 9.6 +# CER 2.78 3.0 +# CER (rescored) 2.70 2.8 +# Final train prob -0.0568 +# Final valid prob -0.0410 +#steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1b +#exp/chain/cnn_e2eali_1b: num-iters=67 nj=3..16 num-params=5.2M dim=40->464 combine=-0.052->-0.052 (over 1) xent:train/valid[43,66,final]=(-0.379,-0.319,-0.304/-0.291,-0.234,-0.227) logprob:train/valid[43,66,final]=(-0.069,-0.058,-0.057/-0.046,-0.041,-0.041) +set -e -o pipefail +stage=0 +nj=30 +train_set=train +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=1000 +# we don't need extra left/right context for TDNN systems. +tdnn_dim=550 +# training options +srand=0 +remove_egs=false +lang_decode=data/lang +decode_chain=true +dropout_schedule='0,0@0.20,0.2@0.50,0' +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.03 dropout-proportion=0.0" + tdnn_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.04" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=90" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-dropout-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-dropout-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-dropout-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-8,-4,0,4,8) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=900" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=16 \ + --trainer.frames-per-iter=2000000 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=16 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi diff --git a/egs/yomdle_korean/v1/local/check_tools.sh b/egs/yomdle_korean/v1/local/check_tools.sh new file mode 100755 index 00000000000..5b4d3107d3b --- /dev/null +++ b/egs/yomdle_korean/v1/local/check_tools.sh @@ -0,0 +1,43 @@ +#!/bin/bash -u + +# Copyright 2015 (c) Johns Hopkins University (Jan Trmal ) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +[ -f ./path.sh ] && . ./path.sh +set +e + +command -v python3 >&/dev/null \ + || { echo >&2 "python3 not found on PATH. You will have to install Python3, preferably >= 3.6"; exit 1; } + +python3 -c "import numpy" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs numpy installed." + exit 1 +fi + +python3 -c "import scipy" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy installed." + exit 1 +fi + +python3 -c "import scipy.misc; scipy.misc.__dict__['imread']" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy-image and Pillow installed." + exit 1 +fi + + +exit 0 diff --git a/egs/yomdle_korean/v1/local/extract_features.sh b/egs/yomdle_korean/v1/local/extract_features.sh new file mode 100755 index 00000000000..3880ebad3e8 --- /dev/null +++ b/egs/yomdle_korean/v1/local/extract_features.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Copyright 2017 Yiwen Shao +# 2018 Ashish Arora + +# Apache 2.0 +# This script runs the make features script in parallel. + +nj=4 +cmd=run.pl +feat_dim=40 +augment='no_aug' +fliplr=false +echo "$0 $@" + +. ./cmd.sh +. 
./path.sh +. ./utils/parse_options.sh || exit 1; + +data=$1 +featdir=$data/data +scp=$data/images.scp +logdir=$data/log + +mkdir -p $logdir +mkdir -p $featdir + +# make $featdir an absolute pathname +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +for n in $(seq $nj); do + split_scps="$split_scps $logdir/images.$n.scp" +done + +# split images.scp +utils/split_scp.pl $scp $split_scps || exit 1; + +$cmd JOB=1:$nj $logdir/extract_features.JOB.log \ + image/ocr/make_features.py $logdir/images.JOB.scp \ + --allowed_len_file_path $data/allowed_lengths.txt \ + --feat-dim $feat_dim --fliplr $fliplr --augment_type $augment \| \ + copy-feats --compress=true --compression-method=7 \ + ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp + +## aggregates the output scp's to get feats.scp +for n in $(seq $nj); do + cat $featdir/images.$n.scp || exit 1; +done > $data/feats.scp || exit 1 diff --git a/egs/yomdle_korean/v1/local/normalize_data.py b/egs/yomdle_korean/v1/local/normalize_data.py new file mode 100755 index 00000000000..fba3e762789 --- /dev/null +++ b/egs/yomdle_korean/v1/local/normalize_data.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Hossein Hadian + +# Apache 2.0 +# This script converts a BPE-encoded text to normal text. It is used in scoring + +import sys, io +import string +import unicodedata +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +for line in infile: + words = line.strip().split() + uttid = words[0] + transcript = ' '.join(words[1:]) + text_normalized = unicodedata.normalize('NFC', transcript) + output.write(uttid + ' ' + text_normalized + '\n') diff --git a/egs/yomdle_korean/v1/local/prepare_dict.sh b/egs/yomdle_korean/v1/local/prepare_dict.sh new file mode 100755 index 00000000000..22db5ae834d --- /dev/null +++ b/egs/yomdle_korean/v1/local/prepare_dict.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +# Copyright 2017 Hossein Hadian +# 2017 Babak Rekabdar +# 2017 Chun Chieh Chang +# 2017 Ashish Arora + +# This script prepares the dictionary. + +set -e +dir=data/local/dict +. ./utils/parse_options.sh || exit 1; + +mkdir -p $dir + +local/prepare_lexicon.py $dir + +cut -d' ' -f2- $dir/lexicon.txt | sed 's/SIL//g' | tr ' ' '\n' | sort -u | sed '/^$/d' >$dir/nonsilence_phones.txt || exit 1; + +echo ' SIL' >> $dir/lexicon.txt + +echo SIL > $dir/silence_phones.txt + +echo SIL >$dir/optional_silence.txt + +echo -n "" >$dir/extra_questions.txt diff --git a/egs/yomdle_korean/v1/local/prepare_lexicon.py b/egs/yomdle_korean/v1/local/prepare_lexicon.py new file mode 100755 index 00000000000..ec8d43d8335 --- /dev/null +++ b/egs/yomdle_korean/v1/local/prepare_lexicon.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Babak Rekabdar +# 2017 Hossein Hadian +# 2017 Chun Chieh Chang +# 2017 Ashish Arora +# Apache 2.0 + +# This script prepares lexicon for BPE. It gets the set of all words that occur in data/train/text. +# Since this lexicon is based on BPE, it replaces '|' with silence. 
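+#
+# For example, assuming the text has been BPE-encoded with
+# utils/lang/bpe/prepend_words.py marking word boundaries with '|', a piece
+# such as '|ab' gets the lexicon entry
+#   |ab SIL a b
+# i.e. each (NFD-decomposed) character becomes a "phone", '|' is rewritten as
+# SIL, and any '#' is stripped (presumably because '#' is reserved for
+# disambiguation symbols in the lang directory).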
+ +import argparse +import os +import unicodedata +parser = argparse.ArgumentParser(description="""Creates the list of characters and words in lexicon""") +parser.add_argument('dir', type=str, help='output path') +args = parser.parse_args() + +### main ### +lex = {} +text_path = os.path.join('data', 'train', 'text') +with open(text_path, 'r', encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split(' ') + for i in range(1, len(line_vect)): + char_normalized = unicodedata.normalize('NFD', line_vect[i]).replace('\n', '') + characters = list(char_normalized) + characters = " ".join([ 'SIL' if char == '|' else char for char in characters]) + characters = list(characters) + characters = "".join([ '' if char == '#' else char for char in characters]) + lex[line_vect[i]] = characters + +with open(os.path.join(args.dir, 'lexicon.txt'), 'w', encoding='utf-8') as fp: + for key in sorted(lex): + fp.write(key + " " + lex[key] + "\n") diff --git a/egs/yomdle_korean/v1/local/process_corpus.py b/egs/yomdle_korean/v1/local/process_corpus.py new file mode 100755 index 00000000000..b39030270b7 --- /dev/null +++ b/egs/yomdle_korean/v1/local/process_corpus.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# Copyright 2018 Ashish Arora +# Apache 2.0 +# This script reads valid phones and removes the lines in the corpus +# which have any other phone. + +import os +import sys, io + +phone_file = os.path.join('data/local/text/cleaned/phones.txt') +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +phone_dict = dict() +with open(phone_file, 'r', encoding='utf-8') as phone_fh: + for line in phone_fh: + line = line.strip().split()[0] + phone_dict[line] = line + +phone_dict[' '] = ' ' +corpus_text = list() +for line in infile: + text = line.strip() + skip_text = False + for phone in text: + if phone not in phone_dict.keys(): + skip_text = True + break + if not skip_text: + output.write(text+ '\n') + diff --git a/egs/yomdle_korean/v1/local/process_data.py b/egs/yomdle_korean/v1/local/process_data.py new file mode 100755 index 00000000000..d7546b0a803 --- /dev/null +++ b/egs/yomdle_korean/v1/local/process_data.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# 2018 Chun Chieh Chang + +""" This script reads the extracted Tamil OCR (yomdle and slam) database files + and creates the following files (for the data subset selected via --dataset): + text, utt2spk, images.scp. + Eg. local/process_data.py data/download/ data/local/splits/train.txt data/train + + Eg. 
text file: english_phone_books_0001_1 To sum up, then, it would appear that + utt2spk file: english_phone_books_0001_0 english_phone_books_0001 + images.scp file: english_phone_books_0001_0 \ + data/download/truth_line_image/english_phone_books_0001_0.png +""" + +import argparse +import os +import sys +import csv +import itertools +import unicodedata +import re +import string +import unicodedata +parser = argparse.ArgumentParser(description="Creates text, utt2spk, and images.scp files") +parser.add_argument('database_path', type=str, help='Path to data') +parser.add_argument('data_split', type=str, help='Path to file that contain datasplits') +parser.add_argument('out_dir', type=str, help='directory to output files') +args = parser.parse_args() + +### main ### +print("Processing '{}' data...".format(args.out_dir)) + +text_file = os.path.join(args.out_dir, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') +utt2spk_file = os.path.join(args.out_dir, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join(args.out_dir, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') + +with open(args.data_split) as f: + for line in f: + line = line.strip() + image_id = line + image_filename = image_id + '.png' + image_filepath = os.path.join(args.database_path, 'truth_line_image', image_filename) + if not os.path.isfile (image_filepath): + print("File does not exist {}".format(image_filepath)) + continue + line_id = int(line.split('_')[-1]) + csv_filename = '_'.join(line.split('_')[:-1]) + '.csv' + csv_filepath = os.path.join(args.database_path, 'truth_csv', csv_filename) + csv_file = open(csv_filepath, 'r', encoding='utf-8') + for row in csv.reader(csv_file): + if row[1] == image_filename: + text = row[11] + text_vect = text.split() # this is to avoid non-utf-8 spaces + text = " ".join(text_vect) + #text_normalized = unicodedata.normalize('NFD', text).replace('\n', '') + if not text: + continue + text_fh.write(image_id + ' ' + text + '\n') + utt2spk_fh.write(image_id + ' ' + '_'.join(line.split('_')[:-1]) + '\n') + image_fh.write(image_id + ' ' + image_filepath + '\n') diff --git a/egs/yomdle_korean/v1/local/score.sh b/egs/yomdle_korean/v1/local/score.sh new file mode 100755 index 00000000000..31564d25326 --- /dev/null +++ b/egs/yomdle_korean/v1/local/score.sh @@ -0,0 +1,5 @@ +#!/bin/bash + + +steps/scoring/score_kaldi_wer.sh "$@" +steps/scoring/score_kaldi_cer.sh --stage 2 "$@" diff --git a/egs/yomdle_korean/v1/local/semisup/chain/run_cnn_chainali_semisupervised_1a.sh b/egs/yomdle_korean/v1/local/semisup/chain/run_cnn_chainali_semisupervised_1a.sh new file mode 100755 index 00000000000..f6b2c1bac42 --- /dev/null +++ b/egs/yomdle_korean/v1/local/semisup/chain/run_cnn_chainali_semisupervised_1a.sh @@ -0,0 +1,327 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# 2018 Ashish Arora +# Apache 2.0 +# This script is semi-supervised recipe with 25k line images of supervised data +# and 22k line images of unsupervised data with naive splitting. +# Based on "Semi-Supervised Training of Acoustic Models using Lattice-Free MMI", +# Vimal Manohar, Hossein Hadian, Daniel Povey, Sanjeev Khudanpur, ICASSP 2018 +# http://www.danielpovey.com/files/2018_icassp_semisupervised_mmi.pdf +# local/semisup/run_semisup.sh shows how to call this. + +# We use 3-gram LM trained on 5M lines of auxilary data. +# This script uses the same tree as that for the seed model. 
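+# At a high level: the seed chain system's decode of the unsupervised set
+# (decode_train_unsup) supplies the lattices; these are pruned with
+# lattice_prune_beam, their LM scores are retained at lattice_lm_scale when
+# building the numerator supervision, and per-frame derivative weights are
+# taken from the best-path pdf posteriors (weights.scp), as configured
+# further down in this script.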
+# Unsupervised set: train_unsup (25k tamil line images) +# unsup_frames_per_eg=150 +# Deriv weights: Lattice posterior of best path pdf +# Unsupervised weight: 1.0 +# Weights for phone LM (supervised, unsupervised): 3,2 +# LM for decoding unsupervised data: 4gram +# Supervision: Naive split lattices +# output-0 and output-1 are for superivsed and unsupervised data respectively. + +# local/chain/compare_wer.sh exp/chain/cnn_e2eali_1b/ exp/semisup_100k/chain/tdnn_semisup_1a/ +# System cnn_e2eali_1b tdnn_semisup_1a +# WER 15.06 13.83 +# CER 3.15 2.83 +# Final train prob -0.0343 0.6103-0.0360 +# Final valid prob -0.0403 0.6054-0.0418 + +# steps/info/chain_dir_info.pl exp/semisup_100k/chain/tdnn_semisup_1a/ +# exp/semisup_100k/chain/tdnn_semisup_1a/: num-iters=58 nj=6..16 num-params=3.7M dim=40->456 combine=0.240->0.240 (over 1) + +# Normalize scoring +#WER = 10.4 +#CER = 2.9 + +set -u -e -o pipefail + +stage=0 # Start from -1 for supervised seed system training +train_stage=-100 +nj=30 +test_nj=30 + +# The following 3 options decide the output directory for semi-supervised +# chain system +# dir=${exp_root}/chain${chain_affix}/tdnn${tdnn_affix} +exp_root=exp/semisup_100k +chain_affix= # affix for chain dir +tdnn_affix=_semisup_1a # affix for semi-supervised chain system + +# Datasets-Expects supervised_set and unsupervised_set +supervised_set=train +unsupervised_set=train_unsup + +# Input seed system +sup_chain_dir=exp/chain/cnn_e2eali_1b # supervised chain system +sup_lat_dir=exp/chain/e2e_train_lats # Seed model options +sup_tree_dir=exp/chain/tree_e2e # tree directory for supervised chain system + +# Semi-supervised options +supervision_weights=1.0,1.0 # Weights for supervised, unsupervised data egs. + # Can be used to scale down the effect of unsupervised data + # by using a smaller scale for it e.g. 1.0,0.3 +lm_weights=3,2 # Weights on phone counts from supervised, unsupervised data for denominator FST creation + +sup_egs_dir= # Supply this to skip supervised egs creation +unsup_egs_dir= # Supply this to skip unsupervised egs creation +unsup_egs_opts= # Extra options to pass to unsupervised egs creation +# Neural network opts +xent_regularize=0.1 +tdnn_dim=450 +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +if [ -f ./path.sh ]; then . ./path.sh; fi +. ./utils/parse_options.sh + +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +dir=$exp_root/chain$chain_affix/tdnn$tdnn_affix +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=40 name=input + + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts + + # We use separate outputs for supervised and unsupervised data + # so we can properly track the train and valid objectives. + output name=output-0 input=output.affine + output name=output-1 input=output.affine + output name=output-0-xent input=output-xent.log-softmax + output name=output-1-xent input=output-xent.log-softmax +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +# Get values for $model_left_context, $model_right_context +. $dir/configs/vars + +left_context=$model_left_context +right_context=$model_right_context + +egs_left_context=$(perl -e "print int($left_context + $frame_subsampling_factor / 2)") +egs_right_context=$(perl -e "print int($right_context + $frame_subsampling_factor / 2)") + +if [ -z "$sup_egs_dir" ]; then + sup_egs_dir=$dir/egs_$supervised_set + frames_per_eg=$(cat $sup_chain_dir/egs/info/frames_per_eg) + + if [ $stage -le 12 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $sup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$sup_egs_dir/storage $sup_egs_dir/storage + fi + mkdir -p $sup_egs_dir/ + touch $sup_egs_dir/.nodelete # keep egs around when that run dies. 
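+
+    # Sketch of what the get_egs.sh call below does: chunks of $frames_per_eg
+    # frames are cut from the supervised utterances and their lattices,
+    # padded with the model context computed above
+    # (egs_left_context/egs_right_context), and dumped as training examples;
+    # --generate-egs-scp true is needed so these egs can later be combined
+    # with the unsupervised egs by combine_egs.sh.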
+ + echo "$0: generating egs from the supervised data" + steps/nnet3/chain/get_egs.sh --cmd "$cmd" \ + --left-tolerance 3 --right-tolerance 3 \ + --left-context $egs_left_context --right-context $egs_right_context \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --frames-overlap-per-eg 0 --constrained false \ + --frames-per-eg $frames_per_eg \ + --frames-per-iter 2000000 \ + --cmvn-opts "$cmvn_opts" \ + --generate-egs-scp true \ + data/${supervised_set} $dir \ + $sup_lat_dir $sup_egs_dir + fi +else + frames_per_eg=$(cat $sup_egs_dir/info/frames_per_eg) +fi + +unsup_frames_per_eg=340,300,200,100 # Using a frames-per-eg of 150 for unsupervised data + # was found to be better than allowing smaller chunks + # (160,140,110,80) like for supervised system +lattice_lm_scale=0.5 # lm-scale for using the weights from unsupervised lattices when + # creating numerator supervision +lattice_prune_beam=6.0 # beam for pruning the lattices prior to getting egs + # for unsupervised data +tolerance=3 # frame-tolerance for chain training + +unsup_lat_dir=$sup_chain_dir/decode_$unsupervised_set +if [ -z "$unsup_egs_dir" ]; then + unsup_egs_dir=$dir/egs_$unsupervised_set + + if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $unsup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$unsup_egs_dir/storage $unsup_egs_dir/storage + fi + mkdir -p $unsup_egs_dir + touch $unsup_egs_dir/.nodelete # keep egs around when that run dies. + + echo "$0: generating egs from the unsupervised data" + steps/nnet3/chain/get_egs.sh \ + --cmd "$cmd" --alignment-subsampling-factor 1 \ + --left-tolerance $tolerance --right-tolerance $tolerance \ + --left-context $egs_left_context --right-context $egs_right_context \ + --frames-per-eg $unsup_frames_per_eg --frames-per-iter 2000000 \ + --frame-subsampling-factor $frame_subsampling_factor \ + --cmvn-opts "$cmvn_opts" --lattice-lm-scale $lattice_lm_scale \ + --lattice-prune-beam "$lattice_prune_beam" \ + --deriv-weights-scp $sup_chain_dir/best_path_$unsupervised_set/weights.scp \ + --generate-egs-scp true $unsup_egs_opts \ + data/$unsupervised_set $dir \ + $unsup_lat_dir $unsup_egs_dir + fi +fi + +comb_egs_dir=$dir/comb_egs +if [ $stage -le 14 ]; then + steps/nnet3/chain/multilingual/combine_egs.sh --cmd "$cmd" \ + --block-size 64 \ + --lang2weight $supervision_weights 2 \ + $sup_egs_dir $unsup_egs_dir $comb_egs_dir + touch $comb_egs_dir/.nodelete # keep egs around when that run dies. +fi + +if [ $train_stage -le -4 ]; then + # This is to skip stages of den-fst creation, which was already done. 
+ train_stage=-4 +fi + +chunk_width=340,300,200,100 +if [ $stage -le 15 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --egs.dir "$comb_egs_dir" \ + --egs.chunk-width=$chunk_width \ + --cmd "$cmd" \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00001 \ + --chain.apply-deriv-weights=true \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=900" \ + --trainer.srand=0 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --trainer.frames-per-iter=2000000 \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs 5 \ + --trainer.optimization.num-jobs-initial 6 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs false \ + --feat-dir data/$supervised_set \ + --tree-dir $sup_tree_dir \ + --lat-dir $sup_lat_dir \ + --dir $dir || exit 1; + +fi + +if [ $stage -le 17 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 $lang_decode $dir $dir/graph +fi + +if [ $stage -le 18 ]; then + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --beam 12 --frames-per-chunk 340 --nj $nj --cmd "$cmd" \ + $dir/graph data/test $dir/decode_test + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test $dir/decode_test{,_rescored} || exit 1 +fi +exit 0; + diff --git a/egs/yomdle_korean/v1/local/semisup/chain/run_cnn_chainali_semisupervised_1b.sh b/egs/yomdle_korean/v1/local/semisup/chain/run_cnn_chainali_semisupervised_1b.sh new file mode 100755 index 00000000000..8185fa2645d --- /dev/null +++ b/egs/yomdle_korean/v1/local/semisup/chain/run_cnn_chainali_semisupervised_1b.sh @@ -0,0 +1,325 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# 2018 Ashish Arora +# Apache 2.0 +# This script is semi-supervised recipe with 25k line images of supervised data +# and 22k line images of unsupervised data with naive splitting. +# Based on "Semi-Supervised Training of Acoustic Models using Lattice-Free MMI", +# Vimal Manohar, Hossein Hadian, Daniel Povey, Sanjeev Khudanpur, ICASSP 2018 +# http://www.danielpovey.com/files/2018_icassp_semisupervised_mmi.pdf +# local/semisup/run_semisup.sh shows how to call this. + +# We use 3-gram LM trained on 5M lines of auxilary data. +# This script uses the same tree as that for the seed model. +# Unsupervised set: train_unsup (25k tamil line images) +# unsup_frames_per_eg=150 +# Deriv weights: Lattice posterior of best path pdf +# Unsupervised weight: 1.0 +# Weights for phone LM (supervised, unsupervised): 3,2 +# LM for decoding unsupervised data: 4gram +# Supervision: Naive split lattices +# output-0 and output-1 are for superivsed and unsupervised data respectively. 
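+# (Both outputs are thin copies of the same final affine layer; see the
+# 'output name=output-0 input=output.affine' lines in the xconfig below.
+# They therefore share parameters, and combine_egs.sh just relabels each
+# example to output-0 or output-1 according to its source, which is what
+# lets the supervised and unsupervised objectives be tracked separately.)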
+ +# local/chain/compare_wer.sh exp/semisup_100k/chain/tdnn_semisup_1b/ +# System tdnn_semisup_1b +# score_basic score_normalized +# WER 13.73 10.2 +# WER (rescored) 12.80 9.4 +# CER 2.78 2.8 +# CER (rescored) 2.57 2.7 +# Final train prob 0.6138-0.0337 +# Final valid prob 0.6115-0.0399 + +# steps/info/chain_dir_info.pl exp/semisup_100k/chain/tdnn_semisup_1b/ +# exp/semisup_100k/chain/tdnn_semisup_1b/: num-iters=46 nj=6..16 num-params=5.7M dim=40->456 combine=0.239->0.239 (over 1) + +set -u -e -o pipefail +stage=0 # Start from -1 for supervised seed system training +train_stage=-100 +nj=30 +test_nj=30 + +# The following 3 options decide the output directory for semi-supervised +# chain system +# dir=${exp_root}/chain${chain_affix}/tdnn${tdnn_affix} +exp_root=exp/semisup_100k +chain_affix= # affix for chain dir +tdnn_affix=_semisup_1b # affix for semi-supervised chain system + +# Datasets-Expects supervised_set and unsupervised_set +supervised_set=train +unsupervised_set=train_unsup + +# Input seed system +sup_chain_dir=exp/chain/cnn_e2eali_1b # supervised chain system +sup_lat_dir=exp/chain/e2e_train_lats # Seed model options +sup_tree_dir=exp/chain/tree_e2e # tree directory for supervised chain system + +# Semi-supervised options +supervision_weights=1.0,1.0 # Weights for supervised, unsupervised data egs. + # Can be used to scale down the effect of unsupervised data + # by using a smaller scale for it e.g. 1.0,0.3 +lm_weights=3,2 # Weights on phone counts from supervised, unsupervised data for denominator FST creation + +sup_egs_dir= # Supply this to skip supervised egs creation +unsup_egs_dir= # Supply this to skip unsupervised egs creation +unsup_egs_opts= # Extra options to pass to unsupervised egs creation +# Neural network opts +xent_regularize=0.1 +tdnn_dim=550 +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +if [ -f ./path.sh ]; then . ./path.sh; fi +. ./utils/parse_options.sh + +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +dropout_schedule='0,0@0.20,0.2@0.50,0' +dir=$exp_root/chain$chain_affix/tdnn$tdnn_affix +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-dropout-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-dropout-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-dropout-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts + + # We use separate outputs for supervised and unsupervised data + # so we can properly track the train and valid objectives. + output name=output-0 input=output.affine + output name=output-1 input=output.affine + output name=output-0-xent input=output-xent.log-softmax + output name=output-1-xent input=output-xent.log-softmax +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +# Get values for $model_left_context, $model_right_context +. $dir/configs/vars + +left_context=$model_left_context +right_context=$model_right_context + +egs_left_context=$(perl -e "print int($left_context + $frame_subsampling_factor / 2)") +egs_right_context=$(perl -e "print int($right_context + $frame_subsampling_factor / 2)") + +if [ -z "$sup_egs_dir" ]; then + sup_egs_dir=$dir/egs_$supervised_set + frames_per_eg=$(cat $sup_chain_dir/egs/info/frames_per_eg) + + if [ $stage -le 12 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $sup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$sup_egs_dir/storage $sup_egs_dir/storage + fi + mkdir -p $sup_egs_dir/ + touch $sup_egs_dir/.nodelete # keep egs around when that run dies. 
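+
+    # Note: $frames_per_eg here was read from the seed system's
+    # egs/info/frames_per_eg above, presumably so the supervised chunks match
+    # what the seed chain model was trained on; the unsupervised data uses
+    # the chunk sizes in $unsup_frames_per_eg, set further down.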
+ + echo "$0: generating egs from the supervised data" + steps/nnet3/chain/get_egs.sh --cmd "$cmd" \ + --left-tolerance 3 --right-tolerance 3 \ + --left-context $egs_left_context --right-context $egs_right_context \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --frames-overlap-per-eg 0 --constrained false \ + --frames-per-eg $frames_per_eg \ + --frames-per-iter 2000000 \ + --cmvn-opts "$cmvn_opts" \ + --generate-egs-scp true \ + data/${supervised_set} $dir \ + $sup_lat_dir $sup_egs_dir + fi +else + frames_per_eg=$(cat $sup_egs_dir/info/frames_per_eg) +fi + +unsup_frames_per_eg=340,300,200,100 # Using a frames-per-eg of 150 for unsupervised data + # was found to be better than allowing smaller chunks + # (160,140,110,80) like for supervised system +lattice_lm_scale=0.5 # lm-scale for using the weights from unsupervised lattices when + # creating numerator supervision +lattice_prune_beam=6.0 # beam for pruning the lattices prior to getting egs + # for unsupervised data +tolerance=3 # frame-tolerance for chain training + +unsup_lat_dir=$sup_chain_dir/decode_$unsupervised_set +if [ -z "$unsup_egs_dir" ]; then + unsup_egs_dir=$dir/egs_$unsupervised_set + + if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $unsup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$unsup_egs_dir/storage $unsup_egs_dir/storage + fi + mkdir -p $unsup_egs_dir + touch $unsup_egs_dir/.nodelete # keep egs around when that run dies. + + echo "$0: generating egs from the unsupervised data" + steps/nnet3/chain/get_egs.sh \ + --cmd "$cmd" --alignment-subsampling-factor 1 \ + --left-tolerance $tolerance --right-tolerance $tolerance \ + --left-context $egs_left_context --right-context $egs_right_context \ + --frames-per-eg $unsup_frames_per_eg --frames-per-iter 2000000 \ + --frame-subsampling-factor $frame_subsampling_factor \ + --cmvn-opts "$cmvn_opts" --lattice-lm-scale $lattice_lm_scale \ + --lattice-prune-beam "$lattice_prune_beam" \ + --deriv-weights-scp $sup_chain_dir/best_path_$unsupervised_set/weights.scp \ + --generate-egs-scp true $unsup_egs_opts \ + data/$unsupervised_set $dir \ + $unsup_lat_dir $unsup_egs_dir + fi +fi + +comb_egs_dir=$dir/comb_egs +if [ $stage -le 14 ]; then + steps/nnet3/chain/multilingual/combine_egs.sh --cmd "$cmd" \ + --block-size 64 \ + --lang2weight $supervision_weights 2 \ + $sup_egs_dir $unsup_egs_dir $comb_egs_dir + touch $comb_egs_dir/.nodelete # keep egs around when that run dies. +fi + +if [ $train_stage -le -4 ]; then + # This is to skip stages of den-fst creation, which was already done. 
+ train_stage=-4 +fi + +chunk_width=340,300,200,100 +if [ $stage -le 15 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --egs.dir "$comb_egs_dir" \ + --egs.chunk-width=$chunk_width \ + --cmd "$cmd" \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00001 \ + --chain.apply-deriv-weights=true \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=900" \ + --trainer.srand=0 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --trainer.frames-per-iter=2000000 \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs 16 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.optimization.num-jobs-initial 6 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs false \ + --feat-dir data/$supervised_set \ + --tree-dir $sup_tree_dir \ + --lat-dir $sup_lat_dir \ + --dir $dir || exit 1; + +fi + +if [ $stage -le 17 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 $lang_decode $dir $dir/graph +fi + +if [ $stage -le 18 ]; then + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --beam 12 --frames-per-chunk 340 --nj $nj --cmd "$cmd" \ + $dir/graph data/test $dir/decode_test + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test $dir/decode_test{,_rescored} || exit 1 +fi +exit 0; + diff --git a/egs/yomdle_korean/v1/local/semisup/process_data.py b/egs/yomdle_korean/v1/local/semisup/process_data.py new file mode 100755 index 00000000000..94ad770ec2d --- /dev/null +++ b/egs/yomdle_korean/v1/local/semisup/process_data.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# 2018 Chun Chieh Chang + +""" This script reads the SLAM boxed OCR dataset and creates the following + files: utt2spk, images.scp. Since the boxed data do not have transcripts, it does not + create a text file. It is kept as a separate script because the data that + local/process_data.py is processing contains some empty transcripts, which + should be removed or they will cause errors while applying BPE. + + Eg. local/semisup/process_data.py data/download/ data/local/splits/train_unsup.txt + data/train_unsup + + Eg.
utt2spk file: english_phone_books_0001_0 english_phone_books_0001 + images.scp file: english_phone_books_0001_0 \ + data/download/truth_line_image/english_phone_books_0001_0.png +""" +import argparse +import os +import sys +import csv +import itertools +import unicodedata +import re +import string +parser = argparse.ArgumentParser(description="Creates text, utt2spk, and images.scp files") +parser.add_argument('database_path', type=str, help='Path to data') +parser.add_argument('data_split', type=str, help='Path to file that contain datasplits') +parser.add_argument('out_dir', type=str, help='directory to output files') +args = parser.parse_args() + +### main ### +print("Processing '{}' data...".format(args.out_dir)) + +utt2spk_file = os.path.join(args.out_dir, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join(args.out_dir, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') +text_file = os.path.join(args.out_dir, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') + +with open(args.data_split) as f: + for line in f: + line = line.strip() + image_id = line + image_filename = image_id + '.png' + image_filepath = os.path.join(args.database_path, 'truth_line_image', image_filename) + if not os.path.isfile (image_filepath): + print("File does not exist {}".format(image_filepath)) + continue + line_id = int(line.split('_')[-1]) + csv_filename = '_'.join(line.split('_')[:-1]) + '.csv' + csv_filepath = os.path.join(args.database_path, 'truth_csv', csv_filename) + csv_file = open(csv_filepath, 'r', encoding='utf-8') + for row in csv.reader(csv_file): + if row[1] == image_filename: + text = 'semisup' + text_fh.write(image_id + ' ' + text + '\n') + utt2spk_fh.write(image_id + ' ' + '_'.join(line.split('_')[:-1]) + '\n') + image_fh.write(image_id + ' ' + image_filepath + '\n') diff --git a/egs/yomdle_korean/v1/local/semisup/run_semisup.sh b/egs/yomdle_korean/v1/local/semisup/run_semisup.sh new file mode 100755 index 00000000000..5e20f50c99e --- /dev/null +++ b/egs/yomdle_korean/v1/local/semisup/run_semisup.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# 2018 Ashish Arora +# Apache 2.0 + +# This script demonstrates semi-supervised training using 25k line images of +# supervised data and 22k line images of unsupervised data. +# We assume the supervised data is in data/train and unsupervised data +# is in data/train_unsup. +# For LM training, we use 5 million lines of tamil text. + +set -e +set -o pipefail +stage=0 +nj=30 +exp_root=exp/semisup_56k +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +mkdir -p data/train_unsup/data +if [ $stage -le 0 ]; then + echo "stage 0: Processing train unsupervised data...$(date)" + local/semisup/process_data.py data/download/ \ + data/local/splits/train_unsup.txt \ + data/train_unsup + image/fix_data_dir.sh data/train_unsup +fi + +if [ $stage -le 1 ]; then + echo "stage 1: Obtaining image groups. calling get_image2num_frames..." + image/get_image2num_frames.py --feat-dim 40 data/train_unsup + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train_unsup + echo "Extracting features and calling compute_cmvn_stats: $(date) " + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/train_unsup + steps/compute_cmvn_stats.sh data/train_unsup || exit 1; + image/fix_data_dir.sh data/train_unsup +fi + +for f in data/train/utt2spk data/train_unsup/utt2spk \ + data/train/text; do + if [ ! 
-f $f ]; then + echo "$0: Could not find $f" + exit 1; + fi +done + +# Prepare semi-supervised train set +if [ $stage -le 1 ]; then + utils/combine_data.sh data/semisup100k_250k \ + data/train data/train_unsup || exit 1 +fi + +############################################################################### +# Semi-supervised training using 25k line images of supervised data and +# 22k line images of unsupervised data. We use the tree, lattices +# and seed chain system from the previous stage. +############################################################################### +if [ $stage -le 2 ]; then + local/semisup/chain/run_cnn_chainali_semisupervised_1b.sh \ + --supervised-set train \ + --unsupervised-set train_unsup \ + --sup-chain-dir exp/chain/cnn_e2eali_1b_ep16_7cnn \ + --sup-lat-dir exp/chain/e2e_train_lats \ + --sup-tree-dir exp/chain/tree_e2e \ + --chain-affix "" \ + --tdnn-affix _semisup_ep16_7cnn \ + --stage 15 --train_stage 9 \ + --exp-root $exp_root || exit 1 +fi diff --git a/egs/yomdle_korean/v1/local/train_lm.sh b/egs/yomdle_korean/v1/local/train_lm.sh new file mode 100755 index 00000000000..c73c42fb7dc --- /dev/null +++ b/egs/yomdle_korean/v1/local/train_lm.sh @@ -0,0 +1,127 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the training transcriptions and corpus text. +# It is based on the example scripts distributed with PocoLM + +# It will check whether pocolm is installed and, if not, ask you to install it. + +set -e +stage=0 +dir=data/local/local_lm +order=6 +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from the output log of train_lm.py. +# These example metaparameter values are for a 4-gram model (with min-counts) +# run with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +#bypass_metaparam_optim_opt="--bypass-metaparameter-optimization=0.031,0.860,0.678,0.194,0.037,0.006,0.928,0.712,0.454,0.220,0.926,0.844,0.749,0.358,0.966,0.879,0.783,0.544,0.966,0.826,0.674,0.450" +# Note: to use these example parameters, you may need to remove the .done files +# to make sure make_lm_dir.py is called and trains only a 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # use the validation data as the dev set. + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + + cat data/local/text/cleaned/bpe_val.txt > ${dir}/data/text/dev.txt + # use the training data as an additional data source. + # we can later fold the dev data into this.
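+  # Sketch of what ${dir}/data/text contains after this stage (file names other
+  # than dev.txt are arbitrary; pocolm treats every *.txt here as a training
+  # data source, and only dev.txt as the held-out set used to estimate the
+  # interpolation weights):
+  #   dev.txt         : BPE-encoded validation text
+  #   train.txt       : training transcriptions (written just below)
+  #   corpus_text.txt : BPE-encoded external corpus text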
+ cat data/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + cat data/local/text/cleaned/bpe_corpus.txt > ${dir}/data/text/corpus_text.txt + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < data/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from train and corpus text + cat ${dir}/data/text/{train,corpus_text}.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=20 --warm-start-ratio=20 \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi + +if [ $stage -le 2 ]; then + echo "$0: pruning the LM (to larger size)" + # Using 10 million n-grams for a big LM for rescoring purposes. + size=10000000 + prune_lm_dir.py --target-num-ngrams=$size --initial-threshold=0.02 ${unpruned_lm_dir} ${dir}/data/lm_${order}_prune_big + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_big 2>&1 | grep -F '[perplexity' + #[perplexity = 22.0613098868] over 151116.0 words + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${dir}/data/lm_${order}_prune_big | gzip -c > ${dir}/data/arpa/${order}gram_big.arpa.gz +fi + +if [ $stage -le 3 ]; then + echo "$0: pruning the LM (to smaller size)" + # Using 2 million n-grams for a smaller LM for graph building. Prune from the + # bigger-pruned LM, it'll be faster. + size=2000000 + prune_lm_dir.py --target-num-ngrams=$size ${dir}/data/lm_${order}_prune_big ${dir}/data/lm_${order}_prune_small + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_small 2>&1 | grep -F '[perplexity' + #[perplexity = 23.4801171202] over 151116.0 words + format_arpa_lm.py ${dir}/data/lm_${order}_prune_small | gzip -c > ${dir}/data/arpa/${order}gram_small.arpa.gz +fi diff --git a/egs/yomdle_korean/v1/local/wer_output_filter b/egs/yomdle_korean/v1/local/wer_output_filter new file mode 100755 index 00000000000..59e364e0231 --- /dev/null +++ b/egs/yomdle_korean/v1/local/wer_output_filter @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Hossein Hadian + +# Apache 2.0 +# This script converts a BPE-encoded text to normal text. 
It is used in scoring + +import sys, io +import string +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +for line in infile: + words = line.strip().split() + uttid = words[0] + transcript = ''.join(words[1:]) + transcript = transcript.replace('|', ' ') + output.write(uttid + ' ' + transcript + '\n') diff --git a/egs/yomdle_korean/v1/local/yomdle b/egs/yomdle_korean/v1/local/yomdle new file mode 120000 index 00000000000..2c4544c1399 --- /dev/null +++ b/egs/yomdle_korean/v1/local/yomdle @@ -0,0 +1 @@ +../../../yomdle_tamil/v1/local/yomdle/ \ No newline at end of file diff --git a/egs/yomdle_korean/v1/path.sh b/egs/yomdle_korean/v1/path.sh new file mode 100755 index 00000000000..2d17b17a84a --- /dev/null +++ b/egs/yomdle_korean/v1/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/yomdle_korean/v1/run_end2end.sh b/egs/yomdle_korean/v1/run_end2end.sh new file mode 100755 index 00000000000..65f5beb4b08 --- /dev/null +++ b/egs/yomdle_korean/v1/run_end2end.sh @@ -0,0 +1,186 @@ +#!/bin/bash + +# Copyright 2018 Hossein Hadian +# Ashish Arora +# Jonathan Chang +# Apache 2.0 + +set -e +stage=0 +nj=30 + +language_main=Korean +slam_dir=/export/corpora5/slam/SLAM/ +yomdle_dir=/export/corpora5/slam/YOMDLE/ +corpus_dir=/export/corpora5/handwriting_ocr/corpus_data/ko/ +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +./local/check_tools.sh +# Start from stage=-2 for data preparation. This stage stores line images, +# csv files and splits{train,test,train_unsup} data/download/truth_line_image, +# data/download/truth_csv and data/local/splits respectively. +if [ $stage -le -2 ]; then + echo "$(date): preparing data, obtaining line images and csv files..." + local/yomdle/create_download_dir.sh --language_main $language_main \ + --slam_dir $slam_dir --yomdle_dir $yomdle_dir +fi + +if [ $stage -le -1 ]; then + echo "$(date): getting corpus text for language modelling..." + mkdir -p data/local/text/cleaned + cat $corpus_dir/* > data/local/text/ko.txt + head -20000 data/local/text/ko.txt > data/local/text/cleaned/val.txt + tail -n +20000 data/local/text/ko.txt > data/local/text/cleaned/corpus.txt +fi + +mkdir -p data/{train,test}/data +if [ $stage -le 0 ]; then + echo "$0 stage 0: Processing train and test data.$(date)" + echo " creating text, images.scp, utt2spk and spk2utt" + #local/prepare_data.sh data/download/ + for set in train test; do + local/process_data.py data/download/ \ + data/local/splits/${set}.txt data/${set} + image/fix_data_dir.sh data/${set} + done +fi + +if [ $stage -le 1 ]; then + echo "$(date) stage 1: getting allowed image widths for e2e training..." + image/get_image2num_frames.py --feat-dim 40 data/train + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train + for set in train test; do + echo "$(date) Extracting features, creating feats.scp file" + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/${set} + steps/compute_cmvn_stats.sh data/${set} || exit 1; + done + image/fix_data_dir.sh data/train +fi + +if [ $stage -le 3 ]; then + echo "$(date) stage 3: BPE preparation" + # getting non-silence phones. 
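+  # The inline python below simply collects the set of characters that occur in
+  # the training transcriptions and prints them one per line; e.g. for a
+  # (hypothetical) transcript "hello world" it would emit the characters
+  # h, e, l, o, w, r, d. These characters act as the "phones" of the BPE-based
+  # lexicon.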
+ cut -d' ' -f2- data/train/text | \ +python3 <( +cat << "END" +import os, sys, io; +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8'); +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8'); +phone_dict = dict(); +for line in infile: + line_vect = line.strip().split(); + for word in line_vect: + for phone in word: + phone_dict[phone] = phone; + +for phone in phone_dict.keys(): + output.write(phone+ '\n'); +END + ) > data/local/text/cleaned/phones.txt + + cut -d' ' -f2- data/train/text > data/local/text/cleaned/train.txt + + echo "learning BPE..." + # it is currently learned with only training text but we can also use all corpus text + # to learn BPE. phones are added so that one isolated occurance of every phone exists. + cat data/local/text/cleaned/phones.txt data/local/text/cleaned/train.txt | \ + utils/lang/bpe/prepend_words.py | utils/lang/bpe/learn_bpe.py -s 700 > data/local/bpe.txt || exit 1; +fi + +if [ $stage -le 4 ]; then + echo "$(date) stage 4: applying BPE..." + echo "applying BPE on train, test text..." + for set in test train; do + cut -d' ' -f1 data/$set/text > data/$set/ids + cut -d' ' -f2- data/$set/text | utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt | \ + sed 's/@@//g' > data/$set/bpe_text + mv data/$set/text data/$set/text.old + paste -d' ' data/$set/ids data/$set/bpe_text > data/$set/text + rm -f data/$set/bpe_text data/$set/ids + done + + echo "applying BPE to corpus text..." + cat data/local/text/cleaned/corpus.txt | utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt | \ + sed 's/@@//g' > data/local/text/cleaned/bpe_corpus.txt + cat data/local/text/cleaned/val.txt | utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt | \ + sed 's/@@//g' > data/local/text/cleaned/bpe_val.txt +fi + +if [ $stage -le 5 ]; then + echo "$(date) stage 5: Preparing dictionary and lang..." + local/prepare_dict.sh --dir data/local/dict + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 4 --sil-prob 0.0 --position-dependent-phones false \ + data/local/dict "" data/lang/temp data/lang + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 data/lang +fi + +if [ $stage -le 6 ]; then + echo "$(date) stage 6: Calling the flat-start chain recipe..." + local/chain/run_e2e_cnn.sh +fi + +if [ $stage -le 7 ]; then + echo "$(date) stage 7: Aligning the training data using the e2e chain model..." + steps/nnet3/align.sh --nj $nj --cmd "$cmd" \ + --scale-opts '--transition-scale=1.0 --acoustic-scale=1.0 --self-loop-scale=1.0' \ + data/train data/lang exp/chain/e2e_cnn_1a exp/chain/e2e_ali_train +fi + +chunk_width='340,300,200,100' +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +if [ $stage -le 8 ]; then + echo "$(date) stage 8: Building a tree and training a regular chain model using the e2e alignments..." + local/chain/run_cnn_e2eali.sh --chunk_width $chunk_width +fi + +if [ $stage -le 9 ]; then + echo "$(date) stage 9: Estimating a language model for decoding..." + local/train_lm.sh + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/6gram_small.arpa.gz \ + data/local/dict/lexicon.txt data/lang + utils/build_const_arpa_lm.sh data/local/local_lm/data/arpa/6gram_unpruned.arpa.gz \ + data/lang data/lang_rescore_6g +fi + +if [ $stage -le 10 ] && $decode_e2e; then + echo "$(date) stage 10: decoding end2end setup..." 
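+  # Decoding follows the usual 'chain' convention: decode with --acwt 1.0 and
+  # then scale the lattice scores with --post-decode-acwt 10.0, so that the
+  # acoustic/LM balance in the lattices matches what the scoring and rescoring
+  # scripts expect.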
+ + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + exp/chain/e2e_cnn_1a/ exp/chain/e2e_cnn_1a/graph || exit 1; + + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 30 --cmd "$cmd" --beam 12 \ + exp/chain/e2e_cnn_1a/graph data/test exp/chain/e2e_cnn_1a/decode_test || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test exp/chain/e2e_cnn_1a/decode_test{,_rescored} || exit 1 + + echo "Done. Date: $(date). Results:" + local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ +fi + +if [ $stage -le 11 ] && $decode_chain; then + echo "$(date) stage 11: decoding chain alignment setup..." + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + exp/chain/cnn_e2eali_1a/ exp/chain/cnn_e2eali_1a/graph || exit 1; + + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 30 --cmd "$cmd" --beam 12 \ + exp/chain/cnn_e2eali_1a/graph data/test exp/chain/cnn_e2eali_1a/decode_test || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test exp/chain/cnn_e2eali_1a/decode_test{,_rescored} || exit 1 + + echo "Done. Date: $(date). Results:" + local/chain/compare_wer.sh exp/chain/cnn_e2eali_1a +fi diff --git a/egs/yomdle_korean/v1/steps b/egs/yomdle_korean/v1/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/yomdle_korean/v1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/yomdle_korean/v1/utils b/egs/yomdle_korean/v1/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/yomdle_korean/v1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/yomdle_russian/README.txt b/egs/yomdle_russian/README.txt new file mode 100644 index 00000000000..3bf4cc8cd2d --- /dev/null +++ b/egs/yomdle_russian/README.txt @@ -0,0 +1,3 @@ +This directory contains example scripts for OCR on the Yomdle and Slam datasets. +Training is done on the Yomdle dataset and testing is done on Slam. +LM rescoring is also done with extra corpus data obtained from various sources diff --git a/egs/yomdle_russian/v1/cmd.sh b/egs/yomdle_russian/v1/cmd.sh new file mode 100755 index 00000000000..3d69546dfe8 --- /dev/null +++ b/egs/yomdle_russian/v1/cmd.sh @@ -0,0 +1,12 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. 
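+# For example, the following are common alternatives (illustrative values):
+# export cmd="run.pl"              # run locally, with no queueing system
+# export cmd="queue.pl --mem 4G"   # GridEngine with a per-job memory request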
+export cmd="queue.pl" diff --git a/egs/yomdle_russian/v1/image b/egs/yomdle_russian/v1/image new file mode 120000 index 00000000000..1668ee99922 --- /dev/null +++ b/egs/yomdle_russian/v1/image @@ -0,0 +1 @@ +../../cifar/v1/image/ \ No newline at end of file diff --git a/egs/yomdle_russian/v1/local/chain/compare_wer.sh b/egs/yomdle_russian/v1/local/chain/compare_wer.sh new file mode 100755 index 00000000000..80f31e0f311 --- /dev/null +++ b/egs/yomdle_russian/v1/local/chain/compare_wer.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/cnn{1a,1b} + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... ]" + echo "e.g.: $0 exp/chain/cnn{1a,1b}" + exit 1 +fi + +echo "# $0 $*" +used_epochs=false + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +echo -n "# WER " +for x in $*; do + wer=$(cat $x/decode_test/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# WER (rescored) " +for x in $*; do + wer=$(cat $x/decode_test_rescored/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER " +for x in $*; do + cer=$(cat $x/decode_test/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +echo -n "# CER (rescored) " +for x in $*; do + cer=$(cat $x/decode_test_rescored/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo diff --git a/egs/yomdle_russian/v1/local/chain/run_cnn_e2eali.sh b/egs/yomdle_russian/v1/local/chain/run_cnn_e2eali.sh new file mode 120000 index 00000000000..e2545b0186e --- /dev/null +++ b/egs/yomdle_russian/v1/local/chain/run_cnn_e2eali.sh @@ -0,0 +1 @@ +tuning/run_cnn_e2eali_1a.sh \ No newline at end of file diff --git a/egs/yomdle_russian/v1/local/chain/run_e2e_cnn.sh b/egs/yomdle_russian/v1/local/chain/run_e2e_cnn.sh new file mode 100755 index 00000000000..6f5742cd34b --- /dev/null +++ b/egs/yomdle_russian/v1/local/chain/run_e2e_cnn.sh @@ -0,0 +1,129 @@ +#!/bin/bash + +# Copyright 2017 Hossein Hadian +# This script does end2end chain training (i.e. from scratch) +# local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ +# System e2e_cnn_1a +# score_basic rescoring + nomalized +# WER 16.24 11.0 +# WER (rescored) 15.63 10.5 +# CER 5.98 5.6 +# CER (rescored) 5.66 5.3 +# Final train prob 0.1376 +# Final valid prob 0.1913 +# steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1a +# exp/chain/e2e_cnn_1a: num-iters=27 nj=5..8 num-params=3.0M dim=40->470 combine=0.091->0.091 (over 1) logprob:train/valid[17,26,final]=(0.135,0.137,0.138/0.191,0.191,0.191) + +set -e +# configs for 'chain' +stage=0 +nj=30 +train_stage=-10 +get_egs_stage=-10 +affix=1a + +# training options +tdnn_dim=450 +minibatch_size=150=64,32/300=32,16/600=16,8/1200=8,4 +cmvn_opts="--norm-means=false --norm-vars=false" +train_set=train +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. 
./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj $nj --cmd "$cmd" \ + --shared-phones true \ + --type mono \ + data/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat data/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py data/lang \| \ + utils/sym2int.pl -f 2- data/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $output_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts "$cmvn_opts" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.apply-deriv-weights true \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs 3 \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial 5 \ + --trainer.optimization.num-jobs-final 8 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir data/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi diff --git a/egs/yomdle_russian/v1/local/chain/tuning/run_cnn_e2eali_1a.sh 
b/egs/yomdle_russian/v1/local/chain/tuning/run_cnn_e2eali_1a.sh new file mode 100755 index 00000000000..cd582472993 --- /dev/null +++ b/egs/yomdle_russian/v1/local/chain/tuning/run_cnn_e2eali_1a.sh @@ -0,0 +1,203 @@ +#!/bin/bash + +# local/chain/compare_wer.sh exp/chain/cnn_e2eali_1a +# System cnn_e2eali_1a rescoring + nomalized +# WER 12.08 7.7 +# WER (rescored) 11.90 7.5 +# CER 3.60 3.4 +# CER (rescored) 3.42 3.2 +# Final train prob -0.0373 +# Final valid prob -0.0362 +# steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1a +# exp/chain/cnn_e2eali_1a: num-iters=74 nj=3..16 num-params=6.3M dim=40->848 combine=-0.039->-0.039 (over 1) xent:train/valid[48,73,final]=(-0.206,-0.153,-0.146/-0.191,-0.156,-0.151) logprob:train/valid[48,73,final]=(-0.044,-0.038,-0.037/-0.040,-0.037,-0.036) + +set -e -o pipefail +stage=0 +nj=30 +train_set=train +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=1000 +# we don't need extra left/right context for TDNN systems. +tdnn_dim=550 +# training options +srand=0 +remove_egs=false +dropout_schedule='0,0@0.20,0.2@0.50,0' +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.03 dropout-proportion=0.0" + tdnn_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.04" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=90" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-dropout-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-dropout-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-dropout-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-8,-4,0,4,8) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=900" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=16 \ + --trainer.frames-per-iter=2000000 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=16 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi diff --git a/egs/yomdle_russian/v1/local/check_tools.sh b/egs/yomdle_russian/v1/local/check_tools.sh new file mode 100755 index 00000000000..5b4d3107d3b --- /dev/null +++ b/egs/yomdle_russian/v1/local/check_tools.sh @@ -0,0 +1,43 @@ +#!/bin/bash -u + +# Copyright 2015 (c) Johns Hopkins University (Jan Trmal ) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +[ -f ./path.sh ] && . ./path.sh +set +e + +command -v python3 >&/dev/null \ + || { echo >&2 "python3 not found on PATH. You will have to install Python3, preferably >= 3.6"; exit 1; } + +python3 -c "import numpy" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs numpy installed." + exit 1 +fi + +python3 -c "import scipy" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy installed." + exit 1 +fi + +python3 -c "import scipy.misc; scipy.misc.__dict__['imread']" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy-image and Pillow installed." + exit 1 +fi + + +exit 0 diff --git a/egs/yomdle_russian/v1/local/extract_features.sh b/egs/yomdle_russian/v1/local/extract_features.sh new file mode 100755 index 00000000000..3880ebad3e8 --- /dev/null +++ b/egs/yomdle_russian/v1/local/extract_features.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Copyright 2017 Yiwen Shao +# 2018 Ashish Arora + +# Apache 2.0 +# This script runs the make features script in parallel. + +nj=4 +cmd=run.pl +feat_dim=40 +augment='no_aug' +fliplr=false +echo "$0 $@" + +. ./cmd.sh +. 
./path.sh +. ./utils/parse_options.sh || exit 1; + +data=$1 +featdir=$data/data +scp=$data/images.scp +logdir=$data/log + +mkdir -p $logdir +mkdir -p $featdir + +# make $featdir an absolute pathname +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +for n in $(seq $nj); do + split_scps="$split_scps $logdir/images.$n.scp" +done + +# split images.scp +utils/split_scp.pl $scp $split_scps || exit 1; + +$cmd JOB=1:$nj $logdir/extract_features.JOB.log \ + image/ocr/make_features.py $logdir/images.JOB.scp \ + --allowed_len_file_path $data/allowed_lengths.txt \ + --feat-dim $feat_dim --fliplr $fliplr --augment_type $augment \| \ + copy-feats --compress=true --compression-method=7 \ + ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp + +## aggregates the output scp's to get feats.scp +for n in $(seq $nj); do + cat $featdir/images.$n.scp || exit 1; +done > $data/feats.scp || exit 1 diff --git a/egs/yomdle_russian/v1/local/prepare_dict.sh b/egs/yomdle_russian/v1/local/prepare_dict.sh new file mode 100755 index 00000000000..22db5ae834d --- /dev/null +++ b/egs/yomdle_russian/v1/local/prepare_dict.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +# Copyright 2017 Hossein Hadian +# 2017 Babak Rekabdar +# 2017 Chun Chieh Chang +# 2017 Ashish Arora + +# This script prepares the dictionary. + +set -e +dir=data/local/dict +. ./utils/parse_options.sh || exit 1; + +mkdir -p $dir + +local/prepare_lexicon.py $dir + +cut -d' ' -f2- $dir/lexicon.txt | sed 's/SIL//g' | tr ' ' '\n' | sort -u | sed '/^$/d' >$dir/nonsilence_phones.txt || exit 1; + +echo ' SIL' >> $dir/lexicon.txt + +echo SIL > $dir/silence_phones.txt + +echo SIL >$dir/optional_silence.txt + +echo -n "" >$dir/extra_questions.txt diff --git a/egs/yomdle_russian/v1/local/prepare_lexicon.py b/egs/yomdle_russian/v1/local/prepare_lexicon.py new file mode 100755 index 00000000000..a68b1cb49dd --- /dev/null +++ b/egs/yomdle_russian/v1/local/prepare_lexicon.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Babak Rekabdar +# 2017 Hossein Hadian +# 2017 Chun Chieh Chang +# 2017 Ashish Arora +# Apache 2.0 + +# This script prepares lexicon for BPE. It gets the set of all words that occur in data/train/text. +# Since this lexicon is based on BPE, it replaces '|' with silence. 
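+# Illustrative example (hypothetical BPE token): the token "|на" would produce
+# the lexicon entry
+#   |на SIL н а
+# i.e. the token itself followed by its characters, with the word-boundary
+# marker '|' mapped to SIL and any '#' characters dropped.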
+ +import argparse +import os +import unicodedata +parser = argparse.ArgumentParser(description="""Creates the list of characters and words in lexicon""") +parser.add_argument('dir', type=str, help='output path') +args = parser.parse_args() + +### main ### +lex = {} +text_path = os.path.join('data', 'train', 'text') +with open(text_path, 'r', encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split(' ') + for i in range(1, len(line_vect)): + characters = list(line_vect[i]) + characters = " ".join([ 'SIL' if char == '|' else char for char in characters]) + characters = list(characters) + characters = "".join([ '' if char == '#' else char for char in characters]) + lex[line_vect[i]] = characters + +with open(os.path.join(args.dir, 'lexicon.txt'), 'w', encoding='utf-8') as fp: + for key in sorted(lex): + fp.write(key + " " + lex[key] + "\n") diff --git a/egs/yomdle_russian/v1/local/process_corpus.py b/egs/yomdle_russian/v1/local/process_corpus.py new file mode 100755 index 00000000000..b39030270b7 --- /dev/null +++ b/egs/yomdle_russian/v1/local/process_corpus.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# Copyright 2018 Ashish Arora +# Apache 2.0 +# This script reads valid phones and removes the lines in the corpus +# which have any other phone. + +import os +import sys, io + +phone_file = os.path.join('data/local/text/cleaned/phones.txt') +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +phone_dict = dict() +with open(phone_file, 'r', encoding='utf-8') as phone_fh: + for line in phone_fh: + line = line.strip().split()[0] + phone_dict[line] = line + +phone_dict[' '] = ' ' +corpus_text = list() +for line in infile: + text = line.strip() + skip_text = False + for phone in text: + if phone not in phone_dict.keys(): + skip_text = True + break + if not skip_text: + output.write(text+ '\n') + diff --git a/egs/yomdle_russian/v1/local/process_data.py b/egs/yomdle_russian/v1/local/process_data.py new file mode 100755 index 00000000000..d7546b0a803 --- /dev/null +++ b/egs/yomdle_russian/v1/local/process_data.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# 2018 Chun Chieh Chang + +""" This script reads the extracted Tamil OCR (yomdle and slam) database files + and creates the following files (for the data subset selected via --dataset): + text, utt2spk, images.scp. + Eg. local/process_data.py data/download/ data/local/splits/train.txt data/train + + Eg. 
text file: english_phone_books_0001_1 To sum up, then, it would appear that + utt2spk file: english_phone_books_0001_0 english_phone_books_0001 + images.scp file: english_phone_books_0001_0 \ + data/download/truth_line_image/english_phone_books_0001_0.png +""" + +import argparse +import os +import sys +import csv +import itertools +import unicodedata +import re +import string +import unicodedata +parser = argparse.ArgumentParser(description="Creates text, utt2spk, and images.scp files") +parser.add_argument('database_path', type=str, help='Path to data') +parser.add_argument('data_split', type=str, help='Path to file that contain datasplits') +parser.add_argument('out_dir', type=str, help='directory to output files') +args = parser.parse_args() + +### main ### +print("Processing '{}' data...".format(args.out_dir)) + +text_file = os.path.join(args.out_dir, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') +utt2spk_file = os.path.join(args.out_dir, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join(args.out_dir, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') + +with open(args.data_split) as f: + for line in f: + line = line.strip() + image_id = line + image_filename = image_id + '.png' + image_filepath = os.path.join(args.database_path, 'truth_line_image', image_filename) + if not os.path.isfile (image_filepath): + print("File does not exist {}".format(image_filepath)) + continue + line_id = int(line.split('_')[-1]) + csv_filename = '_'.join(line.split('_')[:-1]) + '.csv' + csv_filepath = os.path.join(args.database_path, 'truth_csv', csv_filename) + csv_file = open(csv_filepath, 'r', encoding='utf-8') + for row in csv.reader(csv_file): + if row[1] == image_filename: + text = row[11] + text_vect = text.split() # this is to avoid non-utf-8 spaces + text = " ".join(text_vect) + #text_normalized = unicodedata.normalize('NFD', text).replace('\n', '') + if not text: + continue + text_fh.write(image_id + ' ' + text + '\n') + utt2spk_fh.write(image_id + ' ' + '_'.join(line.split('_')[:-1]) + '\n') + image_fh.write(image_id + ' ' + image_filepath + '\n') diff --git a/egs/yomdle_russian/v1/local/score.sh b/egs/yomdle_russian/v1/local/score.sh new file mode 100755 index 00000000000..31564d25326 --- /dev/null +++ b/egs/yomdle_russian/v1/local/score.sh @@ -0,0 +1,5 @@ +#!/bin/bash + + +steps/scoring/score_kaldi_wer.sh "$@" +steps/scoring/score_kaldi_cer.sh --stage 2 "$@" diff --git a/egs/yomdle_russian/v1/local/train_lm.sh b/egs/yomdle_russian/v1/local/train_lm.sh new file mode 100755 index 00000000000..c73c42fb7dc --- /dev/null +++ b/egs/yomdle_russian/v1/local/train_lm.sh @@ -0,0 +1,127 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the training transcriptions and corpus text. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +order=6 +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. 
+ else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +#bypass_metaparam_optim_opt="--bypass-metaparameter-optimization=0.031,0.860,0.678,0.194,0.037,0.006,0.928,0.712,0.454,0.220,0.926,0.844,0.749,0.358,0.966,0.879,0.783,0.544,0.966,0.826,0.674,0.450" +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # use the validation data as the dev set. + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + + cat data/local/text/cleaned/bpe_val.txt > ${dir}/data/text/dev.txt + # use the training data as an additional data source. + # we can later fold the dev data into this. + cat data/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + cat data/local/text/cleaned/bpe_corpus.txt > ${dir}/data/text/corpus_text.txt + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < data/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from train and corpus text + cat ${dir}/data/text/{train,corpus_text}.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=20 --warm-start-ratio=20 \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi + +if [ $stage -le 2 ]; then + echo "$0: pruning the LM (to larger size)" + # Using 10 million n-grams for a big LM for rescoring purposes. 
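+  # prune_lm_dir.py is expected to prune the pocolm model down towards the
+  # requested number of n-grams; --initial-threshold only sets the starting
+  # point of its search, not the final size.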
+ size=10000000 + prune_lm_dir.py --target-num-ngrams=$size --initial-threshold=0.02 ${unpruned_lm_dir} ${dir}/data/lm_${order}_prune_big + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_big 2>&1 | grep -F '[perplexity' + #[perplexity = 22.0613098868] over 151116.0 words + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${dir}/data/lm_${order}_prune_big | gzip -c > ${dir}/data/arpa/${order}gram_big.arpa.gz +fi + +if [ $stage -le 3 ]; then + echo "$0: pruning the LM (to smaller size)" + # Using 2 million n-grams for a smaller LM for graph building. Prune from the + # bigger-pruned LM, it'll be faster. + size=2000000 + prune_lm_dir.py --target-num-ngrams=$size ${dir}/data/lm_${order}_prune_big ${dir}/data/lm_${order}_prune_small + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_small 2>&1 | grep -F '[perplexity' + #[perplexity = 23.4801171202] over 151116.0 words + format_arpa_lm.py ${dir}/data/lm_${order}_prune_small | gzip -c > ${dir}/data/arpa/${order}gram_small.arpa.gz +fi diff --git a/egs/yomdle_russian/v1/local/wer_output_filter b/egs/yomdle_russian/v1/local/wer_output_filter new file mode 100755 index 00000000000..59e364e0231 --- /dev/null +++ b/egs/yomdle_russian/v1/local/wer_output_filter @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Hossein Hadian + +# Apache 2.0 +# This script converts a BPE-encoded text to normal text. It is used in scoring + +import sys, io +import string +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +for line in infile: + words = line.strip().split() + uttid = words[0] + transcript = ''.join(words[1:]) + transcript = transcript.replace('|', ' ') + output.write(uttid + ' ' + transcript + '\n') diff --git a/egs/yomdle_russian/v1/local/yomdle b/egs/yomdle_russian/v1/local/yomdle new file mode 120000 index 00000000000..2c4544c1399 --- /dev/null +++ b/egs/yomdle_russian/v1/local/yomdle @@ -0,0 +1 @@ +../../../yomdle_tamil/v1/local/yomdle/ \ No newline at end of file diff --git a/egs/yomdle_russian/v1/path.sh b/egs/yomdle_russian/v1/path.sh new file mode 100755 index 00000000000..2d17b17a84a --- /dev/null +++ b/egs/yomdle_russian/v1/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/yomdle_russian/v1/run_end2end.sh b/egs/yomdle_russian/v1/run_end2end.sh new file mode 100755 index 00000000000..12beebeaa05 --- /dev/null +++ b/egs/yomdle_russian/v1/run_end2end.sh @@ -0,0 +1,186 @@ +#!/bin/bash + +# Copyright 2018 Hossein Hadian +# Ashish Arora +# Jonathan Chang +# Apache 2.0 + +set -e +stage=0 +nj=30 + +language_main=Russian +slam_dir=/export/corpora5/slam/SLAM/ +yomdle_dir=/export/corpora5/slam/YOMDLE/ +corpus_dir=/export/corpora5/handwriting_ocr/corpus_data/ru/ +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +./local/check_tools.sh +# Start from stage=-2 for data preparation. This stage stores line images, +# csv files and splits{train,test,train_unsup} data/download/truth_line_image, +# data/download/truth_csv and data/local/splits respectively. 
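+# The layout assumed by local/process_data.py after that stage is roughly:
+#   data/download/truth_line_image/<line_id>.png    - segmented line images
+#   data/download/truth_csv/<page_id>.csv           - per-page CSV with line-image names and transcripts
+#   data/local/splits/{train,test,train_unsup}.txt  - lists of line-image ids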
+if [ $stage -le -2 ]; then + echo "$0: $(date): preparing data, obtaining line images and csv files..." + local/yomdle/create_download_dir.sh --language_main $language_main \ + --slam_dir $slam_dir --yomdle_dir $yomdle_dir +fi + +if [ $stage -le -1 ]; then + echo "$0: $(date): getting corpus text for language modelling..." + mkdir -p data/local/text/cleaned + cat $corpus_dir/* > data/local/text/ru.txt + head -20000 data/local/text/ru.txt > data/local/text/cleaned/val.txt + tail -n +20000 data/local/text/ru.txt > data/local/text/cleaned/corpus.txt +fi + +mkdir -p data/{train,test}/data +if [ $stage -le 0 ]; then + echo "$0: stage 0: Processing train and test data.$(date)" + echo "$0: creating text, images.scp, utt2spk and spk2utt" + #local/prepare_data.sh data/download/ + for set in train test; do + local/process_data.py data/download/ \ + data/local/splits/${set}.txt data/${set} + image/fix_data_dir.sh data/${set} + done +fi + +if [ $stage -le 1 ]; then + echo "$0: $(date) stage 1: getting allowed image widths for e2e training..." + image/get_image2num_frames.py --feat-dim 40 data/train + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train + for set in train test; do + echo "$0: $(date) Extracting features, creating feats.scp file" + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/${set} + steps/compute_cmvn_stats.sh data/${set} || exit 1; + done + image/fix_data_dir.sh data/train +fi + +if [ $stage -le 3 ]; then + echo "$0: $(date) stage 3: BPE preparation" + # getting non-silence phones. + cut -d' ' -f2- data/train/text | \ +python3 <( +cat << "END" +import os, sys, io; +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8'); +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8'); +phone_dict = dict(); +for line in infile: + line_vect = line.strip().split(); + for word in line_vect: + for phone in word: + phone_dict[phone] = phone; + +for phone in phone_dict.keys(): + output.write(phone+ '\n'); +END + ) > data/local/text/cleaned/phones.txt + + cut -d' ' -f2- data/train/text > data/local/text/cleaned/train.txt + + echo "$0: learning BPE..." + # it is currently learned with only training text but we can also use all corpus text + # to learn BPE. phones are added so that one isolated occurance of every phone exists. + cat data/local/text/cleaned/phones.txt data/local/text/cleaned/train.txt | \ + utils/lang/bpe/prepend_words.py | utils/lang/bpe/learn_bpe.py -s 700 > data/local/bpe.txt || exit 1; +fi + +if [ $stage -le 4 ]; then + echo "$0: $(date) stage 4: applying BPE..." + echo "$0: applying BPE on train, test text..." + for set in test train; do + cut -d' ' -f1 data/$set/text > data/$set/ids + cut -d' ' -f2- data/$set/text | utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt | \ + sed 's/@@//g' > data/$set/bpe_text + mv data/$set/text data/$set/text.old + paste -d' ' data/$set/ids data/$set/bpe_text > data/$set/text + rm -f data/$set/bpe_text data/$set/ids + done + + echo "$0: applying BPE to corpus text..." + cat data/local/text/cleaned/corpus.txt | utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt | \ + sed 's/@@//g' > data/local/text/cleaned/bpe_corpus.txt + cat data/local/text/cleaned/val.txt | utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt | \ + sed 's/@@//g' > data/local/text/cleaned/bpe_val.txt +fi + +if [ $stage -le 5 ]; then + echo "$0: $(date) stage 5: Preparing dictionary and lang..." 
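  # Editorial sanity check (hedged, optional): the dictionary built below is
  # derived from the BPE-segmented training text, so it is worth confirming
  # that the BPE round trip is lossless. Removing the "@@" markers and all
  # spaces and then turning the "|" word markers back into spaces should
  # reproduce the original lines (modulo a leading space):
  head -3 data/local/text/cleaned/train.txt | \
    utils/lang/bpe/prepend_words.py | \
    utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt | \
    sed 's/@@//g' | sed 's/ //g' | sed 's/|/ /g'
  head -3 data/local/text/cleaned/train.txt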
+ local/prepare_dict.sh --dir data/local/dict + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 4 --sil-prob 0.0 --position-dependent-phones false \ + data/local/dict "" data/lang/temp data/lang + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 data/lang +fi + +if [ $stage -le 6 ]; then + echo "$0: $(date) stage 6: Calling the flat-start chain recipe..." + local/chain/run_e2e_cnn.sh +fi + +if [ $stage -le 7 ]; then + echo "$0: $(date) stage 7: Aligning the training data using the e2e chain model..." + steps/nnet3/align.sh --nj $nj --cmd "$cmd" \ + --scale-opts '--transition-scale=1.0 --acoustic-scale=1.0 --self-loop-scale=1.0' \ + data/train data/lang exp/chain/e2e_cnn_1a exp/chain/e2e_ali_train +fi + +chunk_width='340,300,200,100' +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +if [ $stage -le 8 ]; then + echo "$0: $(date) stage 8: Building a tree and training a regular chain model using the e2e alignments..." + local/chain/run_cnn_e2eali.sh --chunk_width $chunk_width +fi + +if [ $stage -le 9 ]; then + echo "$0: $(date) stage 9: Estimating a language model for decoding..." + local/train_lm.sh + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/6gram_small.arpa.gz \ + data/local/dict/lexicon.txt data/lang + utils/build_const_arpa_lm.sh data/local/local_lm/data/arpa/6gram_unpruned.arpa.gz \ + data/lang data/lang_rescore_6g +fi + +if [ $stage -le 10 ] && $decode_e2e; then + echo "$0: $(date) stage 10: decoding end2end setup..." + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + exp/chain/e2e_cnn_1a/ exp/chain/e2e_cnn_1a/graph || exit 1; + + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 30 --cmd "$cmd" --beam 12 \ + exp/chain/e2e_cnn_1a/graph data/test exp/chain/e2e_cnn_1a/decode_test || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test exp/chain/e2e_cnn_1a/decode_test{,_rescored} || exit 1 + + echo "$0: Done. Date: $(date). Results:" + local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ +fi + +if [ $stage -le 11 ] && $decode_chain; then + echo "$0: $(date) stage 11: decoding chain alignment setup..." + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + exp/chain/cnn_e2eali_1a/ exp/chain/cnn_e2eali_1a/graph || exit 1; + + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 30 --cmd "$cmd" --beam 12 \ + exp/chain/cnn_e2eali_1a/graph data/test exp/chain/cnn_e2eali_1a/decode_test || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test exp/chain/cnn_e2eali_1a/decode_test{,_rescored} || exit 1 + + echo "$0: Done. Date: $(date). Results:" + local/chain/compare_wer.sh exp/chain/cnn_e2eali_1a +fi diff --git a/egs/yomdle_russian/v1/steps b/egs/yomdle_russian/v1/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/yomdle_russian/v1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/yomdle_russian/v1/utils b/egs/yomdle_russian/v1/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/yomdle_russian/v1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/yomdle_tamil/README.txt b/egs/yomdle_tamil/README.txt new file mode 100644 index 00000000000..0f295e5ae5f --- /dev/null +++ b/egs/yomdle_tamil/README.txt @@ -0,0 +1,3 @@ +This directory contains example scripts for OCR on the Yomdle and Slam datasets. 
+Training is done on the Yomdle dataset and testing is done on Slam. +LM rescoring is also done with extra corpus data obtained from various sources. diff --git a/egs/yomdle_tamil/v1/cmd.sh b/egs/yomdle_tamil/v1/cmd.sh new file mode 100755 index 00000000000..3d69546dfe8 --- /dev/null +++ b/egs/yomdle_tamil/v1/cmd.sh @@ -0,0 +1,12 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. +export cmd="queue.pl" diff --git a/egs/yomdle_tamil/v1/image b/egs/yomdle_tamil/v1/image new file mode 120000 index 00000000000..1668ee99922 --- /dev/null +++ b/egs/yomdle_tamil/v1/image @@ -0,0 +1 @@ +../../cifar/v1/image/ \ No newline at end of file diff --git a/egs/yomdle_tamil/v1/local/augment_data.sh b/egs/yomdle_tamil/v1/local/augment_data.sh new file mode 100755 index 00000000000..136bfd24eb2 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/augment_data.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright 2018 Hossein Hadian +# 2018 Ashish Arora + +# Apache 2.0 +# This script performs data augmentation. + +nj=4 +cmd=run.pl +feat_dim=40 +verticle_shift=0 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +srcdir=$1 +outdir=$2 +datadir=$3 + +mkdir -p $datadir/augmentations +echo "copying $srcdir to $datadir/augmentations/aug1, allowed length, creating feats.scp" + +for set in aug1; do + image/copy_data_dir.sh --spk-prefix $set- --utt-prefix $set- \ + $srcdir $datadir/augmentations/$set + cat $srcdir/allowed_lengths.txt > $datadir/augmentations/$set/allowed_lengths.txt + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim $feat_dim \ + --vertical-shift $verticle_shift \ + --fliplr false --augment 'random_scale' $datadir/augmentations/$set +done + +echo " combine original data and data from different augmentations" +utils/combine_data.sh --extra-files images.scp $outdir $srcdir $datadir/augmentations/aug1 +cat $srcdir/allowed_lengths.txt > $outdir/allowed_lengths.txt diff --git a/egs/yomdle_tamil/v1/local/chain/compare_wer.sh b/egs/yomdle_tamil/v1/local/chain/compare_wer.sh new file mode 100755 index 00000000000..80f31e0f311 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/chain/compare_wer.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/cnn{1a,1b} + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... 
]" + echo "e.g.: $0 exp/chain/cnn{1a,1b}" + exit 1 +fi + +echo "# $0 $*" +used_epochs=false + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +echo -n "# WER " +for x in $*; do + wer=$(cat $x/decode_test/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# WER (rescored) " +for x in $*; do + wer=$(cat $x/decode_test_rescored/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER " +for x in $*; do + cer=$(cat $x/decode_test/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +echo -n "# CER (rescored) " +for x in $*; do + cer=$(cat $x/decode_test_rescored/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo diff --git a/egs/yomdle_tamil/v1/local/chain/run_cnn_e2eali.sh b/egs/yomdle_tamil/v1/local/chain/run_cnn_e2eali.sh new file mode 120000 index 00000000000..fcf59f917c1 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/chain/run_cnn_e2eali.sh @@ -0,0 +1 @@ +tuning/run_cnn_e2eali_1b.sh \ No newline at end of file diff --git a/egs/yomdle_tamil/v1/local/chain/run_e2e_cnn.sh b/egs/yomdle_tamil/v1/local/chain/run_e2e_cnn.sh new file mode 100755 index 00000000000..f553467d4a6 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/chain/run_e2e_cnn.sh @@ -0,0 +1,159 @@ +#!/bin/bash + +# Copyright 2017 Hossein Hadian + +# This script does end2end chain training (i.e. from scratch) +# local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ +# System e2e_cnn_1a +# score_basic score_nomalized +# WER 13.64 10.6 +# WER (rescored) 13.13 10.2 +# CER 2.99 3.0 +# CER (rescored) 2.88 2.9 +# Final train prob 0.0113 +# Final valid prob 0.0152 +# steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1a +# exp/chain/e2e_cnn_1a: num-iters=48 nj=5..8 num-params=3.0M dim=40->352 combine=0.047->0.047 (over 2) logprob:train/valid[31,47,final]=(0.002,0.008,0.011/0.008,0.013,0.015) + +set -e +# configs for 'chain' +stage=0 +nj=30 +train_stage=-10 +get_egs_stage=-10 +affix=1a + +# training options +tdnn_dim=450 +minibatch_size=150=64,32/300=32,16/600=16,8/1200=8,4 +cmvn_opts="--norm-means=false --norm-vars=false" +train_set=train +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +decode_e2e=true +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! 
cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj $nj --cmd "$cmd" \ + --shared-phones true \ + --type mono \ + data/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat data/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py data/lang \| \ + utils/sym2int.pl -f 2- data/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $output_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts "$cmvn_opts" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.apply-deriv-weights true \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs 3 \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial 5 \ + --trainer.optimization.num-jobs-final 8 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir data/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi + +if [ $stage -le 4 ] && $decode_e2e; then + # The reason we are using data/lang here, instead of $lang, is just to + # 
emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 5 ] && $decode_e2e; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 30 --cmd "$cmd" --beam 12 \ + $dir/graph data/test $dir/decode_test || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test $dir/decode_test{,_rescored} || exit 1 + + echo "Done. Date: $(date). Results:" + local/chain/compare_wer.sh $dir +fi diff --git a/egs/yomdle_tamil/v1/local/chain/tuning/run_cnn_e2eali_1a.sh b/egs/yomdle_tamil/v1/local/chain/tuning/run_cnn_e2eali_1a.sh new file mode 100755 index 00000000000..03333f6d229 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/chain/tuning/run_cnn_e2eali_1a.sh @@ -0,0 +1,236 @@ +#!/bin/bash + +# e2eali_1a is the same as 1a but uses the e2e chain model to get the +# lattice alignments and to build a tree + +# local/chain/compare_wer.sh exp/old/chain/cnn_e2eali_1a/ +# System cnn_e2eali_1a +# WER 15.68 +# CER 3.18 +# Final train prob -0.0331 +# Final valid prob -0.0395 + +# steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1a/ +# exp/old/chain/cnn_e2eali_1a/: num-iters=33 nj=3..16 num-params=5.2M dim=40->456 combine=-0.035->-0.035 (over 1) xent:train/valid[21,32,final]=(-0.226,-0.175,-0.169/-0.248,-0.202,-0.195) logprob:train/valid[21,32,final]=(-0.039,-0.034,-0.033/-0.046,-0.040,-0.039) + +# Normalize scoring +# WER = 11.7 +# CER = 3.3 + +set -e -o pipefail +stage=0 +nj=30 +train_set=train +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=500 +# we don't need extra left/right context for TDNN systems. +tdnn_dim=450 +# training options +srand=0 +remove_egs=false +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +decode_chain=false +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=90" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=900" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=4 \ + --trainer.frames-per-iter=1000000 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=16 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ] && $decode_chain; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ] && $decode_chain; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --beam 12 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/test $dir/decode_test || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test $dir/decode_test{,_rescored} || exit 1 + + echo "Done. Date: $(date). Results:" + local/chain/compare_wer.sh $dir +fi diff --git a/egs/yomdle_tamil/v1/local/chain/tuning/run_cnn_e2eali_1b.sh b/egs/yomdle_tamil/v1/local/chain/tuning/run_cnn_e2eali_1b.sh new file mode 100755 index 00000000000..fb15ce10dde --- /dev/null +++ b/egs/yomdle_tamil/v1/local/chain/tuning/run_cnn_e2eali_1b.sh @@ -0,0 +1,236 @@ +#!/bin/bash + +# e2eali_1b is the same as e2eali_1a but has fewer CNN layers, smaller +# l2-regularize, more epochs and uses dropout. 
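# Editorial note (hedged): the dropout_schedule used below,
# '0,0@0.20,0.2@0.50,0', is read by steps/nnet3/chain/train.py as a
# piecewise-linear function of the fraction of training completed: dropout
# stays at 0 until 20% of training, rises to 0.2 at 50%, then falls back to 0
# by the end. An illustrative awk evaluation of that interpolation (not part
# of the recipe):
#   echo "0 0.2 0.35 0.5 0.75 1" | tr ' ' '\n' | awk '
#     function sched(x) { if (x <= 0.2) return 0;
#                         if (x <= 0.5) return 0.2 * (x - 0.2) / 0.3;
#                         return 0.2 * (1 - x) / 0.5 }
#     { printf("progress %.2f -> dropout %.3f\n", $1, sched($1)) }'
# which prints 0.000, 0.000, 0.100, 0.200, 0.100 and 0.000 respectively.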
+ +#local/chain/compare_wer.sh exp/chain/cnn_e2eali_1b/ +# System cnn_e2eali_1b +# score_basic score_nomalized +# WER 13.01 10.0 +# WER (rescored) 12.69 9.6 +# CER 2.78 3.0 +# CER (rescored) 2.70 2.8 +# Final train prob -0.0568 +# Final valid prob -0.0410 +#steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1b +#exp/chain/cnn_e2eali_1b: num-iters=67 nj=3..16 num-params=5.2M dim=40->464 combine=-0.052->-0.052 (over 1) xent:train/valid[43,66,final]=(-0.379,-0.319,-0.304/-0.291,-0.234,-0.227) logprob:train/valid[43,66,final]=(-0.069,-0.058,-0.057/-0.046,-0.041,-0.041) +set -e -o pipefail +stage=0 +nj=30 +train_set=train +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1b #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=500 +# we don't need extra left/right context for TDNN systems. +tdnn_dim=550 +# training options +srand=0 +remove_egs=false +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +decode_chain=true +dropout_schedule='0,0@0.20,0.2@0.50,0' +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.03 dropout-proportion=0.0" + tdnn_opts="l2-regularize=0.03" + output_opts="l2-regularize=0.04" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=90" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=40 name=input + + conv-relu-batchnorm-dropout-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-dropout-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=900" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=8 \ + --trainer.frames-per-iter=2000000 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=16 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ] && $decode_chain; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ] && $decode_chain; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --beam 12 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/test $dir/decode_test || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test $dir/decode_test{,_rescored} || exit 1 + + echo "Done. Date: $(date). Results:" + local/chain/compare_wer.sh $dir +fi diff --git a/egs/yomdle_tamil/v1/local/check_tools.sh b/egs/yomdle_tamil/v1/local/check_tools.sh new file mode 100755 index 00000000000..5b4d3107d3b --- /dev/null +++ b/egs/yomdle_tamil/v1/local/check_tools.sh @@ -0,0 +1,43 @@ +#!/bin/bash -u + +# Copyright 2015 (c) Johns Hopkins University (Jan Trmal ) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +[ -f ./path.sh ] && . 
./path.sh +set +e + +command -v python3 >&/dev/null \ + || { echo >&2 "python3 not found on PATH. You will have to install Python3, preferably >= 3.6"; exit 1; } + +python3 -c "import numpy" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs numpy installed." + exit 1 +fi + +python3 -c "import scipy" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy installed." + exit 1 +fi + +python3 -c "import scipy.misc; scipy.misc.__dict__['imread']" +if [ $? -ne 0 ] ; then + echo >&2 "This recipe needs scipy-image and Pillow installed." + exit 1 +fi + + +exit 0 diff --git a/egs/yomdle_tamil/v1/local/extract_features.sh b/egs/yomdle_tamil/v1/local/extract_features.sh new file mode 100755 index 00000000000..3880ebad3e8 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/extract_features.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Copyright 2017 Yiwen Shao +# 2018 Ashish Arora + +# Apache 2.0 +# This script runs the make features script in parallel. + +nj=4 +cmd=run.pl +feat_dim=40 +augment='no_aug' +fliplr=false +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +data=$1 +featdir=$data/data +scp=$data/images.scp +logdir=$data/log + +mkdir -p $logdir +mkdir -p $featdir + +# make $featdir an absolute pathname +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +for n in $(seq $nj); do + split_scps="$split_scps $logdir/images.$n.scp" +done + +# split images.scp +utils/split_scp.pl $scp $split_scps || exit 1; + +$cmd JOB=1:$nj $logdir/extract_features.JOB.log \ + image/ocr/make_features.py $logdir/images.JOB.scp \ + --allowed_len_file_path $data/allowed_lengths.txt \ + --feat-dim $feat_dim --fliplr $fliplr --augment_type $augment \| \ + copy-feats --compress=true --compression-method=7 \ + ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp + +## aggregates the output scp's to get feats.scp +for n in $(seq $nj); do + cat $featdir/images.$n.scp || exit 1; +done > $data/feats.scp || exit 1 diff --git a/egs/yomdle_tamil/v1/local/prepare_dict.sh b/egs/yomdle_tamil/v1/local/prepare_dict.sh new file mode 100755 index 00000000000..22db5ae834d --- /dev/null +++ b/egs/yomdle_tamil/v1/local/prepare_dict.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +# Copyright 2017 Hossein Hadian +# 2017 Babak Rekabdar +# 2017 Chun Chieh Chang +# 2017 Ashish Arora + +# This script prepares the dictionary. + +set -e +dir=data/local/dict +. ./utils/parse_options.sh || exit 1; + +mkdir -p $dir + +local/prepare_lexicon.py $dir + +cut -d' ' -f2- $dir/lexicon.txt | sed 's/SIL//g' | tr ' ' '\n' | sort -u | sed '/^$/d' >$dir/nonsilence_phones.txt || exit 1; + +echo ' SIL' >> $dir/lexicon.txt + +echo SIL > $dir/silence_phones.txt + +echo SIL >$dir/optional_silence.txt + +echo -n "" >$dir/extra_questions.txt diff --git a/egs/yomdle_tamil/v1/local/prepare_lexicon.py b/egs/yomdle_tamil/v1/local/prepare_lexicon.py new file mode 100755 index 00000000000..3de96056c2a --- /dev/null +++ b/egs/yomdle_tamil/v1/local/prepare_lexicon.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Babak Rekabdar +# 2017 Hossein Hadian +# 2017 Chun Chieh Chang +# 2017 Ashish Arora +# Apache 2.0 + +# This script prepares lexicon for BPE. It gets the set of all words that occur in data/train/text. +# Since this lexicon is based on BPE, it replaces '|' with silence. 
+ +import argparse +import os + +parser = argparse.ArgumentParser(description="""Creates the list of characters and words in lexicon""") +parser.add_argument('dir', type=str, help='output path') +args = parser.parse_args() + +### main ### +lex = {} +text_path = os.path.join('data', 'train', 'text') +with open(text_path, 'r', encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split(' ') + for i in range(1, len(line_vect)): + characters = list(line_vect[i]) + characters = " ".join([ 'SIL' if char == '|' else char for char in characters]) + characters = list(characters) + characters = "".join([ '' if char == '#' else char for char in characters]) + lex[line_vect[i]] = characters + +with open(os.path.join(args.dir, 'lexicon.txt'), 'w', encoding='utf-8') as fp: + for key in sorted(lex): + fp.write(key + " " + lex[key] + "\n") diff --git a/egs/yomdle_tamil/v1/local/process_corpus.py b/egs/yomdle_tamil/v1/local/process_corpus.py new file mode 100755 index 00000000000..b39030270b7 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/process_corpus.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# Copyright 2018 Ashish Arora +# Apache 2.0 +# This script reads valid phones and removes the lines in the corpus +# which have any other phone. + +import os +import sys, io + +phone_file = os.path.join('data/local/text/cleaned/phones.txt') +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +phone_dict = dict() +with open(phone_file, 'r', encoding='utf-8') as phone_fh: + for line in phone_fh: + line = line.strip().split()[0] + phone_dict[line] = line + +phone_dict[' '] = ' ' +corpus_text = list() +for line in infile: + text = line.strip() + skip_text = False + for phone in text: + if phone not in phone_dict.keys(): + skip_text = True + break + if not skip_text: + output.write(text+ '\n') + diff --git a/egs/yomdle_tamil/v1/local/process_data.py b/egs/yomdle_tamil/v1/local/process_data.py new file mode 100755 index 00000000000..7c116165ddd --- /dev/null +++ b/egs/yomdle_tamil/v1/local/process_data.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# 2018 Chun Chieh Chang + +""" This script reads the extracted Tamil OCR (yomdle and slam) database files + and creates the following files (for the data subset selected via --dataset): + text, utt2spk, images.scp. + Eg. local/process_data.py data/download/ data/local/splits/train.txt data/train + + Eg. 
text file: english_phone_books_0001_1 To sum up, then, it would appear that + utt2spk file: english_phone_books_0001_0 english_phone_books_0001 + images.scp file: english_phone_books_0001_0 \ + data/download/truth_line_image/english_phone_books_0001_0.png +""" + +import argparse +import os +import sys +import csv +import itertools +import unicodedata +import re +import string +parser = argparse.ArgumentParser(description="Creates text, utt2spk, and images.scp files") +parser.add_argument('database_path', type=str, help='Path to data') +parser.add_argument('data_split', type=str, help='Path to file that contain datasplits') +parser.add_argument('out_dir', type=str, help='directory to output files') +args = parser.parse_args() + +### main ### +print("Processing '{}' data...".format(args.out_dir)) + +text_file = os.path.join(args.out_dir, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') +utt2spk_file = os.path.join(args.out_dir, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join(args.out_dir, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') + +with open(args.data_split) as f: + for line in f: + line = line.strip() + image_id = line + image_filename = image_id + '.png' + image_filepath = os.path.join(args.database_path, 'truth_line_image', image_filename) + if not os.path.isfile (image_filepath): + print("File does not exist {}".format(image_filepath)) + continue + line_id = int(line.split('_')[-1]) + csv_filename = '_'.join(line.split('_')[:-1]) + '.csv' + csv_filepath = os.path.join(args.database_path, 'truth_csv', csv_filename) + csv_file = open(csv_filepath, 'r', encoding='utf-8') + for row in csv.reader(csv_file): + if row[1] == image_filename: + text = row[11] + if not text: + continue + text_fh.write(image_id + ' ' + text + '\n') + utt2spk_fh.write(image_id + ' ' + '_'.join(line.split('_')[:-1]) + '\n') + image_fh.write(image_id + ' ' + image_filepath + '\n') diff --git a/egs/yomdle_tamil/v1/local/score.sh b/egs/yomdle_tamil/v1/local/score.sh new file mode 100755 index 00000000000..31564d25326 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/score.sh @@ -0,0 +1,5 @@ +#!/bin/bash + + +steps/scoring/score_kaldi_wer.sh "$@" +steps/scoring/score_kaldi_cer.sh --stage 2 "$@" diff --git a/egs/yomdle_tamil/v1/local/semisup/chain/run_cnn_chainali_semisupervised_1a.sh b/egs/yomdle_tamil/v1/local/semisup/chain/run_cnn_chainali_semisupervised_1a.sh new file mode 100755 index 00000000000..f6b2c1bac42 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/semisup/chain/run_cnn_chainali_semisupervised_1a.sh @@ -0,0 +1,327 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# 2018 Ashish Arora +# Apache 2.0 +# This script is semi-supervised recipe with 25k line images of supervised data +# and 22k line images of unsupervised data with naive splitting. +# Based on "Semi-Supervised Training of Acoustic Models using Lattice-Free MMI", +# Vimal Manohar, Hossein Hadian, Daniel Povey, Sanjeev Khudanpur, ICASSP 2018 +# http://www.danielpovey.com/files/2018_icassp_semisupervised_mmi.pdf +# local/semisup/run_semisup.sh shows how to call this. + +# We use 3-gram LM trained on 5M lines of auxilary data. +# This script uses the same tree as that for the seed model. 
+# Unsupervised set: train_unsup (25k tamil line images) +# unsup_frames_per_eg=150 +# Deriv weights: Lattice posterior of best path pdf +# Unsupervised weight: 1.0 +# Weights for phone LM (supervised, unsupervised): 3,2 +# LM for decoding unsupervised data: 4gram +# Supervision: Naive split lattices +# output-0 and output-1 are for superivsed and unsupervised data respectively. + +# local/chain/compare_wer.sh exp/chain/cnn_e2eali_1b/ exp/semisup_100k/chain/tdnn_semisup_1a/ +# System cnn_e2eali_1b tdnn_semisup_1a +# WER 15.06 13.83 +# CER 3.15 2.83 +# Final train prob -0.0343 0.6103-0.0360 +# Final valid prob -0.0403 0.6054-0.0418 + +# steps/info/chain_dir_info.pl exp/semisup_100k/chain/tdnn_semisup_1a/ +# exp/semisup_100k/chain/tdnn_semisup_1a/: num-iters=58 nj=6..16 num-params=3.7M dim=40->456 combine=0.240->0.240 (over 1) + +# Normalize scoring +#WER = 10.4 +#CER = 2.9 + +set -u -e -o pipefail + +stage=0 # Start from -1 for supervised seed system training +train_stage=-100 +nj=30 +test_nj=30 + +# The following 3 options decide the output directory for semi-supervised +# chain system +# dir=${exp_root}/chain${chain_affix}/tdnn${tdnn_affix} +exp_root=exp/semisup_100k +chain_affix= # affix for chain dir +tdnn_affix=_semisup_1a # affix for semi-supervised chain system + +# Datasets-Expects supervised_set and unsupervised_set +supervised_set=train +unsupervised_set=train_unsup + +# Input seed system +sup_chain_dir=exp/chain/cnn_e2eali_1b # supervised chain system +sup_lat_dir=exp/chain/e2e_train_lats # Seed model options +sup_tree_dir=exp/chain/tree_e2e # tree directory for supervised chain system + +# Semi-supervised options +supervision_weights=1.0,1.0 # Weights for supervised, unsupervised data egs. + # Can be used to scale down the effect of unsupervised data + # by using a smaller scale for it e.g. 1.0,0.3 +lm_weights=3,2 # Weights on phone counts from supervised, unsupervised data for denominator FST creation + +sup_egs_dir= # Supply this to skip supervised egs creation +unsup_egs_dir= # Supply this to skip unsupervised egs creation +unsup_egs_opts= # Extra options to pass to unsupervised egs creation +# Neural network opts +xent_regularize=0.1 +tdnn_dim=450 +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +if [ -f ./path.sh ]; then . ./path.sh; fi +. ./utils/parse_options.sh + +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +dir=$exp_root/chain$chain_affix/tdnn$tdnn_affix +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=40 name=input + + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts + + # We use separate outputs for supervised and unsupervised data + # so we can properly track the train and valid objectives. + output name=output-0 input=output.affine + output name=output-1 input=output.affine + output name=output-0-xent input=output-xent.log-softmax + output name=output-1-xent input=output-xent.log-softmax +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +# Get values for $model_left_context, $model_right_context +. $dir/configs/vars + +left_context=$model_left_context +right_context=$model_right_context + +egs_left_context=$(perl -e "print int($left_context + $frame_subsampling_factor / 2)") +egs_right_context=$(perl -e "print int($right_context + $frame_subsampling_factor / 2)") + +if [ -z "$sup_egs_dir" ]; then + sup_egs_dir=$dir/egs_$supervised_set + frames_per_eg=$(cat $sup_chain_dir/egs/info/frames_per_eg) + + if [ $stage -le 12 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $sup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$sup_egs_dir/storage $sup_egs_dir/storage + fi + mkdir -p $sup_egs_dir/ + touch $sup_egs_dir/.nodelete # keep egs around when that run dies. 
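    # Editorial aside (hedged worked example): egs_left_context and
    # egs_right_context computed above just pad the model's context by
    # frame_subsampling_factor / 2 frames; assuming the factor of 4 used
    # elsewhere in these recipes and a made-up model left context of 28:
    #   perl -e "print int(28 + 4 / 2)"   # -> 30
    # i.e. each dumped example carries two extra frames on either side for the
    # frame-subsampled output.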
+ + echo "$0: generating egs from the supervised data" + steps/nnet3/chain/get_egs.sh --cmd "$cmd" \ + --left-tolerance 3 --right-tolerance 3 \ + --left-context $egs_left_context --right-context $egs_right_context \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --frames-overlap-per-eg 0 --constrained false \ + --frames-per-eg $frames_per_eg \ + --frames-per-iter 2000000 \ + --cmvn-opts "$cmvn_opts" \ + --generate-egs-scp true \ + data/${supervised_set} $dir \ + $sup_lat_dir $sup_egs_dir + fi +else + frames_per_eg=$(cat $sup_egs_dir/info/frames_per_eg) +fi + +unsup_frames_per_eg=340,300,200,100 # Using a frames-per-eg of 150 for unsupervised data + # was found to be better than allowing smaller chunks + # (160,140,110,80) like for supervised system +lattice_lm_scale=0.5 # lm-scale for using the weights from unsupervised lattices when + # creating numerator supervision +lattice_prune_beam=6.0 # beam for pruning the lattices prior to getting egs + # for unsupervised data +tolerance=3 # frame-tolerance for chain training + +unsup_lat_dir=$sup_chain_dir/decode_$unsupervised_set +if [ -z "$unsup_egs_dir" ]; then + unsup_egs_dir=$dir/egs_$unsupervised_set + + if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $unsup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$unsup_egs_dir/storage $unsup_egs_dir/storage + fi + mkdir -p $unsup_egs_dir + touch $unsup_egs_dir/.nodelete # keep egs around when that run dies. + + echo "$0: generating egs from the unsupervised data" + steps/nnet3/chain/get_egs.sh \ + --cmd "$cmd" --alignment-subsampling-factor 1 \ + --left-tolerance $tolerance --right-tolerance $tolerance \ + --left-context $egs_left_context --right-context $egs_right_context \ + --frames-per-eg $unsup_frames_per_eg --frames-per-iter 2000000 \ + --frame-subsampling-factor $frame_subsampling_factor \ + --cmvn-opts "$cmvn_opts" --lattice-lm-scale $lattice_lm_scale \ + --lattice-prune-beam "$lattice_prune_beam" \ + --deriv-weights-scp $sup_chain_dir/best_path_$unsupervised_set/weights.scp \ + --generate-egs-scp true $unsup_egs_opts \ + data/$unsupervised_set $dir \ + $unsup_lat_dir $unsup_egs_dir + fi +fi + +comb_egs_dir=$dir/comb_egs +if [ $stage -le 14 ]; then + steps/nnet3/chain/multilingual/combine_egs.sh --cmd "$cmd" \ + --block-size 64 \ + --lang2weight $supervision_weights 2 \ + $sup_egs_dir $unsup_egs_dir $comb_egs_dir + touch $comb_egs_dir/.nodelete # keep egs around when that run dies. +fi + +if [ $train_stage -le -4 ]; then + # This is to skip stages of den-fst creation, which was already done. 
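  # Editorial note (hedged): the den.fst being skipped here is normally built
  # in an earlier stage of this script (not shown in this excerpt) from a phone
  # LM that mixes supervised and unsupervised phone sequences according to
  # $lm_weights, along the lines of the Fisher English semi-supervised recipe,
  # roughly:
  #   steps/nnet3/chain/make_weighted_den_fst.sh --num-repeats $lm_weights \
  #     --cmd "$cmd" $sup_tree_dir $sup_chain_dir/best_path_$unsupervised_set $dir
  # Treat the exact arguments as an assumption and check the earlier stages of
  # this script (or local/semisup/run_semisup.sh) for the real call.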
+ train_stage=-4 +fi + +chunk_width=340,300,200,100 +if [ $stage -le 15 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --egs.dir "$comb_egs_dir" \ + --egs.chunk-width=$chunk_width \ + --cmd "$cmd" \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00001 \ + --chain.apply-deriv-weights=true \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=900" \ + --trainer.srand=0 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --trainer.frames-per-iter=2000000 \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs 5 \ + --trainer.optimization.num-jobs-initial 6 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs false \ + --feat-dir data/$supervised_set \ + --tree-dir $sup_tree_dir \ + --lat-dir $sup_lat_dir \ + --dir $dir || exit 1; + +fi + +if [ $stage -le 17 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 $lang_decode $dir $dir/graph +fi + +if [ $stage -le 18 ]; then + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --beam 12 --frames-per-chunk 340 --nj $nj --cmd "$cmd" \ + $dir/graph data/test $dir/decode_test + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test $dir/decode_test{,_rescored} || exit 1 +fi +exit 0; + diff --git a/egs/yomdle_tamil/v1/local/semisup/chain/run_cnn_chainali_semisupervised_1b.sh b/egs/yomdle_tamil/v1/local/semisup/chain/run_cnn_chainali_semisupervised_1b.sh new file mode 100755 index 00000000000..17d59642b05 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/semisup/chain/run_cnn_chainali_semisupervised_1b.sh @@ -0,0 +1,323 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# 2018 Ashish Arora +# Apache 2.0 +# This script is semi-supervised recipe with 25k line images of supervised data +# and 22k line images of unsupervised data with naive splitting. +# Based on "Semi-Supervised Training of Acoustic Models using Lattice-Free MMI", +# Vimal Manohar, Hossein Hadian, Daniel Povey, Sanjeev Khudanpur, ICASSP 2018 +# http://www.danielpovey.com/files/2018_icassp_semisupervised_mmi.pdf +# local/semisup/run_semisup.sh shows how to call this. + +# We use 3-gram LM trained on 5M lines of auxilary data. +# This script uses the same tree as that for the seed model. +# Unsupervised set: train_unsup (25k tamil line images) +# unsup_frames_per_eg=150 +# Deriv weights: Lattice posterior of best path pdf +# Unsupervised weight: 1.0 +# Weights for phone LM (supervised, unsupervised): 3,2 +# LM for decoding unsupervised data: 4gram +# Supervision: Naive split lattices +# output-0 and output-1 are for superivsed and unsupervised data respectively. 
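# Editorial usage note (hedged): supervision_weights, lm_weights and tdnn_affix
# below are ordinary options handled by utils/parse_options.sh, so the
# unsupervised data can be down-weighted without editing this file, e.g.:
#   local/semisup/chain/run_cnn_chainali_semisupervised_1b.sh \
#     --supervision-weights 1.0,0.3 --lm-weights 3,1 --tdnn-affix _semisup_1b_w0.3
# (the affix value here is just an example, used to keep the experiment
# directory separate).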
+ +# local/chain/compare_wer.sh exp/semisup_100k/chain/tdnn_semisup_1b/ +# System tdnn_semisup_1b +# score_basic score_normalized +# WER 13.73 10.2 +# WER (rescored) 12.80 9.4 +# CER 2.78 2.8 +# CER (rescored) 2.57 2.7 +# Final train prob 0.6138-0.0337 +# Final valid prob 0.6115-0.0399 + +# steps/info/chain_dir_info.pl exp/semisup_100k/chain/tdnn_semisup_1b/ +# exp/semisup_100k/chain/tdnn_semisup_1b/: num-iters=46 nj=6..16 num-params=5.7M dim=40->456 combine=0.239->0.239 (over 1) + +set -u -e -o pipefail +stage=0 # Start from -1 for supervised seed system training +train_stage=-100 +nj=30 +test_nj=30 + +# The following 3 options decide the output directory for semi-supervised +# chain system +# dir=${exp_root}/chain${chain_affix}/tdnn${tdnn_affix} +exp_root=exp/semisup_100k +chain_affix= # affix for chain dir +tdnn_affix=_semisup_1b # affix for semi-supervised chain system + +# Datasets-Expects supervised_set and unsupervised_set +supervised_set=train +unsupervised_set=train_unsup + +# Input seed system +sup_chain_dir=exp/chain/cnn_e2eali_1b # supervised chain system +sup_lat_dir=exp/chain/e2e_train_lats # Seed model options +sup_tree_dir=exp/chain/tree_e2e # tree directory for supervised chain system + +# Semi-supervised options +supervision_weights=1.0,1.0 # Weights for supervised, unsupervised data egs. + # Can be used to scale down the effect of unsupervised data + # by using a smaller scale for it e.g. 1.0,0.3 +lm_weights=3,2 # Weights on phone counts from supervised, unsupervised data for denominator FST creation + +sup_egs_dir= # Supply this to skip supervised egs creation +unsup_egs_dir= # Supply this to skip unsupervised egs creation +unsup_egs_opts= # Extra options to pass to unsupervised egs creation +# Neural network opts +xent_regularize=0.1 +tdnn_dim=550 +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +if [ -f ./path.sh ]; then . ./path.sh; fi +. ./utils/parse_options.sh + +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g +dropout_schedule='0,0@0.20,0.2@0.50,0' +dir=$exp_root/chain$chain_affix/tdnn$tdnn_affix +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=40 name=input + conv-relu-batchnorm-dropout-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-dropout-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-dropout-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-dropout-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common3 height-subsample-out=2 + relu-batchnorm-dropout-layer name=tdnn1 input=Append(-4,-2,0,2,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + relu-batchnorm-dropout-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts dropout-proportion=0.0 + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts + + # We use separate outputs for supervised and unsupervised data + # so we can properly track the train and valid objectives. + output name=output-0 input=output.affine + output name=output-1 input=output.affine + output name=output-0-xent input=output-xent.log-softmax + output name=output-1-xent input=output-xent.log-softmax +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +# Get values for $model_left_context, $model_right_context +. $dir/configs/vars + +left_context=$model_left_context +right_context=$model_right_context + +egs_left_context=$(perl -e "print int($left_context + $frame_subsampling_factor / 2)") +egs_right_context=$(perl -e "print int($right_context + $frame_subsampling_factor / 2)") + +if [ -z "$sup_egs_dir" ]; then + sup_egs_dir=$dir/egs_$supervised_set + frames_per_eg=$(cat $sup_chain_dir/egs/info/frames_per_eg) + + if [ $stage -le 12 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $sup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$sup_egs_dir/storage $sup_egs_dir/storage + fi + mkdir -p $sup_egs_dir/ + touch $sup_egs_dir/.nodelete # keep egs around when that run dies. 
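+    # (Illustrative note, not part of the original recipe:) egs_left_context and
+    # egs_right_context, computed above, are the model context plus half the
+    # frame-subsampling factor; e.g. assuming model_left_context=28 and
+    # frame_subsampling_factor=4:
+    #   perl -e "print int(28 + 4 / 2)"    # prints 30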
+ + echo "$0: generating egs from the supervised data" + steps/nnet3/chain/get_egs.sh --cmd "$cmd" \ + --left-tolerance 3 --right-tolerance 3 \ + --left-context $egs_left_context --right-context $egs_right_context \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --frames-overlap-per-eg 0 --constrained false \ + --frames-per-eg $frames_per_eg \ + --frames-per-iter 2000000 \ + --cmvn-opts "$cmvn_opts" \ + --generate-egs-scp true \ + data/${supervised_set} $dir \ + $sup_lat_dir $sup_egs_dir + fi +else + frames_per_eg=$(cat $sup_egs_dir/info/frames_per_eg) +fi + +unsup_frames_per_eg=340,300,200,100 # Using a frames-per-eg of 150 for unsupervised data + # was found to be better than allowing smaller chunks + # (160,140,110,80) like for supervised system +lattice_lm_scale=0.5 # lm-scale for using the weights from unsupervised lattices when + # creating numerator supervision +lattice_prune_beam=6.0 # beam for pruning the lattices prior to getting egs + # for unsupervised data +tolerance=3 # frame-tolerance for chain training + +unsup_lat_dir=$sup_chain_dir/decode_$unsupervised_set +if [ -z "$unsup_egs_dir" ]; then + unsup_egs_dir=$dir/egs_$unsupervised_set + + if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $unsup_egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/fisher_english-$(date +'%m_%d_%H_%M')/s5c/$unsup_egs_dir/storage $unsup_egs_dir/storage + fi + mkdir -p $unsup_egs_dir + touch $unsup_egs_dir/.nodelete # keep egs around when that run dies. + + echo "$0: generating egs from the unsupervised data" + steps/nnet3/chain/get_egs.sh \ + --cmd "$cmd" --alignment-subsampling-factor 1 \ + --left-tolerance $tolerance --right-tolerance $tolerance \ + --left-context $egs_left_context --right-context $egs_right_context \ + --frames-per-eg $unsup_frames_per_eg --frames-per-iter 2000000 \ + --frame-subsampling-factor $frame_subsampling_factor \ + --cmvn-opts "$cmvn_opts" --lattice-lm-scale $lattice_lm_scale \ + --lattice-prune-beam "$lattice_prune_beam" \ + --deriv-weights-scp $sup_chain_dir/best_path_$unsupervised_set/weights.scp \ + --generate-egs-scp true $unsup_egs_opts \ + data/$unsupervised_set $dir \ + $unsup_lat_dir $unsup_egs_dir + fi +fi + +comb_egs_dir=$dir/comb_egs +if [ $stage -le 14 ]; then + steps/nnet3/chain/multilingual/combine_egs.sh --cmd "$cmd" \ + --block-size 64 \ + --lang2weight $supervision_weights 2 \ + $sup_egs_dir $unsup_egs_dir $comb_egs_dir + touch $comb_egs_dir/.nodelete # keep egs around when that run dies. +fi + +if [ $train_stage -le -4 ]; then + # This is to skip stages of den-fst creation, which was already done. 
+ train_stage=-4 +fi + +chunk_width=340,300,200,100 +if [ $stage -le 15 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --egs.dir "$comb_egs_dir" \ + --egs.chunk-width=$chunk_width \ + --cmd "$cmd" \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00001 \ + --chain.apply-deriv-weights=true \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=900" \ + --trainer.srand=0 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --trainer.frames-per-iter=2000000 \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs 8 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.optimization.num-jobs-initial 6 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs false \ + --feat-dir data/$supervised_set \ + --tree-dir $sup_tree_dir \ + --lat-dir $sup_lat_dir \ + --dir $dir || exit 1; + +fi + +if [ $stage -le 17 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 $lang_decode $dir $dir/graph +fi + +if [ $stage -le 18 ]; then + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --beam 12 --frames-per-chunk 340 --nj $nj --cmd "$cmd" \ + $dir/graph data/test $dir/decode_test + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test $dir/decode_test{,_rescored} || exit 1 +fi +exit 0; + diff --git a/egs/yomdle_tamil/v1/local/semisup/process_data.py b/egs/yomdle_tamil/v1/local/semisup/process_data.py new file mode 100755 index 00000000000..94ad770ec2d --- /dev/null +++ b/egs/yomdle_tamil/v1/local/semisup/process_data.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# 2018 Chun Chieh Chang + +""" This script reads the slam boxed Tamil OCR dataset and creates the following + files utt2spk, images.scp. Since boxed data do not have transcripts, it do not + creates text file. It is created as a separate script, because the data that + local/process_data.py is processing contains some empty transcripts which + should be removed or it will create bug while applying BPE. + + Eg. local/semisup/process_data.py data/download/ data/local/splits/train_unsup.txt + data/train_unsup + + Eg. 
utt2spk file: english_phone_books_0001_0 english_phone_books_0001 + images.scp file: english_phone_books_0001_0 \ + data/download/truth_line_image/english_phone_books_0001_0.png +""" +import argparse +import os +import sys +import csv +import itertools +import unicodedata +import re +import string +parser = argparse.ArgumentParser(description="Creates text, utt2spk, and images.scp files") +parser.add_argument('database_path', type=str, help='Path to data') +parser.add_argument('data_split', type=str, help='Path to file that contain datasplits') +parser.add_argument('out_dir', type=str, help='directory to output files') +args = parser.parse_args() + +### main ### +print("Processing '{}' data...".format(args.out_dir)) + +utt2spk_file = os.path.join(args.out_dir, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join(args.out_dir, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') +text_file = os.path.join(args.out_dir, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') + +with open(args.data_split) as f: + for line in f: + line = line.strip() + image_id = line + image_filename = image_id + '.png' + image_filepath = os.path.join(args.database_path, 'truth_line_image', image_filename) + if not os.path.isfile (image_filepath): + print("File does not exist {}".format(image_filepath)) + continue + line_id = int(line.split('_')[-1]) + csv_filename = '_'.join(line.split('_')[:-1]) + '.csv' + csv_filepath = os.path.join(args.database_path, 'truth_csv', csv_filename) + csv_file = open(csv_filepath, 'r', encoding='utf-8') + for row in csv.reader(csv_file): + if row[1] == image_filename: + text = 'semisup' + text_fh.write(image_id + ' ' + text + '\n') + utt2spk_fh.write(image_id + ' ' + '_'.join(line.split('_')[:-1]) + '\n') + image_fh.write(image_id + ' ' + image_filepath + '\n') diff --git a/egs/yomdle_tamil/v1/local/semisup/run_semisup.sh b/egs/yomdle_tamil/v1/local/semisup/run_semisup.sh new file mode 100755 index 00000000000..0b82def2ead --- /dev/null +++ b/egs/yomdle_tamil/v1/local/semisup/run_semisup.sh @@ -0,0 +1,70 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar +# 2018 Ashish Arora +# Apache 2.0 + +# This script demonstrates semi-supervised training using 25k line images of +# supervised data and 22k line images of unsupervised data. +# We assume the supervised data is in data/train and unsupervised data +# is in data/train_unsup. +# For LM training, we use 5 million lines of tamil text. + +set -e +set -o pipefail +stage=0 +nj=30 +exp_root=exp/semisup_100k +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +mkdir -p data/train_unsup/data +if [ $stage -le 0 ]; then + echo "stage 0: Processing train unsupervised data...$(date)" + local/semisup/process_data.py data/download/ \ + data/local/splits/train_unsup.txt \ + data/train_unsup + image/fix_data_dir.sh data/train_unsup +fi + +if [ $stage -le 1 ]; then + echo "stage 1: Obtaining image groups. calling get_image2num_frames..." + image/get_image2num_frames.py --feat-dim 40 data/train_unsup + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train_unsup + echo "Extracting features and calling compute_cmvn_stats: $(date) " + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/train_unsup + steps/compute_cmvn_stats.sh data/train_unsup || exit 1; + image/fix_data_dir.sh data/train_unsup +fi + +for f in data/train/utt2spk data/train_unsup/utt2spk \ + data/train/text; do + if [ ! 
-f $f ]; then + echo "$0: Could not find $f" + exit 1; + fi +done + +# Prepare semi-supervised train set +if [ $stage -le 1 ]; then + utils/combine_data.sh data/semisup100k_250k \ + data/train_aug data/train_unsup || exit 1 +fi + +############################################################################### +# Semi-supervised training using 25k line images supervised data and +# 22k hours unsupervised data. We use tree, lattices +# and seed chain system from the previous stage. +############################################################################### +if [ $stage -le 2 ]; then + local/semisup/chain/run_cnn_chainali_semisupervised_1b.sh \ + --supervised-set train_aug \ + --unsupervised-set train_unsup \ + --sup-chain-dir exp/chain/cnn_e2eali_1b \ + --sup-lat-dir exp/chain/e2e_train_lats \ + --sup-tree-dir exp/chain/tree_e2e \ + --chain-affix "" \ + --tdnn-affix _semisup_1a \ + --exp-root $exp_root || exit 1 +fi diff --git a/egs/yomdle_tamil/v1/local/train_lm.sh b/egs/yomdle_tamil/v1/local/train_lm.sh new file mode 100755 index 00000000000..bb21c67b63f --- /dev/null +++ b/egs/yomdle_tamil/v1/local/train_lm.sh @@ -0,0 +1,129 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the training transcriptions and corpus text. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +order=6 +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +#bypass_metaparam_optim_opt="--bypass-metaparameter-optimization=0.031,0.860,0.678,0.194,0.037,0.006,0.928,0.712,0.454,0.220,0.926,0.844,0.749,0.358,0.966,0.879,0.783,0.544,0.966,0.826,0.674,0.450" +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # use the validation data as the dev set. + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + + cat data/local/text/cleaned/bpe_val.txt > ${dir}/data/text/dev.txt + # use the training data as an additional data source. + # we can later fold the dev data into this. 
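+  # (Note, added for clarity:) pocolm treats each *.txt file placed in
+  # ${dir}/data/text as a separate data source and estimates per-source
+  # interpolation weights; the source name is the file name minus '.txt',
+  # which is why stage 1 below can set min_counts='train=1'.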
+ cat data/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + cat data/local/text/cleaned/bpe_corpus.txt > ${dir}/data/text/corpus_text.txt + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < data/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from train and corpus text + cat ${dir}/data/text/{train,corpus_text}.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=20 --warm-start-ratio=20 \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + #3grm: [perplexity = 27.7734168008] over 151116.0 words + #6grm: [perplexity = 18.6681627154] over 151116.0 words + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi + +if [ $stage -le 2 ]; then + echo "$0: pruning the LM (to larger size)" + # Using 10 million n-grams for a big LM for rescoring purposes. + size=10000000 + prune_lm_dir.py --target-num-ngrams=$size --initial-threshold=0.02 ${unpruned_lm_dir} ${dir}/data/lm_${order}_prune_big + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_big 2>&1 | grep -F '[perplexity' + #[perplexity = 22.0613098868] over 151116.0 words + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${dir}/data/lm_${order}_prune_big | gzip -c > ${dir}/data/arpa/${order}gram_big.arpa.gz +fi + +if [ $stage -le 3 ]; then + echo "$0: pruning the LM (to smaller size)" + # Using 2 million n-grams for a smaller LM for graph building. Prune from the + # bigger-pruned LM, it'll be faster. + size=2000000 + prune_lm_dir.py --target-num-ngrams=$size ${dir}/data/lm_${order}_prune_big ${dir}/data/lm_${order}_prune_small + + get_data_prob.py ${dir}/data/real_dev_set.txt ${dir}/data/lm_${order}_prune_small 2>&1 | grep -F '[perplexity' + #[perplexity = 23.4801171202] over 151116.0 words + format_arpa_lm.py ${dir}/data/lm_${order}_prune_small | gzip -c > ${dir}/data/arpa/${order}gram_small.arpa.gz +fi diff --git a/egs/yomdle_tamil/v1/local/wer_output_filter b/egs/yomdle_tamil/v1/local/wer_output_filter new file mode 100755 index 00000000000..59e364e0231 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/wer_output_filter @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Hossein Hadian + +# Apache 2.0 +# This script converts a BPE-encoded text to normal text. 
It is used in scoring + +import sys, io +import string +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +for line in infile: + words = line.strip().split() + uttid = words[0] + transcript = ''.join(words[1:]) + transcript = transcript.replace('|', ' ') + output.write(uttid + ' ' + transcript + '\n') diff --git a/egs/yomdle_tamil/v1/local/yomdle/create_download_dir.sh b/egs/yomdle_tamil/v1/local/yomdle/create_download_dir.sh new file mode 100755 index 00000000000..de932e01021 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/create_download_dir.sh @@ -0,0 +1,103 @@ +#!/bin/bash + +# Copyright 2018 Chun Chieh Chang +# 2018 Ashish Arora +# 2018 Hossein Hadian +# Apache 2.0 + +# This script assumes that the SLAM and Yomdle OCR database is stored in slam_dir and +# yomdle_dir. It reads the xml files and converts them to csv files. It then with the +# help of csv files, extracts lines images from page images. It can create dataset for +# any yomdle and slam language. Assuming it is creating dataset for Tamil OCR. It +# creates csv files for yomdle English, yomdle Tamil, slam Tamil transcribed and slam +# Tamil boxed. It also creates train, test and train_unsup sets for training and testing. +# Yomdle (English and Tamil) is training set, slam Tamil transcribed is test set, and +# slam Tamil boxed is semi-supervised set. + +set -e +stage=0 +language_main=Tamil +slam_dir=/export/corpora5/slam/SLAM/ +yomdle_dir=/export/corpora5/slam/YOMDLE/ + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +mkdir -p data/local/splits +language_lower=$(echo "$language_main" | tr '[:upper:]' '[:lower:]') + +echo "$0: extracting line images for english and ${language} for shared model training" +if [ $stage -le 0 ]; then + for language in english $language_lower; do + echo "$0: Processing YOMDLE ${language}" + mkdir -p data/download/${language}/{truth_csv,truth_line_image} + local/yomdle/yomdle2csv.py \ + --inputDir $yomdle_dir/final_$language/ \ + --outputDir data/download/${language}/truth_csv/ \ + --log data/download/yomdle2csv.${language}.log + local/yomdle/create_line_image_from_page_image.py \ + $yomdle_dir/final_$language/images/ \ + data/download/${language}/truth_csv/ \ + data/download/${language}/truth_line_image/ \ + data/local/yomdle-${language}-train.list \ + --filter + done +fi + +echo "$0: extracting line images for slam ${language} for testing" +if [ $stage -le 1 ]; then + echo "$0: Processing slam ${language_main}" + mkdir -p data/download/${language_main}/{truth_csv,truth_line_image} + local/yomdle/gedi2csv_enriched.py \ + --inputDir $slam_dir/${language_main}/transcribed/ \ + --outputDir data/download/${language_main}/truth_csv/ \ + --log data/download/gedi2csv.${language_main}.log + local/yomdle/create_line_image_from_page_image.py \ + $slam_dir/${language_main}/transcribed/ \ + data/download/${language_main}/truth_csv/ \ + data/download/${language_main}/truth_line_image/ \ + data/local/yomdle-${language_main}-test.list \ + --ext '.png' +fi + +echo "$0: extracting line images for semi supervised training for slam ${language}" +if [ $stage -le 2 ]; then + echo "$0: Processing slam ${language_main}" + mkdir -p data/download/${language_main}_boxed/{truth_csv,truth_line_image} + local/yomdle/gedi2csv_enriched.py \ + --inputDir $slam_dir/${language_main}/boxed \ + --ftype boxed \ + --outputDir data/download/${language_main}_boxed/truth_csv/ \ + --log data/download/gedi2csv.${language_main}_boxed.log + 
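+  # The CSV files written by gedi2csv_enriched.py have one text region per row,
+  # with the header:
+  #   ID,name,col1,row1,col2,row2,col3,row3,col4,row4,confidence,truth,pgrot,bbrot,qual,script,text_type
+  # The line-extraction script below uses the 'name' field and the four corner
+  # points (col1,row1 ... col4,row4).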
local/yomdle/create_line_image_from_page_image.py \ + $slam_dir/${language_main}/boxed \ + data/download/${language_main}_boxed/truth_csv/ \ + data/download/${language_main}_boxed/truth_line_image/ \ + data/local/yomdle-${language_main}-train_unsup.list \ + --ext '.png' \ + --filter +fi + +echo "$0: storing english, given language(transcribed and untranscribed) line images together" +if [ $stage -le 3 ]; then + cp -r data/download/${language_main}_boxed/truth_line_image/* data/download/$language_lower/truth_line_image/ + cp -r data/download/$language_main/truth_line_image/* data/download/$language_lower/truth_line_image/ + cp -r data/download/english/truth_line_image/* data/download/$language_lower/truth_line_image/ + cp -r data/download/${language_main}_boxed/truth_csv/* data/download/$language_lower/truth_csv/ + cp -r data/download/$language_main/truth_csv/* data/download/$language_lower/truth_csv/ + cp -r data/download/english/truth_csv/* data/download/$language_lower/truth_csv/ +fi + + +if [ $stage -le 4 ]; then + mv data/download/$language_lower/truth_line_image/ data/download/ + mv data/download/$language_lower/truth_csv/ data/download/ +fi + +echo "$0: storing train, test and train unsupervised splits" +if [ $stage -le 5 ]; then + cat data/local/yomdle-${language_lower}-train.list data/local/yomdle-english-train.list > data/local/splits/train.txt + cp data/local/yomdle-${language_main}-test.list data/local/splits/test.txt + cp data/local/yomdle-${language_main}-train_unsup.list data/local/splits/train_unsup.txt +fi diff --git a/egs/yomdle_tamil/v1/local/yomdle/create_line_image_from_page_image.py b/egs/yomdle_tamil/v1/local/yomdle/create_line_image_from_page_image.py new file mode 100755 index 00000000000..885f18c7deb --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/create_line_image_from_page_image.py @@ -0,0 +1,406 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# 2018 Chun Chieh Chang + +# Apache 2.0 + +# minimum bounding box part in this script is originally from +#https://github.com/BebeSparkelSparkel/MinimumBoundingBox +#https://startupnextdoor.com/computing-convex-hull-in-python/ +""" This module will be used for extracting line images from page image. + Given the word segmentation (bounding box around a word) for every word, it will + extract line segmentation. To extract line segmentation, it will take word bounding + boxes of a line as input, will create a minimum area bounding box that will contain + all corner points of word bounding boxes. The obtained bounding box (will not necessarily + be vertically or horizontally aligned). Hence to extract line image from line bounding box, + page image is rotated and line image is cropped and saved. 
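+
+    Eg. (paths are illustrative; see local/yomdle/create_download_dir.sh for the
+    actual calls made in this recipe):
+      local/yomdle/create_line_image_from_page_image.py \
+        /export/corpora5/slam/YOMDLE/final_tamil/images/ \
+        data/download/tamil/truth_csv/ \
+        data/download/tamil/truth_line_image/ \
+        data/local/yomdle-tamil-train.list --filter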
+""" + +import argparse +import csv +import itertools +import sys +import os +import numpy as np +from math import atan2, cos, sin, pi, degrees, sqrt +from collections import namedtuple + +from scipy.spatial import ConvexHull +from PIL import Image +from scipy.misc import toimage +from pathlib import Path +from glob import glob +parser = argparse.ArgumentParser(description="Creates line images from page image") +parser.add_argument('image_dir', type=str, help='Path to full page images') +parser.add_argument('csv_dir', type=str, help='Path to csv files') +parser.add_argument('out_dir', type=str, help='Path to output directory') +parser.add_argument('output_file', type=str, help='file containing all line images id') +parser.add_argument('--padding', type=int, default=100, help='Padding so BBox does not exceed image area') +parser.add_argument('--ext', type=str, default='.jpg', help='Extention of the line images') +parser.add_argument("--filter", action="store_true", + help="If true, filter height/width<10 pixels minimum area rectangles") +args = parser.parse_args() + +""" +bounding_box is a named tuple which contains: + area (float): area of the rectangle + length_parallel (float): length of the side that is parallel to unit_vector + length_orthogonal (float): length of the side that is orthogonal to unit_vector + rectangle_center(int, int): coordinates of the rectangle center + (use rectangle_corners to get the corner points of the rectangle) + unit_vector (float, float): direction of the length_parallel side. + (it's orthogonal vector can be found with the orthogonal_vector function + unit_vector_angle (float): angle of the unit vector to be in radians. + corner_points [(float, float)]: set that contains the corners of the rectangle +""" + +bounding_box_tuple = namedtuple('bounding_box_tuple', 'area ' + 'length_parallel ' + 'length_orthogonal ' + 'rectangle_center ' + 'unit_vector ' + 'unit_vector_angle ' + 'corner_points' + ) + + +def unit_vector(pt0, pt1): + """ Given two points pt0 and pt1, return a unit vector that + points in the direction of pt0 to pt1. + Returns + ------- + (float, float): unit vector + """ + dis_0_to_1 = sqrt((pt0[0] - pt1[0])**2 + (pt0[1] - pt1[1])**2) + return (pt1[0] - pt0[0]) / dis_0_to_1, \ + (pt1[1] - pt0[1]) / dis_0_to_1 + + +def orthogonal_vector(vector): + """ Given a vector, returns a orthogonal/perpendicular vector of equal length. + Returns + ------ + (float, float): A vector that points in the direction orthogonal to vector. + """ + return -1 * vector[1], vector[0] + + +def bounding_area(index, hull): + """ Given index location in an array and convex hull, it gets two points + hull[index] and hull[index+1]. From these two points, it returns a named + tuple that mainly contains area of the box that bounds the hull. This + bounding box orintation is same as the orientation of the lines formed + by the point hull[index] and hull[index+1]. + Returns + ------- + a named tuple that contains: + area: area of the rectangle + length_parallel: length of the side that is parallel to unit_vector + length_orthogonal: length of the side that is orthogonal to unit_vector + rectangle_center: coordinates of the rectangle center + unit_vector: direction of the length_parallel side. 
+ (it's orthogonal vector can be found with the orthogonal_vector function) + """ + unit_vector_p = unit_vector(hull[index], hull[index+1]) + unit_vector_o = orthogonal_vector(unit_vector_p) + + dis_p = tuple(np.dot(unit_vector_p, pt) for pt in hull) + dis_o = tuple(np.dot(unit_vector_o, pt) for pt in hull) + + min_p = min(dis_p) + min_o = min(dis_o) + len_p = max(dis_p) - min_p + len_o = max(dis_o) - min_o + + return {'area': len_p * len_o, + 'length_parallel': len_p, + 'length_orthogonal': len_o, + 'rectangle_center': (min_p + float(len_p) / 2, min_o + float(len_o) / 2), + 'unit_vector': unit_vector_p, + } + + +def to_xy_coordinates(unit_vector_angle, point): + """ Given angle from horizontal axis and a point from origin, + returns converted unit vector coordinates in x, y coordinates. + angle of unit vector should be in radians. + Returns + ------ + (float, float): converted x,y coordinate of the unit vector. + """ + angle_orthogonal = unit_vector_angle + pi / 2 + return point[0] * cos(unit_vector_angle) + point[1] * cos(angle_orthogonal), \ + point[0] * sin(unit_vector_angle) + point[1] * sin(angle_orthogonal) + + +def rotate_points(center_of_rotation, angle, points): + """ Rotates a point cloud around the center_of_rotation point by angle + input + ----- + center_of_rotation (float, float): angle of unit vector to be in radians. + angle (float): angle of rotation to be in radians. + points [(float, float)]: Points to be a list or tuple of points. Points to be rotated. + Returns + ------ + [(float, float)]: Rotated points around center of rotation by angle + """ + rot_points = [] + ang = [] + for pt in points: + diff = tuple([pt[d] - center_of_rotation[d] for d in range(2)]) + diff_angle = atan2(diff[1], diff[0]) + angle + ang.append(diff_angle) + diff_length = sqrt(sum([d**2 for d in diff])) + rot_points.append((center_of_rotation[0] + diff_length * cos(diff_angle), + center_of_rotation[1] + diff_length * sin(diff_angle))) + + return rot_points + + +def rectangle_corners(rectangle): + """ Given rectangle center and its inclination, returns the corner + locations of the rectangle. + Returns + ------ + [(float, float)]: 4 corner points of rectangle. + """ + corner_points = [] + for i1 in (.5, -.5): + for i2 in (i1, -1 * i1): + corner_points.append((rectangle['rectangle_center'][0] + i1 * rectangle['length_parallel'], + rectangle['rectangle_center'][1] + i2 * rectangle['length_orthogonal'])) + + return rotate_points(rectangle['rectangle_center'], rectangle['unit_vector_angle'], corner_points) + + +def minimum_bounding_box(points): + """ Given a list of 2D points, it returns the minimum area rectangle bounding all + the points in the point cloud. + Returns + ------ + returns a namedtuple that contains: + area: area of the rectangle + length_parallel: length of the side that is parallel to unit_vector + length_orthogonal: length of the side that is orthogonal to unit_vector + rectangle_center: coordinates of the rectangle center + unit_vector: direction of the length_parallel side. 
RADIANS + unit_vector_angle: angle of the unit vector + corner_points: set that contains the corners of the rectangle + """ + + if len(points) <= 2: raise ValueError('More than two points required.') + + hull_ordered = [points[index] for index in ConvexHull(points).vertices] + hull_ordered.append(hull_ordered[0]) + hull_ordered = tuple(hull_ordered) + + min_rectangle = bounding_area(0, hull_ordered) + for i in range(1, len(hull_ordered)-1): + rectangle = bounding_area(i, hull_ordered) + if rectangle['area'] < min_rectangle['area']: + min_rectangle = rectangle + + min_rectangle['unit_vector_angle'] = atan2(min_rectangle['unit_vector'][1], min_rectangle['unit_vector'][0]) + min_rectangle['rectangle_center'] = to_xy_coordinates(min_rectangle['unit_vector_angle'], min_rectangle['rectangle_center']) + + return bounding_box_tuple( + area = min_rectangle['area'], + length_parallel = min_rectangle['length_parallel'], + length_orthogonal = min_rectangle['length_orthogonal'], + rectangle_center = min_rectangle['rectangle_center'], + unit_vector = min_rectangle['unit_vector'], + unit_vector_angle = min_rectangle['unit_vector_angle'], + corner_points = set(rectangle_corners(min_rectangle)) + ) + + +def get_center(im): + """ Given image, returns the location of center pixel + Returns + ------- + (int, int): center of the image + """ + center_x = float(im.size[0]) / 2 + center_y = float(im.size[1]) / 2 + return int(center_x), int(center_y) + + +def get_horizontal_angle(unit_vector_angle): + """ Given an angle in radians, returns angle of the unit vector in + first or fourth quadrant. + Returns + ------ + (float): updated angle of the unit vector to be in radians. + It is only in first or fourth quadrant. + """ + if unit_vector_angle > pi / 2 and unit_vector_angle <= pi: + unit_vector_angle = unit_vector_angle - pi + elif unit_vector_angle > -pi and unit_vector_angle < -pi / 2: + unit_vector_angle = unit_vector_angle + pi + + return unit_vector_angle + + +def get_smaller_angle(bounding_box): + """ Given a rectangle, returns its smallest absolute angle from horizontal axis. + Returns + ------ + (float): smallest angle of the rectangle to be in radians. + """ + unit_vector = bounding_box.unit_vector + unit_vector_angle = bounding_box.unit_vector_angle + ortho_vector = orthogonal_vector(unit_vector) + ortho_vector_angle = atan2(ortho_vector[1], ortho_vector[0]) + + unit_vector_angle_updated = get_horizontal_angle(unit_vector_angle) + ortho_vector_angle_updated = get_horizontal_angle(ortho_vector_angle) + + if abs(unit_vector_angle_updated) < abs(ortho_vector_angle_updated): + return unit_vector_angle_updated + else: + return ortho_vector_angle_updated + + +def rotated_points(bounding_box, center): + """ Given the rectangle, returns corner points of rotated rectangle. + It rotates the rectangle around the center by its smallest angle. + Returns + ------- + [(int, int)]: 4 corner points of rectangle. 
+ """ + p1, p2, p3, p4 = bounding_box.corner_points + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + x4, y4 = p4 + center_x, center_y = center + rotation_angle_in_rad = -get_smaller_angle(bounding_box) + x_dash_1 = (x1 - center_x) * cos(rotation_angle_in_rad) - (y1 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_2 = (x2 - center_x) * cos(rotation_angle_in_rad) - (y2 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_3 = (x3 - center_x) * cos(rotation_angle_in_rad) - (y3 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_4 = (x4 - center_x) * cos(rotation_angle_in_rad) - (y4 - center_y) * sin(rotation_angle_in_rad) + center_x + + y_dash_1 = (y1 - center_y) * cos(rotation_angle_in_rad) + (x1 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_2 = (y2 - center_y) * cos(rotation_angle_in_rad) + (x2 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_3 = (y3 - center_y) * cos(rotation_angle_in_rad) + (x3 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_4 = (y4 - center_y) * cos(rotation_angle_in_rad) + (x4 - center_x) * sin(rotation_angle_in_rad) + center_y + return x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 + + +def pad_image(image): + """ Given an image, returns a padded image around the border. + This routine save the code from crashing if bounding boxes that are + slightly outside the page boundary. + Returns + ------- + image: page image + """ + offset = int(args.padding // 2) + padded_image = Image.new('L', (image.size[0] + int(args.padding), image.size[1] + int(args.padding)), "white") + padded_image.paste(im = image, box = (offset, offset)) + return padded_image + +def update_minimum_bounding_box_input(bounding_box_input): + """ Given list of 2D points, returns list of 2D points shifted by an offset. 
+ Returns + ------ + points [(float, float)]: points, a list or tuple of 2D coordinates + """ + updated_minimum_bounding_box_input = [] + offset = int(args.padding // 2) + for point in bounding_box_input: + x, y = point + new_x = x + offset + new_y = y + offset + word_coordinate = (new_x, new_y) + updated_minimum_bounding_box_input.append(word_coordinate) + + return updated_minimum_bounding_box_input + + +### main ### +globvar = 0 +text_fh = open(args.output_file, 'w', encoding='utf-8') +file_list = list(Path(args.csv_dir).rglob("*.[cC][sS][vV]")) +for filename in sorted(file_list): + filename = str(filename) + with open(str(filename), 'r', encoding='utf-8') as f: + base_name = os.path.basename(filename) + image_file = os.path.join(args.image_dir, base_name.split('.')[0] + args.ext) + try: + im = Image.open(image_file).convert('L') + except Exception as e: + print("Error: No such Image " + row[1]) + globvar += 1 + continue + im = pad_image(im) + for row in itertools.islice(csv.reader(f), 1, None): + points = [] + points.append((int(row[2]), int(row[3]))) + points.append((int(row[4]), int(row[5]))) + points.append((int(row[6]), int(row[7]))) + points.append((int(row[8]), int(row[9]))) + x = [int(row[2]), int(row[4]), int(row[6]), int(row[8])] + y = [int(row[3]), int(row[5]), int(row[7]), int(row[9])] + min_x, min_y = min(x), min(y) + max_x, max_y = max(x), max(y) + try: + updated_mbb_input = update_minimum_bounding_box_input(points) + bounding_box = minimum_bounding_box(updated_mbb_input) + except Exception as e: + globvar += 1 + continue + p1, p2, p3, p4 = bounding_box.corner_points + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + x4, y4 = p4 + min_x = int(min(x1, x2, x3, x4)) + min_y = int(min(y1, y2, y3, y4)) + max_x = int(max(x1, x2, x3, x4)) + max_y = int(max(y1, y2, y3, y4)) + box = (min_x, min_y, max_x, max_y) + region_initial = im.crop(box) + rot_points = [] + p1_new = (x1 - min_x, y1 - min_y) + p2_new = (x2 - min_x, y2 - min_y) + p3_new = (x3 - min_x, y3 - min_y) + p4_new = (x4 - min_x, y4 - min_y) + rot_points.append(p1_new) + rot_points.append(p2_new) + rot_points.append(p3_new) + rot_points.append(p4_new) + cropped_bounding_box = bounding_box_tuple(bounding_box.area, + bounding_box.length_parallel, + bounding_box.length_orthogonal, + bounding_box.length_orthogonal, + bounding_box.unit_vector, + bounding_box.unit_vector_angle, + set(rot_points)) + rotation_angle_in_rad = get_smaller_angle(cropped_bounding_box) + img2 = region_initial.rotate(degrees(rotation_angle_in_rad), resample = Image.BICUBIC) + x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 = rotated_points( + cropped_bounding_box, get_center(region_initial)) + min_x = int(min(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + min_y = int(min(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + max_x = int(max(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + max_y = int(max(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + box = (min_x, min_y, max_x, max_y) + region_final = img2.crop(box) + width, height = region_final.size + if args.filter: + if height > (width * 2): + globvar += 1 + continue + if height < 10: + globvar += 1 + continue + if width < 10: + globvar += 1 + continue + fname = row[1].split('.')[0] + text_fh.write(fname + '\n') + image_out_file = os.path.join(args.out_dir, row[1]) + region_final.save(image_out_file) +print(globvar) diff --git a/egs/yomdle_tamil/v1/local/yomdle/gedi2csv_enriched.py b/egs/yomdle_tamil/v1/local/yomdle/gedi2csv_enriched.py new file mode 100755 index 00000000000..51d7a34e7e8 --- /dev/null 
+++ b/egs/yomdle_tamil/v1/local/yomdle/gedi2csv_enriched.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python3 + +''' +Convert GEDI-type bounding boxes to CSV format +''' + +import logging +import os +import sys +import time +import glob +import csv +import imghdr +from PIL import Image +import argparse +import pdb +import cv2 +import numpy as np +import xml.etree.ElementTree as ET + +sin = np.sin +cos = np.cos +pi = np.pi + +def Rotate2D(pts, cnt, ang=90): + M = np.array([[cos(ang),-sin(ang)],[sin(ang),cos(ang)]]) + res = np.dot(pts-cnt,M)+cnt + return M, res + +def npbox2string(npar): + if np.shape(npar)[0] != 1: + print('Error during CSV conversion\n') + c1,r1 = npar[0][0],npar[0][1] + c2,r2 = npar[0][2],npar[0][3] + c3,r3 = npar[0][4],npar[0][5] + c4,r4 = npar[0][6],npar[0][7] + + return c1,r1,c2,r2,c3,r3,c4,r4 + +# cv2.minAreaRect() returns a Box2D structure which contains following detals - ( center (x,y), (width, height), angle of rotation ) +# Get 4 corners of the rectangle using cv2.boxPoints() +class GEDI2CSV(object): + ''' Initialize the extractor''' + def __init__(self, logger, args): + self._logger = logger + self._args = args + + ''' + Segment image with GEDI bounding box information + ''' + def csvfile(self, coords, polys, baseName, pgrot): + + ''' for writing the files ''' + writePath = self._args.outputDir + writePath = os.path.join(writePath,'') + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + rotlist = [] + + header=['ID','name','col1','row1','col2','row2','col3','row3','col4','row4','confidence','truth','pgrot','bbrot','qual','script','text_type'] + conf=100 + write_ctr = 0 + if len(coords) == 0 and len(polys) == 0: + self._logger.info('Found %s with no text content',(baseName)) + print('...Found %s with no text content' % (baseName)) + return + + strPos = writePath + baseName + + ''' for each group of coordinates ''' + for i in coords: + + [id,x,y,w,h,degrees,text,qual,script,text_type] = i + contour = np.array([(x,y),(x+w,y),(x+w,y+h),(x,y+h)]) + """First rotate around upper left corner based on orientationD keyword""" + M, rot = Rotate2D(contour, np.array([x,y]), degrees*pi/180) + rot = np.int0(rot) + + # rot is the 8 points rotated by degrees + # pgrot is the rotation after extraction, so save + # save rotated points to list or array + rot = np.reshape(rot,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(rot) + + bbrot = degrees + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,text_type]) + + # if there are polygons, first save the text + for j in polys: + arr = [] + [id,poly_val,text,qual,script,text_type] = j + for i in poly_val: + arr.append(eval(i)) + + contour = np.asarray(arr) + convex = cv2.convexHull(contour) + rect = cv2.minAreaRect(convex) + box = cv2.boxPoints(rect) + box = np.int0(box) + box = np.reshape(box,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(box) + + bbrot = 0.0 + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,text_type]) + # then write out all of list to file + with open(strPos + ".csv", "w", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(header) + for row in rotlist: + writer.writerow(row) + write_ctr += 1 + return write_ctr + + +def main(args): + startTime = time.clock() + writePath = args.outputDir + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + ''' Setup logging ''' + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + if args.log: + handler = 
logging.FileHandler(args.log) + handler.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + gtconverter = GEDI2CSV(logger, args) + namespaces = {"gedi" : "http://lamp.cfar.umd.edu/media/projects/GEDI/"} + keyCnt=0 + fileCnt = 0 + line_write_ctr = 0 + line_error_ctr = 0 + + ''' + Get all XML files in the directory and sub folders + ''' + for root, dirnames, filenames in os.walk(args.inputDir, followlinks=True): + for file in filenames: + if file.lower().endswith('.xml'): + fullName = os.path.join(root,file) + baseName = os.path.splitext(fullName) + fileCnt += 1 + ''' read the XML file ''' + tree = ET.parse(fullName) + gedi_root = tree.getroot() + child = gedi_root.findall('gedi:DL_DOCUMENT',namespaces)[0] + totalpages = int(child.attrib['NrOfPages']) + coordinates=[] + polygons = [] + if args.ftype == 'boxed': + fileTypeStr = 'col' + elif args.ftype == 'transcribed': + fileTypeStr = 'Text_Content' + else: + print('Filetype must be either boxed or transcribed!') + logger.info('Filetype must be either boxed or transcribed!') + sys.exit(-1) + + if args.quality == 'both': + qualset = {'Regular','Low-Quality'} + elif args.quality == 'low': + qualset = {'Low-Quality'} + elif args.quality == 'regular': + qualset = {'Regular'} + else: + print('Quality must be both, low or regular!') + logger.info('Quality must be both, low or regular!') + sys.exit(-1) + + + + ''' and for each page ''' + for i, pgs in enumerate(child.iterfind('gedi:DL_PAGE',namespaces)): + + if 'GEDI_orientation' not in pgs.attrib: + pageRot=0 + else: + pageRot = int(pgs.attrib['GEDI_orientation']) + logger.info(' PAGE ROTATION %s, %s' % (fullName, str(pageRot))) + + ''' find children for each page ''' + for zone in pgs.findall('gedi:DL_ZONE',namespaces): + + if zone.attrib['gedi_type']=='Text' and zone.attrib['Type'] in \ + ('Machine_Print','Confusable_Allograph','Handwriting') and zone.attrib['Quality'] in qualset: + if zone.get('polygon'): + keyCnt+=1 + polygons.append([zone.attrib['id'],zone.get('polygon').split(';'), + zone.get('Text_Content'),zone.get('Quality'),zone.get('Script'),zone.get('Type')]) + elif zone.get(fileTypeStr) != None: + keyCnt+=1 + coord = [zone.attrib['id'],int(zone.attrib['col']),int(zone.attrib['row']), + int(zone.attrib['width']), int(zone.attrib['height']), + float(zone.get('orientationD',0.0)), + zone.get('Text_Content'),zone.get('Quality'),zone.get('Script'),zone.get('Type')] + coordinates.append(coord) + + if len(coordinates) > 0 or len(polygons) > 0: + line_write_ctr += gtconverter.csvfile(coordinates, polygons, os.path.splitext(file)[0], pageRot) + else: + print('...%s has no applicable content' % (baseName[0])) + + print('complete...total files %d, lines written %d' % (fileCnt, line_write_ctr)) + + +''' Args and defaults ''' +def parse_arguments(argv): + parser = argparse.ArgumentParser() + + parser.add_argument('--inputDir', type=str, help='Input directory', required=True) + parser.add_argument('--outputDir', type=str, help='Output directory', required=True) + parser.add_argument('--ftype', type=str, help='GEDI file type (either "boxed" or "transcribed")', default='transcribed') + parser.add_argument('--quality', type=str, help='GEDI file quality (either "both" or "low" or "regular")', default='regular') + parser.add_argument('--log', type=str, help='Log directory', default='./GEDI2CSV_enriched.log') + + return parser.parse_args(argv) + +''' Run ''' +if __name__ == 
'__main__': + main(parse_arguments(sys.argv[1:])) + + + + + + diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/baseline_text_detect.sh b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/baseline_text_detect.sh new file mode 100755 index 00000000000..057d22ab492 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/baseline_text_detect.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +# INPUT: +# LANGUAGE - SLAM language to evaluate +# TRUTH_CSV - Transcription annotation csv file +# formt: ID,name,col1,row1,col2,row2,col3,row3,col4,row4,confidence,truth,rotation,quality,script +# +# PREDICT_CSV - The predicted transcription csv file +# formt: ID,name,col1,row1,col2,row2,col3,row3,col4,row4,confidence,truth +# +# OUTPUT: +# OUTPUT_DIR - +# +# + +source activate py35 +export LD_LIBRARY_PATH=/exp/scale18/ocr/tools/leptonica-1.74.4/lib:/exp/scale18/ocr/tools/tesseract/install/lib:$LD_LIBRARY_PATH + +echo "LD_LIBRARY_PATH = ${LD_LIBRARY_PATH}" +echo "PATH = ${PATH}" + +MODELS_DIR=/exp/scale18/ocr/tools/tessdata +MODELS_LANG=far+eng +LANGUAGE=Farsi +INPUT=/exp/scale18/ocr/data/derived/SLAM_2.0/${LANGUAGE}/transcribed_list.txt +OUTPUT=/exp/detter/scale18/slam2/results/${LANGUAGE}/transcribed/bbox_csv +OVERLAY=/exp/detter/scale18/slam2/results/${LANGUAGE}/transcribed/bbox_overlay +LOG=/exp/detter/scale18/slam2/results/${LANGUAGE}/transcribed/bbox_log.txt + +echo "Models ${MODELS_DIR}" +echo "Model lang ${MODELS_LANG}" + +echo "Language ${LANGUAGE}" +echo "Input ${INPUT}" +echo "Output ${OUTPUT}" +echo "Overlay ${OVERLAY}" + +echo "...evalulate text detection" +python /exp/detter/scale18/ocr/cv_scale/detect_lines/get_bbox_tesserocr.py \ +--tess_data=${MODELS_DIR} \ +--lang=${MODELS_LANG} \ +--oem=1 \ +--blur=0 \ +--line=1 \ +--input=${INPUT} \ +--output=${OUTPUT} \ +--overlay=${OVERLAY} \ +--log=${LOG} + +echo "...COMPLETE..." diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/convert2snor.py b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/convert2snor.py new file mode 100755 index 00000000000..c8f1d2efa48 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/convert2snor.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +""" This script converts kaldi format into snor format.. 
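+    For example, a kaldi-format input line "utt_0001 the quick brown fox"
+    (utterance id first) is written out as "the quick brown fox (utt_0001)",
+    i.e. the id moves to the end in parentheses. The example id and words are
+    made up; only the transformation is taken from the code below.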
+""" +import sys +import io + +sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf8") +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf8") + +for line in sys.stdin: + line = line.strip() + line_vect = line.split() + utt_id = line_vect[0] + utt = ' '.join(line_vect[1:]) + sys.stdout.write(utt + " (" + utt_id + ")\n") diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/eval_text_detect.sh b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/eval_text_detect.sh new file mode 100755 index 00000000000..2243d46e10a --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/eval_text_detect.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# DESC: Evaluate text detection bounding boxes +# +# INPUT: +# truth - Directory of input truth csv files +# formt: ID,name,col1,row1,col2,row2,col3,row3,col4,row4,box_conf +# predict - Directory of input predict csv files +# formt: ID,name,col1,row1,col2,row2,col3,row3,col4,row4,box_conf +# iou - intersection over union +# +# OUTPUT: +# output - output directory of results (plot) +# log - log output + +source activate py35 +echo "LD_LIBRARY_PATH = ${LD_LIBRARY_PATH}" +echo "PATH = ${PATH}" + +LANGUAGE=Farsi +IOU=0.50 +TRUTH_CSV=/exp/scale18/ocr/data/derived/SLAM_2.0/${LANGUAGE}/transcribed/truth_csv +PREDICT_CSV=/exp/detter/scale18/slam2/results/${LANGUAGE}/transcribed/bbox_csv +RESULTS=/exp/detter/scale18/slam2/results/${LANGUAGE}/transcribed/bbox_results +LOG=/exp/detter/scale18/slam2/results/${LANGUAGE}/transcribed/bbox_results_log.txt + +echo "Language ${LANGUAGE}" +echo "Truth ${TRUTH_CSV}" +echo "Predict ${PREDICT_CSV}" +echo "Results ${RESULTS}" + +echo "\n...evalulate text detection" +python /exp/detter/scale18/ocr/cv_scale/eval/eval_detect_lines.py \ +--truth=${TRUTH_CSV} \ +--predict=${PREDICT_CSV} \ +--iou=${IOU} \ +--output=${RESULTS} \ +--log=${LOG} + +echo "...COMPLETE..." diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/normalized_score.sh b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/normalized_score.sh new file mode 100755 index 00000000000..f55600939ae --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/normalized_score.sh @@ -0,0 +1,130 @@ +#!/bin/bash + +# This script normalizes hypothesis and reference file and performs scoring. +# Eg. ./local/yomdle/normalized_scoring/normalized_score.sh + +if [ $# -ne 3 ]; then + echo "USAGE: ./local/yomdle/normalized_scoring/normalized_score.sh " + exit 1 +fi + +OUTDIR=$1 +HYP_FILE=$2 +LANG=$3 + +# ocr_score.pl is slow, especially for CER computation +# Therefore default option is to convert files to uxxxx format and use sclite for scoring +# Turn following switch to false to use ocr_score.pl instead +USE_SCLITE=true +script_dir=local/yomdle/normalized_scoring +OCR_SCORE=${script_dir}/ocr_score.pl +SCLITE=../../../tools/sctk/bin/sclite + +LANG=$(echo $LANG | tr '[:upper:]' '[:lower:]') +echo "performing some normalizations..." + +mkdir -p $OUTDIR +cat $HYP_FILE | python3 $script_dir/convert2snor.py > data/local/text/hyp_file.txt +cat data/test/text.old | python3 $script_dir/convert2snor.py > data/local/text/ref_file.txt +# Step 1. 
Run some normalizations that are common to all languages +python3 ${script_dir}/utils/normalize_spaces.py data/local/text/hyp_file.txt $OUTDIR/hyp.norm-sp.txt +python3 ${script_dir}/utils/normalize_spaces.py data/local/text/ref_file.txt $OUTDIR/ref.norm-sp.txt + +python3 ${script_dir}/utils/normalize_common.py $OUTDIR/hyp.norm-sp.txt $OUTDIR/hyp.norm-sp-common.txt +python3 ${script_dir}/utils/normalize_common.py $OUTDIR/ref.norm-sp.txt $OUTDIR/ref.norm-sp-common.txt + +# Step 1. Run language specific normalization +if [ "$LANG" == "farsi" ]; then + # Farsi Normalization + python3 ${script_dir}/utils/normalize_farsi.py $OUTDIR/hyp.norm-sp-common.txt $OUTDIR/hyp.norm-final.txt + python3 ${script_dir}/utils/normalize_farsi.py $OUTDIR/ref.norm-sp-common.txt $OUTDIR/ref.norm-final.txt +else + # For now no normalization for other langs + cp $OUTDIR/hyp.norm-sp-common.txt $OUTDIR/hyp.norm-final.txt + cp $OUTDIR/ref.norm-sp-common.txt $OUTDIR/ref.norm-final.txt +fi + +# Step 2. Run tokenization to get word-based output +python3 ${script_dir}/utils/trans_to_tokenized_words.py $OUTDIR/hyp.norm-final.txt $OUTDIR/hyp.norm-final.words.txt +python3 ${script_dir}/utils/trans_to_tokenized_words.py $OUTDIR/ref.norm-final.txt $OUTDIR/ref.norm-final.words.txt + +# Step 3. Also need to turn into space-seperated character stream to get char-based output +python3 ${script_dir}/utils/trans_to_chars.py $OUTDIR/hyp.norm-final.txt $OUTDIR/hyp.norm-final.chars.txt +python3 ${script_dir}/utils/trans_to_chars.py $OUTDIR/ref.norm-final.txt $OUTDIR/ref.norm-final.chars.txt + +# Step 5. Look for reference uttids that aren't in hypothesis and add them in as blank hypotheses. This is needed because +# otherwise sclite will not penalize systems for missing hypotheses +#python3 ${script_dir}/utils/find_missing_hyp_ids.py $OUTDIR/ref.norm-final.words.txt $OUTDIR/hyp.norm-final.words.txt > $OUTDIR/missing-hyp-ids.list +#python3 ${script_dir}/utils/insert_empty_hyp.py $OUTDIR/missing-hyp-ids.list $OUTDIR/hyp.norm-final.words.txt $OUTDIR/hyp.norm-final.words.withmissing.txt +#python3 ${script_dir}/utils/insert_empty_hyp.py $OUTDIR/missing-hyp-ids.list $OUTDIR/hyp.norm-final.chars.txt $OUTDIR/hyp.norm-final.chars.withmissing.txt + +# Step 5. Look for reference uttids that aren't in hypothesis and add them in as blank hypotheses. This is needed because +# otherwise sclite will not penalize systems for missing hypotheses +python3 ${script_dir}/utils/find_missing_hyp_ids.py $OUTDIR/ref.norm-final.words.txt $OUTDIR/hyp.norm-final.words.txt > $OUTDIR/missing-hyp-ids.list +#python3 ${script_dir}/utils/insert_empty_hyp.py $OUTDIR/missing-hyp-ids.list $OUTDIR/hyp.norm-final.words.txt $OUTDIR/hyp.norm-final.words.withmissing.txt +#python3 ${script_dir}/utils/insert_empty_hyp.py $OUTDIR/missing-hyp-ids.list $OUTDIR/hyp.norm-final.chars.txt $OUTDIR/hyp.norm-final.chars.withmissing.txt +cp $OUTDIR/hyp.norm-final.words.txt $OUTDIR/hyp.norm-final.words.withmissing.txt +cp $OUTDIR/hyp.norm-final.chars.txt $OUTDIR/hyp.norm-final.chars.withmissing.txt + +# Step 6. 
Possible filtering +# TODO +# Currently just cp non-filtered transcripts to filtered transcripts +# This will eventually filter out "bad" uttids that should be removed prior to scoring +cp $OUTDIR/ref.norm-final.words.txt $OUTDIR/ref.norm-final.words.filtered.txt +cp $OUTDIR/ref.norm-final.chars.txt $OUTDIR/ref.norm-final.chars.filtered.txt +cp $OUTDIR/hyp.norm-final.words.withmissing.txt $OUTDIR/hyp.norm-final.words.filtered.txt +cp $OUTDIR/hyp.norm-final.chars.withmissing.txt $OUTDIR/hyp.norm-final.chars.filtered.txt + + +# Step 7. Now we can run scoring + +if [ "$USE_SCLITE" == true ]; then + # First convert files to uxxxx format + python3 ${script_dir}/utils/word_trans_utf8_to_uxxxx.py $OUTDIR/ref.norm-final.words.filtered.txt $OUTDIR/ref.norm-final.words.filtered.uxxxx + python3 ${script_dir}/utils/word_trans_utf8_to_uxxxx.py $OUTDIR/hyp.norm-final.words.filtered.txt $OUTDIR/hyp.norm-final.words.filtered.uxxxx + python3 ${script_dir}/utils/char_trans_utf8_to_uxxxx.py $OUTDIR/ref.norm-final.chars.filtered.txt $OUTDIR/ref.norm-final.chars.filtered.uxxxx + python3 ${script_dir}/utils/char_trans_utf8_to_uxxxx.py $OUTDIR/hyp.norm-final.chars.filtered.txt $OUTDIR/hyp.norm-final.chars.filtered.uxxxx + + echo "Computing WER" + $SCLITE -r $OUTDIR/ref.norm-final.words.filtered.uxxxx -h $OUTDIR/hyp.norm-final.words.filtered.uxxxx -i swb -o all >/dev/null + wer_sys_file=$OUTDIR/hyp.norm-final.words.filtered.uxxxx.sys + + WER=$(grep 'Sum/Avg' ${wer_sys_file} | awk '{print $(NF-2)}') + echo "WER = $WER" + + echo "Computing CER" + $SCLITE -r $OUTDIR/ref.norm-final.chars.filtered.uxxxx -h $OUTDIR/hyp.norm-final.chars.filtered.uxxxx -i swb -o all >/dev/null + cer_sys_file=$OUTDIR/hyp.norm-final.chars.filtered.uxxxx.sys + + CER=$(grep 'Sum/Avg' ${cer_sys_file} | awk '{print $(NF-2)}') + echo "CER = $CER" + +else + echo "Computing WER" + LANG=C perl -CSAD $OCR_SCORE --ref_format trn --hyp_format trn $OUTDIR/ref.norm-final.words.filtered.txt $OUTDIR/hyp.norm-final.words.filtered.txt >/dev/null + wer_sys_file=$OUTDIR/hyp.norm-final.words.filtered.txt.sys + + WER=$(awk '{print $4}' ${wer_sys_file} | head -n 4 | tail -n 1) + echo "WER = $WER" + + echo "Computing CER" + LANG=C perl -CSAD $OCR_SCORE --ref_format trn --hyp_format trn $OUTDIR/ref.norm-final.chars.filtered.txt $OUTDIR/hyp.norm-final.chars.filtered.txt >/dev/null + cer_sys_file=$OUTDIR/hyp.norm-final.chars.filtered.txt.sys + + CER=$(awk '{print $4}' ${cer_sys_file} | head -n 4 | tail -n 1) + echo "CER = $CER" +fi + +num_missing_hyp=$(wc -l $OUTDIR/missing-hyp-ids.list | awk '{print $1}') + +echo "Done." +echo "" +echo "For detailed system scores see:" +echo -e "\t${wer_sys_file}" +echo -e "\t${cer_sys_file}" + +if [ "$num_missing_hyp" -gt 0 ]; then +echo "" +echo "Warning, you are missing ${num_missing_hyp} hypothesis lines. Your score is penalized due to missing lines." +echo -e "\tFind missing hypothesis ids here: $OUTDIR/missing-hyp-ids.list" +fi diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/ocr_score.pl b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/ocr_score.pl new file mode 100755 index 00000000000..14b7e50a66c --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/ocr_score.pl @@ -0,0 +1,1431 @@ +eval '(exit $?0)' && eval 'exec perl -w -S $0 ${1+"$@"}' && eval 'exec perl -w -S $0 $argv:q' + if 0; +# The above two evil-looking lines enable the perl interpreter to be found on +# the PATH when this script is executed as a Unix command. 
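+# Example usage (file names are illustrative; normalized_score.sh in this
+# directory is the caller used by the recipe):
+#   LANG=C perl -CSAD local/yomdle/normalized_scoring/ocr_score.pl \
+#     --ref_format trn --hyp_format trn \
+#     ref.norm-final.words.txt hyp.norm-final.words.txt out_prefix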
+ +################################################################### +# Copyright 2008 by BBN Technologies Corp. All Rights Reserved +################################################################### + +use strict; + +use File::Basename; + +my $usage = + "Usage: $0 [options] reference hypothesis [output-prefix]\n". + "Options:\n". + " --ref_format reference file format (default is kaldi)\n". + " --hyp_format hypothesis file format (default is kaldi)\n". " trn denotes SNOR format transcript:\n". + " line ends with unique ID in parentheses,\n". + " no other parentheses allowed.\n"; + +# Set default formats + +my %ref_format_values = map { $_ => 1 } ("trn", "stm", "kaldi" ); +my %hyp_format_values = map { $_ => 1 } ("trn", "ctm", "kaldi" ); + +my $ref_format = "kaldi"; +my $hyp_format = "kaldi"; +my $SPEAKER_SIDE_DELIM = "-"; +my $require_all_in_ref = 0; + +use Getopt::Long; + +if ( ! GetOptions( "ref_format=s" => \$ref_format, + "hyp_format=s" => \$hyp_format, + "use_all_ref=i" => \$require_all_in_ref) ) { + die $usage; +} + +if ( ! $ref_format_values{$ref_format} ) { + die "Invalid format for reference file: $ref_format\n$usage"; +} + +if ( ! $hyp_format_values{$hyp_format} ) { + die "Invalid format for hypothesis file: $hyp_format\n$usage"; +} + +die $usage + if ( ( @ARGV > 3 ) || ( @ARGV < 2 ) ); + +my $ref = $ARGV[0]; +my $hyp = $ARGV[1]; +my $out_pre = (defined $ARGV[2]) ? $ARGV[2] : $ARGV[1]; + +# Todo: +# 1) Support nested 's in hypothesis ctm files +# 2) Support multiple reference word paths +# 5) Output IGNORE alignment for non-scoring regions +# 7) Match hyp words to reference by audiofile channels, not speakers +# 6) Det plot output of confidence scores? +# 12) Output #Snt, #S.Err ? Std. Dev. / Median of speaker statistics? +# 13) Worry about GB, EUC and non-UTF8 unicode encoded stm/ctm files +# (Use perl's Encode support w/ command line option?) +# 16) Join together close reference utterances for alignment step +# 4) intersegment gaps? => fine as is. +# 17) Check that start_times are increasing across the hyp lattice +# from left to right. +# 18) Check that tags have * as start time and duration -- +# otherwise they might be words :). + +# Some constants. Many of these should eventually become command line options. + +my $fake_region_id = "-9999"; +my $insert_cost = 3; +my $delete_cost = 3; +my $sub_cost = 4; +my $opt_del_cost = 2; # It's better to substitute versus (%HEST) + # than to insert and then get it correct. + +my $log_zero = -1000; # a fair approximation to -infinity, + # given that confidence values are typically + # only precise to at most 3 digits. 
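+
+# Worked example for the cost constants above (illustration only): if the
+# reference is the single optionally deletable word "(%HEST)" and the
+# hypothesis contains one word, scoring that word as a substitution costs
+# $sub_cost = 4, whereas inserting it and optionally deleting "(%HEST)" costs
+# $insert_cost + $opt_del_cost = 3 + 2 = 5, so the substitution path is
+# preferred by the aligner.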
+my $max_cost = 99999999; + +my $allow_multiple_reference_paths = 0; +my $allow_word_fragment_matching = 1; + +my $verbosity = 1; +my $debug = 0; + +# Globals +my @cost; +my @traceback; + +# Load reference +print "Loading $ref_format reference file $ref ...\n" if ( $verbosity > 0 ); +my @load_ref_data; +if ($ref_format eq "stm") { + @load_ref_data = load_stm( $ref ); +} +elsif (($ref_format eq "trn") or ($ref_format eq "kaldi")) { + @load_ref_data = load_snor( $ref, $ref_format ); +} +else { + die "Internal error: invalid format $ref_format"; +} +my ( $refreg, $label_names, $category_names, $reforder, $get_side_from_speaker ) = @load_ref_data; + +# Merge reference regions into scoring regions +# my $score_regions = merge_ref_regions( $ref_regions ); + +# Load hypothesis +print "Loading $hyp_format hypothesis file $hyp ...\n" if ( $verbosity > 0 ); +my @load_hyp_data; +if ($hyp_format eq "ctm") { + @load_hyp_data = load_ctm( $hyp, $get_side_from_speaker ); +} +elsif (($hyp_format eq "trn") or ($hyp_format eq "kaldi")) { + my ( $hypreg ) = load_snor( $hyp, $hyp_format ); + @load_hyp_data = ( $hypreg, 5 ); # indicates no confidence field + # Make word networks for all utterances + map { MakeNetworkFromText( $_ ) } map { values %$_ } values %$hypreg; +} +else { + die "Internal error: invalid format $hyp_format"; +} +my ( $hypreg, $num_ctm_fields ) = @load_hyp_data; + +# Assign hypothesis words to scoring regions (currently done in load_ctm) +# my $assigned_hyps = assign_hyp_words_to_regions( $hyp_words, $score_regions ); + +# Do the alignment +# my $stats = align( $score_regions, $assigned_hyps, $out_pre ); +my $stats = align(); + +# Print stats +print_stats( $stats, $out_pre ); + +print STDOUT "Output files written to ${out_pre}.pra (alignments), ${out_pre}.sys (statistics), ${out_pre}.dtl.*\n"; + +exit(0); + +################################################################################# +# Subroutines +################################################################################# + +sub align { + + my $stats = {}; + + my %sub_count = (); + my %ins_count = (); + my %del_count = (); + my %ref_correct_count = (); + my %hyp_correct_count = (); + my %ref_count = (); + my %hyp_count = (); + my %ref_sub_count = (); + my %hyp_sub_count = (); + + open ( F, ">" . $out_pre . ".pra" ) + or die "Couldn't open ${out_pre}.pra for writing alignments\n"; + open ( SF, ">" . $out_pre . ".sgml" ) + or die "Couldn't open ${out_pre}.sgml for writing sgml alignments\n"; + + my $date = `date`; + chomp $date; + print SF '' . "\n"; + + foreach my $label ( sort keys %{ $label_names } ) { + print SF '\n"; + } + + foreach my $category ( sort { $a <=> $b } keys %{ $category_names } ) { + print SF '' . "\n"; + print SF "\n"; + } + + foreach my $spkr ( @$reforder ) { + if (not $require_all_in_ref) { + next unless exists($hypreg->{$spkr}); + } + + print "Aligning $spkr ...\n" if ( $verbosity > 1 ); + print SF '' . 
"\n"; + + my $cnt = 1; + ALIGN_UTT: foreach my $st ( sort { $a <=> $b } keys %{ $refreg->{$spkr} } ) { + + # Align + my $correct = 0; + my $insertions = 0; + my $deletions = 0; + my $substitutions = 0; + my $log_prob = 0; + + next ALIGN_UTT if ( $refreg->{$spkr}{$st}->{words} =~ /^IGNORE_TIME_SEGMENT_IN_SCORING$/i ); + + next ALIGN_UTT if ( ( $st eq $fake_region_id ) && !defined($hypreg->{$spkr}{$st}) ); + + print "Aligning ${spkr}-${st}\n" if ( $debug ); + + # Make the reference lattice from word string + my $reflat = { arcs => [], nodes => [] }; + push @{ $reflat->{nodes} }, { in_arcs => [], out_arcs => [ 0 ] }; + foreach my $refword ( split( ' ', $refreg->{$spkr}{$st}->{words} ) ) { + my $last_node_id = $#{ $reflat->{nodes} }; + push @{ $reflat->{arcs} }, { src => $last_node_id, + dst => $last_node_id + 1, + word => $refword }; + $reflat->{nodes}->[$last_node_id]->{out_arcs} = [ $#{ $reflat->{arcs} } ]; + push @{ $reflat->{nodes} }, { in_arcs => [ $#{$reflat->{arcs} } ], + out_arcs => [] }; + } + + my $hyplat = defined( $hypreg->{$spkr}{$st} ) ? + $hypreg->{$spkr}{$st} : + { nodes => [ { in_arcs => [], out_arcs => [] } ], + arcs => [] }; + + if ( $debug ) { + print "Reference lattice =\n"; + print_lattice( $reflat ); + print "Hypothesis lattice =\n"; + print_lattice( $hyplat ); + } + + @cost = (); + @traceback = (); + $cost[0][0] = 0; + $traceback[0][0] = {}; + + # Assign lowest costs to every ( ref_lat_node, hyp_lat_node ) pair + + for ( my $i = 0; $i <= $#{ $reflat->{nodes} }; $i++ ) { + HYP_NODES: for ( my $j = 0; $j <= $#{ $hyplat->{nodes} }; $j++ ) { + + next HYP_NODES if ( ( $i == 0 ) && ( $j == 0 ) ); + + $cost[$i][$j] = $max_cost; + print "Aligning $i,$j\n" if ( $debug ); + + foreach my $ref_arc ( @{ $reflat->{nodes}->[$i]->{in_arcs} } ) { + my $ref_arc_hash = $reflat->{arcs}->[$ref_arc]; + my $ref_word = $ref_arc_hash->{word}; + + foreach my $hyp_arc ( @{ $hyplat->{nodes}->[$j]->{in_arcs} } ) { + my $hyp_arc_hash = $hyplat->{arcs}->[$hyp_arc]; + + my $base_cost = $cost[ $ref_arc_hash->{src} ][ $hyp_arc_hash->{src} ]; + my $hyp_word = $hyp_arc_hash->{word}; + my $move_cost; + my $tb_str; + + print "Comparing ref $ref_word vs. hyp $hyp_word\n" if ( $debug ); + + if ( $ref_word eq $hyp_word ) { + $move_cost = $base_cost; + $tb_str = "CORRECT: $ref_word"; + } elsif ( ( $ref_word eq "(" . $hyp_word . ")" ) || + ( $allow_word_fragment_matching && + ( ( ( $ref_word =~ /^\((.*)\-\)$/ ) && # (X-) can match XY + ( $hyp_word =~ /^$1/ ) ) || + ( ( $ref_word =~ /^\(\-(.*)\)$/ ) && # (-X) can match YX + ( $hyp_word =~ /$1$/ ) ) ) ) ){ + $move_cost = $base_cost; + $tb_str = "CORRECT: hyp $hyp_word for ref $ref_word"; + } else { + $move_cost = $base_cost + $sub_cost; + $tb_str = "SUBSTITUTION: hyp $hyp_word for ref $ref_word"; + } + + update_cost( $i, $j, $ref_arc, $hyp_arc, $move_cost, $tb_str ); + } + + # Deletions + my $base_cost = $cost[ $ref_arc_hash->{src} ][ $j ]; + my $move_cost; + my $tb_str; + if ( $ref_word =~ /^\(.*\)$/ ) { + $move_cost = $base_cost + $opt_del_cost; + $tb_str = "CORRECT (Opt. 
Del.): $ref_word"; + } else { + $move_cost = $base_cost + $delete_cost; + $tb_str = "DELETION: $ref_word"; + } + + update_cost( $i, $j, $ref_arc, undef, $move_cost, $tb_str ); + + } + + # Insertions + foreach my $hyp_arc ( @{ $hyplat->{nodes}->[$j]->{in_arcs} } ) { + my $hyp_arc_hash = $hyplat->{arcs}->[$hyp_arc]; + my $base_cost = $cost[$i][ $hyp_arc_hash->{src} ]; + my $hyp_word = $hyp_arc_hash->{word}; + + my $move_cost = $base_cost + $insert_cost; + my $tb_str = "INSERTION: $hyp_word"; + update_cost( $i, $j, undef, $hyp_arc, $move_cost, $tb_str ); + } + + } # for $j + } # for $i + + # Traceback + my $i = $#{ $reflat->{nodes} }; + my $j = $#{ $hyplat->{nodes} }; + my $aligned_ref = ""; + my $aligned_hyp = ""; + my $align_str = ""; + my $sgml_str = ""; + + while ( ( $i > 0 ) || ( $j > 0 ) ) { + my $tb = $traceback[$i][$j]; + +# print "Traceback for $i,$j is $tb" if ( $debug ); + + my $tb_str = $tb->{str}; + die "Undefined traceback string for speaker $spkr start time $st (i=$i,j=$j)" unless defined( $tb_str ); + + $align_str = $tb_str . "\n" . $align_str; + + my $ref_arc_hash = defined($tb->{ref_arc}) ? $reflat->{arcs}->[$tb->{ref_arc}] : {}; + my $hyp_arc_hash = defined($tb->{hyp_arc}) ? $hyplat->{arcs}->[$tb->{hyp_arc}] : {}; + my $ref_word = defined( $ref_arc_hash->{word} ) ? $ref_arc_hash->{word} : ""; + my $hyp_word = defined( $hyp_arc_hash->{word} ) ? $hyp_arc_hash->{word} : ""; + my $hyp_word_conf = defined( $hyp_arc_hash->{conf} ) ? $hyp_arc_hash->{conf} : ""; + my $hyp_start_time = defined( $hyp_arc_hash->{start_time} ) ? $hyp_arc_hash->{start_time} : ""; + my $hyp_end_time = defined( $hyp_arc_hash->{end_time} ) ? $hyp_arc_hash->{end_time} : ""; + + if ( $ref_word ) { + $aligned_ref = $ref_word . " " . $aligned_ref; + $ref_count{$ref_word} = 0 unless defined( $ref_count{$ref_word} ); + $ref_count{$ref_word} += 1; + } + + if ( $hyp_word ) { + $aligned_hyp = $hyp_word . " " . $aligned_hyp; + $hyp_count{$hyp_word} = 0 unless defined( $ref_count{$hyp_word} ); + $hyp_count{$hyp_word} += 1; + } + + my $next_i = defined($ref_arc_hash->{src}) ? $ref_arc_hash->{src} : $i; + my $next_j = defined($hyp_arc_hash->{src}) ? $hyp_arc_hash->{src} : $j; + + if ( $tb_str =~ /^C/ ) { + $correct += 1; + $ref_correct_count{$ref_word} = 0 + unless defined( $ref_correct_count{$ref_word} ); + $ref_correct_count{$ref_word} += 1; + if ( $tb_str !~ /^CORRECT \(Opt/ ) { + $log_prob += mylog( $hyp_word_conf ); + $sgml_str = 'C,"' . $ref_word . '","' . $hyp_word . '",' . $hyp_start_time . '+' . $hyp_end_time . ',' . $hyp_word_conf . ':' . $sgml_str; + $hyp_correct_count{$hyp_word} = 0 + unless defined( $hyp_correct_count{$hyp_word} ); + $hyp_correct_count{$hyp_word} += 1; + } else { + $sgml_str = 'C,"' . $ref_word . '","",0.000+0.000,0.000000:' . $sgml_str; + } + } elsif ( $tb_str =~ /^S/ ) { + $substitutions +=1; + $log_prob += mylog( 1.0 - $hyp_word_conf ); + $sgml_str = 'S,"' . $ref_word . '","' . $hyp_word . '",' + . $hyp_start_time . '+' . $hyp_end_time . ',' + . $hyp_word_conf . ':' . $sgml_str; + $sub_count{"$ref_word $hyp_word"} = 0 + unless defined( $sub_count{"$ref_word $hyp_word"} ); + $sub_count{"$ref_word $hyp_word"} += 1; + $ref_sub_count{$ref_word} = 0 + unless defined( $ref_sub_count{$ref_word} ); + $ref_sub_count{$ref_word} += 1; + $hyp_sub_count{$hyp_word} = 0 + unless defined( $hyp_sub_count{$hyp_word} ); + $hyp_sub_count{$hyp_word} += 1; + } elsif ( $tb_str =~ /^I/ ) { + $insertions += 1; + $log_prob += mylog( 1.0 - $hyp_word_conf ); + $sgml_str = 'I,,"' . $hyp_word . '",' . 
$hyp_start_time . '+' . + $hyp_end_time . ',' . $hyp_word_conf + . ':' . $sgml_str; + $ins_count{$hyp_word} = 0 unless defined( $ins_count{$hyp_word} ); + $ins_count{$hyp_word} += 1; + } elsif ( $tb_str =~ /^D/ ) { + $deletions += 1; + $sgml_str = 'D,"' . $ref_word . '",,,:' . $sgml_str; + $del_count{$ref_word} = 0 unless defined( $del_count{$ref_word} ); + $del_count{$ref_word} += 1; + } else { + die "INTERNAL ERROR: Unknown traceback string $tb_str while aligning speaker $spkr reference starting at $st\n"; + } + + $i = $next_i; + $j = $next_j; + } # end while + + my $et = $refreg->{$spkr}{$st}->{end_time}; + if ( $st eq $fake_region_id ) { + print F "Speaker $spkr Hypothesis words outside of reference regions\n"; + } + else { + if (($ref_format eq 'trn') or ($ref_format eq 'kaldi')) { + print F "id: ${spkr}${SPEAKER_SIDE_DELIM}$st\n"; + } + else { + print F "Speaker $spkr Start time $st End time $et\n"; + } + } + print F "Ref: $aligned_ref\n"; + print F "Hyp: $aligned_hyp\n"; + + print F "Scores: ( #C #S #D #I ) = ( $correct $substitutions $deletions $insertions )\n"; + print F $align_str; + print F "\n"; + + my $nreference = $correct + $substitutions + $deletions; + my $nhypothesis = $correct + $substitutions + $insertions; + + print SF '' . "\n"; + chop $sgml_str; + print SF $sgml_str . "\n"; + print SF "\n"; + + # Accumulate statistics + foreach my $t ( split( ',', $refreg->{$spkr}{$st}->{tags} ), + $refreg->{$spkr}{$st}->{speaker}, + "ALL" ) { + my $s = { nref => $nreference, + nhyp => $nhypothesis, + cor => $correct, + sub => $substitutions, + ins => $insertions, + del => $deletions, + logprob => $log_prob }; + foreach my $k ( keys %{ $s } ) { + $stats->{$t}->{$k} += $s->{$k}; + } + } + + } # $foreach $st + + print SF "\n"; + } # foreach $spkr + + close( F ); + + print SF "\n"; + close( SF ); + + # dtl files + + open ( DTL, ">" . $out_pre . ".dtl.sub" ) + or die "Couldn't open ${out_pre}.dtl.sub for writing substitution counts\n"; + print DTL "Substitutions\n\n"; + print DTL "Count Ref_word Hyp_word\n"; + print DTL "---------------------------------------\n"; + foreach my $k ( sort { $sub_count{$b} <=> $sub_count{$a} } keys %sub_count ) { + printf DTL "%5d %-70s\n", $sub_count{$k}, $k; + } + close( DTL ); + + open ( DTL, ">" . $out_pre . ".dtl.ins" ) + or die "Couldn't open ${out_pre}.dtl.ins for writing insertion counts\n"; + print DTL "Insertions\n\n"; + print DTL "Count Hyp_word\n"; + print DTL "---------------------------------------\n"; + foreach my $k ( sort { $ins_count{$b} <=> $ins_count{$a} } keys %ins_count ) { + printf DTL "%5d %-70s\n", $ins_count{$k}, $k; + } + close( DTL ); + + open ( DTL, ">" . $out_pre . ".dtl.del" ) + or die "Couldn't open ${out_pre}.dtl.del for writing deletion counts\n"; + print DTL "Deletions\n\n"; + print DTL "Count Ref_word\n"; + print DTL "---------------------------------------\n"; + foreach my $k ( sort { $del_count{$b} <=> $del_count{$a} } keys %del_count ) { + printf DTL "%5d %-70s\n", $del_count{$k}, $k; + } + close( DTL ); + + open ( DTL, ">" . $out_pre . 
".dtl.ref_words" ) + or die "Couldn't open ${out_pre}.dtl.ref_words for writing reference word statistics\n"; + print DTL "Statistics by reference word\n\n"; + printf DTL "%-25s %6s %4s %4s %4s %4s\n", + "Word", "Count", "%Cor", "%Err", "%Sub", "%Del"; + print DTL "---------------------------------------------------------\n"; + foreach my $k ( sort { $ref_count{$b} <=> $ref_count{$a} } keys %ref_count ) { + $ref_correct_count{$k} = 0 unless defined( $ref_correct_count{$k} ); + $ref_sub_count{$k} = 0 unless defined( $ref_sub_count{$k} ); + $del_count{$k} = 0 unless defined( $del_count{$k} ); + printf DTL "%-25s %6s %4d %4d %4d %4d\n", + $k, $ref_count{$k}, + 100 * ( $ref_correct_count{$k} / $ref_count{$k} ), + 100 * ( $ref_sub_count{$k} + $del_count{$k} ) / $ref_count{$k}, + 100 * ( $ref_sub_count{$k} / $ref_count{$k} ), + 100 * ( $del_count{$k} / $ref_count{$k} ); + } + close( DTL ); + + open ( DTL, ">" . $out_pre . ".dtl.hyp_words" ) + or die "Couldn't open ${out_pre}.dtl.hyp_words for writing reference word statistics\n"; + print DTL "Statistics by hypothesis word\n\n"; + printf DTL "%-25s %6s %4s %4s %4s %4s\n", + "Word", "Count", "%Cor", "%Err", "%Sub", "%Ins"; + print DTL "---------------------------------------------------------\n"; + foreach my $k ( sort { $hyp_count{$b} <=> $hyp_count{$a} } keys %hyp_count ) { + $hyp_correct_count{$k} = 0 unless defined( $hyp_correct_count{$k} ); + $hyp_sub_count{$k} = 0 unless defined( $hyp_sub_count{$k} ); + $ins_count{$k} = 0 unless defined( $ins_count{$k} ); + printf DTL "%-25s %6s %4d %4d %4d %4d\n", + $k, $hyp_count{$k}, + 100 * ( $hyp_correct_count{$k} / $hyp_count{$k} ), + 100 * ( $hyp_sub_count{$k} + $ins_count{$k} ) / $hyp_count{$k}, + 100 * ( $hyp_sub_count{$k} / $hyp_count{$k} ), + 100 * ( $ins_count{$k} / $hyp_count{$k} ); + } + close( DTL ); + + + return $stats; + +} # end of align + + +sub print_stats { + + my ( $stats, $out_pre ) = @_; + + my $sysf = $out_pre . ".sys"; + my $rawf = $out_pre . ".raw"; + + if ( $verbosity > 1 ) { + open( F, "| tee $sysf" ) or + die "Couldn't open | tee $sysf for writing\n"; + } + else { + open( F, ">" . $sysf ) or + die "Couldn't open $sysf for writing\n"; + } + + open( RAW, ">" . $rawf ) or + die "Couldn't open $rawf for writing\n"; + + my $format = "%15s %6s %6s %6s %5s %5s %5s %5s %7s\n"; + my $dash_line = ("-" x 79) . "\n"; + + printf F $format, "Label", "#Ref", "#Hyp", "WER", "%Cor", "%Sub", "%Del", "%Ins", "NCE"; + print F $dash_line; + + printf RAW $format, "Label", "#Ref", "#Hyp", "#Err", "#Cor", "#Sub", "#Del", "#Ins", "NCE"; + print RAW $dash_line; + + foreach my $t ( sort keys %{ $stats } ) { + my $label = $t; + $label = $label_names->{$t}->{short} if defined( $label_names->{$t}->{short} ); + + my $st = $stats->{$t}; + + # Prevent divide by zero + $st->{nref} = 1 if ($st->{nref} == 0); + + my $wer = ( $st->{sub} + $st->{del} + $st->{ins} ) / $st->{nref} * 100.0; + my $p_c = ( $st->{nhyp} == 0 ) ? 0.5 : $st->{cor} / $st->{nhyp}; + my $h_max = -$st->{cor} * mylog($p_c) - ($st->{nhyp} - $st->{cor}) * mylog( 1 - $p_c ); + my $NCE = ( $h_max == 0.00 ) ? 
"XXX" : sprintf( "%7.3f", 1.0 + $st->{logprob} / $h_max ); + $NCE = "n/a" if ( $num_ctm_fields == 5 ); + + print F $dash_line if ( $t eq "ALL" ); + print RAW $dash_line if ( $t eq "ALL" ); + + printf F $format, $label, $st->{nref}, $st->{nhyp}, sprintf( "%6.2f", $wer ), + sprintf( "%5.1f", $st->{cor} / $st->{nref} * 100.0 ), + sprintf( "%5.1f", $st->{sub} / $st->{nref} * 100.0 ), + sprintf( "%5.1f", $st->{del} / $st->{nref} * 100.0 ), + sprintf( "%5.1f", $st->{ins} / $st->{nref} * 100.0 ), + $NCE; + + printf RAW $format, $label, $st->{nref}, $st->{nhyp}, + ( $st->{sub} + $st->{del} + $st->{ins} ), + $st->{cor}, $st->{sub}, $st->{del}, $st->{ins}, $NCE; + + print F $dash_line if ( $t eq "ALL" ); + print RAW $dash_line if ( $t eq "ALL" ); + + } + print F $dash_line; + printf F $format, "Label", "#Ref", "#Hyp", "WER", "%Cor", "%Sub", "%Del", "%Ins", "NCE"; + + print RAW $dash_line; + printf RAW $format, "TAG", "#Ref", "#Hyp", "#Err", "#Cor", "#Sub", "#Del", "#Ins", "NCE"; + + close( F ); + close( RAW ); + + return; +} + +sub mylog +{ + my $x = shift; + + return ( $x > 0 ) ? log( $x) : $log_zero; +} + +sub load_stm +{ + my $ref = shift; + + my $refreg = {}; + my $label_names = {}; + my $category_names = {}; + my $reforder = []; + my $side_from_speaker = 0; + + open( R, $ref ) or die "Can't open stm file $ref for reading\n"; + REF: while( ) { + + if ( /^;;\s*LABEL\s*\"([^\"]*)\"\s*\"([^\"]*)\"\s*\"([^\"]*)\"/ ) { + my $label = $1; + warn "Previously defined label $label is redefined on STM file $ref line $_\n" + if ( defined( $label_names->{$label} ) + && ( $label_names->{$label}->{short} ne $2 ) ); + die "Label '$label' may not contain spaces on STM file $ref line $_\n" + if ( $label =~ / / ); + $label_names->{$label} = { short => $2, long => $3 }; + } + + if ( /^;;\s*CATEGORY\s*\"([^\"]*)\"\s*\"([^\"]*)\"\s*\"([^\"]*)\"/ ) { + my $category = $1; + warn "Previously defined category $category is redefined on STM file $ref line $_\n" + if ( defined( $category_names->{$category} ) + && ( $category_names->{$category}->{short} ne $2 ) ); + die "Category '$category' may not contain spaces on STM file $ref line $_\n" if ( $category =~ / / ); + $category_names->{$category} = { short => $2, long => $3 }; + } + + next REF if (/^;/ or /^\s*$/); + + my @f = split; + (@f >= 6) or die "Stm file $ref line\n$_ doesn't have enough fields. (It must have wavefile channel speaker start_time end_time tag)\n"; + my ($wavefile) = fileparse($f[0], qr/\.[^.]*$/); + my $channel = $f[1]; + my $speaker = $f[2]; + my $start_time = $f[3]; + my $end_time = $f[4]; + my $tag = $f[5]; + $tag =~ /^\<(.*)\>$/ or die "Couldn't parse tag field $tag of stm file $ref line $_ ; tag field must start with < and end with >\n"; + $tag = $1; + my $words = join( ' ', @f[6 .. $#f] ); + + if ( $end_time < $start_time + 0.0001 ) { + print "WARNING: For stm file $ref line $_ the end time $end_time isn't after the start time $start_time plus 0.0001 .\n"; + $end_time = $start_time + 0.0001; + } + + # For first cut, stm file utterances = scorable regions + my $side; + if ($channel =~ /^[A-Z]$/) { # usually will be just 'A' or 'B', but some intermediate scripts can have a channel of 'X' + $side = $wavefile . "_" . 
$channel; + } + else { + $side = $channel; + $side_from_speaker = 1; + } + + push(@$reforder, $side) unless $refreg->{$side}; + + # Check for overlapping reference regions + foreach my $st ( keys %{ $refreg->{$side} } ) { + my $et = $refreg->{$side}{$st}->{end_time}; + if ( ( ( $start_time > $st ) && ( $start_time < $et ) ) + || ( $end_time > $st ) && ( $end_time < $et ) ) { + warn "STM line $_ overlaps with STM utterance starting at $st and ending at $et\n\n"; + } + } + + $refreg->{$side}{$start_time} = { end_time => $end_time, + tags => $tag, + words => $words, + speaker => $speaker, + wavefile => $wavefile, + channel => $channel }; + $refreg->{$side}{$fake_region_id} = { end_time => $fake_region_id, + tags => $tag, + words => "", + speaker => $speaker, + wavefile => $wavefile, + channel => $channel } + unless defined( $refreg->{$side}{$fake_region_id} ); + + } + close( R ); + + return ( $refreg, $label_names, $category_names, $reforder, $side_from_speaker ); +} + +sub load_snor +{ + my $filename = shift; + my $fmt = shift; + + my $reg = {}; + my $label_names = {}; + my $category_names = {}; + my $order = []; + + open( R, $filename ) or die "Can't open trn file $filename for reading\n"; + RECORD: while( ) { + chomp; + next RECORD if (/^\s*$/); + + my @f = split; + next RECORD unless @f; + + my $snorIdField; + if ($fmt eq "trn") { + $snorIdField = pop(@f); + $snorIdField = StripParens($snorIdField); + } elsif ($fmt eq "kaldi") { + $snorIdField = shift(@f); + } else { + die "load_snor(): unknown format \"$fmt\"! "; + } + my $side; + my $uttIndex; + + if ($snorIdField =~ m/^(\S+)([_-])(\d+)$/) { + $side = $1; + $SPEAKER_SIDE_DELIM = $2; + $uttIndex = $3; + } else { + $side = $snorIdField; + $uttIndex = "1"; + } + + unless (defined($side) and defined($uttIndex)) { + die "Transcript (SNOR) file $filename bad SNOR id $snorIdField on line $_\n"; + } + + my $words = join( ' ', @f); + + push(@$order, $side) unless $reg->{$side}; + + # Check for repeated index within speaker + if ( $reg->{$side}{$uttIndex} ) { + warn "Transcript (SNOR) line side $side index $uttIndex repeated\n\n"; + } +# XXXX later need to decide exactly what field values will be + $reg->{$side}{$uttIndex} = { end_time => "XXX", + tags => "", + words => $words, + speaker => $side, + wavefile => "XXX", + channel => "XXX" }; +# XXXX do we need fake_region_id? +# XXXX MD disabling + # $reg->{$side}{$fake_region_id} = { end_time => $fake_region_id, + # tags => "", + # words => "", + # speaker => $side, + # wavefile => "XXX", + # channel => "XXX", } + # unless defined( $reg->{$side}{$fake_region_id} ); + + } + close( R ); + +# Note: label_names, category_names always empty for SNOR file + return ( $reg, $label_names, $category_names, $order, 0 ); +} + +sub StripParens { + my ($str) = @_; + # Return contents of one-level of matched, enclosing parentheses + # Strips the parens and adjoining space + # (More general than needs to be: allows internal blanks + # in the contained string, e.g., + # " ( Hi there )" --> "Hi there" + # if ( $str =~ /^\s*\(\s*(\S+(\s+\S+)*)\s*\)\s*$/ ) { + # Changed my mind, just keep it simple, allow no spaces: + if ( $str =~ /^\((\S+)\)$/ ) { + return $1; + } + return undef; +} + +sub ParseSNORID { + my ($snorId) = @_; + # This must stupidly assume that ids are in form side-uttindex, eg, sw2001-A-0001. 
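+ # Under that assumption, ParseSNORID("sw2001-A-0001") returns
+ # ("sw2001-A", "0001"): the greedy \S+ keeps everything up to the last
+ # hyphen as the side, and the trailing digits become the utterance index.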
+ if ( $snorId && $snorId =~ /^(\S+)-(\d+)$/ ) { + return ($1, $2); + } + return (undef, undef); +} + +sub MakeNetworkFromText { + my ($lat) = @_; + + @$lat{"arcs", "nodes"} = ( [], [] ); + push @{ $lat->{nodes} }, { in_arcs => [], out_arcs => [ 0 ] }; + my $words = $lat->{words}; + return unless $words; + + foreach my $word ( split( ' ', $words) ) { + my $last_node_id = $#{ $lat->{nodes} }; + push @{ $lat->{arcs} }, { src => $last_node_id, + dst => $last_node_id + 1, + word => $word, + conf => 0.5 }; + $lat->{nodes}->[$last_node_id]->{out_arcs} = [ $#{ $lat->{arcs} } ]; + push @{ $lat->{nodes} }, { in_arcs => [ $#{$lat->{arcs} } ], + out_arcs => [] }; + } +} + +sub check_conf +{ + my ( $hyp, $line, $conf, $num_ctm_fields ) = @_; + + if ( defined( $conf ) ) { + # check that conf value is valid, if it isn't verify that the decoding type is being called correctly, especially DecodeFastFWBW + die "On ctm file $hyp line $line confidence value $conf isn't valid numeric value." unless ( $conf =~ /^\s*[-+]?[0-9]*(?:[0-9]|\.[0-9]*)?(?:[eE][-+]?[0-9]+)?\s*$/); + + if ( $num_ctm_fields == 5 ) { + warn "CTM file $hyp started out having five fields, but line $line has six!\n"; + } else { + $num_ctm_fields = 6; + } + } else { + $conf = 0.5; + if ( $num_ctm_fields == 6 ) { + warn "CTM file $hyp started out having six fields, but line $line has only five!\n"; + } else { + $num_ctm_fields = 5; + } + } + + die "On ctm file $hyp line $line confidence value $conf isn't between 0 and 1\n" + if ( ( $conf > 1.0 ) || ( $conf < 0.0 ) ); + + return ( $conf, $num_ctm_fields ); +} + +sub load_ctm +{ + my $hyp = shift; + my $side_from_speaker = shift; + + my $hypreg = {}; + my $num_ctm_fields = 0; + my $curr_spkr = undef; + + open( H, $hyp ) or die "Can't open ctm file $hyp for reading\n"; + HYP: while( ) { + #next HYP if ( /^[;#]/ or /^\s*$/ ); + if ( /^[;#]/ ) { + # Save the speaker ID from the comment that starts a new utterance. If the STM was + # indexed by speaker ID, we will use this to look up the matching reference transcription. + if ( /spkr (\S+)/ ) { + $curr_spkr = $1; + } + next HYP; + } + next HYP if ( /^\s*$/ ); + + my ($wavefile, $channel, $start_time, $duration, $word, $conf, @foo) = split; + (defined($word) && !(@foo)) or die "Ctm file $hyp line $_ doesn't have five or six fields\n(It must have wavefile channel start_time end_time word [confidence].)\n"; + + # Extract id from file + $wavefile = fileparse($wavefile, qr/\.[^.]*$/); + + # Assign it a scorable region + my $side; + if ($side_from_speaker && defined($curr_spkr)) { + $side = $curr_spkr; + } + else { + $side = $wavefile . "_" . 
$channel; + } + + if ( $word eq "" ) { + my $orig_wavefile = $wavefile; + my $orig_channel = $channel; + my @alt_hyps = ( [] ); + my $i = 0; + + my $region_start = 99999999; + my $region_end = -99999; + + ALT_LINE: while ( ) { + my ($wavefile2, $channel2, $start_time2, $duration2, $word2, $conf2, @foo2) = split; + $wavefile2 = fileparse($wavefile2, qr/\.[^.]*$/); + die "Wavefile switched from $orig_wavefile to $wavefile2 inside block at ctm file $hyp line $_" unless ( $orig_wavefile eq $wavefile2); + die "Channel switched from $orig_channel to $channel2 inside block at ctm file $hyp line $_" unless ( $orig_channel eq $channel2 ); + + (defined($word2) && !(@foo2)) or die "Ctm file $hyp line $_ doesn't have five or six fields\n(It must have wavefile channel start_time end_time word [confidence].)\n"; + + if ( $word2 eq "" ) { + $i++; + $alt_hyps[$i] = []; + next ALT_LINE; + } + if ( $word2 eq "" ) { + last ALT_LINE; + } + + ($conf2,$num_ctm_fields) = check_conf( $hyp, $_, $conf2, $num_ctm_fields); + + push @{ $alt_hyps[$i] }, [$start_time2, $start_time2 + $duration2, $word2, $conf2]; + + $region_start = min( $region_start, $start_time2 ); + $region_end = max( $start_time2 + $duration2, $region_end ); + } + + # Put the block into a hypreg + my $best_region = find_best_region( $side, $region_start, $region_end ); + + $hypreg->{$side}{$best_region} = { arcs => [], + nodes => [ { in_arcs => [], + out_arcs => [] } ] } + unless defined( $hypreg->{$side}{$best_region} ); + + my $alt_start_node_id = $#{ $hypreg->{$side}{$best_region}->{nodes} }; + my @arc_ids_to_fix = (); + + foreach my $i ( 0 .. $#alt_hyps ) { + my $start_node_id = $alt_start_node_id; + foreach my $j ( 0 .. $#{ $alt_hyps[$i] } ) { + + push @{ $hypreg->{$side}{$best_region}->{arcs} }, + { word => $alt_hyps[$i]->[$j]->[2], + conf => $alt_hyps[$i]->[$j]->[3], + start_time => $alt_hyps[$i]->[$j]->[0], + end_time => $alt_hyps[$i]->[$j]->[1], + src => $start_node_id, + dst => $#{ $hypreg->{$side}{$best_region}->{nodes} } + 1, + }; + + push @{ $hypreg->{$side}{$best_region}->{nodes}->[$start_node_id]->{out_arcs} }, $#{ $hypreg->{$side}{$best_region}->{arcs} }; + + if ( $j != $#{ $alt_hyps[$i] } ) { + push @{ $hypreg->{$side}{$best_region}->{nodes} }, + { in_arcs => [ $#{ $hypreg->{$side}{$best_region}->{arcs} } ], + out_arcs => [] }; + $start_node_id = $#{ $hypreg->{$side}{$best_region}->{nodes} }; + } else { + push @arc_ids_to_fix, + $#{ $hypreg->{$side}{$best_region}->{arcs} }; + } + + } # foreach $j + } # foreach $i + + push @{ $hypreg->{$side}{$best_region}->{nodes} }, + { in_arcs => [ @arc_ids_to_fix ], out_arcs => [] }; + foreach my $arc_id ( @arc_ids_to_fix ) { + $hypreg->{$side}{$best_region}->{arcs}->[$arc_id]->{dst} = + $#{ $hypreg->{$side}{$best_region}->{nodes} }; + } + + next HYP; + } # if $word eq "" + + my $end_time = $start_time + $duration; + + ($conf,$num_ctm_fields) = check_conf( $hyp, $_, $conf, $num_ctm_fields); + + my $best_region = find_best_region( $side, $start_time, $end_time ); + + $hypreg->{$side}{$best_region} = { arcs => [], + nodes => [ { in_arcs => [], + out_arcs => [] } ] } + unless defined( $hypreg->{$side}{$best_region} ); + + my $last_node_id = $#{ $hypreg->{$side}{$best_region}->{nodes} }; + push @{ $hypreg->{$side}{$best_region}->{arcs} }, + { src => $last_node_id, dst => $last_node_id + 1, word => $word, + conf => $conf, start_time => $start_time, end_time => $end_time }; + $hypreg->{$side}{$best_region}->{nodes}->[$last_node_id]->{out_arcs} + = [ $#{ $hypreg->{$side}{$best_region}->{arcs} } ]; + 
push @{ $hypreg->{$side}{$best_region}->{nodes} }, + { in_arcs => [ $#{ $hypreg->{$side}{$best_region}->{arcs} } ], + out_arcs => [] }; + + } # while + close( H ); + + return( $hypreg, $num_ctm_fields ); +} + + +sub min { + my ( $a, $b ) = @_; + return ( $a < $b ) ? $a : $b; +} + +sub max { + my ( $a, $b ) = @_; + return ( $a > $b ) ? $a : $b; +} + +sub find_best_region { + my ( $side, $start_time, $end_time ) = @_; + + if ( !defined( $refreg->{$side} ) ) { + die "Wavefile+channel $side from ctm file $hyp line $_ wasn't seen in the stm file reference.\n"; + } + + my $best_region = $fake_region_id; + my $dist_to_best_region = 9999999999; + + ST: foreach my $st ( keys %{ $refreg->{$side} } ) { + next ST if ( $st eq $fake_region_id ); + my $et = $refreg->{$side}{$st}->{end_time}; + my $dist = 0; + if ( $start_time < $st ) { + $dist = $st - $start_time; + } + if ( $end_time > $et ) { + $dist += $end_time - $et; + } + if ( $dist < $dist_to_best_region ) { + $best_region = $st; + $dist_to_best_region = $dist; + } + } + + return $best_region; +} + +sub update_cost { + my ( $i, $j, $ref_arc, $hyp_arc, $move_cost, $str ) = @_; + + die "Disconnected lattice at $i, $j $str" if ( $move_cost > $max_cost ); + + if ( $move_cost < $cost[$i][$j] ) { + $cost[$i][$j] = $move_cost; + $traceback[$i][$j] = { ref_arc => $ref_arc, + hyp_arc => $hyp_arc, + str => $str }; + print "New lowest cost $move_cost for $i,$j with $str\n" if ( $debug ); + } + +} + +sub print_lattice { + my ( $lat ) = @_; + + return unless defined( $lat->{nodes} ); + + print "Nodes:\n"; + for ( my $n = 0 ; $n <= $#{ $lat->{nodes} }; $n++ ) { + print " $n in_arcs = ", join( ' ', @{ $lat->{nodes}->[$n]->{in_arcs} } ), + " out_arcs = ", join( ' ', @{ $lat->{nodes}->[$n]->{out_arcs} } ), + "\n"; + } + print "Arcs:\n"; + for ( my $a = 0 ; $a <= $#{ $lat->{arcs} }; $a++ ) { + print " $a word = ", $lat->{arcs}->[$a]->{word}, " src = ", + $lat->{arcs}->[$a]->{src}, " dst = ", $lat->{arcs}->[$a]->{dst}, "\n"; + } + +} +__END__ + +=head1 NAME + +scorer.pl - Score speech recognition system output + +=head1 SYNOPSIS + +scorer.pl STM-reference-file CTM-hypothesis-file [output-filename-prefix] + +=head1 DESCRIPTION + +scorer.pl aligns the words in the CTM-hypothesis-file against the STM-reference-file +and then prints out various statistics of the alignment, including the +word error rate (WER), to standard out and to files beginning with +output-filename-prefix (or CTM-hypothesis-file if output-filename-prefix +is not given). It is intended as a replacement for sclite(1). + +=head1 OPTIONS + +Currently, scorer.pl takes no options. + +=head1 ALIGNMENT + +The alignment process consists of two steps. In the first step, each word from +the CTM hypothesis file is assigned to an utterance from the STM-reference-file. +In the second step, the reference words in each utterance are aligned with +the hypothesis words assigned to that utterance so as to minimize a +Levenshtein edit-distance function with correct words, insertions, deletions +and substitutions given costs of 0, 3, 3, and 4 respectively. (Inserting an +optionally deletable word counts as correct, but is given a cost of 2 for +alignment purposes.) + +=head1 STM FILE FORMAT + +STM (Segment Time Mark) files are text files, any line of which can be either +a blank line, a comment line, a label declaration line, or a regular line. Blank +and comment lines are ignored. + +Comment lines begin with a semicolon character and may then consist of any +number of non-new line characters. 
[Note: sclite requires STM comment lines +to begin with two semicolon characters.] + +Label declaration lines begin with two semicolons followed by optional whitespace +followed by the word "LABEL". Next comes three strings, each of which is delimmited +on both ends by double quotes ("), with optional whitespace between the strings. +The first string is the label tag used to mark utterances. It may not contain spaces. +The second string is the short label description, and is used when presenting +summary statistics for the utterances belonging to the label. The third string +is a long label description, and is currently unused. + +Here are some example label declarations: + + ;; LABEL "F" "Female" "Female Speakers" + ;; LABEL "FISH" "Fisher" "Fisher Speakers" + ;; LABEL "CH-M" "Callhome Male" "Male Callhome Speakers" + +Label declarations may be grouped in category sections, which are declared with +lines that look like + + ;; CATEGORY "0" "" "" + +scorer.pl currently does not use category information in any way. + +Regular STM file lines give the transcription and time information for reference +utterances, and consist of at least six whitespace separated fields. The meaning +of the fields is as follows: + +=over 4 + +=item Field 1: + +Audio file identifier. Typically this is the basename of the audio +file, without any path information or file type suffixes (like ".sph" or ".wav"). + +=item Field 2: + +Channel identifier. Typically "A" for channel 1, and "B" for +channel 2. + +=item Field 3: + +Speaker identifier. Typically this is the audio file identifier +followed by an underscore (_) followed by the channel identifier. + +=item Field 4: + +Utterance begin time in seconds, as counted from the beginning of +the audio file. Typically specified to 1/100ths of a second. + +=item Field 5: + +Utterance end time in seconds. + +=item Field 6: + +Label tags for this utterance. The label tags should be separated by commas (,) and enclosed by < and >. For example: . If there are no label tags +for the utterance, the string <> is expected. [Note: unlike in sclite, this field is +mandatory.] + +=item Fields 7+ (Optional): + +The words for this utterance. Any words enclosed with +parenthesis are considered to be "optionally deletable": if no hypothesis word +aligns to the optionally deletable word, then it is counted as correct. For example, +if "(%HESITATION)" is the sole reference word for an utterance, then either +%HESITATION or no hypothesis for the utterance will assigned 1 correct word and +0 errors for the utterance. If a single non-%HESITATION word is hypothesized, +it will be counted as a substitution. If a optionally deletable word ends with +a dash, then it is considered to be a word fragment and any hypothesis word that +matches the word upto dash will be considered correct. For example, the hypothesis +MOLD would be correct if aligned to (MOL-), (MO-), or (M-). + +The reference words can be in any encoding scheme in which +the bytes for whitespace, new lines, parenthesis, and dash (ascii 9, 10, 13, 32, 40, +41, and 45) always represent themselves. This is true for UTF-8 and (I believe) +EUC-JP, but not (I believe) for UTF-16 or GB18030-2000. In addition, if the +encoding scheme contains multiple byte sequences that code for the same character, +then the reference and hypothesis words should both be normalized into an +encoding subset for which every character has an unique byte sequence. 
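+
+For illustration, a hypothetical regular STM line combining the fields above
+(audio file sw2001, channel A, speaker sw2001_A, an utterance running from
+10.00 to 12.50 seconds, carrying the label tag F, and beginning with an
+optionally deletable hesitation) could look like:
+
+ sw2001 A sw2001_A 10.00 12.50 <F> (%HESITATION) YEAH THAT IS RIGHT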
+ +=back + +=head1 CTM FILE FORMAT + +CTM (Conversation Time Marked) files are text files, any line of which may a +blank line, a comment line, or a regular line. As with STM files, comment lines +begin with a semicolon (;) character, and blank and comment lines are ignored. + +Regular CTM file lines give the information for a single hypothesis word, and +consist of either five or six whitespace separated fields: + +=over 4 + +=item Field 1: + +Audio file identifier. As in the STM file. + +=item Field 2: + +Channel identifier. As in the STM file. + +=item Field 3: + +Word start time in seconds, as counted from the beginning of +the audio file. Typically specified to 1/100ths of a second. + +=item Field 4: + +Word duration in seconds. + +=item Field 5: + +The hypothesis word. + +=item Field 6 (Optional): + +A confidence score for the hypothesis word. The +score must be between 0 and 1 inclusive. + +=back + +A CTM file may also contain alternate hypothesis paths. These are typically +the result of filtering an initial CTM file with a GLM mapping file, and +are intended to deal with hypothesis words that have multiple valid transcriptions. +Alternate hypothesis pathes are described by a format that looks like + + fsh_109487 A 90.500 0.210 WHAT 0.645777 + fsh_109487 A * * + fsh_109487 A 90.710 0.230 THAT'S 0.347474 + fsh_109487 A * * + fsh_109487 A 90.710 0.115 THAT 0.347474 + fsh_109487 A 90.825 0.115 IS 0.347474 + fsh_109487 A * * + fsh_109487 A 90.710 0.115 THAT 0.347474 + fsh_109487 A 90.825 0.115 HAS 0.347474 + fsh_109487 A * * + fsh_109487 A 94.240 0.320 JUST 0.884898 + +Specifically, the alternate paths should be surrounded by the tokens + and in the CTM file word field, and the alternate +paths should be separated by s, also in the word field. In all of +these cases, fields 3 and 4 should contain only single asterisks. + +For a particular audio file/channel combination, the words in a CTM file must +appear in order of increasing start time. The UNIX command +"sort +0 -1 +1 -2 +2nb -3" will accomplish this while also sorting the +conversations into an order sclite likes, but only if the CTM file does not +contain regions. + +=head1 OUTPUT + +scorer.pl outputs four files: the .sys, .raw, .sgml and .pra files. These are +written to output-filename-prefix plus the suffix; if output-filename-prefix +is not given, ctm-hypothesis-file is used as the output filename prefix +instead. The .sys file is additionally written to standard output. + +=head2 The .sys File + +The .sys file contains the following statistics for every label, every speaker, +and for the entire test set ("ALL"): + +=over 4 + +=item #Ref = number of words in the reference STM file + +=item #Hyp = number of words in the hypothesis CTM file + +=item WER = the word error rate = ( #_substitutions + #_deletions + #_insertions ) / #_reference_words + +=item %Cor = percentage correct = #_correct / #_reference_words + +=item %Sub = percentage substitutions = #_substitutions / #_reference_words + +=item %Del = percentage deletions = #_deletions / #_reference_words + +=item %Ins = percentage insertions = #_insertions / #_reference_words + +=item NCE = Normalized Cross Entropy, a measure of the goodness of the confidence +values in the CTM file. 
It is calculated using the following formula: + + NCE = 1 - LL / ( #Cor log(p_c) + (#Hyp - #Cor) log(1-p_c) ) + LL = Log Likelihood of Confidence Values + = sum_{w correct} log( conf(w) ) + sum_{w incorrect} log( 1 - conf(w) ) + p_c = (ML Estimate of) Probability of Correctness + = #Cor / #Hyp + +In all of the above formulas, log( 0 ) is replaced with -1000 whenever it +occurs. + +=back + +=head2 The .raw File + +The .raw file contains the following statistics for every label, every speaker, +and for the entire test set ("ALL"): + +=over 4 + +=item #Ref = number of words in the reference STM file + +=item #Hyp = number of words in the hypothesis CTM file + +=item #Err = number of errors = #_substitutions + #_deletions + #_insertions + +=item #Cor = number of correct hypothesis words + +=item #Sub = number of substitutions + +=item #Del = number of reference words deleted + +=item #Ins = number of hypothesis words inserted + +=item NCE = Normalized Cross Entropy, see above description under L<"The .sys File"> + +=back + +=head2 The .pra File + +The .pra file contains alignment information for each STM file reference +utterance. Here is an example: + + Speaker fsh_110103_A Start time 123.44 End time 127.61 + Ref: (%HESITATION) TOPIC IS NEEDED OR OR WHERE THEY HAVE + Hyp: OUR TOPIC IS NEEDED OR WHAT THEY HAVE + Scores: ( #C #S #D #I ) = ( 7 1 1 1 ) + INSERTION: OUR + CORRECT (Opt. Del.): (%HESITATION) + CORRECT: TOPIC + CORRECT: IS + CORRECT: NEEDED + DELETION: OR + CORRECT: OR + SUBSTITION: hypothesis WHAT for reference WHERE + CORRECT: THEY + CORRECT: HAVE + +The utterance start and end times are given in seconds. #C, #S, #D, and #I stand for number of correct, substitution, deletion and insertion words, respectively. "Opt. Del." stands for optionally deletable (see L<"STM FILE FORMAT"> above). + +Note that scorer.pl's .pra file output format is rather different than sclite's. + +=head2 The .sgml File + +The .sgml file also contains alignment information, but in a slightly +more computer parseable format. Here is an example: + + + + + I,,"YEAH",86.18+86.62,0.911505:C,"YEAH","YEAH",87.31+87.74,0.943606 + + + S,"YUP","YEAH",227.04+227.42,0.940359 + + + + +The alignment information for each reference utterance is described by a +colon delimmited list. Each alignment step is described by either + + C,"ref_word","hyp_word",start_time+end_time,confidence [CORRECT] + S,"ref_word","hyp_word",start_time+end_time,confidence [SUBSTITUTION] + I,,"hyp_word",start_time+end_time,confidence [INSERTION] + D,"ref_word",,, [DELETION] + +scorer.pl's .sgml file output is intended to be 100% compatible with sclite's, +with the one exception of PATH id names: scorer.pl's are + wavefile_channel-starttime-endtime +while sclite's are + wavefile_channel-number + +=head1 ADVANTAGES OVER SCLITE + +Better error messages. + +Less finicky about input: conversations do not need to appear in any particular +order, and it's okay if there are no hypothesis words for a speaker. + +Totally case sensitive: scorer.pl never upper or lower cases anything. + +Fewer "special" characters: words can now contain semicolons (;) and less than +signs (<), for example. + +Small, easy to maintain implementation. + +Summary statistics per label are put in the .sys file, rather than hidden +in separate .lur files. + +=head1 CAVEATS + +Some of sclite's output files aren't supported: .det and .hist plots and +.lur files. No sentence/utterance statistics or median statistics are output. 
+ +"IGNORE_TIME_SEGMENT_IN_SCORING" segments are properly ignored for scoring, but +they produce no alignment information. (The effected hypothesis words should +be given "IGNORED" alignment tags in the .pra and .sgml files, but aren't.) + +If there isn't at least a 2-3 second gap between two reference utterances, +they should be joined together for the purposes of aligning the hypothesis +words (and then separated again when outputing the alignment statistics). + +Arguably, %Ins would make more sense as #_insertions / #_hypothesis_words, +but tradition and consistency define it as #_insertions / #_reference_words. + +Nested regions in CTM files and multiple reference paths in STM files +are not supported. + +There should be command line options to fiddle with various things +(insertion/deletion/substitution costs, whether to match word fragments, +which output files to produce, etc.). + +=head1 AUTHOR + +Thomas Colthurst, thomasc@bbn.com. Z + +=head1 COPYRIGHT + +Copyright 2005 by BBN Technologies. + diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/char_trans_utf8_to_uxxxx.py b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/char_trans_utf8_to_uxxxx.py new file mode 100644 index 00000000000..81e38f1b109 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/char_trans_utf8_to_uxxxx.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +""" This script convertss the text from character format to + hexadecimal format (uxxxx). + Eg. char_trans_utf8_to_uxxxx.py +""" + +import sys +from snor import SnorIter + +if len(sys.argv) != 3: + print("Usage: char_trans_utf8_to_uxxxx.py ") + sys.exit(1) + +input_file = sys.argv[1] +output_file = sys.argv[2] + + +def main(): + + with open(input_file, 'r', encoding='utf-8') as fh, open(output_file, 'w', encoding='utf-8') as fh_out: + for utt, uttid in SnorIter(fh): + for char in utt.split(): + if char == "": + fh_out.write("u0020 ") + else: + fh_out.write(utf8_char_to_uxxxx(char)) + fh_out.write(" ") + + # Finally write out uttid and newline + fh_out.write("(%s)\n" % uttid) + + +def utf8_char_to_uxxxx(char): + raw_hex = hex(ord(char))[2:].zfill(4).lower() + uxxxx_char = "u%s" % raw_hex + return uxxxx_char + +if __name__ == "__main__": + main() diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/filter_ids.py b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/filter_ids.py new file mode 100644 index 00000000000..ee55a900fbf --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/filter_ids.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +""" This script is used for partial scoring, it will remove the given + utterance ids and will score with the remailing utterance ids. + Eg. 
filter_ids.py +""" + +import unicodedata +import sys +from snor import SnorIter + +if len(sys.argv) != 4: + print("Usage: filter_ids.py ") + sys.exit(1) + +input_ids_file = sys.argv[1] +input_trans = sys.argv[2] +output_trans = sys.argv[3] + +def main(): + + # First load ids to filter out of transcript + ids_to_filter = set() + with open(input_ids_file, 'r') as fh: + for line in fh: + ids_to_filter.add(line.strip()) + + # Now load input transcript and filter out the ids + with open(input_trans, 'r', encoding='utf-8') as fh, open(output_trans, 'w', encoding='utf-8') as fh_out: + for utt, uttid in SnorIter(fh): + if uttid in ids_to_filter: + continue + + fh_out.write("%s (%s)\n" % (utt, uttid)) + + + +if __name__ == "__main__": + main() diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/find_missing_hyp_ids.py b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/find_missing_hyp_ids.py new file mode 100644 index 00000000000..e578ea6cc54 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/find_missing_hyp_ids.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +""" This script finds and prints the hypothesis utterance ids which + are not present in the reference utterance ids. + Eg. find_missing_hyp_ids.py +""" + +import sys +from snor import SnorIter + +if len(sys.argv) != 3: + print("Usage: find_missing_hyp_ids.py ") + sys.exit(1) + +hyp_file = sys.argv[1] +ref_file = sys.argv[2] + +def main(): + + with open(hyp_file, 'r', encoding='utf-8') as hyp_fh, open(ref_file, 'r', encoding='utf-8') as ref_fh: + ref_ids = set() + for utt, uttid in SnorIter(ref_fh): + ref_ids.add(uttid) + + for utt, uttid in SnorIter(hyp_fh): + if uttid not in ref_ids: + print(uttid) + +if __name__ == "__main__": + main() diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/insert_empty_hyp.py b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/insert_empty_hyp.py new file mode 100644 index 00000000000..fa9e51e38fc --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/insert_empty_hyp.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +""" This script adds ids with empty utterance. It is used during scoring + in cases where some of the reference ids are missing in the hypothesis. + Eg. insert_empty_hyp.py +""" + +import sys +from snor import SnorIter + +if len(sys.argv) != 4: + print("Usage: insert_empty_hyp.py ") + sys.exit(1) + +ids_file = sys.argv[1] +hyp_in_file = sys.argv[2] +hyp_out_file = sys.argv[3] + +def main(): + + with open(hyp_in_file, 'r', encoding='utf-8') as hyp_in_fh, open(hyp_out_file, 'w', encoding='utf-8') as hyp_out_fh, open(ids_file, 'r') as ids_fh: + # First just copy input hyp file over + for line in hyp_in_fh: + hyp_out_fh.write(line) + + # Now add missing ids + + for line in ids_fh: + uttid = line.strip() + hyp_out_fh.write("(%s)\n" % uttid) + +if __name__ == "__main__": + main() diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/normalize_common.py b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/normalize_common.py new file mode 100644 index 00000000000..5bab669d175 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/normalize_common.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +""" This script normalizes a text file. It performs following normalizations: + dots/filled-circles to periods, # variuos dashes to regular hyphen, full + width left/right-paren to regular left/right paren. + Eg. 
normalize_common.py +""" +import sys +from snor import SnorIter + +if len(sys.argv) != 3: + print("Usage: normalize_common.py ") + sys.exit(1) + +input_file = sys.argv[1] +output_file = sys.argv[2] + +def main(): + + with open(input_file, 'r', encoding='utf-8') as fh, open(output_file, 'w', encoding='utf-8') as fh_out: + for utt, uttid in SnorIter(fh): + for char in utt: + if char == "\u25cf" or char == "\u2022" or char == "\u2219": + # Convert "dots"/"filled-circles" to periods + fh_out.write("\u002e") + elif char == "\u2010" or char == "\u2011" or char == "\u2012" or char == "\u2013" or char == "\u2014" or char == "\u2015": + # Change variuos Unicode dashes to Reular hyphen + fh_out.write("\u002d") + elif char == "\uff09": + # Change Full width right-paren to regular paren + fh_out.write("\u0029") + elif char == "\uff08": + # Change Full width left-paren to regular paren + fh_out.write("\u0028") + else: + # Otherwise just apapend char w/o modification + fh_out.write(char) + + # Finally, print out uttid and newline + fh_out.write(" (%s)\n" % uttid) + + +if __name__ == "__main__": + main() diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/normalize_farsi.py b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/normalize_farsi.py new file mode 100644 index 00000000000..aa6205fee51 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/normalize_farsi.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 + +""" This script normalizes a text file. It performs following normalizations: + remove tatweel, vowels and hamza, RTL and LTR marks, + convert arabic keheh to arabic kaf, Farsi Yeh to Arabic Yeh, + Extended (farsi) arabic-indic digit to regular arabic-indic digit, + arabic comma to regular comma. + Eg. normalize_farsi.py +""" + +import sys +from snor import SnorIter + +if len(sys.argv) != 3: + print("Usage: normalize_farsi.py ") + sys.exit(1) + +input_file = sys.argv[1] +output_file = sys.argv[2] + +def main(): + + with open(input_file, 'r', encoding='utf-8') as fh, open(output_file, 'w', encoding='utf-8') as fh_out: + for utt, uttid in SnorIter(fh): + for char in utt: + # First, convert from presentation form to base form + if char in PRESENTATION_TO_BASE: + char = PRESENTATION_TO_BASE[char] + + # Next, handle character-level transformations + if char == "\u0640": + # remove tatweel + continue + elif char == "\u064b" or char == "\u064c" or char == "\u064d" or char == "\u064e" or char == "\u064f" or char == "\u0650" or char == "\u0651" or char == "\u0652" or char == "\u0653" or char == "\u0654" or char == "\u0655": + # remove vowels and hamza + continue + elif char == "\u200f" or char == "\u200e": + # remove RTL and LTR marks + continue + elif char == "\u06a9": + # u06a9 (arabic keheh) -> u0643 (arabic kaf) + fh_out.write("\u0643") + elif char == "\u06cc": + # u06cc (Farsi Yeh) -> u064a (Arabic Yeh) + fh_out.write("\u064a") + elif char == "\ufdfc": + # Transform ligature for RIAL sign -> seq of chars for rial sign + fh_out.write("\u0631\u06cc\u0627\u0644") + elif char == "\u06f0": + # Extended (farsi) arabic-indic digit -> regular arabic-indic digit + fh_out.write("\u0660") + elif char == "\u06f1": + # Extended (farsi) arabic-indic digit -> regular arabic-indic digit + fh_out.write("\u0661") + elif char == "\u06f2": + # Extended (farsi) arabic-indic digit -> regular arabic-indic digit + fh_out.write("\u0662") + elif char == "\u06f3": + # Extended (farsi) arabic-indic digit -> regular arabic-indic digit + fh_out.write("\u0663") + elif char == 
"\u06f4": + # Extended (farsi) arabic-indic digit -> regular arabic-indic digit + fh_out.write("\u0664") + elif char == "\u06f5": + # Extended (farsi) arabic-indic digit -> regular arabic-indic digit + fh_out.write("\u0665") + elif char == "\u06f6": + # Extended (farsi) arabic-indic digit -> regular arabic-indic digit + fh_out.write("\u0666") + elif char == "\u06f7": + # Extended (farsi) arabic-indic digit -> regular arabic-indic digit + fh_out.write("\u0667") + elif char == "\u06f8": + # Extended (farsi) arabic-indic digit -> regular arabic-indic digit + fh_out.write("\u0668") + elif char == "\u06f9": + # Extended (farsi) arabic-indic digit -> regular arabic-indic digit + fh_out.write("\u0669") + elif char == "\u060c": + # Change Arabic comma to Reular Comma + fh_out.write("\u002c") + else: + # Otherwise just apapend char w/o modification + fh_out.write(char) + + # Finally, print out uttid and newline + fh_out.write(" (%s)\n" % uttid) + + +BASE_TO_PRESENTATION = { + # ARABIC LETTER HAMZA + '\u0621': ('\uFE80', '', '', ''), + # ARABIC LETTER ALEF WITH MADDA ABOVE + '\u0622': ('\uFE81', '', '', '\uFE82'), + # ARABIC LETTER ALEF WITH HAMZA ABOVE + '\u0623': ('\uFE83', '', '', '\uFE84'), + # ARABIC LETTER WAW WITH HAMZA ABOVE + '\u0624': ('\uFE85', '', '', '\uFE86'), + # ARABIC LETTER ALEF WITH HAMZA BELOW + '\u0625': ('\uFE87', '', '', '\uFE88'), + # ARABIC LETTER YEH WITH HAMZA ABOVE + '\u0626': ('\uFE89', '\uFE8B', '\uFE8C', '\uFE8A'), + # ARABIC LETTER ALEF + '\u0627': ('\uFE8D', '', '', '\uFE8E'), + # ARABIC LETTER BEH + '\u0628': ('\uFE8F', '\uFE91', '\uFE92', '\uFE90'), + # ARABIC LETTER TEH MARBUTA + '\u0629': ('\uFE93', '', '', '\uFE94'), + # ARABIC LETTER TEH + '\u062A': ('\uFE95', '\uFE97', '\uFE98', '\uFE96'), + # ARABIC LETTER THEH + '\u062B': ('\uFE99', '\uFE9B', '\uFE9C', '\uFE9A'), + # ARABIC LETTER JEEM + '\u062C': ('\uFE9D', '\uFE9F', '\uFEA0', '\uFE9E'), + # ARABIC LETTER HAH + '\u062D': ('\uFEA1', '\uFEA3', '\uFEA4', '\uFEA2'), + # ARABIC LETTER KHAH + '\u062E': ('\uFEA5', '\uFEA7', '\uFEA8', '\uFEA6'), + # ARABIC LETTER DAL + '\u062F': ('\uFEA9', '', '', '\uFEAA'), + # ARABIC LETTER THAL + '\u0630': ('\uFEAB', '', '', '\uFEAC'), + # ARABIC LETTER REH + '\u0631': ('\uFEAD', '', '', '\uFEAE'), + # ARABIC LETTER ZAIN + '\u0632': ('\uFEAF', '', '', '\uFEB0'), + # ARABIC LETTER SEEN + '\u0633': ('\uFEB1', '\uFEB3', '\uFEB4', '\uFEB2'), + # ARABIC LETTER SHEEN + '\u0634': ('\uFEB5', '\uFEB7', '\uFEB8', '\uFEB6'), + # ARABIC LETTER SAD + '\u0635': ('\uFEB9', '\uFEBB', '\uFEBC', '\uFEBA'), + # ARABIC LETTER DAD + '\u0636': ('\uFEBD', '\uFEBF', '\uFEC0', '\uFEBE'), + # ARABIC LETTER TAH + '\u0637': ('\uFEC1', '\uFEC3', '\uFEC4', '\uFEC2'), + # ARABIC LETTER ZAH + '\u0638': ('\uFEC5', '\uFEC7', '\uFEC8', '\uFEC6'), + # ARABIC LETTER AIN + '\u0639': ('\uFEC9', '\uFECB', '\uFECC', '\uFECA'), + # ARABIC LETTER GHAIN + '\u063A': ('\uFECD', '\uFECF', '\uFED0', '\uFECE'), + # ARABIC LETTER FEH + '\u0641': ('\uFED1', '\uFED3', '\uFED4', '\uFED2'), + # ARABIC LETTER QAF + '\u0642': ('\uFED5', '\uFED7', '\uFED8', '\uFED6'), + # ARABIC LETTER KAF + '\u0643': ('\uFED9', '\uFEDB', '\uFEDC', '\uFEDA'), + # ARABIC LETTER LAM + '\u0644': ('\uFEDD', '\uFEDF', '\uFEE0', '\uFEDE'), + # ARABIC LETTER MEEM + '\u0645': ('\uFEE1', '\uFEE3', '\uFEE4', '\uFEE2'), + # ARABIC LETTER NOON + '\u0646': ('\uFEE5', '\uFEE7', '\uFEE8', '\uFEE6'), + # ARABIC LETTER HEH + '\u0647': ('\uFEE9', '\uFEEB', '\uFEEC', '\uFEEA'), + # ARABIC LETTER WAW + '\u0648': ('\uFEED', '', '', '\uFEEE'), + # ARABIC LETTER (UIGHUR 
KAZAKH KIRGHIZ)? ALEF MAKSURA + '\u0649': ('\uFEEF', '\uFBE8', '\uFBE9', '\uFEF0'), + # ARABIC LETTER YEH + '\u064A': ('\uFEF1', '\uFEF3', '\uFEF4', '\uFEF2'), + # ARABIC LETTER ALEF WASLA + '\u0671': ('\uFB50', '', '', '\uFB51'), + # ARABIC LETTER U WITH HAMZA ABOVE + '\u0677': ('\uFBDD', '', '', ''), + # ARABIC LETTER TTEH + '\u0679': ('\uFB66', '\uFB68', '\uFB69', '\uFB67'), + # ARABIC LETTER TTEHEH + '\u067A': ('\uFB5E', '\uFB60', '\uFB61', '\uFB5F'), + # ARABIC LETTER BEEH + '\u067B': ('\uFB52', '\uFB54', '\uFB55', '\uFB53'), + # ARABIC LETTER PEH + '\u067E': ('\uFB56', '\uFB58', '\uFB59', '\uFB57'), + # ARABIC LETTER TEHEH + '\u067F': ('\uFB62', '\uFB64', '\uFB65', '\uFB63'), + # ARABIC LETTER BEHEH + '\u0680': ('\uFB5A', '\uFB5C', '\uFB5D', '\uFB5B'), + # ARABIC LETTER NYEH + '\u0683': ('\uFB76', '\uFB78', '\uFB79', '\uFB77'), + # ARABIC LETTER DYEH + '\u0684': ('\uFB72', '\uFB74', '\uFB75', '\uFB73'), + # ARABIC LETTER TCHEH + '\u0686': ('\uFB7A', '\uFB7C', '\uFB7D', '\uFB7B'), + # ARABIC LETTER TCHEHEH + '\u0687': ('\uFB7E', '\uFB80', '\uFB81', '\uFB7F'), + # ARABIC LETTER DDAL + '\u0688': ('\uFB88', '', '', '\uFB89'), + # ARABIC LETTER DAHAL + '\u068C': ('\uFB84', '', '', '\uFB85'), + # ARABIC LETTER DDAHAL + '\u068D': ('\uFB82', '', '', '\uFB83'), + # ARABIC LETTER DUL + '\u068E': ('\uFB86', '', '', '\uFB87'), + # ARABIC LETTER RREH + '\u0691': ('\uFB8C', '', '', '\uFB8D'), + # ARABIC LETTER JEH + '\u0698': ('\uFB8A', '', '', '\uFB8B'), + # ARABIC LETTER VEH + '\u06A4': ('\uFB6A', '\uFB6C', '\uFB6D', '\uFB6B'), + # ARABIC LETTER PEHEH + '\u06A6': ('\uFB6E', '\uFB70', '\uFB71', '\uFB6F'), + # ARABIC LETTER KEHEH + '\u06A9': ('\uFB8E', '\uFB90', '\uFB91', '\uFB8F'), + # ARABIC LETTER NG + '\u06AD': ('\uFBD3', '\uFBD5', '\uFBD6', '\uFBD4'), + # ARABIC LETTER GAF + '\u06AF': ('\uFB92', '\uFB94', '\uFB95', '\uFB93'), + # ARABIC LETTER NGOEH + '\u06B1': ('\uFB9A', '\uFB9C', '\uFB9D', '\uFB9B'), + # ARABIC LETTER GUEH + '\u06B3': ('\uFB96', '\uFB98', '\uFB99', '\uFB97'), + # ARABIC LETTER NOON GHUNNA + '\u06BA': ('\uFB9E', '', '', '\uFB9F'), + # ARABIC LETTER RNOON + '\u06BB': ('\uFBA0', '\uFBA2', '\uFBA3', '\uFBA1'), + # ARABIC LETTER HEH DOACHASHMEE + '\u06BE': ('\uFBAA', '\uFBAC', '\uFBAD', '\uFBAB'), + # ARABIC LETTER HEH WITH YEH ABOVE + '\u06C0': ('\uFBA4', '', '', '\uFBA5'), + # ARABIC LETTER HEH GOAL + '\u06C1': ('\uFBA6', '\uFBA8', '\uFBA9', '\uFBA7'), + # ARABIC LETTER KIRGHIZ OE + '\u06C5': ('\uFBE0', '', '', '\uFBE1'), + # ARABIC LETTER OE + '\u06C6': ('\uFBD9', '', '', '\uFBDA'), + # ARABIC LETTER U + '\u06C7': ('\uFBD7', '', '', '\uFBD8'), + # ARABIC LETTER YU + '\u06C8': ('\uFBDB', '', '', '\uFBDC'), + # ARABIC LETTER KIRGHIZ YU + '\u06C9': ('\uFBE2', '', '', '\uFBE3'), + # ARABIC LETTER VE + '\u06CB': ('\uFBDE', '', '', '\uFBDF'), + # ARABIC LETTER FARSI YEH + '\u06CC': ('\uFBFC', '\uFBFE', '\uFBFF', '\uFBFD'), + # ARABIC LETTER E + '\u06D0': ('\uFBE4', '\uFBE6', '\uFBE7', '\uFBE5'), + # ARABIC LETTER YEH BARREE + '\u06D2': ('\uFBAE', '', '', '\uFBAF'), + # ARABIC LETTER YEH BARREE WITH HAMZA ABOVE + '\u06D3': ('\uFBB0', '', '', '\uFBB1'), +} + +PRESENTATION_TO_BASE = dict() +for base in BASE_TO_PRESENTATION: + for presentation in BASE_TO_PRESENTATION[base]: + if presentation == '': + continue + PRESENTATION_TO_BASE[presentation.lower()] = base.lower() + + + +if __name__ == "__main__": + main() diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/normalize_spaces.py b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/normalize_spaces.py new 
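Aside on the table that normalize_farsi.py defines above: BASE_TO_PRESENTATION maps each base Arabic letter to its (isolated, initial, medial, final) presentation forms, and the script inverts it into PRESENTATION_TO_BASE so that every presentation-form code point folds back to its base letter before the other substitutions run. A minimal sketch of that inversion pattern, using a made-up two-letter table (the names below are illustrative only, not part of the patch):

```python
# Illustrative sketch of the inversion done in normalize_farsi.py above,
# with a tiny stand-in table instead of the full Arabic letter list.
TOY_BASE_TO_PRESENTATION = {
    '\u0628': ('\uFE8F', '\uFE91', '\uFE92', '\uFE90'),  # BEH: isolated/initial/medial/final
    '\u0644': ('\uFEDD', '\uFEDF', '\uFEE0', '\uFEDE'),  # LAM
}

# Invert the table, skipping empty slots, exactly as the patch does.
TOY_PRESENTATION_TO_BASE = {
    pres: base
    for base, forms in TOY_BASE_TO_PRESENTATION.items()
    for pres in forms
    if pres != ''
}

def fold_to_base(text):
    """Map any presentation-form character back to its base code point."""
    return ''.join(TOY_PRESENTATION_TO_BASE.get(ch, ch) for ch in text)

print(fold_to_base('\uFE91\uFEE0'))  # -> the base letters '\u0628\u0644'
```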
file mode 100644 index 00000000000..a64ad74e440 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/normalize_spaces.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 + +""" This script normalizes a text file. It performs following normalizations: + multiple continuous spaces to single space, removes spaces at the begining + and end of the word. + Eg. normalize_spaces.py +""" +import sys +from snor import SnorIter + +if len(sys.argv) != 3: + print("Usage: normalize_spaces.py ") + sys.exit(1) + +input_file = sys.argv[1] +output_file = sys.argv[2] + +def main(): + + with open(input_file, 'r', encoding='utf-8') as fh, open(output_file, 'w', encoding='utf-8') as fh_out: + for utt, uttid in SnorIter(fh): + # Only output one space at a time + space_chars = set([" ", "\t", "\u00a0"]) + + last_char_was_space = False + + # Strip spaces at beginning and end of utterance + utt = utt.strip(' ') + for char in utt: + if char in space_chars: + if not last_char_was_space: + fh_out.write(" ") + last_char_was_space = True + else: + fh_out.write(char) + last_char_was_space = False + + # Finally, print out uttid and newline + fh_out.write(" (%s)\n" % uttid) + +if __name__ == "__main__": + main() diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/snor.py b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/snor.py new file mode 100644 index 00000000000..29aa22e97ab --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/snor.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +""" This script defines an iterator over SNOR-formatted files. + The iterator iterates over lines, returning tuples of form (utt, utt-id). + snor-format: + some text goes here (id-of-utterance) + some other text here (id-of-next-utterance) +""" + +def SnorIter(fh): + for line in fh: + lparen_location = line.rfind("(") + rparen_location = line.rfind(")") + + if lparen_location > 0 and line[lparen_location-1] == " ": + lparen_location_modifier = -1 + else: + lparen_location_modifier = 0 + utt = line[ :lparen_location + lparen_location_modifier ] + uttid = line[ lparen_location+1 : rparen_location ] + + yield utt, uttid diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/trans_to_chars.py b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/trans_to_chars.py new file mode 100644 index 00000000000..1a01d8cb618 --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/trans_to_chars.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +""" This script converts the words into sequence of space separated characters. + It also converts space between words into " " + Eg. 
trans_to_chars.py +""" + +import unicodedata +import sys +from snor import SnorIter + +if len(sys.argv) != 3: + print("Usage: trans_to_chars.py ") + sys.exit(1) + +input_file = sys.argv[1] +output_file = sys.argv[2] + +def main(): + with open(input_file, 'r', encoding='utf-8') as fh, open(output_file, 'w', encoding='utf-8') as fh_out: + for utt, uttid in SnorIter(fh): + for char in utt: + if char == " ": + fh_out.write(" ") + else: + fh_out.write(char) + fh_out.write(" ") + # Finally write out uttid and newline + fh_out.write("(%s)\n" % uttid) + + +if __name__ == "__main__": + main() diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/trans_to_tokenized_words.py b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/trans_to_tokenized_words.py new file mode 100644 index 00000000000..7af13f7fd9e --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/trans_to_tokenized_words.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +""" This script splits punctuations, digits and currency symbols + from the word. + Eg. "They have come!" he said reverently, gripping his + " They have come ! " he said reverently , gripping his + Eg. trans_to_tokenized_words.py +""" + +import unicodedata +import sys +from snor import SnorIter + +if len(sys.argv) != 3: + print("Usage: trans_to_tokenized_words.py ") + sys.exit(1) + +input_file = sys.argv[1] +output_file = sys.argv[2] + + +punc = set(chr(i) for i in range(sys.maxunicode) + if unicodedata.category(chr(i)).startswith('P')) +currency_symbols = set(chr(i) for i in range(sys.maxunicode) + if unicodedata.category(chr(i)) == "Sc") +digits = set(chr(i) for i in range(sys.maxunicode) + if unicodedata.category(chr(i)) == "Nd") + +split_punc = True +split_digits = True +def main(): + + with open(input_file, 'r', encoding='utf-8') as fh, open(output_file, 'w', encoding='utf-8') as fh_out: + for utt, uttid in SnorIter(fh): + for char in utt: + if (split_punc and char in punc) or (split_punc and char in currency_symbols) or (split_digits and char in digits): + fh_out.write(" ") + fh_out.write(char) + fh_out.write(" ") + else: + fh_out.write(char) + + # Finally write out uttid and newline + fh_out.write(" (%s)\n" % uttid) + + + +if __name__ == "__main__": + main() diff --git a/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/word_trans_utf8_to_uxxxx.py b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/word_trans_utf8_to_uxxxx.py new file mode 100644 index 00000000000..143667b1e8c --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/normalized_scoring/utils/word_trans_utf8_to_uxxxx.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 + +""" This script converts characters from utf-8 format to hexadecimal format. + Eg. 
word_trans_utf8_to_uxxxx.py +""" + +import sys +from snor import SnorIter + +if len(sys.argv) != 3: + print("Usage: word_trans_utf8_to_uxxxx.py ") + sys.exit(1) + +input_file = sys.argv[1] +output_file = sys.argv[2] + +def main(): + with open(input_file, 'r', encoding='utf-8') as fh, open(output_file, 'w', encoding='utf-8') as fh_out: + for utt, uttid in SnorIter(fh): + for word in utt.split(): + fh_out.write(utf8_char_to_uxxxx(word[0])) + for char in word[1:]: + fh_out.write("_") + fh_out.write(utf8_char_to_uxxxx(char)) + fh_out.write(" ") + # Finally write out uttid and newline + fh_out.write("(%s)\n" % uttid) + + +def utf8_char_to_uxxxx(char): + raw_hex = hex(ord(char))[2:].zfill(4).lower() + uxxxx_char = "u%s" % raw_hex + return uxxxx_char + +if __name__ == "__main__": + main() diff --git a/egs/yomdle_tamil/v1/local/yomdle/yomdle2csv.py b/egs/yomdle_tamil/v1/local/yomdle/yomdle2csv.py new file mode 100755 index 00000000000..d75b8bcbe8b --- /dev/null +++ b/egs/yomdle_tamil/v1/local/yomdle/yomdle2csv.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python3 +''' + +GEDI2CSV + +Convert GEDI-type bounding boxes to CSV format + +''' + +import logging +import os +import sys +import time +import glob +import csv +import imghdr +from PIL import Image +import argparse +import pdb +import cv2 +import numpy as np +import xml.etree.ElementTree as ET + +sin = np.sin +cos = np.cos +pi = np.pi + +def Rotate2D(pts, cnt, ang=90): + M = np.array([[cos(ang),-sin(ang)],[sin(ang),cos(ang)]]) + res = np.dot(pts-cnt,M)+cnt + return M, res + +def npbox2string(npar): + if np.shape(npar)[0] != 1: + print('Error during CSV conversion\n') + c1,r1 = npar[0][0],npar[0][1] + c2,r2 = npar[0][2],npar[0][3] + c3,r3 = npar[0][4],npar[0][5] + c4,r4 = npar[0][6],npar[0][7] + + return c1,r1,c2,r2,c3,r3,c4,r4 + +# cv2.minAreaRect() returns a Box2D structure which contains following detals - ( center (x,y), (width, height), angle of rotation ) +# Get 4 corners of the rectangle using cv2.boxPoints() + +class GEDI2CSV(object): + + ''' Initialize the extractor''' + def __init__(self, logger, args): + self._logger = logger + self._args = args + + ''' + Segment image with GEDI bounding box information + ''' + def csvfile(self, coords, polys, baseName, pgrot): + + ''' for writing the files ''' + writePath = self._args.outputDir + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + rotlist = [] + + header=['ID','name','col1','row1','col2','row2','col3','row3','col4','row4','confidence','truth','pgrot','bbrot','qual','script','lang'] + conf=100 + pgrot = 0 + bbrot = 0 + qual = 0 + script = '' + + write_ctr = 0 + if len(coords) == 0 and len(polys) == 0: + self._logger.info('Found %s with no text content',(baseName)) + print('...Found %s with no text content' % (baseName)) + return + + strPos = writePath + baseName + + for j in polys: + try: + arr = [] + [id,poly_val,text,qual,lang] = j + script=None + #print(j) + for i in poly_val: + if len(i.strip()) > 0: + #print(i) + arr.append(eval(i)) + + contour = np.asarray(arr) + #print(contour) + convex = cv2.convexHull(contour) + rect = cv2.minAreaRect(convex) + box = cv2.boxPoints(rect) + box = np.int0(box) + box = np.reshape(box,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(box) + + bbrot = 0.0 + + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,lang]) + + except: + print('...polygon error %s, %s' % (j, baseName)) + continue + + # then write out all of list to file + with open(strPos + ".csv", "w", encoding="utf-8") as f: + 
writer = csv.writer(f) + writer.writerow(header) + for row in rotlist: + writer.writerow(row) + write_ctr += 1 + + return write_ctr + + +def main(args): + + startTime = time.clock() + + writePath = args.outputDir + print('write to %s' % (writePath)) + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + ''' Setup logging ''' + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + if args.log: + handler = logging.FileHandler(args.log) + handler.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + gtconverter = GEDI2CSV(logger, args) + namespaces = {"gedi" : "http://lamp.cfar.umd.edu/media/projects/GEDI/"} + keyCnt=0 + + fileCnt = 0 + line_write_ctr = 0 + line_error_ctr = 0 + file_error_ctr = 0 + ''' + Get all XML files in the directory and sub folders + ''' + print('reading %s' % (args.inputDir)) + for root, dirnames, filenames in os.walk(args.inputDir, followlinks=True): + for file in filenames: + if file.lower().endswith('.xml'): + fullName = os.path.join(root,file) + baseName = os.path.splitext(fullName) + + fileCnt += 1 + + try: + ''' read the XML file ''' + tree = ET.parse(fullName) + except: + print('...ERROR parsing %s' % (fullName)) + file_error_ctr += 1 + continue + + gedi_root = tree.getroot() + child = gedi_root.findall('gedi:DL_DOCUMENT',namespaces)[0] + totalpages = int(child.attrib['NrOfPages']) + coordinates=[] + polygons = [] + + ''' and for each page ''' + for i, pgs in enumerate(child.iterfind('gedi:DL_PAGE',namespaces)): + + if 'GEDI_orientation' not in pgs.attrib: + pageRot=0 + else: + pageRot = int(pgs.attrib['GEDI_orientation']) + logger.info(' PAGE ROTATION %s, %s' % (fullName, str(pageRot))) + + ''' find children for each page ''' + for zone in pgs.findall('gedi:DL_ZONE',namespaces): + + if zone.attrib['gedi_type']=='Text' : + if zone.get('polygon'): + keyCnt+=1 + polygons.append([zone.attrib['id'],zone.get('polygon').split(';'), + zone.get('Text_Content'),zone.get('Illegible'),zone.get('Language')]) + else: + print('...Not polygon') + + + if len(coordinates) > 0 or len(polygons) > 0: + line_write_ctr += gtconverter.csvfile(coordinates, polygons, os.path.splitext(file)[0], pageRot) + else: + print('...%s has no text content' % (baseName[0])) + + + print('complete...total files %d, lines written %d, img errors %d, line error %d' % (fileCnt, line_write_ctr, file_error_ctr, line_error_ctr)) + + +''' Args and defaults ''' +def parse_arguments(argv): + parser = argparse.ArgumentParser() + + parser.add_argument('--inputDir', type=str, help='Input directory', default='/data/YOMDLE/final_arabic/xml') + parser.add_argument('--outputDir', type=str, help='Output directory', default='/exp/YOMDL/final_arabic/csv_truth/') + parser.add_argument('--log', type=str, help='Log directory', default='/exp/logs.txt') + + return parser.parse_args(argv) + + +''' Run ''' +if __name__ == '__main__': + main(parse_arguments(sys.argv[1:])) diff --git a/egs/yomdle_tamil/v1/path.sh b/egs/yomdle_tamil/v1/path.sh new file mode 100755 index 00000000000..2d17b17a84a --- /dev/null +++ b/egs/yomdle_tamil/v1/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. 
$KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/yomdle_tamil/v1/run_end2end.sh b/egs/yomdle_tamil/v1/run_end2end.sh new file mode 100755 index 00000000000..e6a8e0a4432 --- /dev/null +++ b/egs/yomdle_tamil/v1/run_end2end.sh @@ -0,0 +1,165 @@ +#!/bin/bash + +# Copyright 2018 Hossein Hadian +# Ashish Arora +# Jonathan Chang +# Apache 2.0 + +set -e +stage=0 +nj=30 + +language_main=Tamil +slam_dir=/export/corpora5/slam/SLAM/ +yomdle_dir=/export/corpora5/slam/YOMDLE/ +corpus_dir=/export/corpora5/handwriting_ocr/corpus_data/ta/ + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +./local/check_tools.sh +# Start from stage=-2 for data preparation. This stage stores line images, +# csv files and splits{train,test,train_unsup} data/download/truth_line_image, +# data/download/truth_csv and data/local/splits respectively. +if [ $stage -le -2 ]; then + echo "$(date): preparing data, obtaining line images and csv files..." + local/yomdle/create_download_dir.sh --language_main $language_main \ + --slam_dir $slam_dir --yomdle_dir $yomdle_dir +fi + +if [ $stage -le -1 ]; then + echo "$(date): getting corpus text for language modelling..." + mkdir -p data/local/text/cleaned + cat $corpus_dir/* > data/local/text/ta.txt + head -20000 data/local/text/ta.txt > data/local/text/val.txt + tail -n +20000 data/local/text/ta.txt > data/local/text/corpus.txt +fi + +mkdir -p data/{train,test}/data +if [ $stage -le 0 ]; then + echo "$(date) stage 0: Processing train and test data." + echo " creating text, images.scp, utt2spk and spk2utt" + # removing empty transcription line images from train and test set. + # It can cause error while applying BPE. + for set in train test; do + local/process_data.py data/download/ \ + data/local/splits/${set}.txt data/${set} + image/fix_data_dir.sh data/${set} + done +fi + +if [ $stage -le 1 ]; then + echo "$(date) stage 1: getting allowed image widths for e2e training..." + image/get_image2num_frames.py --feat-dim 40 data/train + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train + for set in train test; do + echo "$(date) Extracting features, creating feats.scp file" + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/${set} + steps/compute_cmvn_stats.sh data/${set} || exit 1; + done + image/fix_data_dir.sh data/train +fi + +if [ $stage -le 2 ]; then + for set in train; do + echo "$(date) stage 2: Performing augmentation, it will double training data" + local/augment_data.sh --nj $nj --cmd "$cmd" --feat-dim 40 data/${set} data/${set}_aug data + steps/compute_cmvn_stats.sh data/${set}_aug || exit 1; + done +fi + +if [ $stage -le 3 ]; then + echo "$(date) stage 3: BPE preparation" + # getting non-silence phones. + cut -d' ' -f2- data/train/text | \ +python3 <( +cat << "END" +import os, sys, io; +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8'); +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8'); +phone_dict = dict(); +for line in infile: + line_vect = line.strip().split(); + for word in line_vect: + for phone in word: + phone_dict[phone] = phone; + +for phone in phone_dict.keys(): + output.write(phone+ '\n'); +END + ) > data/local/text/cleaned/phones.txt + + cut -d' ' -f2- data/train/text > data/local/text/cleaned/train.txt + + echo "Processing corpus text..." + # we are removing the lines from the corpus which which have + # phones other than the phones in data/local/text/cleaned/phones.txt. 
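local/process_corpus.py is only referenced here and is not part of this patch; as a rough sketch of the filtering described in the comment above (assumptions: it reads the corpus on stdin, writes kept lines to stdout, and loads the allowed phones from the file just mentioned), it could look like the following:

```python
#!/usr/bin/env python3
# Hypothetical sketch, NOT the actual local/process_corpus.py: drop every
# corpus line that contains a character outside the allowed-phone list in
# data/local/text/cleaned/phones.txt, and print the surviving lines.
import io
import sys

infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

with open('data/local/text/cleaned/phones.txt', encoding='utf-8') as fh:
    allowed = set(line.strip() for line in fh if line.strip())

for line in infile:
    chars = set(line.strip()) - set(' ')
    if chars and chars.issubset(allowed):
        output.write(line)
```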
+ cat data/local/text/corpus.txt | \ + local/process_corpus.py > data/local/text/cleaned/corpus.txt + cat data/local/text/val.txt | \ + local/process_corpus.py > data/local/text/cleaned/val.txt + + echo "learning BPE..." + # it is currently learned with only training text but we can also use all corpus text + # to learn BPE. phones are added so that one isolated occurance of every phone exists. + cat data/local/text/cleaned/phones.txt data/local/text/cleaned/train.txt | \ + utils/lang/bpe/prepend_words.py | utils/lang/bpe/learn_bpe.py -s 700 > data/local/bpe.txt || exit 1; +fi + +if [ $stage -le 4 ]; then + echo "$(date) stage 4: applying BPE..." + echo "applying BPE on train, test text..." + for set in test train train_aug; do + cut -d' ' -f1 data/$set/text > data/$set/ids + cut -d' ' -f2- data/$set/text | utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt | \ + sed 's/@@//g' > data/$set/bpe_text + mv data/$set/text data/$set/text.old + paste -d' ' data/$set/ids data/$set/bpe_text > data/$set/text + rm -f data/$set/bpe_text data/$set/ids + done + + echo "applying BPE to corpus text..." + cat data/local/text/cleaned/corpus.txt | utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt | \ + sed 's/@@//g' > data/local/text/cleaned/bpe_corpus.txt + cat data/local/text/cleaned/val.txt | utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt | \ + sed 's/@@//g' > data/local/text/cleaned/bpe_val.txt +fi + +if [ $stage -le 5 ]; then + echo "$(date) stage 5: Preparing dictionary and lang..." + local/prepare_dict.sh --dir data/local/dict + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ + data/local/dict "" data/lang/temp data/lang + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 data/lang +fi + +if [ $stage -le 6 ]; then + echo "$(date) stage 6: Estimating a language model for decoding..." + local/train_lm.sh + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/6gram_small.arpa.gz \ + data/local/dict/lexicon.txt data/lang + utils/build_const_arpa_lm.sh data/local/local_lm/data/arpa/6gram_unpruned.arpa.gz \ + data/lang data/lang_rescore_6g +fi + +if [ $stage -le 7 ]; then + echo "$(date) stage 7: Calling the flat-start chain recipe..." + local/chain/run_e2e_cnn.sh --train_set train_aug +fi + +if [ $stage -le 8 ]; then + echo "$(date) stage 8: Aligning the training data using the e2e chain model..." + steps/nnet3/align.sh --nj $nj --cmd "$cmd" \ + --use-gpu false \ + --scale-opts '--transition-scale=1.0 --acoustic-scale=1.0 --self-loop-scale=1.0' \ + data/train_aug data/lang exp/chain/e2e_cnn_1a exp/chain/e2e_ali_train +fi + +if [ $stage -le 9 ]; then + echo "$(date) stage 9: Building a tree and training a regular chain model using the e2e alignments..." 
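For readers new to the BPE steps in stages 3 and 4 above: apply_bpe.py splits words into subword pieces and marks word-internal continuations with "@@", and the sed 's/@@//g' commands simply delete those markers so the pieces become plain space-separated tokens (word boundaries are carried by whatever symbol prepend_words.py inserted earlier; the "|" below is used purely for illustration, and the segmentation is invented, not real learn_bpe.py output):

```python
# Toy illustration of the '@@' convention stripped by the sed commands above.
segmented = "|incred@@ ible |sto@@ ry"   # hypothetical apply_bpe.py output
print(segmented.replace("@@", ""))       # -> "|incred ible |sto ry"
```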
+ local/chain/run_cnn_e2eali.sh --train_set train_aug +fi diff --git a/egs/yomdle_tamil/v1/steps b/egs/yomdle_tamil/v1/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/yomdle_tamil/v1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/yomdle_tamil/v1/utils b/egs/yomdle_tamil/v1/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/yomdle_tamil/v1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/yomdle_zh/README.txt b/egs/yomdle_zh/README.txt new file mode 100644 index 00000000000..39d2348ca10 --- /dev/null +++ b/egs/yomdle_zh/README.txt @@ -0,0 +1,3 @@ +This directory contains example scripts for OCR on the Yomdle and Slam datasets. +Training is done on the Yomdle dataset and testing is done on Slam. +LM rescoring is also done with extra corpus data obtained from various sources (e.g. Hamshahri) diff --git a/egs/yomdle_zh/v1/cmd.sh b/egs/yomdle_zh/v1/cmd.sh new file mode 100755 index 00000000000..3c8eb9f93a5 --- /dev/null +++ b/egs/yomdle_zh/v1/cmd.sh @@ -0,0 +1,13 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export cmd="queue.pl" diff --git a/egs/yomdle_zh/v1/image b/egs/yomdle_zh/v1/image new file mode 120000 index 00000000000..1668ee99922 --- /dev/null +++ b/egs/yomdle_zh/v1/image @@ -0,0 +1 @@ +../../cifar/v1/image/ \ No newline at end of file diff --git a/egs/yomdle_zh/v1/local/augment_data.sh b/egs/yomdle_zh/v1/local/augment_data.sh new file mode 100755 index 00000000000..1f13ed15ded --- /dev/null +++ b/egs/yomdle_zh/v1/local/augment_data.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2018 Hossein Hadian +# 2018 Ashish Arora + +# Apache 2.0 +# This script performs data augmentation. + +nj=4 +cmd=run.pl +feat_dim=40 +fliplr=false +verticle_shift=0 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. 
./utils/parse_options.sh || exit 1; + +srcdir=$1 +outdir=$2 +datadir=$3 + +mkdir -p $datadir/augmentations +echo "copying $srcdir to $datadir/augmentations/aug1, allowed length, creating feats.scp" + +for set in aug1; do + image/copy_data_dir.sh --spk-prefix $set- --utt-prefix $set- \ + $srcdir $datadir/augmentations/$set + cat $srcdir/allowed_lengths.txt > $datadir/augmentations/$set/allowed_lengths.txt + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim $feat_dim \ + --vertical-shift $verticle_shift \ + --fliplr $fliplr --augment 'random_scale' $datadir/augmentations/$set +done + +echo " combine original data and data from different augmentations" +utils/combine_data.sh --extra-files images.scp $outdir $srcdir $datadir/augmentations/aug1 +cat $srcdir/allowed_lengths.txt > $outdir/allowed_lengths.txt diff --git a/egs/yomdle_zh/v1/local/bidi.py b/egs/yomdle_zh/v1/local/bidi.py new file mode 100755 index 00000000000..447313a5d02 --- /dev/null +++ b/egs/yomdle_zh/v1/local/bidi.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2018 Chun-Chieh Chang + +# This script is largely written by Stephen Rawls +# and uses the python package https://pypi.org/project/PyICU_BiDi/ +# The code leaves right to left text alone and reverses left to right text. + +import icu_bidi +import io +import sys +import unicodedata +# R=strong right-to-left; AL=strong arabic right-to-left +rtl_set = set(chr(i) for i in range(sys.maxunicode) + if unicodedata.bidirectional(chr(i)) in ['R','AL']) +def determine_text_direction(text): + # Easy case first + for char in text: + if char in rtl_set: + return icu_bidi.UBiDiLevel.UBIDI_RTL + # If we made it here we did not encounter any strongly rtl char + return icu_bidi.UBiDiLevel.UBIDI_LTR + +def utf8_visual_to_logical(text): + text_dir = determine_text_direction(text) + + bidi = icu_bidi.Bidi() + bidi.inverse = True + bidi.reordering_mode = icu_bidi.UBiDiReorderingMode.UBIDI_REORDER_INVERSE_LIKE_DIRECT + bidi.reordering_options = icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_DEFAULT # icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_INSERT_MARKS + + bidi.set_para(text, text_dir, None) + + res = bidi.get_reordered(0 | icu_bidi.UBidiWriteReorderedOpt.UBIDI_DO_MIRRORING | icu_bidi.UBidiWriteReorderedOpt.UBIDI_KEEP_BASE_COMBINING) + + return res + +def utf8_logical_to_visual(text): + text_dir = determine_text_direction(text) + + bidi = icu_bidi.Bidi() + + bidi.reordering_mode = icu_bidi.UBiDiReorderingMode.UBIDI_REORDER_DEFAULT + bidi.reordering_options = icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_DEFAULT #icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_INSERT_MARKS + + bidi.set_para(text, text_dir, None) + + res = bidi.get_reordered(0 | icu_bidi.UBidiWriteReorderedOpt.UBIDI_DO_MIRRORING | icu_bidi.UBidiWriteReorderedOpt.UBIDI_KEEP_BASE_COMBINING) + + return res + + +##main## +sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf8") +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf8") +for line in sys.stdin: + line = line.strip() + line = utf8_logical_to_visual(line)[::-1] + sys.stdout.write(line + '\n') diff --git a/egs/yomdle_zh/v1/local/chain/compare_wer.sh b/egs/yomdle_zh/v1/local/chain/compare_wer.sh new file mode 100755 index 00000000000..ab880c1adb5 --- /dev/null +++ b/egs/yomdle_zh/v1/local/chain/compare_wer.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. 
local/chain/compare_wer.sh exp/chain/cnn{1a,1b} + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... ]" + echo "e.g.: $0 exp/chain/cnn{1a,1b}" + exit 1 +fi + +echo "# $0 $*" +used_epochs=false + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +echo -n "# WER " +for x in $*; do + wer=$(cat $x/decode_test/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER " +for x in $*; do + cer=$(cat $x/decode_test/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo diff --git a/egs/yomdle_zh/v1/local/chain/run_cnn_e2eali_1b.sh b/egs/yomdle_zh/v1/local/chain/run_cnn_e2eali_1b.sh new file mode 100755 index 00000000000..0a4e00d7aed --- /dev/null +++ b/egs/yomdle_zh/v1/local/chain/run_cnn_e2eali_1b.sh @@ -0,0 +1,245 @@ +#!/bin/bash + +# e2eali_1b is the same as chainali_1a but uses the e2e chain model to get the +# lattice alignments and to build a tree + +# ./local/chain/compare_wer.sh exp_yomdle_chinese/chain/e2e_cnn_1a exp_yomdle_chinese/chain/cnn_e2eali_1b +# System e2e_cnn_1a cnn_e2eali_1b +# CER 15.44 13.57 +# Final train prob 0.0616 -0.0512 +# Final valid prob 0.0390 -0.0718 +# Final train prob (xent) -0.6199 +# Final valid prob (xent) -0.7448 + +set -e -o pipefail + +data_dir=data +exp_dir=exp + +stage=0 + +nj=30 +train_set=train +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1b #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=1000 +# we don't need extra left/right context for TDNN systems. +chunk_left_context=0 +chunk_right_context=0 +tdnn_dim=450 +# training options +srand=0 +remove_egs=true +lang_test=lang_test +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} $data_dir/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts + +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. 
We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=3 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=32" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=128" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=512" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=180 name=input + + conv-relu-batchnorm-layer name=cnn1 height-in=60 height-out=60 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=60 height-out=60 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn3 height-in=60 height-out=30 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn4 height-in=30 height-out=30 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=30 height-out=30 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn6 height-in=30 height-out=15 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn7 height-in=15 height-out=15 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn8 height-in=15 height-out=15 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn9 height-in=15 height-out=15 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-8,-4,0,4,8) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. 
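+ # (Concretely: with xent_regularize=0.1 as set above, this factor works out to 0.5 / 0.1 = 5.0.)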
+ relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=500" \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=6 \ + --trainer.frames-per-iter=1000000 \ + --trainer.optimization.num-jobs-initial=4 \ + --trainer.optimization.num-jobs-final=8 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=16,8 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $data_dir/$lang_test \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --extra-left-context $chunk_left_context \ + --extra-right-context $chunk_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph $data_dir/test $dir/decode_test || exit 1; +fi diff --git a/egs/yomdle_zh/v1/local/chain/run_flatstart_cnn1a.sh b/egs/yomdle_zh/v1/local/chain/run_flatstart_cnn1a.sh new file mode 100755 index 00000000000..88bbd32790c --- /dev/null +++ b/egs/yomdle_zh/v1/local/chain/run_flatstart_cnn1a.sh @@ -0,0 +1,169 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +# This script does end2end chain training (i.e. 
from scratch) + +# ./local/chain/compare_wer.sh exp_yomdle_chinese/chain/e2e_cnn_1a exp_yomdle_chinese/chain/cnn_e2eali_1b +# System e2e_cnn_1a cnn_e2eali_1b +# CER 15.44 13.57 +# Final train prob 0.0616 -0.0512 +# Final valid prob 0.0390 -0.0718 +# Final train prob (xent) -0.6199 +# Final valid prob (xent) -0.7448 + +set -e + +data_dir=data +exp_dir=exp + +# configs for 'chain' +stage=0 +nj=30 +train_stage=-10 +get_egs_stage=-10 +affix=1a + +# training options +tdnn_dim=450 +num_epochs=4 +num_jobs_initial=4 +num_jobs_final=8 +minibatch_size=150=64,32/300=32,16/600=16,8/1200=8,4 +common_egs_dir= +l2_regularize=0.00005 +frames_per_iter=1000000 +cmvn_opts="--norm-means=false --norm-vars=false" +train_set=train +lang_test=lang_test + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj $nj --cmd "$cmd" \ + --shared-phones true \ + --type mono \ + $data_dir/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat $data_dir/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py $data_dir/lang \| \ + utils/sym2int.pl -f 2- $data_dir/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=32" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=128" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=512" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=180 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=60 height-out=60 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=60 height-out=30 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=30 height-out=30 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=30 height-out=30 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=30 height-out=15 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=15 height-out=15 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=15 height-out=15 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $output_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + # no need to store the egs in a shared storage because we always + # remove them. Anyway, it takes only 5 minutes to generate them. 
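+ # (Roughly how the lrate options above behave: the effective learning rate decays exponentially from + # initial_effective_lrate to final_effective_lrate over the run, and each job uses the effective rate + # times the current number of jobs, i.e. about 0.001*4=0.004 per job at the start and 0.0001*8=0.0008 at the end.)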
+ + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts "$cmvn_opts" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize $l2_regularize \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter $frames_per_iter \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir $data_dir/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi + +if [ $stage -le 4 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $data_dir/$lang_test \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 5 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$cmd" \ + $dir/graph $data_dir/test $dir/decode_test || exit 1; +fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/yomdle_zh/v1/local/create_download.sh b/egs/yomdle_zh/v1/local/create_download.sh new file mode 100755 index 00000000000..1daad354473 --- /dev/null +++ b/egs/yomdle_zh/v1/local/create_download.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright 2018 Chun-Chieh Chang + +# The original format of the dataset given is GEDI and page images. +# This script is written to create line images from page images. +# It also creates csv files from the GEDI files. + +database_slam=/export/corpora5/slam/SLAM/Farsi/transcribed +database_yomdle=/export/corpora5/slam/YOMDLE/final_farsi +cangjie_url=https://raw.githubusercontent.com/wanleung/libcangjie/master/tables/cj5-cc.txt +download_dir=download +slam_dir=$download_dir/slam_farsi +yomdle_dir=$download_dir/yomdle_farsi + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +echo "$0: Processing SLAM ${language}" +echo "Date: $(date)." +mkdir -p ${slam_dir}/{truth_csv,truth_csv_raw,truth_line_image} +local/gedi2csv.py \ + --inputDir ${database_slam} \ + --outputDir ${slam_dir}/truth_csv_raw \ + --log ${slam_dir}/GEDI2CSV_enriched.log +local/create_line_image_from_page_image.py \ + ${database_slam} \ + ${slam_dir}/truth_csv_raw \ + ${slam_dir} + +echo "$0: Processing YOMDLE ${language}" +echo "Date: $(date)." 
+mkdir -p ${yomdle_dir}/{truth_csv,truth_csv_raw,truth_line_image} +local/yomdle2csv.py \ + --inputDir ${database_yomdle} \ + --outputDir ${yomdle_dir}/truth_csv_raw/ \ + --log ${yomdle_dir}/YOMDLE2CSV.log +local/create_line_image_from_page_image.py \ + --im-format "jpg" \ + ${database_yomdle}/images \ + ${yomdle_dir}/truth_csv_raw \ + ${yomdle_dir} + +echo "Downloading table for CangJie." +wget -P $download_dir/ $cangjie_url || exit 1; +perl -n -i -e 'print if $. > 8' $download_dir/cj5-cc.txt diff --git a/egs/yomdle_zh/v1/local/create_line_image_from_page_image.py b/egs/yomdle_zh/v1/local/create_line_image_from_page_image.py new file mode 100755 index 00000000000..7135bb1b242 --- /dev/null +++ b/egs/yomdle_zh/v1/local/create_line_image_from_page_image.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# Apache 2.0 +# minimum bounding box part in this script is originally from +#https://github.com/BebeSparkelSparkel/MinimumBoundingBox +#https://startupnextdoor.com/computing-convex-hull-in-python/ +""" This module will be used for extracting line images from page image. + Given the word segmentation (bounding box around a word) for every word, it will + extract line segmentation. To extract line segmentation, it will take word bounding + boxes of a line as input, will create a minimum area bounding box that will contain + all corner points of word bounding boxes. The obtained bounding box (will not necessarily + be vertically or horizontally aligned). Hence to extract line image from line bounding box, + page image is rotated and line image is cropped and saved. +""" + +import argparse +import csv +import itertools +import sys +import os +import numpy as np +from math import atan2, cos, sin, pi, degrees, sqrt +from collections import namedtuple + +from scipy.spatial import ConvexHull +from PIL import Image +from scipy.misc import toimage + +parser = argparse.ArgumentParser(description="Creates line images from page image") +parser.add_argument('image_dir', type=str, help='Path to full page images') +parser.add_argument('csv_dir', type=str, help='Path to csv files') +parser.add_argument('out_dir', type=str, help='Path to output directory') +parser.add_argument('--im-format', type=str, default='png', help='What file format are the images') +parser.add_argument('--padding', type=int, default=100, help='Padding so BBox does not exceed image area') +parser.add_argument('--head', type=int, default=-1, help='Number of csv files to process') +args = parser.parse_args() + +""" +bounding_box is a named tuple which contains: + area (float): area of the rectangle + length_parallel (float): length of the side that is parallel to unit_vector + length_orthogonal (float): length of the side that is orthogonal to unit_vector + rectangle_center(int, int): coordinates of the rectangle center + (use rectangle_corners to get the corner points of the rectangle) + unit_vector (float, float): direction of the length_parallel side. + (it's orthogonal vector can be found with the orthogonal_vector function + unit_vector_angle (float): angle of the unit vector to be in radians. + corner_points [(float, float)]: set that contains the corners of the rectangle +""" + +bounding_box_tuple = namedtuple('bounding_box_tuple', 'area ' + 'length_parallel ' + 'length_orthogonal ' + 'rectangle_center ' + 'unit_vector ' + 'unit_vector_angle ' + 'corner_points' + ) + + +def unit_vector(pt0, pt1): + """ Given two points pt0 and pt1, return a unit vector that + points in the direction of pt0 to pt1. 
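+ For example, unit_vector((0, 0), (3, 4)) is (0.6, 0.8).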
+ Returns + ------- + (float, float): unit vector + """ + dis_0_to_1 = sqrt((pt0[0] - pt1[0])**2 + (pt0[1] - pt1[1])**2) + return (pt1[0] - pt0[0]) / dis_0_to_1, \ + (pt1[1] - pt0[1]) / dis_0_to_1 + + +def orthogonal_vector(vector): + """ Given a vector, returns an orthogonal/perpendicular vector of equal length. + Returns + ------ + (float, float): A vector that points in the direction orthogonal to vector. + """ + return -1 * vector[1], vector[0] + + +def bounding_area(index, hull): + """ Given an index location in an array and a convex hull, it gets two points + hull[index] and hull[index+1]. From these two points, it returns a named + tuple that mainly contains the area of the box that bounds the hull. This + bounding box orientation is the same as the orientation of the line formed + by the points hull[index] and hull[index+1]. + Returns + ------- + a named tuple that contains: + area: area of the rectangle + length_parallel: length of the side that is parallel to unit_vector + length_orthogonal: length of the side that is orthogonal to unit_vector + rectangle_center: coordinates of the rectangle center + unit_vector: direction of the length_parallel side. + (its orthogonal vector can be found with the orthogonal_vector function) + """ + unit_vector_p = unit_vector(hull[index], hull[index+1]) + unit_vector_o = orthogonal_vector(unit_vector_p) + + dis_p = tuple(np.dot(unit_vector_p, pt) for pt in hull) + dis_o = tuple(np.dot(unit_vector_o, pt) for pt in hull) + + min_p = min(dis_p) + min_o = min(dis_o) + len_p = max(dis_p) - min_p + len_o = max(dis_o) - min_o + + return {'area': len_p * len_o, + 'length_parallel': len_p, + 'length_orthogonal': len_o, + 'rectangle_center': (min_p + float(len_p) / 2, min_o + float(len_o) / 2), + 'unit_vector': unit_vector_p, + } + + +def to_xy_coordinates(unit_vector_angle, point): + """ Given the angle of the unit vector from the horizontal axis and a point + expressed in the (parallel, orthogonal) frame of that vector, returns the + point converted to x, y coordinates. The angle should be in radians. + Returns + ------ + (float, float): converted x, y coordinates of the point. + """ + angle_orthogonal = unit_vector_angle + pi / 2 + return point[0] * cos(unit_vector_angle) + point[1] * cos(angle_orthogonal), \ + point[0] * sin(unit_vector_angle) + point[1] * sin(angle_orthogonal) + + +def rotate_points(center_of_rotation, angle, points): + """ Rotates a point cloud around the center_of_rotation point by angle + input + ----- + center_of_rotation (float, float): the point around which the points are rotated. + angle (float): angle of rotation, in radians. + points [(float, float)]: a list or tuple of 2D points to be rotated. + Returns + ------ + [(float, float)]: points rotated around the center of rotation by angle + """ + rot_points = [] + ang = [] + for pt in points: + diff = tuple([pt[d] - center_of_rotation[d] for d in range(2)]) + diff_angle = atan2(diff[1], diff[0]) + angle + ang.append(diff_angle) + diff_length = sqrt(sum([d**2 for d in diff])) + rot_points.append((center_of_rotation[0] + diff_length * cos(diff_angle), + center_of_rotation[1] + diff_length * sin(diff_angle))) + + return rot_points + + +def rectangle_corners(rectangle): + """ Given the rectangle center and its inclination, returns the corner + locations of the rectangle. + Returns + ------ + [(float, float)]: 4 corner points of rectangle.
+ """ + corner_points = [] + for i1 in (.5, -.5): + for i2 in (i1, -1 * i1): + corner_points.append((rectangle['rectangle_center'][0] + i1 * rectangle['length_parallel'], + rectangle['rectangle_center'][1] + i2 * rectangle['length_orthogonal'])) + + return rotate_points(rectangle['rectangle_center'], rectangle['unit_vector_angle'], corner_points) + + +def get_orientation(origin, p1, p2): + """ + Given origin and two points, return the orientation of the Point p1 with + regards to Point p2 using origin. + Returns + ------- + integer: Negative if p1 is clockwise of p2. + """ + difference = ( + ((p2[0] - origin[0]) * (p1[1] - origin[1])) + - ((p1[0] - origin[0]) * (p2[1] - origin[1])) + ) + return difference + + +def compute_hull(points): + """ + Given input list of points, return a list of points that + made up the convex hull. + Returns + ------- + [(float, float)]: convexhull points + """ + hull_points = [] + start = points[0] + min_x = start[0] + for p in points[1:]: + if p[0] < min_x: + min_x = p[0] + start = p + + point = start + hull_points.append(start) + + far_point = None + while far_point is not start: + p1 = None + for p in points: + if p is point: + continue + else: + p1 = p + break + + far_point = p1 + + for p2 in points: + if p2 is point or p2 is p1: + continue + else: + direction = get_orientation(point, far_point, p2) + if direction > 0: + far_point = p2 + + hull_points.append(far_point) + point = far_point + return hull_points + + +def minimum_bounding_box(points): + """ Given a list of 2D points, it returns the minimum area rectangle bounding all + the points in the point cloud. + Returns + ------ + returns a namedtuple that contains: + area: area of the rectangle + length_parallel: length of the side that is parallel to unit_vector + length_orthogonal: length of the side that is orthogonal to unit_vector + rectangle_center: coordinates of the rectangle center + unit_vector: direction of the length_parallel side. 
RADIANS + unit_vector_angle: angle of the unit vector + corner_points: set that contains the corners of the rectangle + """ + + if len(points) <= 2: raise ValueError('More than two points required.') + + hull_ordered = [points[index] for index in ConvexHull(points).vertices] + hull_ordered.append(hull_ordered[0]) + #hull_ordered = compute_hull(points) + hull_ordered = tuple(hull_ordered) + + min_rectangle = bounding_area(0, hull_ordered) + for i in range(1, len(hull_ordered)-1): + rectangle = bounding_area(i, hull_ordered) + if rectangle['area'] < min_rectangle['area']: + min_rectangle = rectangle + + min_rectangle['unit_vector_angle'] = atan2(min_rectangle['unit_vector'][1], min_rectangle['unit_vector'][0]) + min_rectangle['rectangle_center'] = to_xy_coordinates(min_rectangle['unit_vector_angle'], min_rectangle['rectangle_center']) + + return bounding_box_tuple( + area = min_rectangle['area'], + length_parallel = min_rectangle['length_parallel'], + length_orthogonal = min_rectangle['length_orthogonal'], + rectangle_center = min_rectangle['rectangle_center'], + unit_vector = min_rectangle['unit_vector'], + unit_vector_angle = min_rectangle['unit_vector_angle'], + corner_points = set(rectangle_corners(min_rectangle)) + ) + + +def get_center(im): + """ Given image, returns the location of center pixel + Returns + ------- + (int, int): center of the image + """ + center_x = float(im.size[0]) / 2 + center_y = float(im.size[1]) / 2 + return int(center_x), int(center_y) + + +def get_horizontal_angle(unit_vector_angle): + """ Given an angle in radians, returns angle of the unit vector in + first or fourth quadrant. + Returns + ------ + (float): updated angle of the unit vector to be in radians. + It is only in first or fourth quadrant. + """ + if unit_vector_angle > pi / 2 and unit_vector_angle <= pi: + unit_vector_angle = unit_vector_angle - pi + elif unit_vector_angle > -pi and unit_vector_angle < -pi / 2: + unit_vector_angle = unit_vector_angle + pi + + return unit_vector_angle + + +def get_smaller_angle(bounding_box): + """ Given a rectangle, returns its smallest absolute angle from horizontal axis. + Returns + ------ + (float): smallest angle of the rectangle to be in radians. + """ + unit_vector = bounding_box.unit_vector + unit_vector_angle = bounding_box.unit_vector_angle + ortho_vector = orthogonal_vector(unit_vector) + ortho_vector_angle = atan2(ortho_vector[1], ortho_vector[0]) + + unit_vector_angle_updated = get_horizontal_angle(unit_vector_angle) + ortho_vector_angle_updated = get_horizontal_angle(ortho_vector_angle) + + if abs(unit_vector_angle_updated) < abs(ortho_vector_angle_updated): + return unit_vector_angle_updated + else: + return ortho_vector_angle_updated + + +def rotated_points(bounding_box, center): + """ Given the rectangle, returns corner points of rotated rectangle. + It rotates the rectangle around the center by its smallest angle. + Returns + ------- + [(int, int)]: 4 corner points of rectangle. 
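A quick sanity check of the minimum_bounding_box function above, as a usage sketch (assumptions: the helpers in this file are imported or pasted into a session, and the point set below is made up for illustration):

```python
# Usage sketch for minimum_bounding_box() above, on a hand-made point set.
# For points lying on an axis-aligned 2x1 rectangle, the minimum-area box
# is that rectangle itself: area 2.0, side lengths {1.0, 2.0}, center (1.0, 0.5).
points = [(0, 0), (2, 0), (2, 1), (0, 1), (1, 0), (1, 1)]
bb = minimum_bounding_box(points)
print(round(bb.area, 6))                                   # 2.0
print(sorted((round(bb.length_parallel, 6),
              round(bb.length_orthogonal, 6))))            # [1.0, 2.0]
print(tuple(round(c, 6) for c in bb.rectangle_center))     # (1.0, 0.5)
```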
+ """ + p1, p2, p3, p4 = bounding_box.corner_points + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + x4, y4 = p4 + center_x, center_y = center + rotation_angle_in_rad = -get_smaller_angle(bounding_box) + x_dash_1 = (x1 - center_x) * cos(rotation_angle_in_rad) - (y1 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_2 = (x2 - center_x) * cos(rotation_angle_in_rad) - (y2 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_3 = (x3 - center_x) * cos(rotation_angle_in_rad) - (y3 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_4 = (x4 - center_x) * cos(rotation_angle_in_rad) - (y4 - center_y) * sin(rotation_angle_in_rad) + center_x + + y_dash_1 = (y1 - center_y) * cos(rotation_angle_in_rad) + (x1 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_2 = (y2 - center_y) * cos(rotation_angle_in_rad) + (x2 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_3 = (y3 - center_y) * cos(rotation_angle_in_rad) + (x3 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_4 = (y4 - center_y) * cos(rotation_angle_in_rad) + (x4 - center_x) * sin(rotation_angle_in_rad) + center_y + return x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 + + +def pad_image(image): + """ Given an image, returns a padded image around the border. + This routine save the code from crashing if bounding boxes that are + slightly outside the page boundary. + Returns + ------- + image: page image + """ + offset = int(args.padding // 2) + padded_image = Image.new('RGB', (image.size[0] + int(args.padding), image.size[1] + int(args.padding)), "white") + padded_image.paste(im = image, box = (offset, offset)) + return padded_image + +def update_minimum_bounding_box_input(bounding_box_input): + """ Given list of 2D points, returns list of 2D points shifted by an offset. + Returns + ------ + points [(float, float)]: points, a list or tuple of 2D coordinates + """ + updated_minimum_bounding_box_input = [] + offset = int(args.padding // 2) + for point in bounding_box_input: + x, y = point + new_x = x + offset + new_y = y + offset + word_coordinate = (new_x, new_y) + updated_minimum_bounding_box_input.append(word_coordinate) + + return updated_minimum_bounding_box_input + + +### main ### +csv_count = 0 +for filename in sorted(os.listdir(args.csv_dir)): + if filename.endswith('.csv') and (csv_count < args.head or args.head < 0): + csv_count = csv_count + 1 + with open(os.path.join(args.csv_dir, filename), 'r', encoding='utf-8') as f: + image_file = os.path.join(args.image_dir, os.path.splitext(filename)[0] + '.' 
+ args.im_format) + if not os.path.isfile(image_file): + continue + csv_out_file = os.path.join(args.out_dir, 'truth_csv', filename) + csv_out_fh = open(csv_out_file, 'w', encoding='utf-8') + csv_out_writer = csv.writer(csv_out_fh) + im = Image.open(image_file) + im = pad_image(im) + count = 1 + for row in itertools.islice(csv.reader(f), 0, None): + if count == 1: + count = 0 + continue + + points = [] + points.append((int(row[2]), int(row[3]))) + points.append((int(row[4]), int(row[5]))) + points.append((int(row[6]), int(row[7]))) + points.append((int(row[8]), int(row[9]))) + + x = [int(row[2]), int(row[4]), int(row[6]), int(row[8])] + y = [int(row[3]), int(row[5]), int(row[7]), int(row[9])] + min_x, min_y = min(x), min(y) + max_x, max_y = max(x), max(y) + if min_x == max_x or min_y == max_y: + continue + + try: + updated_mbb_input = update_minimum_bounding_box_input(points) + bounding_box = minimum_bounding_box(updated_mbb_input) + except Exception as e: + print("Error: Skipping Image " + row[1]) + continue + + p1, p2, p3, p4 = bounding_box.corner_points + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + x4, y4 = p4 + min_x = int(min(x1, x2, x3, x4)) + min_y = int(min(y1, y2, y3, y4)) + max_x = int(max(x1, x2, x3, x4)) + max_y = int(max(y1, y2, y3, y4)) + box = (min_x, min_y, max_x, max_y) + region_initial = im.crop(box) + rot_points = [] + p1_new = (x1 - min_x, y1 - min_y) + p2_new = (x2 - min_x, y2 - min_y) + p3_new = (x3 - min_x, y3 - min_y) + p4_new = (x4 - min_x, y4 - min_y) + rot_points.append(p1_new) + rot_points.append(p2_new) + rot_points.append(p3_new) + rot_points.append(p4_new) + + cropped_bounding_box = bounding_box_tuple(bounding_box.area, + bounding_box.length_parallel, + bounding_box.length_orthogonal, + bounding_box.length_orthogonal, + bounding_box.unit_vector, + bounding_box.unit_vector_angle, + set(rot_points)) + + rotation_angle_in_rad = get_smaller_angle(cropped_bounding_box) + img2 = region_initial.rotate(degrees(rotation_angle_in_rad), resample = Image.BICUBIC) + x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 = rotated_points( + cropped_bounding_box, get_center(region_initial)) + + min_x = int(min(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + min_y = int(min(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + max_x = int(max(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + max_y = int(max(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + box = (min_x, min_y, max_x, max_y) + region_final = img2.crop(box) + csv_out_writer.writerow(row) + image_out_file = os.path.join(args.out_dir, 'truth_line_image', row[1]) + region_final.save(image_out_file) diff --git a/egs/yomdle_zh/v1/local/extract_features.sh b/egs/yomdle_zh/v1/local/extract_features.sh new file mode 100755 index 00000000000..f75837ae5b3 --- /dev/null +++ b/egs/yomdle_zh/v1/local/extract_features.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright 2017 Yiwen Shao +# 2018 Ashish Arora + +nj=4 +cmd=run.pl +feat_dim=40 +fliplr=false +augment='no_aug' +num_channels=3 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. 
./utils/parse_options.sh || exit 1; + +data=$1 +featdir=$data/data +scp=$data/images.scp +logdir=$data/log + +mkdir -p $logdir +mkdir -p $featdir + +# make $featdir an absolute pathname +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +for n in $(seq $nj); do + split_scps="$split_scps $logdir/images.$n.scp" +done + +# split images.scp +utils/split_scp.pl $scp $split_scps || exit 1; + +$cmd JOB=1:$nj $logdir/extract_features.JOB.log \ + image/ocr/make_features.py $logdir/images.JOB.scp \ + --allowed_len_file_path $data/allowed_lengths.txt \ + --feat-dim $feat_dim --num-channels $num_channels --fliplr $fliplr --augment_type $augment \| \ + copy-feats --compress=true --compression-method=7 \ + ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp + +## aggregates the output scp's to get feats.scp +for n in $(seq $nj); do + cat $featdir/images.$n.scp || exit 1; +done > $data/feats.scp || exit 1 diff --git a/egs/yomdle_zh/v1/local/gedi2csv.py b/egs/yomdle_zh/v1/local/gedi2csv.py new file mode 100755 index 00000000000..0b80c2e80bb --- /dev/null +++ b/egs/yomdle_zh/v1/local/gedi2csv.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 + +""" +GEDI2CSV +Convert GEDI-type bounding boxes to CSV format + +GEDI Format Example: + + + + + + + + + +CSV Format Example +ID,name,col1,row1,col2,row2,col3,row3,col4,row4,confidence,truth,pgrot,bbrot,qual,script,lang +0,chinese_scanned_books_0001_0.png,99,41,99,14,754,14,754,41,100,凡我的邻人说是好的,有一大部分在我灵魂中却,0,0.0,0,,zh-cn +""" + +import logging +import os +import sys +import time +import glob +import csv +import imghdr +from PIL import Image +import argparse +import pdb +import cv2 +import numpy as np +import xml.etree.ElementTree as ET + +sin = np.sin +cos = np.cos +pi = np.pi + +def Rotate2D(pts, cnt, ang=90): + M = np.array([[cos(ang),-sin(ang)],[sin(ang),cos(ang)]]) + res = np.dot(pts-cnt,M)+cnt + return M, res + +def npbox2string(npar): + if np.shape(npar)[0] != 1: + print('Error during CSV conversion\n') + c1,r1 = npar[0][0],npar[0][1] + c2,r2 = npar[0][2],npar[0][3] + c3,r3 = npar[0][4],npar[0][5] + c4,r4 = npar[0][6],npar[0][7] + + return c1,r1,c2,r2,c3,r3,c4,r4 + +# cv2.minAreaRect() returns a Box2D structure which contains following detals - ( center (x,y), (width, height), angle of rotation ) +# Get 4 corners of the rectangle using cv2.boxPoints() + +class GEDI2CSV(object): + + """ Initialize the extractor""" + def __init__(self, logger, args): + self._logger = logger + self._args = args + + """ + Segment image with GEDI bounding box information + """ + def csvfile(self, coords, polys, baseName, pgrot): + + """ for writing the files """ + writePath = self._args.outputDir + writePath = os.path.join(writePath,'') + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + rotlist = [] + + header=['ID','name','col1','row1','col2','row2','col3','row3','col4','row4','confidence','truth','pgrot','bbrot','qual','script','text_type'] + conf=100 + write_ctr = 0 + if len(coords) == 0 and len(polys) == 0: + self._logger.info('Found %s with no text content',(baseName)) + print('...Found %s with no text content' % (baseName)) + return + + strPos = writePath + baseName + + """ for each group of coordinates """ + for i in coords: + + [id,x,y,w,h,degrees,text,qual,script,text_type] = i + + contour = np.array([(x,y),(x+w,y),(x+w,y+h),(x,y+h)]) + + """ + First rotate around upper left corner based on orientationD keyword + """ + M, rot = Rotate2D(contour, np.array([x,y]), degrees*pi/180) + rot = 
np.int0(rot) + + # rot is the 8 points rotated by degrees + # pgrot is the rotation after extraction, so save + + # save rotated points to list or array + rot = np.reshape(rot,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(rot) + + text = text.replace(u'\ufeff','') + + bbrot = degrees + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,text_type]) + + # if there are polygons, first save the text + for j in polys: + arr = [] + [id,poly_val,text,qual,script,text_type] = j + for i in poly_val: + arr.append(eval(i)) + + contour = np.asarray(arr) + convex = cv2.convexHull(contour) + rect = cv2.minAreaRect(convex) + box = cv2.boxPoints(rect) + box = np.int0(box) + box = np.reshape(box,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(box) + + bbrot = 0.0 + + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,text_type]) + + # then write out all of list to file + with open(strPos + ".csv", "w", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(header) + for row in rotlist: + writer.writerow(row) + write_ctr += 1 + + return write_ctr + + +def main(args): + + startTime = time.clock() + + writePath = args.outputDir + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + """ Setup logging """ + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + if args.log: + handler = logging.FileHandler(args.log) + handler.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + gtconverter = GEDI2CSV(logger, args) + namespaces = {"gedi" : "http://lamp.cfar.umd.edu/media/projects/GEDI/"} + keyCnt=0 + + fileCnt = 0 + line_write_ctr = 0 + line_error_ctr = 0 + + """ + Get all XML files in the directory and sub folders + """ + for root, dirnames, filenames in os.walk(args.inputDir, followlinks=True): + for file in filenames: + if file.lower().endswith('.xml'): + fullName = os.path.join(root,file) + baseName = os.path.splitext(fullName) + + fileCnt += 1 + + """ read the XML file """ + tree = ET.parse(fullName) + gedi_root = tree.getroot() + child = gedi_root.findall('gedi:DL_DOCUMENT',namespaces)[0] + totalpages = int(child.attrib['NrOfPages']) + coordinates=[] + polygons = [] + if args.ftype == 'boxed': + fileTypeStr = 'col' + elif args.ftype == 'transcribed': + fileTypeStr = 'Text_Content' + else: + print('Filetype must be either boxed or transcribed!') + logger.info('Filetype must be either boxed or transcribed!') + sys.exit(-1) + + if args.quality == 'both': + qualset = {'Regular','Low-Quality'} + elif args.quality == 'low': + qualset = {'Low-Quality'} + elif args.quality == 'regular': + qualset = {'Regular'} + else: + print('Quality must be both, low or regular!') + logger.info('Quality must be both, low or regular!') + sys.exit(-1) + + + + """ and for each page """ + for i, pgs in enumerate(child.iterfind('gedi:DL_PAGE',namespaces)): + + if 'GEDI_orientation' not in pgs.attrib: + pageRot=0 + else: + pageRot = int(pgs.attrib['GEDI_orientation']) + logger.info(' PAGE ROTATION %s, %s' % (fullName, str(pageRot))) + + """ find children for each page """ + for zone in pgs.findall('gedi:DL_ZONE',namespaces): + + if zone.attrib['gedi_type']=='Text' and zone.attrib['Type'] in \ + ('Machine_Print','Confusable_Allograph','Handwriting') and zone.attrib['Quality'] in qualset: + if zone.get('polygon'): + keyCnt+=1 + 
polygons.append([zone.attrib['id'],zone.get('polygon').split(';'), + zone.get('Text_Content'),zone.get('Quality'),zone.get('Script'),zone.get('Type')]) + elif zone.get(fileTypeStr) != None: + keyCnt+=1 + coord = [zone.attrib['id'],int(zone.attrib['col']),int(zone.attrib['row']), + int(zone.attrib['width']), int(zone.attrib['height']), + float(zone.get('orientationD',0.0)), + zone.get('Text_Content'),zone.get('Quality'),zone.get('Script'),zone.get('Type')] + coordinates.append(coord) + + if len(coordinates) > 0 or len(polygons) > 0: + line_write_ctr += gtconverter.csvfile(coordinates, polygons, os.path.splitext(file)[0], pageRot) + else: + print('...%s has no applicable content' % (baseName[0])) + + print('complete...total files %d, lines written %d' % (fileCnt, line_write_ctr)) + + +def parse_arguments(argv): + """ Args and defaults """ + parser = argparse.ArgumentParser() + + parser.add_argument('--inputDir', type=str, help='Input directory', required=True) + parser.add_argument('--outputDir', type=str, help='Output directory', required=True) + parser.add_argument('--ftype', type=str, help='GEDI file type (either "boxed" or "transcribed")', default='transcribed') + parser.add_argument('--quality', type=str, help='GEDI file quality (either "both" or "low" or "regular")', default='regular') + parser.add_argument('--log', type=str, help='Log directory', default='./GEDI2CSV_enriched.log') + + return parser.parse_args(argv) + +if __name__ == '__main__': + """ Run """ + main(parse_arguments(sys.argv[1:])) + + + + + + diff --git a/egs/yomdle_zh/v1/local/prepare_dict.sh b/egs/yomdle_zh/v1/local/prepare_dict.sh new file mode 100755 index 00000000000..65b2e7aa901 --- /dev/null +++ b/egs/yomdle_zh/v1/local/prepare_dict.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +# Copyright 2017 Hossein Hadian +# 2017 Chun Chieh Chang +# 2017 Ashish Arora + +# This script prepares the dictionary. + +set -e +dir=data/local/dict +data_dir=data + +. ./utils/parse_options.sh || exit 1; + +base_dir=$(echo "$DIRECTORY" | cut -d "/" -f2) + +mkdir -p $dir + +local/prepare_lexicon.py --data-dir $data_dir $dir + +cut -d' ' -f2- $dir/lexicon.txt | sed 's/SIL//g' | tr ' ' '\n' | sort -u | sed '/^$/d' >$dir/nonsilence_phones.txt || exit 1; + +echo ' SIL' >> $dir/lexicon.txt + +echo SIL > $dir/silence_phones.txt + +echo SIL >$dir/optional_silence.txt + +echo -n "" >$dir/extra_questions.txt diff --git a/egs/yomdle_zh/v1/local/prepare_lexicon.py b/egs/yomdle_zh/v1/local/prepare_lexicon.py new file mode 100755 index 00000000000..3ebb52e38f4 --- /dev/null +++ b/egs/yomdle_zh/v1/local/prepare_lexicon.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# Chun-Chieh Chang + +import argparse +import os + +parser = argparse.ArgumentParser(description="""Creates the list of characters and words in lexicon""") +parser.add_argument('dir', type=str, help='output path') +parser.add_argument('--data-dir', type=str, default='data', help='Path to text file') +args = parser.parse_args() + +### main ### +lex = {} +text_path = os.path.join(args.data_dir, 'train', 'text') +text_fh = open(text_path, 'r', encoding='utf-8') + +# Used specially for Chinese. +# Uses the ChangJie keyboard input method to create subword units for Chinese. 
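+# The table maps a character to its ChangJie keystroke code; each keystroke
+# letter becomes one "cj5_" subword unit. As a purely illustrative, made-up
+# example (real entries come from download/cj5-cc.txt): a line "abc 字" would
+# give cj5_table['字'] = "cj5_a cj5_b cj5_c", so any word containing that
+# character is spelled out in per-keystroke units in the lexicon.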
+cj5_table = {} +with open('download/cj5-cc.txt', 'r', encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split() + if not line_vect[0].startswith('yyy') and not line_vect[0].startswith('z'): + cj5_table[line_vect[1]] = "cj5_" + " cj5_".join(list(line_vect[0])) + +with open(text_path, 'r', encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split() + for i in range(1, len(line_vect)): + characters = list(line_vect[i]) + # Put SIL instead of "|". Because every "|" in the beginning of the words is for initial-space of that word + characters = " ".join([ 'SIL' if char == '|' else cj5_table[char] if char in cj5_table else char for char in characters]) + characters = characters.replace('#','') + lex[line_vect[i]] = characters + +with open(os.path.join(args.dir, 'lexicon.txt'), 'w', encoding='utf-8') as fp: + for key in sorted(lex): + fp.write(key + " " + lex[key] + "\n") diff --git a/egs/yomdle_zh/v1/local/process_data.py b/egs/yomdle_zh/v1/local/process_data.py new file mode 100755 index 00000000000..8964af8890a --- /dev/null +++ b/egs/yomdle_zh/v1/local/process_data.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# 2018 Chun Chieh Chang + +""" This script reads the extracted Farsi OCR (yomdle and slam) database files + and creates the following files (for the data subset selected via --dataset): + text, utt2spk, images.scp. + Eg. local/process_data.py data/download/ data/local/splits/train.txt data/train + Eg. text file: english_phone_books_0001_1 To sum up, then, it would appear that + utt2spk file: english_phone_books_0001_0 english_phone_books_0001 + images.scp file: english_phone_books_0001_0 \ + data/download/truth_line_image/english_phone_books_0001_0.png +""" + +import argparse +import os +import sys +import csv +import itertools +import unicodedata + +parser = argparse.ArgumentParser(description="Creates text, utt2spk, and images.scp files") +parser.add_argument('database_path', type=str, help='Path to data') +parser.add_argument('out_dir', type=str, help='directory to output files') +parser.add_argument('--head', type=int, default=-1, help='limit on number of synth data') +args = parser.parse_args() + +### main ### +print("Processing '{}' data...".format(args.out_dir)) + +text_file = os.path.join(args.out_dir, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') +utt2spk_file = os.path.join(args.out_dir, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join(args.out_dir, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') + +count = 0 +for filename in sorted(os.listdir(os.path.join(args.database_path, 'truth_csv'))): + if filename.endswith('.csv') and (count < args.head or args.head < 0): + count = count + 1 + csv_filepath = os.path.join(args.database_path, 'truth_csv', filename) + csv_file = open(csv_filepath, 'r', encoding='utf-8') + row_count = 0 + for row in csv.reader(csv_file): + if row_count == 0: + row_count = 1 + continue + image_id = os.path.splitext(row[1])[0] + image_filepath = os.path.join(args.database_path, 'truth_line_image', row[1]) + text = unicodedata.normalize('NFC', row[11]).replace('\n', '') + if os.path.isfile(image_filepath) and os.stat(image_filepath).st_size != 0 and text: + text_fh.write(image_id + ' ' + text + '\n') + utt2spk_fh.write(image_id + ' ' + '_'.join(image_id.split('_')[:-1]) + '\n') + image_fh.write(image_id + ' ' + image_filepath + ' ' + row[13] + '\n') diff --git a/egs/yomdle_zh/v1/local/score.sh 
b/egs/yomdle_zh/v1/local/score.sh new file mode 100755 index 00000000000..f2405205f02 --- /dev/null +++ b/egs/yomdle_zh/v1/local/score.sh @@ -0,0 +1,5 @@ +#!/bin/bash + + +steps/scoring/score_kaldi_wer.sh --max-lmwt 10 "$@" +steps/scoring/score_kaldi_cer.sh --max-lmwt 10 --stage 2 "$@" diff --git a/egs/yomdle_zh/v1/local/train_lm.sh b/egs/yomdle_zh/v1/local/train_lm.sh new file mode 100755 index 00000000000..bc738f217da --- /dev/null +++ b/egs/yomdle_zh/v1/local/train_lm.sh @@ -0,0 +1,110 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the YOMDLE training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +data_dir=data + +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +#bypass_metaparam_optim_opt= +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + nr=`cat $data_dir/train/text | wc -l` + nr_dev=$(($nr / 10 )) + nr_train=$(( $nr - $nr_dev )) + + # use the training data as an additional data source. + # we can later fold the dev data into this. + head -n $nr_train $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + tail -n $nr_dev $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < $data_dir/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from MADCAT text + cat ${dir}/data/text/train.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +order=3 + +if [ $stage -le 1 ]; then + # decide on the vocabulary. 
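+  # For reference: with order=3 (set above) and the min_counts='train=1'
+  # default below, the lm_name logic that follows yields an unpruned LM in
+  # ${lm_dir}/wordlist_3_train-1.pocolm (shown only to illustrate the naming).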
+ # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=5 --warm-start-ratio=1 \ + --min-counts="$min_counts" \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi diff --git a/egs/yomdle_zh/v1/local/train_lm_lr.sh b/egs/yomdle_zh/v1/local/train_lm_lr.sh new file mode 100755 index 00000000000..b95e6474b18 --- /dev/null +++ b/egs/yomdle_zh/v1/local/train_lm_lr.sh @@ -0,0 +1,113 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the YOMDLE+Extra training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +data_dir=data +extra_lm=download/extra_lm.txt +order=3 + +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +#bypass_metaparam_optim_opt= +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + cat ${extra_lm} | utils/lang/bpe/prepend_words.py | python3 utils/lang/bpe/apply_bpe.py -c $data_dir/train/bpe.out | sed 's/@@//g' > ${dir}/data/text/extra_lm.txt + + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + nr=`cat $data_dir/train/text | wc -l` + nr_dev=$(($nr / 10 )) + nr_train=$(( $nr - $nr_dev )) + + # use the training data as an additional data source. 
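+  # (the head/tail split below holds out roughly 10% of train/text as dev.txt
+  # and keeps the remaining ~90% as train.txt; e.g. 50,000 transcript lines
+  # would give about 45,000 training and 5,000 dev lines -- figures here are
+  # purely illustrative)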
+ # we can later fold the dev data into this. + head -n $nr_train $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + tail -n $nr_dev $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < $data_dir/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from MADCAT text + cat ${dir}/data/text/{train,extra_lm}.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + #cat ${dir}/data/text/extra_fa.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='extra_lm=10 train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=30 --warm-start-ratio=1 \ + --min-counts="$min_counts" \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi diff --git a/egs/yomdle_zh/v1/local/wer_output_filter b/egs/yomdle_zh/v1/local/wer_output_filter new file mode 100755 index 00000000000..08d5563bca4 --- /dev/null +++ b/egs/yomdle_zh/v1/local/wer_output_filter @@ -0,0 +1,151 @@ +#!/usr/bin/env perl +# Copyright 2012-2014 Johns Hopkins University (Author: Yenda Trmal) +# Apache 2.0 + +use utf8; + +use open qw(:encoding(utf8)); +binmode STDIN, ":utf8"; +binmode STDOUT, ":utf8"; +binmode STDERR, ":utf8"; + +# Arabic-specific normalization +while (<>) { + @F = split " "; + print "$F[0] "; + foreach $s (@F[1..$#F]) { + # Normalize tabs, spaces, and no-break spaces + $s =~ s/[\x{0009}\x{0020}\x{00A0}]+/ /g; + # Normalize "dots"/"filled-circles" to periods + $s =~ s/[\x{25CF}\x{u2022}\x{2219}]+/\x{002E}/g; + # Normalize dashes to regular hyphen + $s =~ s/[\x{2010}\x{2011}\x{2012}\x{2013}\x{2014}\x{2015}]+/\x{002D}/g; + # Normalize various parenthesis to regular parenthesis + $s =~ s/\x{UFF09}/\x{0029}/g; + $s =~ s/\x{UFF08}/\x{0028}/g; + + # Convert various presentation forms to base form + $s =~ s/[\x{FED1}\x{FED3}\x{FED4}\x{FED2}]+/\x{0641}/g; + $s =~ s/[\x{FBB0}\x{FBB1}]+/\x{06D3}/g; + $s =~ s/[\x{FECD}\x{FECF}\x{FED0}\x{FECE}]+/\x{063A}/g; + $s =~ s/[\x{FBDD}]+/\x{0677}/g; + $s =~ s/[\x{FBA6}\x{FBA8}\x{FBA9}\x{FBA7}]+/\x{06C1}/g; + $s =~ s/[\x{FEC1}\x{FEC3}\x{FEC4}\x{FEC2}]+/\x{0637}/g; + $s =~ s/[\x{FE85}\x{FE86}]+/\x{0624}/g; + $s =~ s/[\x{FEA5}\x{FEA7}\x{FEA8}\x{FEA6}]+/\x{062E}/g; + $s =~ s/[\x{FBD9}\x{FBDA}]+/\x{06C6}/g; + $s =~ 
s/[\x{FE8F}\x{FE91}\x{FE92}\x{FE90}]+/\x{0628}/g; + $s =~ s/[\x{FEED}\x{FEEE}]+/\x{0648}/g; + $s =~ s/[\x{FE99}\x{FE9B}\x{FE9C}\x{FE9A}]+/\x{062B}/g; + $s =~ s/[\x{FEBD}\x{FEBF}\x{FEC0}\x{FEBE}]+/\x{0636}/g; + $s =~ s/[\x{FEE5}\x{FEE7}\x{FEE8}\x{FEE6}]+/\x{0646}/g; + $s =~ s/[\x{FBFC}\x{FBFE}\x{FBFF}\x{FBFD}]+/\x{06CC}/g; + $s =~ s/[\x{FBA4}\x{FBA5}]+/\x{06C0}/g; + $s =~ s/[\x{FB72}\x{FB74}\x{FB75}\x{FB73}]+/\x{0684}/g; + $s =~ s/[\x{FBD3}\x{FBD5}\x{FBD6}\x{FBD4}]+/\x{06AD}/g; + $s =~ s/[\x{FB6A}\x{FB6C}\x{FB6D}\x{FB6B}]+/\x{06A4}/g; + $s =~ s/[\x{FB66}\x{FB68}\x{FB69}\x{FB67}]+/\x{0679}/g; + $s =~ s/[\x{FB5E}\x{FB60}\x{FB61}\x{FB5F}]+/\x{067A}/g; + $s =~ s/[\x{FB88}\x{FB89}]+/\x{0688}/g; + $s =~ s/[\x{FB7E}\x{FB80}\x{FB81}\x{FB7F}]+/\x{0687}/g; + $s =~ s/[\x{FB8E}\x{FB90}\x{FB91}\x{FB8F}]+/\x{06A9}/g; + $s =~ s/[\x{FB86}\x{FB87}]+/\x{068E}/g; + $s =~ s/[\x{FE83}\x{FE84}]+/\x{0623}/g; + $s =~ s/[\x{FB8A}\x{FB8B}]+/\x{0698}/g; + $s =~ s/[\x{FED5}\x{FED7}\x{FED8}\x{FED6}]+/\x{0642}/g; + $s =~ s/[\x{FED9}\x{FEDB}\x{FEDC}\x{FEDA}]+/\x{0643}/g; + $s =~ s/[\x{FBE0}\x{FBE1}]+/\x{06C5}/g; + $s =~ s/[\x{FEB9}\x{FEBB}\x{FEBC}\x{FEBA}]+/\x{0635}/g; + $s =~ s/[\x{FEC5}\x{FEC7}\x{FEC8}\x{FEC6}]+/\x{0638}/g; + $s =~ s/[\x{FE8D}\x{FE8E}]+/\x{0627}/g; + $s =~ s/[\x{FB9A}\x{FB9C}\x{FB9D}\x{FB9B}]+/\x{06B1}/g; + $s =~ s/[\x{FEAD}\x{FEAE}]+/\x{0631}/g; + $s =~ s/[\x{FEF1}\x{FEF3}\x{FEF4}\x{FEF2}]+/\x{064A}/g; + $s =~ s/[\x{FE93}\x{FE94}]+/\x{0629}/g; + $s =~ s/[\x{FBE4}\x{FBE6}\x{FBE7}\x{FBE5}]+/\x{06D0}/g; + $s =~ s/[\x{FE89}\x{FE8B}\x{FE8C}\x{FE8A}]+/\x{0626}/g; + $s =~ s/[\x{FB84}\x{FB85}]+/\x{068C}/g; + $s =~ s/[\x{FE9D}\x{FE9F}\x{FEA0}\x{FE9E}]+/\x{062C}/g; + $s =~ s/[\x{FB82}\x{FB83}]+/\x{068D}/g; + $s =~ s/[\x{FEA1}\x{FEA3}\x{FEA4}\x{FEA2}]+/\x{062D}/g; + $s =~ s/[\x{FB52}\x{FB54}\x{FB55}\x{FB53}]+/\x{067B}/g; + $s =~ s/[\x{FB92}\x{FB94}\x{FB95}\x{FB93}]+/\x{06AF}/g; + $s =~ s/[\x{FB7A}\x{FB7C}\x{FB7D}\x{FB7B}]+/\x{0686}/g; + $s =~ s/[\x{FBDB}\x{FBDC}]+/\x{06C8}/g; + $s =~ s/[\x{FB56}\x{FB58}\x{FB59}\x{FB57}]+/\x{067E}/g; + $s =~ s/[\x{FEB5}\x{FEB7}\x{FEB8}\x{FEB6}]+/\x{0634}/g; + $s =~ s/[\x{FBE2}\x{FBE3}]+/\x{06C9}/g; + $s =~ s/[\x{FB96}\x{FB98}\x{FB99}\x{FB97}]+/\x{06B3}/g; + $s =~ s/[\x{FE80}]+/\x{0621}/g; + $s =~ s/[\x{FBAE}\x{FBAF}]+/\x{06D2}/g; + $s =~ s/[\x{FB62}\x{FB64}\x{FB65}\x{FB63}]+/\x{067F}/g; + $s =~ s/[\x{FEE9}\x{FEEB}\x{FEEC}\x{FEEA}]+/\x{0647}/g; + $s =~ s/[\x{FE81}\x{FE82}]+/\x{0622}/g; + $s =~ s/[\x{FBDE}\x{FBDF}]+/\x{06CB}/g; + $s =~ s/[\x{FE87}\x{FE88}]+/\x{0625}/g; + $s =~ s/[\x{FB6E}\x{FB70}\x{FB71}\x{FB6F}]+/\x{06A6}/g; + $s =~ s/[\x{FBA0}\x{FBA2}\x{FBA3}\x{FBA1}]+/\x{06BB}/g; + $s =~ s/[\x{FBAA}\x{FBAC}\x{FBAD}\x{FBAB}]+/\x{06BE}/g; + $s =~ s/[\x{FEA9}\x{FEAA}]+/\x{062F}/g; + $s =~ s/[\x{FEE1}\x{FEE3}\x{FEE4}\x{FEE2}]+/\x{0645}/g; + $s =~ s/[\x{FEEF}\x{FBE8}\x{FBE9}\x{FEF0}]+/\x{0649}/g; + $s =~ s/[\x{FB8C}\x{FB8D}]+/\x{0691}/g; + $s =~ s/[\x{FB76}\x{FB78}\x{FB79}\x{FB77}]+/\x{0683}/g; + $s =~ s/[\x{FB5A}\x{FB5C}\x{FB5D}\x{FB5B}]+/\x{0680}/g; + $s =~ s/[\x{FB9E}\x{FB9F}]+/\x{06BA}/g; + $s =~ s/[\x{FEC9}\x{FECB}\x{FECC}\x{FECA}]+/\x{0639}/g; + $s =~ s/[\x{FEDD}\x{FEDF}\x{FEE0}\x{FEDE}]+/\x{0644}/g; + $s =~ s/[\x{FB50}\x{FB51}]+/\x{0671}/g; + $s =~ s/[\x{FEB1}\x{FEB3}\x{FEB4}\x{FEB2}]+/\x{0633}/g; + $s =~ s/[\x{FE95}\x{FE97}\x{FE98}\x{FE96}]+/\x{062A}/g; + $s =~ s/[\x{FBD7}\x{FBD8}]+/\x{06C7}/g; + $s =~ s/[\x{FEAF}\x{FEB0}]+/\x{0632}/g; + $s =~ s/[\x{FEAB}\x{FEAC}]+/\x{0630}/g; + + # Remove tatweel + $s =~ s/\x{0640}//g; + # Remove vowels and hamza + $s =~ 
s/[\x{064B}-\x{0655}]+//g; + # Remove right-to-left and left-to-right + $s =~ s/[\x{200F}\x{200E}]+//g; + # Arabic Keheh to Arabic Kaf + $s =~ s/\x{06A9}/\x{0643}/g; + # Arabic Yeh to Farsi Yeh + $s =~ s/\x{064A}/\x{06CC}/g; + # Decompose RIAL + $s =~ s/\x{FDFC}/\x{0631}\x{06CC}\x{0627}\x{0644}/g; + # Farsi arabic-indic digits to arabic-indic digits + $s =~ s/\x{06F0}/\x{0660}/g; + $s =~ s/\x{06F1}/\x{0661}/g; + $s =~ s/\x{06F2}/\x{0662}/g; + $s =~ s/\x{06F3}/\x{0663}/g; + $s =~ s/\x{06F4}/\x{0664}/g; + $s =~ s/\x{06F5}/\x{0665}/g; + $s =~ s/\x{06F6}/\x{0666}/g; + $s =~ s/\x{06F7}/\x{0667}/g; + $s =~ s/\x{06F8}/\x{0668}/g; + $s =~ s/\x{06F9}/\x{0669}/g; + # Arabic-indic digits to digits + $s =~ s/\x{0660}/0/g; + $s =~ s/\x{0661}/1/g; + $s =~ s/\x{0662}/2/g; + $s =~ s/\x{0663}/3/g; + $s =~ s/\x{0664}/4/g; + $s =~ s/\x{0665}/5/g; + $s =~ s/\x{0666}/6/g; + $s =~ s/\x{0667}/7/g; + $s =~ s/\x{0668}/8/g; + $s =~ s/\x{0669}/9/g; + # Arabic comma to comma + $s =~ s/\x{060C}/\x{002C}/g; + + $s =~ s/\|/ /g; + if ($s ne "") { + print "$s"; + } else { + print ""; + } + } + print "\n"; +} + diff --git a/egs/yomdle_zh/v1/local/yomdle2csv.py b/egs/yomdle_zh/v1/local/yomdle2csv.py new file mode 100755 index 00000000000..8f208e2d968 --- /dev/null +++ b/egs/yomdle_zh/v1/local/yomdle2csv.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 + +""" +GEDI2CSV +Convert GEDI-type bounding boxes to CSV format + +GEDI Format Example: + + + + + + + + + +CSV Format Example +ID,name,col1,row1,col2,row2,col3,row3,col4,row4,confidence,truth,pgrot,bbrot,qual,script,lang +0,chinese_scanned_books_0001_0.png,99,41,99,14,754,14,754,41,100,凡我的邻人说是好的,有一大部分在我灵魂中却,0,0.0,0,,zh-cn +""" + +import logging +import os +import sys +import time +import glob +import csv +import imghdr +from PIL import Image +import argparse +import pdb +import cv2 +import numpy as np +import xml.etree.ElementTree as ET + +sin = np.sin +cos = np.cos +pi = np.pi + +def Rotate2D(pts, cnt, ang=90): + M = np.array([[cos(ang),-sin(ang)],[sin(ang),cos(ang)]]) + res = np.dot(pts-cnt,M)+cnt + return M, res + +def npbox2string(npar): + if np.shape(npar)[0] != 1: + print('Error during CSV conversion\n') + c1,r1 = npar[0][0],npar[0][1] + c2,r2 = npar[0][2],npar[0][3] + c3,r3 = npar[0][4],npar[0][5] + c4,r4 = npar[0][6],npar[0][7] + + return c1,r1,c2,r2,c3,r3,c4,r4 + +# cv2.minAreaRect() returns a Box2D structure which contains following detals - ( center (x,y), (width, height), angle of rotation ) +# Get 4 corners of the rectangle using cv2.boxPoints() + +class GEDI2CSV(object): + + """ Initialize the extractor""" + def __init__(self, logger, args): + self._logger = logger + self._args = args + + """ + Segment image with GEDI bounding box information + """ + def csvfile(self, coords, polys, baseName, pgrot): + + """ for writing the files """ + writePath = self._args.outputDir + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + rotlist = [] + + header=['ID','name','col1','row1','col2','row2','col3','row3','col4','row4','confidence','truth','pgrot','bbrot','qual','script','lang'] + conf=100 + pgrot = 0 + bbrot = 0 + qual = 0 + script = '' + + write_ctr = 0 + if len(coords) == 0 and len(polys) == 0: + self._logger.info('Found %s with no text content',(baseName)) + print('...Found %s with no text content' % (baseName)) + return + + strPos = writePath + baseName + + for j in polys: + try: + arr = [] + [id,poly_val,text,qual,lang] = j + script=None + #print(j) + for i in poly_val: + if len(i.strip()) > 0: + #print(i) + arr.append(eval(i)) + + contour = 
np.asarray(arr) + #print(contour) + convex = cv2.convexHull(contour) + rect = cv2.minAreaRect(convex) + box = cv2.boxPoints(rect) + box = np.int0(box) + box = np.reshape(box,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(box) + + bbrot = 0.0 + + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,lang]) + + except: + print('...polygon error %s, %s' % (j, baseName)) + continue + + # then write out all of list to file + with open(strPos + ".csv", "w", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(header) + for row in rotlist: + writer.writerow(row) + write_ctr += 1 + + return write_ctr + + +def main(args): + + startTime = time.clock() + + writePath = args.outputDir + print('write to %s' % (writePath)) + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + """ Setup logging """ + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + if args.log: + handler = logging.FileHandler(args.log) + handler.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + gtconverter = GEDI2CSV(logger, args) + namespaces = {"gedi" : "http://lamp.cfar.umd.edu/media/projects/GEDI/"} + keyCnt=0 + + fileCnt = 0 + line_write_ctr = 0 + line_error_ctr = 0 + file_error_ctr = 0 + """ + Get all XML files in the directory and sub folders + """ + print('reading %s' % (args.inputDir)) + for root, dirnames, filenames in os.walk(args.inputDir, followlinks=True): + for file in filenames: + if file.lower().endswith('.xml'): + fullName = os.path.join(root,file) + baseName = os.path.splitext(fullName) + + fileCnt += 1 + + try: + """ read the XML file """ + tree = ET.parse(fullName) + except: + print('...ERROR parsing %s' % (fullName)) + file_error_ctr += 1 + continue + + gedi_root = tree.getroot() + child = gedi_root.findall('gedi:DL_DOCUMENT',namespaces)[0] + totalpages = int(child.attrib['NrOfPages']) + coordinates=[] + polygons = [] + + """ and for each page """ + for i, pgs in enumerate(child.iterfind('gedi:DL_PAGE',namespaces)): + + if 'GEDI_orientation' not in pgs.attrib: + pageRot=0 + else: + pageRot = int(pgs.attrib['GEDI_orientation']) + logger.info(' PAGE ROTATION %s, %s' % (fullName, str(pageRot))) + + """ find children for each page """ + for zone in pgs.findall('gedi:DL_ZONE',namespaces): + + if zone.attrib['gedi_type']=='Text' : + if zone.get('polygon'): + keyCnt+=1 + polygons.append([zone.attrib['id'],zone.get('polygon').split(';'), + zone.get('Text_Content'),zone.get('Illegible'),zone.get('Language')]) + else: + print('...Not polygon') + + + if len(coordinates) > 0 or len(polygons) > 0: + line_write_ctr += gtconverter.csvfile(coordinates, polygons, os.path.splitext(file)[0], pageRot) + else: + print('...%s has no text content' % (baseName[0])) + + + print('complete...total files %d, lines written %d, img errors %d, line error %d' % (fileCnt, line_write_ctr, file_error_ctr, line_error_ctr)) + + +def parse_arguments(argv): + """ Args and defaults """ + parser = argparse.ArgumentParser() + + parser.add_argument('--inputDir', type=str, help='Input directory', default='/data/YOMDLE/final_arabic/xml') + parser.add_argument('--outputDir', type=str, help='Output directory', default='/exp/YOMDLE/final_arabic/csv_truth/') + parser.add_argument('--log', type=str, help='Log directory', default='/exp/logs.txt') + + return parser.parse_args(argv) + + +if __name__ == '__main__': + """ Run """ + 
main(parse_arguments(sys.argv[1:])) diff --git a/egs/yomdle_zh/v1/path.sh b/egs/yomdle_zh/v1/path.sh new file mode 100644 index 00000000000..2d17b17a84a --- /dev/null +++ b/egs/yomdle_zh/v1/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/yomdle_zh/v1/run.sh b/egs/yomdle_zh/v1/run.sh new file mode 100755 index 00000000000..128f15694cc --- /dev/null +++ b/egs/yomdle_zh/v1/run.sh @@ -0,0 +1,120 @@ +#!/bin/bash + +set -e +stage=0 +nj=60 + +database_slam=/export/corpora5/slam/SLAM/Chinese/transcribed +database_yomdle=/export/corpora5/slam/YOMDLE/final_chinese +download_dir=data_yomdle_chinese/download/ +extra_lm=download/extra_lm.txt +data_dir=data_yomdle_chinese +exp_dir=exp_yomdle_chinese + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + local/create_download.sh --database-slam $database_slam \ + --database-yomdle $database_yomdle \ + --slam-dir download/slam_chinese \ + --yomdle-dir download/yomdle_chinese +fi + +if [ $stage -le 0 ]; then + mkdir -p data_slam_chinese/slam + mkdir -p data_yomdle_chinese/yomdle + local/process_data.py download/slam_chinese data_slam_chinese/slam + local/process_data.py download/yomdle_chinese data_yomdle_chinese/yomdle + ln -s ../data_slam_chinese/slam ${data_dir}/test + ln -s ../data_yomdle_chinese/yomdle ${data_dir}/train + image/fix_data_dir.sh ${data_dir}/test + image/fix_data_dir.sh ${data_dir}/train +fi + +mkdir -p $data_dir/{train,test}/data +if [ $stage -le 1 ]; then + echo "$0: Obtaining image groups. calling get_image2num_frames" + echo "Date: $(date)." + image/get_image2num_frames.py --feat-dim 60 $data_dir/train + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 $data_dir/train + + for datasplit in train test; do + echo "$0: Extracting features and calling compute_cmvn_stats for dataset: $datasplit. " + echo "Date: $(date)." + local/extract_features.sh --nj $nj --cmd "$cmd" \ + --feat-dim 60 --num-channels 3 \ + $data_dir/${datasplit} + steps/compute_cmvn_stats.sh $data_dir/${datasplit} || exit 1; + done + + echo "$0: Fixing data directory for train dataset" + echo "Date: $(date)." + utils/fix_data_dir.sh $data_dir/train +fi + +if [ $stage -le 2 ]; then + for datasplit in train; do + echo "$(date) stage 2: Performing augmentation, it will double training data" + local/augment_data.sh --nj $nj --cmd "$cmd" --feat-dim 60 $data_dir/${datasplit} $data_dir/${datasplit}_aug $data_dir + steps/compute_cmvn_stats.sh $data_dir/${datasplit}_aug || exit 1; + done +fi + +if [ $stage -le 3 ]; then + echo "$0: Preparing dictionary and lang..." + if [ ! 
-f $data_dir/train/bpe.out ]; then + cut -d' ' -f2- $data_dir/train/text | utils/lang/bpe/prepend_words.py | python3 utils/lang/bpe/learn_bpe.py -s 700 > $data_dir/train/bpe.out + for datasplit in test train train_aug; do + cut -d' ' -f1 $data_dir/$datasplit/text > $data_dir/$datasplit/ids + cut -d' ' -f2- $data_dir/$datasplit/text | utils/lang/bpe/prepend_words.py | python3 utils/lang/bpe/apply_bpe.py -c $data_dir/train/bpe.out | sed 's/@@//g' > $data_dir/$datasplit/bpe_text + mv $data_dir/$datasplit/text $data_dir/$datasplit/text.old + paste -d' ' $data_dir/$datasplit/ids $data_dir/$datasplit/bpe_text > $data_dir/$datasplit/text + done + fi + + local/prepare_dict.sh --data-dir $data_dir --dir $data_dir/local/dict + # This recipe uses byte-pair encoding, the silences are part of the words' pronunciations. + # So we set --sil-prob to 0.0 + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ + $data_dir/local/dict "" $data_dir/lang/temp $data_dir/lang + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 $data_dir/lang +fi + +if [ $stage -le 4 ]; then + echo "$0: Estimating a language model for decoding..." + local/train_lm.sh --data-dir $data_dir --dir $data_dir/local/local_lm + utils/format_lm.sh $data_dir/lang $data_dir/local/local_lm/data/arpa/3gram_unpruned.arpa.gz \ + $data_dir/local/dict/lexicon.txt $data_dir/lang_test +fi + +if [ $stage -le 5 ]; then + echo "$0: Calling the flat-start chain recipe..." + echo "Date: $(date)." + local/chain/run_flatstart_cnn1a.sh --nj $nj --train-set train_aug --data-dir $data_dir --exp-dir $exp_dir +fi + +if [ $stage -le 6 ]; then + echo "$0: Aligning the training data using the e2e chain model..." + echo "Date: $(date)." + steps/nnet3/align.sh --nj $nj --cmd "$cmd" --use-gpu false \ + --scale-opts '--transition-scale=1.0 --acoustic-scale=1.0 --self-loop-scale=1.0' \ + $data_dir/train_aug $data_dir/lang $exp_dir/chain/e2e_cnn_1a $exp_dir/chain/e2e_ali_train +fi + +if [ $stage -le 7 ]; then + echo "$0: Building a tree and training a regular chain model using the e2e alignments..." + echo "Date: $(date)." + local/chain/run_cnn_e2eali_1b.sh --nj $nj --train-set train_aug --data-dir $data_dir --exp-dir $exp_dir +fi + +if [ $stage -le 8 ]; then + echo "$0: Estimating a language model for lattice rescoring...$(date)" + local/train_lm_lr.sh --data-dir $data_dir --dir $data_dir/local/local_lm_lr --extra-lm $extra_lm --order 6 + + utils/build_const_arpa_lm.sh $data_dir/local/local_lm_lr/data/arpa/6gram_unpruned.arpa.gz \ + $data_dir/lang_test $data_dir/lang_test_lr + steps/lmrescore_const_arpa.sh $data_dir/lang_test $data_dir/lang_test_lr \ + $data_dir/test $exp_dir/chain/cnn_e2eali_1b/decode_test $exp_dir/chain/cnn_e2eali_1b/decode_test_lr +fi diff --git a/egs/yomdle_zh/v1/steps b/egs/yomdle_zh/v1/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/yomdle_zh/v1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/yomdle_zh/v1/utils b/egs/yomdle_zh/v1/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/yomdle_zh/v1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/zeroth_korean/s5/README.txt b/egs/zeroth_korean/s5/README.txt new file mode 100644 index 00000000000..daa007362d8 --- /dev/null +++ b/egs/zeroth_korean/s5/README.txt @@ -0,0 +1,13 @@ +Zeroth-Korean kaldi example is from Zeroth Project. 
Zeroth project introduces free Korean speech corpus and aims to make Korean speech recognition more broadly accessible to everyone. This project was developed in collaboration between Lucas Jo(@Atlas Guide Inc.) and Wonkyum Lee(@Gridspace Inc.). + +In this example, we are using 51.6 hours transcribed Korean audio for training data (22,263 utterances, 105 people, 3000 sentences) and 1.2 hours transcribed Korean audio for testing data (457 utterances, 10 people). Besides audio and transcription, we provide pre-trained/designed language model, lexicon and morpheme-based segmenter(morfessor) + +The database can be also downloaded from openslr: +http://www.openslr.org/40 + +The database is licensed under Attribution 4.0 International (CC BY 4.0) + +This folder contains a speech recognition recipe which is based on WSJ/Librispeech example. + +For more details about Zeroth project, please visit: +https://github.com/goodatlas/zeroth diff --git a/egs/zeroth_korean/s5/RESULTS b/egs/zeroth_korean/s5/RESULTS new file mode 100644 index 00000000000..9255ec17673 --- /dev/null +++ b/egs/zeroth_korean/s5/RESULTS @@ -0,0 +1,63 @@ +#!/bin/bash + +# this RESULTS file was obtained by Wonkyum Lee in July 2018. + +for dir in exp/*; do + steps/info/gmm_dir_info.pl $dir + for x in $dir/decode*test*; do [ -d $x ] && [[ $x =~ "$1" ]] && grep WER $x/wer_* | utils/best_wer.sh; done +done +exit 0 + +# monophone, trained on the 2k shortest utterances +exp/mono: nj=16 align prob=-99.85 over 2.66h [retry=0.8%, fail=0.3%] states=130 gauss=1004 +%WER 70.24 [ 6499 / 9253, 295 ins, 1399 del, 4805 sub ] exp/mono/decode_nosp_fglarge_test_clean/wer_8_0.5 +%WER 71.28 [ 6596 / 9253, 185 ins, 1721 del, 4690 sub ] exp/mono/decode_nosp_tglarge_test_clean/wer_9_1.0 +%WER 78.83 [ 7294 / 9253, 218 ins, 1752 del, 5324 sub ] exp/mono/decode_nosp_tgsmall_test_clean/wer_10_0.0 + +# first triphone build, trained on 5k utterances +exp/tri1: nj=16 align prob=-98.34 over 11.55h [retry=1.6%, fail=0.6%] states=1568 gauss=10030 tree-impr=4.07 +%WER 37.44 [ 3464 / 9253, 258 ins, 725 del, 2481 sub ] exp/tri1/decode_nosp_fglarge_test_clean/wer_15_0.5 +%WER 38.85 [ 3595 / 9253, 347 ins, 633 del, 2615 sub ] exp/tri1/decode_nosp_tglarge_test_clean/wer_15_0.0 +%WER 53.23 [ 4925 / 9253, 296 ins, 1060 del, 3569 sub ] exp/tri1/decode_nosp_tgsmall_test_clean/wer_15_0.0 + +# tri2 is an LDA+MLLT systemm, trained on 10k utterances +exp/tri2: nj=16 align prob=-49.63 over 23.00h [retry=1.7%, fail=0.8%] states=2000 gauss=15039 tree-impr=4.70 lda-sum=18.11 mllt:impr,logdet=0.99,1.39 +%WER 33.50 [ 3100 / 9253, 248 ins, 626 del, 2226 sub ] exp/tri2/decode_nosp_fglarge_test_clean/wer_16_0.5 +%WER 34.55 [ 3197 / 9253, 315 ins, 537 del, 2345 sub ] exp/tri2/decode_nosp_tglarge_test_clean/wer_16_0.0 +%WER 48.98 [ 4532 / 9253, 303 ins, 903 del, 3326 sub ] exp/tri2/decode_nosp_tgsmall_test_clean/wer_14_0.0 + +# tri3 is an LDA+MLLT+SAT system, trained on entire clean training set +exp/tri3: nj=16 align prob=-48.95 over 51.22h [retry=1.6%, fail=0.7%] states=3336 gauss=40065 fmllr-impr=2.72 over 19.18h tree-impr=7.23 +%WER 23.89 [ 2211 / 9253, 233 ins, 404 del, 1574 sub ] exp/tri3/decode_nosp_fglarge_test_clean/wer_15_0.0 +%WER 24.47 [ 2264 / 9253, 252 ins, 385 del, 1627 sub ] exp/tri3/decode_nosp_tglarge_test_clean/wer_13_0.0 +%WER 37.81 [ 3499 / 9253, 274 ins, 671 del, 2554 sub ] exp/tri3/decode_nosp_tgsmall_test_clean/wer_13_0.0 +%WER 49.00 [ 4534 / 9253, 302 ins, 874 del, 3358 sub ] exp/tri3/decode_nosp_tgsmall_test_clean.si/wer_14_0.0 +%WER 21.68 [ 2006 / 9253, 226 ins, 
346 del, 1434 sub ] exp/tri3/decode_fglarge_test_clean/wer_15_0.0 +%WER 22.59 [ 2090 / 9253, 231 ins, 372 del, 1487 sub ] exp/tri3/decode_tglarge_test_clean/wer_15_0.0 +%WER 34.83 [ 3223 / 9253, 294 ins, 605 del, 2324 sub ] exp/tri3/decode_tgsmall_test_clean/wer_12_0.0 +%WER 45.28 [ 4190 / 9253, 270 ins, 880 del, 3040 sub ] exp/tri3/decode_tgsmall_test_clean.si/wer_15_0.0 + +# tri4 is an LDA+MLLT+SAT system after estimating pronunciation probabilities +# and word-and-pronunciation-dependent silence probabilities. +exp/tri4: nj=16 align prob=-48.70 over 51.22h [retry=1.5%, fail=0.7%] states=3368 gauss=40039 fmllr-impr=0.23 over 42.91h tree-impr=7.87 +%WER 21.61 [ 2000 / 9253, 210 ins, 379 del, 1411 sub ] exp/tri4/decode_fglarge_test_clean/wer_14_0.5 +%WER 22.59 [ 2090 / 9253, 237 ins, 371 del, 1482 sub ] exp/tri4/decode_tglarge_test_clean/wer_15_0.0 +%WER 34.57 [ 3199 / 9253, 285 ins, 595 del, 2319 sub ] exp/tri4/decode_tgsmall_test_clean/wer_12_0.0 +%WER 45.82 [ 4240 / 9253, 270 ins, 833 del, 3137 sub ] exp/tri4/decode_tgsmall_test_clean.si/wer_13_0.0 + +for dir in exp/chain/tdnn*_sp; do + steps/info/chain_dir_info.pl $dir + for x in ${dir}_online/decode*test*; do [ -d $x ] && [[ $x =~ "$1" ]] && grep WER $x/wer_* | utils/best_wer.sh; done +done +exit 0 + +# tdnn_1a is a kind of factorized TDNN, with skip connections. +exp/chain/tdnn1b_sp: num-iters=174 nj=2..8 num-params=12.9M dim=40+100->3040 combine=-0.041->-0.041 (over 2) xent:train/valid[115,173,final]=(-1.14,-0.759,-0.751/-1.14,-0.788,-0.777) logprob:train/valid[115,173,final]=(-0.084,-0.047,-0.046/-0.080,-0.050,-0.048) +%WER 10.55 [ 976 / 9253, 122 ins, 166 del, 688 sub ] exp/chain/tdnn1b_sp_online/decode_fglarge_test_clean/wer_13_1.0 +%WER 17.65 [ 1633 / 9253, 208 ins, 233 del, 1192 sub ] exp/chain/tdnn1b_sp_online/decode_tgsmall_test_clean/wer_10_0.0 + +# This chain system has TDNN+Norm-OPGRU architecture. +exp/chain/tdnn_opgru1a_sp: num-iters=99 nj=2..12 num-params=38.0M dim=40+100->3040 combine=-0.045->-0.045 (over 1) xent:train/valid[65,98,final]=(-1.18,-0.663,-0.651/-1.21,-0.698,-0.684) logprob:train/valid[65,98,final]=(-0.079,-0.038,-0.037/-0.076,-0.040,-0.039) +%WER 9.45 [ 874 / 9253, 109 ins, 159 del, 606 sub ] exp/chain/tdnn_opgru1a_sp_online/decode_fglarge_test_clean/wer_10_1.0 +%WER 15.22 [ 1408 / 9253, 175 ins, 196 del, 1037 sub ] exp/chain/tdnn_opgru1a_sp_online/decode_tgsmall_test_clean/wer_8_0.0 + diff --git a/egs/zeroth_korean/s5/cmd.sh b/egs/zeroth_korean/s5/cmd.sh new file mode 100644 index 00000000000..34031439792 --- /dev/null +++ b/egs/zeroth_korean/s5/cmd.sh @@ -0,0 +1,17 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. 
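+#
+# For example, to run everything locally without a grid engine you could use
+# something like the following instead of the settings below (memory values
+# are illustrative; run.pl generally ignores resource options such as --mem):
+# export train_cmd="run.pl --mem 2G"
+# export decode_cmd="run.pl --mem 4G"
+# export mkgraph_cmd="run.pl --mem 8G"
+# export normalize_cmd="run.pl --mem 4G"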
+ +export train_cmd="queue.pl --mem 2G" +export decode_cmd="queue.pl --mem 4G" +export mkgraph_cmd="queue.pl --mem 8G" +export normalize_cmd="queue.pl --mem 4G" + diff --git a/egs/zeroth_korean/s5/conf/decode.config b/egs/zeroth_korean/s5/conf/decode.config new file mode 100644 index 00000000000..7ba966f2b83 --- /dev/null +++ b/egs/zeroth_korean/s5/conf/decode.config @@ -0,0 +1 @@ +# empty config, just use the defaults. diff --git a/egs/zeroth_korean/s5/conf/mfcc.conf b/egs/zeroth_korean/s5/conf/mfcc.conf new file mode 100644 index 00000000000..7361509099f --- /dev/null +++ b/egs/zeroth_korean/s5/conf/mfcc.conf @@ -0,0 +1 @@ +--use-energy=false # only non-default option. diff --git a/egs/zeroth_korean/s5/conf/mfcc_hires.conf b/egs/zeroth_korean/s5/conf/mfcc_hires.conf new file mode 100644 index 00000000000..434834a6725 --- /dev/null +++ b/egs/zeroth_korean/s5/conf/mfcc_hires.conf @@ -0,0 +1,10 @@ +# config for high-resolution MFCC features, intended for neural network training +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--num-mel-bins=40 # similar to Google's setup. +--num-ceps=40 # there is no dimensionality reduction. +--low-freq=20 # low cutoff frequency for mel bins... this is high-bandwidth data, so + # there might be some information at the low end. +--high-freq=-400 # high cutoff frequently, relative to Nyquist of 8000 (=7600) diff --git a/egs/zeroth_korean/s5/conf/online_cmvn.conf b/egs/zeroth_korean/s5/conf/online_cmvn.conf new file mode 100644 index 00000000000..7748a4a4dd3 --- /dev/null +++ b/egs/zeroth_korean/s5/conf/online_cmvn.conf @@ -0,0 +1 @@ +# configuration file for apply-cmvn-online, used in the script ../local/run_online_decoding.sh diff --git a/egs/zeroth_korean/s5/local/chain/compare_wer.sh b/egs/zeroth_korean/s5/local/chain/compare_wer.sh new file mode 100755 index 00000000000..e8366bfb358 --- /dev/null +++ b/egs/zeroth_korean/s5/local/chain/compare_wer.sh @@ -0,0 +1,107 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/tdnn_{c,d}_sp +# For use with discriminatively trained systems you specify the epochs after a colon: +# for instance, +# local/chain/compare_wer.sh exp/chain/tdnn_c_sp exp/chain/tdnn_c_sp_smbr:{1,2,3} + + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... ]" + echo "e.g.: $0 exp/chain/tdnn_{b,c}_sp" + echo "or (with epoch numbers for discriminative training):" + echo "$0 exp/chain/tdnn_b_sp_disc:{1,2,3}" + exit 1 +fi + +echo "# $0 $*" + +used_epochs=false + +# this function set_names is used to separate the epoch-related parts of the name +# [for discriminative training] and the regular parts of the name. 
+# If called with a colon-free directory name, like: +# set_names exp/chain/tdnn_lstm1e_sp_bi_smbr +# it will set dir=exp/chain/tdnn_lstm1e_sp_bi_smbr and epoch_infix="" +# If called with something like: +# set_names exp/chain/tdnn_d_sp_smbr:3 +# it will set dir=exp/chain/tdnn_d_sp_smbr and epoch_infix="_epoch3" + + +set_names() { + if [ $# != 1 ]; then + echo "compare_wer_general.sh: internal error" + exit 1 # exit the program + fi + dirname=$(echo $1 | cut -d: -f1) + epoch=$(echo $1 | cut -s -d: -f2) + if [ -z $epoch ]; then + epoch_infix="" + else + used_epochs=true + epoch_infix=_epoch${epoch} + fi +} + + + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +strings=( + "#WER test_clean (tgsmall) " + "#WER test_clean (fglarge) ") + +for n in 0 1 ; do + echo -n "${strings[$n]}" + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + decode_names=(tgsmall_test_clean fglarge_test_clean) + + wer=$(grep WER ${dirname}_online/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo +done + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent)" +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent)" +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Num-params " +for x in $*; do + printf "% 10s" $(grep num-parameters $x/log/progress.1.log | awk '{print $2}') +done +echo diff --git a/egs/zeroth_korean/s5/local/chain/run_tdnn.sh b/egs/zeroth_korean/s5/local/chain/run_tdnn.sh new file mode 120000 index 00000000000..34499362831 --- /dev/null +++ b/egs/zeroth_korean/s5/local/chain/run_tdnn.sh @@ -0,0 +1 @@ +tuning/run_tdnn_1a.sh \ No newline at end of file diff --git a/egs/zeroth_korean/s5/local/chain/run_tdnn_opgru.sh b/egs/zeroth_korean/s5/local/chain/run_tdnn_opgru.sh new file mode 120000 index 00000000000..aedd4c8b4ac --- /dev/null +++ b/egs/zeroth_korean/s5/local/chain/run_tdnn_opgru.sh @@ -0,0 +1 @@ +tuning/run_tdnn_opgru_1a.sh \ No newline at end of file diff --git a/egs/zeroth_korean/s5/local/chain/tuning/run_tdnn_1a.sh b/egs/zeroth_korean/s5/local/chain/tuning/run_tdnn_1a.sh new file mode 100755 index 00000000000..14b9a8d6c8e --- /dev/null +++ b/egs/zeroth_korean/s5/local/chain/tuning/run_tdnn_1a.sh @@ -0,0 +1,290 @@ +#!/bin/bash + +set -e -o pipefail + +# This recipe trains TDNN-F AM +# The training recipe is from WSJ example(egs/wsj/s5/local/chain/tuning/run_tdnn_1g.sh) + +# steps/info/chain_dir_info.pl exp/chain/tdnn1a_sp +# exp/chain/tdnn1b_sp: num-iters=174 nj=2..8 num-params=12.9M dim=40+100->3040 combine=-0.041->-0.041 (over 2) xent:train/valid[115,173,final]=(-1.14,-0.759,-0.751/-1.14,-0.788,-0.777) logprob:train/valid[115,173,final]=(-0.084,-0.047,-0.046/-0.080,-0.050,-0.048) + +# ./local/chain/compare_wer.sh exp/chain/tdnn1a_sp +# System tdnn1b_sp 
+#WER test_clean (tgsmall) 17.65 +#WER test_clean (fglarge) 10.55 +# Final train prob -0.0460 +# Final valid prob -0.0480 +# Final train prob (xent) -0.7512 +# Final valid prob (xent) -0.7769 +# Num-params 12922560 + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +nj=30 +train_set=train_clean +speed_perturb=true +test_sets="test_clean" +gmm=tri4 # this is the source gmm-dir that we'll use for alignments; it + # should have alignments for the specified training data. +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. + +# Options which are not passed through to run_ivector_common.sh +affix=1a #affix for TDNN directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= + +# LSTM/chain options +train_stage=-10 +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.5@0.50,0' + +# training chunk-options +chunk_width=140,100,160 +# we don't need extra left/right context for TDNN systems. +chunk_left_context=0 +chunk_right_context=0 + +# training options +srand=0 +remove_egs=true + +#decode options +test_online_decoding=true # if true, it will run the last decoding stage. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 9 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj $nj --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 10 ]; then + # Build a tree using our new topology. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + +if [ $stage -le 11 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + tdnn_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim-continuous=true" + tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.01" + output_opts="l2-regularize=0.005" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-dropout-layer name=tdnn1 $tdnn_opts dim=1280 + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=1280 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf3 $tdnnf_opts dim=1280 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf4 $tdnnf_opts dim=1280 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf5 $tdnnf_opts dim=1280 bottleneck-dim=160 time-stride=0 + tdnnf-layer name=tdnnf6 $tdnnf_opts dim=1280 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf7 $tdnnf_opts dim=1280 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1280 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1280 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1280 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1280 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1280 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1280 bottleneck-dim=160 time-stride=3 + linear-component name=prefinal-l dim=256 $linear_opts + + + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1280 small-dim=256 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=1280 small-dim=256 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 12 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/wsj-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.0 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=10 \ + --trainer.frames-per-iter=2000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=8 \ + --trainer.optimization.initial-effective-lrate=0.0005 \ + --trainer.optimization.final-effective-lrate=0.00005 \ + --trainer.num-chunk-per-minibatch=128,64 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=0 \ + --egs.chunk-right-context=0 \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; + +fi + +if [ $stage -le 13 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/lang/check_phones_compatible.sh \ + data/lang_test_tgsmall/phones.txt $lang/phones.txt + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_test_tgsmall \ + $tree_dir $tree_dir/graph_tgsmall || exit 1; +fi + +if $test_online_decoding && [ $stage -le 14 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. + steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/nnet3/extractor ${dir} ${dir}_online + + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + data_affix=$(echo $data | sed s/test_//) + nspk=$(wc -l 3040 combine=-0.045->-0.045 (over 1) xent:train/valid[65,98,final]=(-1.19,-0.661,-0.647/-1.21,-0.696,-0.680) logprob:train/valid[65,98,final]=(-0.080,-0.039,-0.038/-0.076,-0.039,-0.038) + +# ./local/chain/compare_wer.sh exp/chain/tdnn_opgru1a_sp +# System tdnn_opgru1a_sp +#WER test_clean (tgsmall) 15.22 +#WER test_clean (fglarge) 9.45 +# Final train prob -0.0373 +# Final valid prob -0.0386 +# Final train prob (xent) -0.6506 +# Final valid prob (xent) -0.6837 +# Num-params 37970368 + + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +nj=30 +train_set=train_clean +speed_perturb=true +test_sets="test_clean" +gmm=tri4 # this is the source gmm-dir that we'll use for alignments; it + # should have alignments for the specified training data. +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. + +# Options which are not passed through to run_ivector_common.sh +affix=1a #affix for TDNN directory e.g. "1a" or "1b", in case we change the configuration. 
+common_egs_dir= + +# OPGRU/chain options +train_stage=-10 +get_egs_stage=-10 + +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.2@0.50,0' + +chunk_width=140,100,160 +label_delay=5 + +remove_egs=true + + +#decode options +test_online_decoding=true # if true, it will run the last decoding stage. + +# decode options +extra_left_context=50 +frames_per_chunk= + +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 9 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj $nj --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 10 ]; then + # Build a tree using our new topology. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + +if [ $stage -le 11 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + gru_opts="dropout-per-frame=true dropout-proportion=0.0" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=1024 + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=1024 + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=1024 + + # check steps/libs/nnet3/xconfig/gru.py for the other options and defaults + norm-opgru-layer name=opgru1 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $gru_opts + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=1024 + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=1024 + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=1024 + norm-opgru-layer name=opgru2 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $gru_opts + relu-batchnorm-layer name=tdnn7 input=Append(-3,0,3) dim=1024 + relu-batchnorm-layer name=tdnn8 input=Append(-3,0,3) dim=1024 + relu-batchnorm-layer name=tdnn9 input=Append(-3,0,3) dim=1024 + norm-opgru-layer name=opgru3 cell-dim=1024 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 $gru_opts + + ## adding the layers for chain branch + output-layer name=output input=opgru3 output-delay=$label_delay include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. 
we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + output-layer name=output-xent input=opgru3 output-delay=$label_delay dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ + +fi + + +if [ $stage -le 12 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/wsj-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $chunk_width \ + --egs.chunk-left-context 40 \ + --egs.chunk-right-context 0 \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.optimization.backstitch-training-scale 0.3 \ + --trainer.optimization.backstitch-training-interval 1 \ + --egs.chunk-left-context-initial 0 \ + --egs.chunk-right-context-final 0 \ + --trainer.num-chunk-per-minibatch 64,32 \ + --trainer.frames-per-iter 2000000 \ + --trainer.num-epochs=8 \ + --trainer.optimization.shrink-value 0.99 \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 12 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.max-param-change 2.0 \ + --trainer.deriv-truncate-margin 8 \ + --cleanup.remove-egs true \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir $lat_dir \ + --dir $dir + +fi + +if [ $stage -le 13 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/lang/check_phones_compatible.sh \ + data/lang_test_tgsmall/phones.txt $lang/phones.txt + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_test_tgsmall \ + $tree_dir $tree_dir/graph_tgsmall || exit 1; +fi + +if $test_online_decoding && [ $stage -le 14 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. + steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/nnet3/extractor ${dir} ${dir}_online + + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + data_affix=$(echo $data | sed s/test_//) + nspk=$(wc -l " + echo "e.g.: $0 ./db/train_data_01 data/train_data_01" + exit 1 +fi + +db_dir=$1 +data_part=$2 + +src=${db_dir}/${data_part} +dst=data/${data_part} + +# all utterances are FLAC compressed +if ! 
which flac >&/dev/null; then + echo "Please install 'flac' on ALL worker nodes!" + exit 1 +fi + +spk_file=${db_dir}/AUDIO_INFO + +mkdir -p $dst || exit 1; + +[ ! -d $src ] && echo "$0: no such directory $src" && exit 1; +[ ! -f $spk_file ] && echo "$0: expected file $spk_file to exist" && exit 1; + +wav_scp=$dst/wav.scp; [[ -f "$wav_scp" ]] && rm $wav_scp +trans=$dst/text; [[ -f "$trans" ]] && rm $trans +utt2spk=$dst/utt2spk; [[ -f "$utt2spk" ]] && rm $utt2spk +spk2gender=$dst/spk2gender; [[ -f $spk2gender ]] && rm $spk2gender +utt2dur=$dst/utt2dur; [[ -f "$utt2dur" ]] && rm $utt2dur + +for scriptid_dir in $(find -L $src -mindepth 1 -maxdepth 1 -type d | sort); do + scriptid=$(basename $scriptid_dir) + if ! [ $scriptid -eq $scriptid ]; then # not integer. + echo "$0: unexpected subdirectory name $scriptid" + exit 1; + fi + + for reader_dir in $(find -L $scriptid_dir/ -mindepth 1 -maxdepth 1 -type d | sort); do + reader=$(basename $reader_dir) + if ! [ "$reader" -eq "$reader" ]; then + echo "$0: unexpected reader-subdirectory name $reader" + exit 1; + fi + + reader_gender=$(egrep "^$reader\|" $spk_file | awk -F'|' '{gsub(/[ ]+/, ""); print tolower($3)}') + if [ "$reader_gender" != 'm' ] && [ "$reader_gender" != 'f' ]; then + echo "Unexpected gender: '$reader_gender'" + exit 1; + fi + + echo " "$scriptid $reader $reader_gender + + find -L $reader_dir/ -iname "*.flac" | sort | xargs -I% basename % .flac | \ + awk -v "dir=$reader_dir" '{printf "%s flac -c -d -s %s/%s.flac |\n", $0, dir, $0}' >>$wav_scp|| exit 1 + + reader_trans=$reader_dir/${reader}_${scriptid}.trans.txt + [ ! -f $reader_trans ] && echo "$0: expected file $reader_trans to exist" && exit 1 + cat $reader_trans >>$trans + + # NOTE: Each chapter is dedicated to each speaker. + awk -v "reader=$reader" -v "scriptid=$scriptid" '{printf "%s %s_%s\n", $1, reader, scriptid}' \ + <$reader_trans >>$utt2spk || exit 1 + + # reader -> gender map (again using per-chapter granularity) + echo "${reader}_${scriptid} $reader_gender" >>$spk2gender + + done +done + +# sort +cat $wav_scp | sort > tmp +cp tmp $wav_scp +cat $trans | sort > tmp +cp tmp $trans +cat $utt2spk | sort > tmp +cp tmp $utt2spk +cat $spk2gender | sort > tmp +cp tmp $spk2gender +rm tmp + + +spk2utt=$dst/spk2utt +utils/utt2spk_to_spk2utt.pl <$utt2spk >$spk2utt || exit 1 + +ntrans=$(wc -l <$trans) +nutt2spk=$(wc -l <$utt2spk) +! [ "$ntrans" -eq "$nutt2spk" ] && \ + echo "Inconsistent #transcripts($ntrans) and #utt2spk($nutt2spk)" && exit 1; + +utils/data/get_utt2dur.sh $dst 1>&2 || exit 1 + +utils/validate_data_dir.sh --no-feats $dst || exit 1; + +echo "$0: successfully prepared data in $dst" + +exit 0 diff --git a/egs/zeroth_korean/s5/local/download_and_untar.sh b/egs/zeroth_korean/s5/local/download_and_untar.sh new file mode 100755 index 00000000000..2e62a3273d4 --- /dev/null +++ b/egs/zeroth_korean/s5/local/download_and_untar.sh @@ -0,0 +1,61 @@ +#!/bin/bash + +# Copyright 2018 Lucas Jo (Atlas Guide) +# 2018 Wonkyum Lee (Gridspace) +# Apache 2.0 + +if [ $# -ne "1" ]; then + echo "Usage: $0 " + echo "e.g.: $0 ./db" + exit 1 +fi + +exists(){ + command -v "$1" >/dev/null 2>&1 +} + + +dir=$1 +local_lm_dir=data/local/lm + +AUDIOINFO='AUDIO_INFO' +AUDIOLIST='train_data_01 test_data_01' + +echo "Now download corpus ----------------------------------------------------" +if [ ! -f $dir/db.tar.gz ]; then + if [ ! 
-d $dir ]; then + mkdir -p $dir + fi + wget -O $dir/db.tar.gz http://www.openslr.org/resources/40/zeroth_korean.tar.gz +else + echo " $dir/db.tar.gz already exist" +fi + +echo "Now extract corpus ----------------------------------------------------" +if [ ! -f $dir/$AUDIOINFO ]; then + tar -zxvf $dir/db.tar.gz -C $dir + else + echo " corpus already extracted" +fi + +if [ ! -d $local_lm_dir ]; then + mkdir -p $local_lm_dir +fi +echo "Check LMs files" +LMList="\ + zeroth.lm.fg.arpa.gz \ + zeroth.lm.tg.arpa.gz \ + zeroth.lm.tgmed.arpa.gz \ + zeroth.lm.tgsmall.arpa.gz \ + zeroth_lexicon \ + zeroth_morfessor.seg" + +for file in $LMList; do + if [ -f $local_lm_dir/$file ]; then + echo $file already exist + else + echo "Linking "$file + ln -s $PWD/$dir/$file $local_lm_dir/$file + fi +done +echo "all the files (lexicon, LM, segment model) are ready" diff --git a/egs/zeroth_korean/s5/local/format_lms.sh b/egs/zeroth_korean/s5/local/format_lms.sh new file mode 100755 index 00000000000..a9111e80eeb --- /dev/null +++ b/egs/zeroth_korean/s5/local/format_lms.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +# Copyright 2014 Vassil Panayotov +# Apache 2.0 + +# Prepares the test time language model(G) transducers +# (adapted from wsj/s5/local/wsj_format_data.sh) + +# Modified by Lucas Jo 2017 (Altas Guide) + +. ./path.sh || exit 1; + +# begin configuration section +src_dir=data/lang +# end configuration section + +. utils/parse_options.sh || exit 1; + +set -e + +if [ $# -ne 1 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /export/a15/vpanayotov/data/lm" + echo ", where:" + echo " is the directory in which the language model is stored/downloaded" + echo "Options:" + echo " --src-dir

# source lang directory, default data/lang" + exit 1 +fi + +lm_dir=$1 + +if [ ! -d $lm_dir ]; then + echo "$0: expected source LM directory $lm_dir to exist" + exit 1; +fi +if [ ! -f $src_dir/words.txt ]; then + echo "$0: expected $src_dir/words.txt to exist." + exit 1; +fi + + +tmpdir=data/local/lm_tmp.$$ +trap "rm -r $tmpdir" EXIT + +mkdir -p $tmpdir + +#lm_sets="tgsmall tgmed" +lm_sets="tgsmall" +for lm_suffix in ${lm_sets}; do + # tglarge is prepared by a separate command, called from run.sh; we don't + # want to compile G.fst for tglarge, as it takes a while. + test=${src_dir}_test_${lm_suffix} + mkdir -p $test + cp -r ${src_dir}/* $test + gunzip -c $lm_dir/zeroth.lm.${lm_suffix}.arpa.gz | \ + arpa2fst --disambig-symbol=#0 \ + --read-symbol-table=$test/words.txt - $test/G.fst + + utils/validate_lang.pl --skip-determinization-check $test || exit 1; +done + +echo "Succeeded in formatting data." + +exit 0 diff --git a/egs/zeroth_korean/s5/local/nnet3/run_ivector_common.sh b/egs/zeroth_korean/s5/local/nnet3/run_ivector_common.sh new file mode 100755 index 00000000000..70be96310e1 --- /dev/null +++ b/egs/zeroth_korean/s5/local/nnet3/run_ivector_common.sh @@ -0,0 +1,108 @@ +#!/bin/bash + +# this script contains some common (shared) parts of the run_nnet*.sh scripts. +. cmd.sh + + +stage=0 +gmmdir=exp/tri4 +speed_perturb=false +trainset=train_clean + +set -e +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if [ "$speed_perturb" == "true" ]; then + if [ $stage -le 1 ]; then + echo "$0: preparing directory for speed-perturbed data" + #Although the nnet will be trained by high resolution data, we still have to perturbe the normal data to get the alignment + # _sp stands for speed-perturbed + for datadir in ${trainset} ; do + utils/data/perturb_data_dir_speed_3way.sh data/${datadir} data/${datadir}_sp + + mfccdir=mfcc_perturbed + steps/make_mfcc.sh --cmd "$train_cmd" --nj 40 \ + data/${datadir}_sp exp/make_mfcc/${datadir}_sp $mfccdir || exit 1; + steps/compute_cmvn_stats.sh data/${datadir}_sp exp/make_mfcc/${datadir}_sp $mfccdir || exit 1; + utils/fix_data_dir.sh data/${datadir}_sp + done + fi + + if [ $stage -le 2 ]; then + echo "$0: aligning with the perturbed low-resolution data" + #obtain the alignment of the perturbed data + steps/align_fmllr.sh --nj 100 --cmd "$train_cmd" \ + data/${trainset}_sp data/lang_nosp ${gmmdir} ${gmmdir}_ali_${trainset}_sp || exit 1 + fi + trainset=${trainset}_sp +fi + +if [ $stage -le 3 ]; then + # Create high-resolution MFCC features (with 40 cepstra instead of 13). + # this shows how you can split across multiple file-systems. we'll split the + # MFCC dir across multiple locations. You might want to be careful here, if you + # have multiple copies of Kaldi checked out and run the same recipe, not to let + # them overwrite each other. + + echo "$0: creating high-resolution MFCC features" + for datadir in ${trainset} ; do + utils/copy_data_dir.sh data/$datadir data/${datadir}_hires + steps/make_mfcc.sh --nj 40 --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/${datadir}_hires || exit 1; + steps/compute_cmvn_stats.sh data/${datadir}_hires || exit 1; + done + + # We need to build a small system just because we need PCA transform + # to train the diag-UBM on top of. + utils/subset_data_dir.sh data/${trainset}_hires 30000 data/train_30k_hires +fi + + +if [ $stage -le 4 ]; then + # Train a small system just for its PCA transform. + echo "$0: computing a PCA transform from the hires data." 
+ mkdir exp -p exp/nnet3 + steps/online/nnet2/get_pca_transform.sh --cmd "$train_cmd" \ + --splice-opts "--left-context=3 --right-context=3" \ + --max-utts 30000 --subsample 2 \ + data/train_30k_hires exp/nnet3/pca_transform +fi + +if [ $stage -le 5 ]; then + # To train a diagonal UBM we don't need very much data, so use a small subset + echo "$0: training the diagonal UBM." + steps/online/nnet2/train_diag_ubm.sh --cmd "$train_cmd" --nj 30 --num-frames 700000 \ + data/train_30k_hires 512 exp/nnet3/pca_transform exp/nnet3/diag_ubm +fi + +if [ $stage -le 6 ]; then + # Train the iVector extractor. Use all of the speed-perturbed data since iVector extractors + # can be sensitive to the amount of data. The script defaults to an iVector dimension of 100 + echo "$0: training the iVector extractor" + steps/online/nnet2/train_ivector_extractor.sh --cmd "$train_cmd" --nj 10 \ + data/${trainset}_hires exp/nnet3/diag_ubm exp/nnet3/extractor || exit 1; +fi + +if [ $stage -le 7 ]; then + ivectordir=exp/nnet3/ivectors_${trainset}_hires + + # We extract iVectors on all the train data, which will be what we train the + # system on. With --utts-per-spk-max 2, the script. pairs the utterances + # into twos, and treats each of these pairs as one speaker. Note that these + # are extracted 'online'. + + # having a larger number of speakers is helpful for generalization, and to + # handle per-utterance decoding well (iVector starts at zero). + echo "$0: extracing iVector using trained iVector extractor" + utils/data/modify_speaker_info.sh --utts-per-spk-max 2 \ + data/${trainset}_hires data/${trainset}_hires_max2 + + steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj 60 \ + data/${trainset}_hires_max2 exp/nnet3/extractor $ivectordir || exit 1; +fi + + +exit 0; diff --git a/egs/zeroth_korean/s5/local/prepare_dict.sh b/egs/zeroth_korean/s5/local/prepare_dict.sh new file mode 100755 index 00000000000..76c6821e11e --- /dev/null +++ b/egs/zeroth_korean/s5/local/prepare_dict.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +# Copyright 2014 Vassil Panayotov +# Apache 2.0 + +# Modified by Lucas Jo 2017 (Altas Guide) +# Prepare dictionary + +if [ $# -ne 2 ]; then + echo "Usage: $0 " + echo "e.g.: /data/local/lm data/local/dict_nosp" + exit 1 +fi +lm_dir=$1 +dst_dir=$2 + +mkdir -p $dst_dir || exit 1; + +# this file is a copy of the lexicon we obtained from download_lm.sh process +lexicon_raw_nosil=$dst_dir/lexicon_raw_nosil.txt + +if [[ ! -s "$lexicon_raw_nosil" ]]; then + cp $lm_dir/zeroth_lexicon $lexicon_raw_nosil || exit 1 +fi + +silence_phones=$dst_dir/silence_phones.txt +optional_silence=$dst_dir/optional_silence.txt +nonsil_phones=$dst_dir/nonsilence_phones.txt +extra_questions=$dst_dir/extra_questions.txt + +echo "Preparing phone lists and clustering questions" +(echo SIL; echo SPN;) > $silence_phones +#( echo SIL; echo BRH; echo CGH; echo NSN ; echo SMK; echo UM; echo UHH ) > $silence_phones +echo SIL > $optional_silence +# nonsilence phones; on each line is a list of phones that correspond +# really to the same base phone. +awk '{for (i=2; i<=NF; ++i) { print $i; gsub(/[0-9]/, "", $i); print $i}}' $lexicon_raw_nosil |\ + sort -u |\ + perl -e 'while(<>){ + chop; m:^([^\d]+)(\d*)$: || die "Bad phone $_"; + $phones_of{$1} .= "$_ "; } + foreach $list (values %phones_of) {print $list . "\n"; } ' \ + > $nonsil_phones || exit 1; +# A few extra questions that will be added to those obtained by +# automatically clustering +# the "real" phones. These ask about stress; there's also one for +# silence. 
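To make the grouping concrete (the phone symbols here are hypothetical, not taken from the zeroth lexicon): if the lexicon contained the stress-marked phones A0 A1 and E0 E1, the awk/perl pipeline above would write one line per base phone to nonsilence_phones.txt,

  A A0 A1
  E E0 E1

and the extra-questions code below would then add one question per stress marker (plus a line listing the silence phones), for example

  SIL SPN
  A E
  A0 E0
  A1 E1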
+cat $silence_phones| awk '{printf("%s ", $1);} END{printf "\n";}' > $extra_questions || exit 1; +cat $nonsil_phones | perl -e 'while(<>){ foreach $p (split(" ", $_)){ +$p =~ m:^([^\d]+)(\d*)$: || die "Bad phone $_"; $q{$2} .= "$p "; } } foreach $l (values %q) {print "$l\n";}' \ + >> $extra_questions || exit 1; + +echo "$(wc -l <$silence_phones) silence phones saved to: $silence_phones" +echo "$(wc -l <$optional_silence) optional silence saved to: $optional_silence" +echo "$(wc -l <$nonsil_phones) non-silence phones saved to: $nonsil_phones" +echo "$(wc -l <$extra_questions) extra triphone clustering-related questions saved to: $extra_questions" + +#(echo '!SIL SIL'; echo '[BREATH] BRH'; echo '[NOISE] NSN'; echo '[COUGH] CGH'; +# echo '[SMACK] SMK'; echo '[UM] UM'; echo '[UH] UHH' +# echo ' NSN' ) | \ +(echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) |\ + cat - $lexicon_raw_nosil | sort | uniq >$dst_dir/lexicon.txt +echo "Lexicon text file saved as: $dst_dir/lexicon.txt" +exit 0 + diff --git a/egs/zeroth_korean/s5/local/score.sh b/egs/zeroth_korean/s5/local/score.sh new file mode 100755 index 00000000000..c812199fc98 --- /dev/null +++ b/egs/zeroth_korean/s5/local/score.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# 2014 Guoguo Chen +# Apache 2.0 + +[ -f ./path.sh ] && . ./path.sh + +# begin configuration section. +cmd=run.pl +stage=0 +decode_mbr=true +word_ins_penalty=0.0,0.5,1.0 +min_lmwt=7 +max_lmwt=17 +iter=final +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: local/score.sh [--cmd (run.pl|queue.pl...)] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --stage (0|1|2) # start scoring script from part-way through." + echo " --decode_mbr (true/false) # maximum bayes risk decoding (confusion network)." + echo " --min_lmwt # minumum LM-weight for lattice rescoring " + echo " --max_lmwt # maximum LM-weight for lattice rescoring " + exit 1; +fi + +data=$1 +lang_or_graph=$2 +dir=$3 + +symtab=$lang_or_graph/words.txt + +for f in $symtab $dir/lat.1.gz $data/text; do + [ ! 
-f $f ] && echo "score.sh: no such file $f" && exit 1; +done + +mkdir -p $dir/scoring/log + +cat $data/text | sed 's:::g' | sed 's:::g' > $dir/scoring/test_filt.txt + +for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/best_path.LMWT.$wip.log \ + lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \ + lattice-best-path --word-symbol-table=$symtab \ + ark:- ark,t:$dir/scoring/LMWT.$wip.tra || exit 1; +done + +# Note: the double level of quoting for the sed command +for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.$wip.log \ + cat $dir/scoring/LMWT.$wip.tra \| \ + utils/int2sym.pl -f 2- $symtab \| sed 's:\::g' \| \ + compute-wer --text --mode=present \ + ark:$dir/scoring/test_filt.txt ark,p:- ">&" $dir/wer_LMWT_$wip || exit 1; +done + +exit 0; diff --git a/egs/zeroth_korean/s5/local/update_segmentation.sh b/egs/zeroth_korean/s5/local/update_segmentation.sh new file mode 100755 index 00000000000..e1eea821645 --- /dev/null +++ b/egs/zeroth_korean/s5/local/update_segmentation.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# Copyright 2017 Lucas Jo (Atlas Guide) +# Apache 2.0 + +# do this when the segmentation rule is changed +dataDir=$1 +lmDir=$2 + +exists(){ + command -v "$1" >/dev/null 2>&1 +} + +# check morfessor installation +if ! exists morfessor; then + echo "You appear to not have Morfessor installed, either on your path." + echo "See tools/extras/install_morfessor.sh installation instructions." + exit 1 +fi + +trans=$dataDir/text +echo "Re-segment transcripts: $trans --------------------------------------------" +if [ ! -f $trans ]; then + echo "transcription file is not found in "$dataDir + exit 1 +fi +cp $trans $trans".old" +awk '{print $1}' $trans".old" > $trans"_tmp_index" +cut -d' ' -f2- $trans".old" |\ + sed -E 's/\s+/ /g; s/^\s//g; s/\s$//g' |\ + morfessor -e 'utf-8' -l $lmDir/zeroth_morfessor.seg -T - -o - \ + --output-format '{analysis} ' --output-newlines \ + --nosplit-re '[0-9\[\]\(\){}a-zA-Z&.,\-]+' \ + | paste -d" " $trans"_tmp_index" - > $trans +rm -f $trans"_tmp_index" + diff --git a/egs/zeroth_korean/s5/path.sh b/egs/zeroth_korean/s5/path.sh new file mode 100755 index 00000000000..2d17b17a84a --- /dev/null +++ b/egs/zeroth_korean/s5/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/zeroth_korean/s5/run.sh b/egs/zeroth_korean/s5/run.sh new file mode 100755 index 00000000000..c5c7506980b --- /dev/null +++ b/egs/zeroth_korean/s5/run.sh @@ -0,0 +1,264 @@ +#!/bin/bash +# +# Based mostly on the WSJ/Librispeech recipe. +# The training/testing database is described in http://www.openslr.org/40/ +# This corpus consists of 51hrs korean speech with cleaned automatic transcripts: +# +# Copyright 2018 Atlas Guide (Author : Lucas Jo) +# 2018 Gridspace Inc. (Author: Wonkyum Lee) +# +# Apache 2.0 +# + +# Check list before start +# 1. 
required software: Morfessor-2.0.1 (see tools/extras/install_morfessor.sh) + +stage=0 +db_dir=./db +nj=16 + +chain_train=true +decode=true # set false if you don't want to decode each GMM model +decode_rescoring=true # set false if you don't want to rescore with large language model +test_set="test_clean" + +. ./cmd.sh +. ./path.sh +. utils/parse_options.sh # e.g. this parses the --stage option if supplied. + +# you might not want to do this for interactive shells. +set -e + +if [ $stage -le 0 ]; then + # download the data. + local/download_and_untar.sh $db_dir +fi + +if [ $stage -le 1 ]; then + # format the data as Kaldi data directories + for part in train_data_01 test_data_01; do + # use underscore-separated names in data directories. + local/data_prep.sh $db_dir $part + done +fi + +if [ $stage -le 2 ]; then + # update segmentation of transcripts + for part in train_data_01 test_data_01; do + local/update_segmentation.sh data/$part data/local/lm + done +fi + +if [ $stage -le 3 ]; then + # prepare dictionary and language model + local/prepare_dict.sh data/local/lm data/local/dict_nosp + + utils/prepare_lang.sh data/local/dict_nosp \ + "" data/local/lang_tmp_nosp data/lang_nosp +fi + +if [ $stage -le 4 ]; then + # build testing language model + local/format_lms.sh --src-dir data/lang_nosp data/local/lm + + # re-scoring language model + if $decode_rescoring ; then + utils/build_const_arpa_lm.sh data/local/lm/zeroth.lm.tg.arpa.gz \ + data/lang_nosp data/lang_nosp_test_tglarge + utils/build_const_arpa_lm.sh data/local/lm/zeroth.lm.fg.arpa.gz \ + data/lang_nosp data/lang_nosp_test_fglarge + fi +fi + + +if [ $stage -le 5 ]; then + # Feature extraction (MFCC) + mfccdir=mfcc + for part in train_data_01 test_data_01; do + steps/make_mfcc.sh --cmd "$train_cmd" --nj $nj data/$part exp/make_mfcc/$part $mfccdir + steps/compute_cmvn_stats.sh data/$part exp/make_mfcc/$part $mfccdir + done + + # ... and then combine data sets into one (for later extension) + utils/combine_data.sh \ + data/train_clean data/train_data_01 + + utils/combine_data.sh \ + data/test_clean data/test_data_01 + + # Make some small data subsets for early system-build stages. + utils/subset_data_dir.sh --shortest data/train_clean 2000 data/train_2kshort + utils/subset_data_dir.sh data/train_clean 5000 data/train_5k + utils/subset_data_dir.sh data/train_clean 10000 data/train_10k +fi + +if [ $stage -le 6 ]; then + echo "$0: #### Monophone Training ###########" + # train a monophone system with 2k short utts + steps/train_mono.sh --boost-silence 1.25 --nj $nj --cmd "$train_cmd" \ + data/train_2kshort data/lang_nosp exp/mono + if $decode; then + utils/mkgraph.sh data/lang_nosp_test_tgsmall exp/mono exp/mono/graph_nosp_tgsmall + nspk=$(wc -l " data/local/lang_tmp data/lang + + local/format_lms.sh --src-dir data/lang data/local/lm + + utils/build_const_arpa_lm.sh \ + data/local/lm/zeroth.lm.tg.arpa.gz data/lang data/lang_test_tglarge + utils/build_const_arpa_lm.sh \ + data/local/lm/zeroth.lm.fg.arpa.gz data/lang data/lang_test_fglarge + + if $decode; then + utils/mkgraph.sh data/lang_test_tgsmall exp/tri3 exp/tri3/graph_tgsmall + nspk=$(wc -l trigger: - error_level = int(math.log(self.lines_in_function / base_trigger, 2)) + error_level = int(math.log(float(self.lines_in_function) / base_trigger, 2)) # 50 => 0, 100 => 1, 200 => 2, 400 => 3, 800 => 4, 1600 => 5, ... 
if error_level > 5: error_level = 5 @@ -676,7 +677,7 @@ class _IncludeError(Exception): pass -class FileInfo: +class FileInfo(object): """Provides utility functions for filenames. FileInfo provides easy access to the components of a file's path @@ -1012,7 +1013,7 @@ def CheckForCopyright(filename, lines, error): # We'll say it should occur by line 10. Don't forget there's a # dummy line at the front. - for line in xrange(1, min(len(lines), 11)): + for line in range(1, min(len(lines), 11)): if re.search(r'Copyright', lines[line], re.I): break else: # means no copyright line was found error(filename, 0, 'legal/copyright', 5, @@ -1604,7 +1605,7 @@ def CheckForFunctionLengths(filename, clean_lines, linenum, if starting_func: body_found = False - for start_linenum in xrange(linenum, clean_lines.NumLines()): + for start_linenum in range(linenum, clean_lines.NumLines()): start_line = lines[start_linenum] joined_line += ' ' + start_line.lstrip() if Search(r'(;|})', start_line): # Declarations and trivial functions @@ -2073,7 +2074,7 @@ def GetLineWidth(line): The width of the line in column positions, accounting for Unicode combining characters and wide characters. """ - if isinstance(line, unicode): + if isinstance(line, str): width = 0 for c in unicodedata.normalize('NFC', line): if unicodedata.east_asian_width(c) in ('W', 'F'): @@ -2861,7 +2862,7 @@ def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error, required = {} # A map of header name to linenumber and the template entity. # Example of required: { '': (1219, 'less<>') } - for linenum in xrange(clean_lines.NumLines()): + for linenum in range(clean_lines.NumLines()): line = clean_lines.elided[linenum] if not line or line[0] == '#': continue @@ -2994,7 +2995,7 @@ def ProcessFileData(filename, file_extension, lines, error): RemoveMultiLineComments(filename, lines, error) clean_lines = CleansedLines(lines) - for line in xrange(clean_lines.NumLines()): + for line in range(clean_lines.NumLines()): ProcessLine(filename, file_extension, clean_lines, line, include_state, function_state, class_state, error) class_state.CheckFinished(filename, error) diff --git a/scripts/rnnlm/choose_features.py b/scripts/rnnlm/choose_features.py index 0686c8f88c6..595c1d85bc1 100755 --- a/scripts/rnnlm/choose_features.py +++ b/scripts/rnnlm/choose_features.py @@ -8,10 +8,10 @@ import sys import math from collections import defaultdict -sys.stdout = open(1, 'w', encoding='latin-1', closefd=False) +sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) import re -tab_or_space = re.compile('[ \t]') + parser = argparse.ArgumentParser(description="This script chooses the sparse feature representation of words. " "To be more specific, it chooses the set of features-- you compute " @@ -86,9 +86,9 @@ # and 'wordlist' is a list indexed by integer id, that returns the string-valued word. def read_vocab(vocab_file): vocab = {} - with open(vocab_file, 'r', encoding="latin-1") as f: + with open(vocab_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 if fields[0] in vocab: sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}" @@ -115,9 +115,9 @@ def read_vocab(vocab_file): # id of the word, which evaluates to the unigram prob of the word. 
def read_unigram_probs(unigram_probs_file): unigram_probs = [] - with open(unigram_probs_file, 'r', encoding="latin-1") as f: + with open(unigram_probs_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 idx = int(fields[0]) if idx >= len(unigram_probs): diff --git a/scripts/rnnlm/compute_sentence_scores_back.sh b/scripts/rnnlm/compute_sentence_scores_back.sh new file mode 100755 index 00000000000..3024d43439e --- /dev/null +++ b/scripts/rnnlm/compute_sentence_scores_back.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +# Copyright 2017 Hainan Xu +# 2017 Szu-Jui Chen + +# This script is very similar to rnnlm/compute_sentence_scores.sh, where it do the +# same procedure for reversed data. And it computes log-likelihoods from a +# Kaldi-RNNLM model instead of that of Mikolov's RNNLM. Because Kaldi-RNNLM uses +# letter-features which does not need an symbol, we don't need the "unk.probs" +# file any more to add as a penalty term in sentence likelihoods. + +ensure_normalized_probs=false # If true then the probabilities computed by the + # RNNLM will be correctly normalized. Note it is + # OK to set it to false because Kaldi-RNNLM is + # trained in a way that ensures the sum of probabilities + # is close to 1. + +. ./path.sh || exit 1; +. utils/parse_options.sh + +if [ $# != 4 ]; then + echo "Usage: $0 " + exit 1; +fi + +dir=$1 +tempdir=$2 +text_in=$3 +scores_out=$4 + +if [ -f $dir/word_embedding.final.mat ]; then + word_embedding=$dir/word_embedding.final.mat +else + [ ! -f $dir/feat_embedding.final.mat ] && + echo "$0: expect file $dir/feat_embedding.final.mat to exit" + word_embedding="rnnlm-get-word-embedding $dir/word_feats.txt $dir/feat_embedding.final.mat -|" +fi + +for x in final.raw config/words.txt; do + if [ ! -f $dir/$x ]; then + echo "$0: expected file $dir/$x to exist." + exit 1; + fi +done + +mkdir -p $tempdir +cat $text_in | sym2int.pl -f 2- $dir/config/words.txt | \ + awk '{printf("%s ",$1);for(i=NF;i>1;i--) printf("%s ",$i); print""}' > $tempdir/text.int + +special_symbol_opts=$(cat ${dir}/special_symbol_opts.txt) + +rnnlm-sentence-probs --normalize-probs=$ensure_normalized_probs \ + $special_symbol_opts $dir/final.raw "$word_embedding" $tempdir/text.int > $tempdir/loglikes.rnn +# Now $tempdir/loglikes.rnn has the following structure +# utt-id log P(word1 | ) log P(word2 | word1) ... 
log P( | all word histories) +# for example, +# +# en_4156-A_058697-058813-2 -3.57205 -2.70411 -4.29876 -3.63707 -6.00299 -2.11093 -2.03955 +# en_4156-A_058697-058813-3 -6.6074 -1.21244 -3.89991 -3.23747 -5.35102 -1.90448 -1.77809 +# en_4156-A_058697-058813-4 -5.09022 -1.24148 -4.76337 -4.75594 -5.77118 -2.08555 -2.18403 +# en_4156-A_058697-058813-5 -4.54489 -2.97485 -3.93646 -3.28041 -5.18779 -2.83356 -1.72601 +# en_4156-A_058697-058813-6 -2.31464 -3.74738 -4.03309 -3.22942 -5.66818 -2.0396 -1.64734 +# en_4156-A_058697-058813-7 -5.0728 -2.96303 -4.6539 -3.20266 -5.40682 -2.10625 -1.90956 + +[ $(cat $tempdir/loglikes.rnn | wc -l) -ne $(cat $tempdir/text.int | wc -l) ] && \ + echo "$0: rnnlm rescoring failed" && exit 1; + +# We need the negative log-probabilities +cat $tempdir/loglikes.rnn | awk '{sum=0;for(i=2;i<=NF;i++)sum-=$i; print $1,sum}' >$scores_out diff --git a/scripts/rnnlm/get_best_model.py b/scripts/rnnlm/get_best_model.py index e8c6bd8a2f4..ed266346e06 100755 --- a/scripts/rnnlm/get_best_model.py +++ b/scripts/rnnlm/get_best_model.py @@ -3,14 +3,14 @@ # Copyright 2017 Johns Hopkins University (author: Daniel Povey) # License: Apache 2.0. -import os import argparse -import sys +import glob import re +import sys parser = argparse.ArgumentParser(description="Works out the best iteration of RNNLM training " - "based on dev-set perplexity, and prints the number corresponding " - "to that iteration", + "based on dev-set perplexity, and prints the number corresponding " + "to that iteration", epilog="E.g. " + sys.argv[0] + " exp/rnnlm_a", formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -19,10 +19,9 @@ args = parser.parse_args() - -num_iters=None +num_iters = None try: - with open(args.rnnlm_dir + "/info.txt", encoding="latin-1") as f: + with open(args.rnnlm_dir + "/info.txt", encoding="utf-8") as f: for line in f: a = line.split("=") if a[0] == "num_iters": @@ -36,15 +35,15 @@ sys.exit(sys.argv[0] + ": could not get num_iters from {0}/info.txt".format( args.rnnlm_dir)) -best_objf=-2000 -best_iter=-1 +best_objf = -2000 +best_iter = -1 for i in range(1, num_iters): this_logfile = "{0}/log/compute_prob.{1}.log".format(args.rnnlm_dir, i) try: - f = open(this_logfile, 'r', encoding='latin-1') + f = open(this_logfile, 'r', encoding='utf-8') except: sys.exit(sys.argv[0] + ": could not open log-file {0}".format(this_logfile)) - this_objf=-1000 + this_objf = -1000 for line in f: m = re.search('Overall objf .* (\S+)$', str(line)) if m is not None: @@ -53,6 +52,10 @@ except Exception as e: sys.exit(sys.argv[0] + ": line in file {0} could not be parsed: {1}, error is: {2}".format( this_logfile, line, str(e))) + # verify this iteration still has model files present + if len(glob.glob("{0}/{1}.raw".format(args.rnnlm_dir, i))) == 0: + # this iteration has log files, but model files have been cleaned up, skip it + continue if this_objf == -1000: print(sys.argv[0] + ": warning: could not parse objective function from {0}".format( this_logfile), file=sys.stderr) @@ -63,5 +66,4 @@ if best_iter == -1: sys.exit(sys.argv[0] + ": error: could not get best iteration.") - print(str(best_iter)) diff --git a/scripts/rnnlm/get_embedding_dim.py b/scripts/rnnlm/get_embedding_dim.py index a5ddb8c25f3..1d516e0edf5 100755 --- a/scripts/rnnlm/get_embedding_dim.py +++ b/scripts/rnnlm/get_embedding_dim.py @@ -45,7 +45,7 @@ left_context=0 right_context=0 for line in out_lines: - line = line.decode('latin-1') + line = line.decode('utf-8') m = re.search(r'input-node name=input dim=(\d+)', line) if m is not None: 
try: @@ -101,4 +101,4 @@ "nnet '{0}': {1} != {2}".format( args.nnet, input_dim, output_dim)) -print(str(input_dim)) +print('{}'.format(input_dim)) diff --git a/scripts/rnnlm/get_special_symbol_opts.py b/scripts/rnnlm/get_special_symbol_opts.py index 83f7d708a49..7ee0ca54c9a 100755 --- a/scripts/rnnlm/get_special_symbol_opts.py +++ b/scripts/rnnlm/get_special_symbol_opts.py @@ -9,7 +9,7 @@ import sys import re -tab_or_space = re.compile('[ \t]') + parser = argparse.ArgumentParser(description="This script checks whether the special symbols " "appear in words.txt with expected values, if not, it will " @@ -28,9 +28,9 @@ lower_ids = {} upper_ids = {} -input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='latin-1') +input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') for line in input_stream: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) == 2) sym = fields[0] if sym in special_symbols: diff --git a/scripts/rnnlm/get_unigram_probs.py b/scripts/rnnlm/get_unigram_probs.py index abb8515f330..e3189b26a92 100755 --- a/scripts/rnnlm/get_unigram_probs.py +++ b/scripts/rnnlm/get_unigram_probs.py @@ -8,7 +8,7 @@ import sys import re -tab_or_space = re.compile('[ \t]') + parser = argparse.ArgumentParser(description="This script gets the unigram probabilities of words.", epilog="E.g. " + sys.argv[0] + " --vocab-file=data/rnnlm/vocab/words.txt " @@ -77,10 +77,10 @@ def get_all_data_sources_except_dev(text_dir): # value is a tuple (repeated_times_per_epoch, weight) def read_data_weights(weights_file, data_sources): data_weights = {} - with open(weights_file, 'r', encoding="latin-1") as f: + with open(weights_file, 'r', encoding="utf-8") as f: for line in f: try: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 3 if fields[0] in data_weights: raise Exception("duplicated data source({0}) specified in " @@ -102,9 +102,9 @@ def read_data_weights(weights_file, data_sources): # return the vocab, which is a dict mapping the word to a integer id. 
def read_vocab(vocab_file): vocab = {} - with open(vocab_file, 'r', encoding="latin-1") as f: + with open(vocab_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 if fields[0] in vocab: sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}" @@ -131,9 +131,9 @@ def get_counts(data_sources, data_weights, vocab): if weight == 0.0: continue - with open(counts_file, 'r', encoding="latin-1") as f: + with open(counts_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() if len(fields) != 2: print("Warning, should be 2 cols:", fields, line, file=sys.stderr); assert(len(fields) == 2) word = fields[0] diff --git a/scripts/rnnlm/get_vocab.py b/scripts/rnnlm/get_vocab.py index e30ce4a94c9..baafcb3a131 100755 --- a/scripts/rnnlm/get_vocab.py +++ b/scripts/rnnlm/get_vocab.py @@ -6,10 +6,10 @@ import os import argparse import sys -sys.stdout = open(1, 'w', encoding='latin-1', closefd=False) +sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) import re -tab_or_space = re.compile('[ \t]') + parser = argparse.ArgumentParser(description="This script get a vocab from unigram counts " "of words produced by get_unigram_counts.sh", @@ -28,10 +28,10 @@ # Add the count for every word in counts_file # the result is written into word_counts def add_counts(word_counts, counts_file): - with open(counts_file, 'r', encoding="latin-1") as f: + with open(counts_file, 'r', encoding="utf-8") as f: for line in f: - line = line.strip() - word_and_count = re.split(tab_or_space, line) + line = line.strip(" \t\r\n") + word_and_count = line.split() assert len(word_and_count) == 2 if word_and_count[0] in word_counts: word_counts[word_and_count[0]] += int(word_and_count[1]) diff --git a/scripts/rnnlm/get_word_features.py b/scripts/rnnlm/get_word_features.py index 54d84077060..cdcc0a77734 100755 --- a/scripts/rnnlm/get_word_features.py +++ b/scripts/rnnlm/get_word_features.py @@ -10,7 +10,7 @@ from collections import defaultdict import re -tab_or_space = re.compile('[ \t]') + parser = argparse.ArgumentParser(description="This script turns the words into the sparse feature representation, " "using features from rnnlm/choose_features.py.", @@ -41,9 +41,9 @@ # return the vocab, which is a dict mapping the word to a integer id. 
def read_vocab(vocab_file): vocab = {} - with open(vocab_file, 'r', encoding="latin-1") as f: + with open(vocab_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 if fields[0] in vocab: sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}" @@ -62,9 +62,9 @@ def read_vocab(vocab_file): # return a list of unigram_probs, indexed by word id def read_unigram_probs(unigram_probs_file): unigram_probs = [] - with open(unigram_probs_file, 'r', encoding="latin-1") as f: + with open(unigram_probs_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 idx = int(fields[0]) if idx >= len(unigram_probs): @@ -103,9 +103,9 @@ def read_features(features_file): feats['min_ngram_order'] = 10000 feats['max_ngram_order'] = -1 - with open(features_file, 'r', encoding="latin-1") as f: + with open(features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) in [3, 4, 5]) feat_id = int(fields[0]) diff --git a/scripts/rnnlm/lmrescore_nbest.sh b/scripts/rnnlm/lmrescore_nbest.sh index 6f28c960dd9..58b19b9fa79 100755 --- a/scripts/rnnlm/lmrescore_nbest.sh +++ b/scripts/rnnlm/lmrescore_nbest.sh @@ -29,7 +29,7 @@ if [ $# != 6 ]; then echo "This version applies an RNNLM and mixes it with the LM scores" echo "previously in the lattices., controlled by the first parameter (rnnlm-weight)" echo "" - echo "Usage: utils/rnnlmrescore.sh " + echo "Usage: $0 [options] " echo "Main options:" echo " --inv-acwt # default 12. e.g. --inv-acwt 17. Equivalent to LM scale to use." echo " # for N-best list generation... note, we'll score at different acwt's" diff --git a/scripts/rnnlm/lmrescore_nbest_back.sh b/scripts/rnnlm/lmrescore_nbest_back.sh new file mode 100755 index 00000000000..7531d99b0a4 --- /dev/null +++ b/scripts/rnnlm/lmrescore_nbest_back.sh @@ -0,0 +1,142 @@ +#!/bin/bash + +# Copyright 2017 Hainan Xu +# 2017 Szu-Jui Chen + +# This script is very similar to scripts/rnnlm/lmrescore_nbest.sh, and it takes the results +# from forward model then performs n-best LM rescoring based on backward model with Kaldi-RNNLM. + +# Begin configuration section. +N=10 +inv_acwt=10 +cmd=run.pl +use_phi=false # This is kind of an obscure option. If true, we'll remove the old + # LM weights (times 1-RNN_scale) using a phi (failure) matcher, which is + # appropriate if the old LM weights were added in this way, e.g. by + # lmrescore.sh. Otherwise we'll use normal composition, which is appropriate + # if the lattices came directly from decoding. This won't actually make much + # difference (if any) to WER, it's more so we know we are doing the right thing. +test=false # Activate a testing option. +stage=1 # Stage of this script, for partial reruns. +skip_scoring=false +keep_ali=true +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh +. utils/parse_options.sh + +if [ $# != 6 ]; then + echo "Do language model rescoring of lattices (partially remove old LM, add new LM)" + echo "This version applies an RNNLM and mixes it with the LM scores" + echo "previously in the lattices., controlled by the first parameter (rnnlm-weight)" + echo "" + echo "Usage: $0 [options] " + echo "Main options:" + echo " --inv-acwt # default 12. e.g. --inv-acwt 17. Equivalent to LM scale to use." + echo " # for N-best list generation... 
note, we'll score at different acwt's" + echo " --cmd # how to run jobs." + echo " --phi (true|false) # Should be set to true if the source lattices were created" + echo " # by lmrescore.sh, false if they came from decoding." + echo " --N # Value of N in N-best rescoring (default: 10)" + exit 1; +fi + +rnnweight=$1 +oldlang=$2 +rnndir=$3 +data=$4 +indir=$5 +dir=$6 + +acwt=`perl -e "print (1.0/$inv_acwt);"` + +# Figures out if the old LM is G.fst or G.carpa +oldlm=$oldlang/G.fst +if [ -f $oldlang/G.carpa ]; then + oldlm=$oldlang/G.carpa +elif [ ! -f $oldlm ]; then + echo "$0: expecting either $oldlang/G.fst or $oldlang/G.carpa to exist" &&\ + exit 1; +fi + +for f in $rnndir/final.raw $data/feats.scp $indir/lat.1.gz; do + [ ! -f $f ] && echo "$0: expected file $f to exist." && exit 1; +done + +nj=`cat $indir/num_jobs` || exit 1; +mkdir -p $dir; +cp $indir/num_jobs $dir/num_jobs + +adir=$dir/archives + +phi=`grep -w '#0' $oldlang/words.txt | awk '{print $2}'` + +rm $dir/.error 2>/dev/null +mkdir -p $dir/log + +# First convert lattice to N-best. Be careful because this +# will be quite sensitive to the acoustic scale; this should be close +# to the one we'll finally get the best WERs with. +# Note: the lattice-rmali part here is just because we don't +# need the alignments for what we're doing. +if [ $stage -le 5 ]; then + echo "$0: Copying needed information from $indir/archives to $adir" + # Do some small tasks; for these we don't use the queue, it will only slow us down. + for n in `seq $nj`; do + mkdir -p $adir.$n + cp $indir/archives.$n/ali $adir.$n/ + cp $indir/archives.$n/words $adir.$n/ + cp $indir/archives.$n/words_text $adir.$n/ + cp $indir/archives.$n/lmwt.nolm $adir.$n/ + cp $indir/archives.$n/acwt $adir.$n/ + cp $indir/archives.$n/lmwt.withlm $adir.$n/ + + mkdir -p $adir.$n/temp + paste $adir.$n/lmwt.nolm $adir.$n/lmwt.withlm | awk '{print $1, ($4-$2);}' > \ + $adir.$n/lmwt.lmonly || exit 1; + done +fi +if [ $stage -le 6 ]; then + echo "$0: invoking rnnlm/compute_sentence_scores_back.sh which calls rnnlm to get RNN LM scores." + $cmd JOB=1:$nj $dir/log/rnnlm_compute_scores.JOB.log \ + rnnlm/compute_sentence_scores_back.sh $rnndir $adir.JOB/temp \ + $adir.JOB/words_text $adir.JOB/lmwt.rnn +fi + +if [ $stage -le 7 ]; then + echo "$0: doing average on forward and backward scores." + for n in `seq $nj`; do + paste $indir/archives.$n/lmwt.rnn $adir.$n/lmwt.rnn | awk -F' ' '{print $1,$2 * 0.5 + $4 * 0.5}' \ + > $adir.$n/lmwt.rnn_bi + done +fi + +if [ $stage -le 8 ]; then + echo "$0: reconstructing total LM+graph scores including interpolation of RNNLM and old LM scores." + for n in `seq $nj`; do + paste $adir.$n/lmwt.nolm $adir.$n/lmwt.lmonly $adir.$n/lmwt.rnn_bi | awk -v rnnweight=$rnnweight \ + '{ key=$1; graphscore=$2; lmscore=$4; rnnscore=$6; + score = graphscore+(rnnweight*rnnscore)+((1-rnnweight)*lmscore); + print $1,score; } ' > $adir.$n/lmwt.interp.$rnnweight || exit 1; + done +fi + +if [ $stage -le 9 ]; then + echo "$0: reconstructing archives back into lattices." + $cmd JOB=1:$nj $dir/log/reconstruct_lattice.JOB.log \ + linear-to-nbest "ark:$adir.JOB/ali" "ark:$adir.JOB/words" \ + "ark:$adir.JOB/lmwt.interp.$rnnweight" "ark:$adir.JOB/acwt" ark:- \| \ + nbest-to-lattice ark:- "ark:|gzip -c >$dir/lat.JOB.gz" || exit 1; +fi + +if ! $skip_scoring ; then + [ ! -x local/score.sh ] && \ + echo "Not scoring because local/score.sh does not exist or not executable." && exit 1; + local/score.sh --cmd "$cmd" $data $oldlang $dir || + { echo "$0: Scoring failed. 
(ignore by '--skip-scoring true')"; exit 1; } +fi + +exit 0; + diff --git a/scripts/rnnlm/prepare_rnnlm_dir.sh b/scripts/rnnlm/prepare_rnnlm_dir.sh index 1de91bb7232..e101822d983 100755 --- a/scripts/rnnlm/prepare_rnnlm_dir.sh +++ b/scripts/rnnlm/prepare_rnnlm_dir.sh @@ -23,7 +23,7 @@ if [ $# != 3 ]; then echo "Usage: $0 [options] " echo "Sets up the directory for RNNLM training as done by" echo "rnnlm/train_rnnlm.sh, and initializes the model." - echo " is as validated by rnnlm/validate_data_dir.py" + echo " is as validated by rnnlm/validate_text_dir.py" echo " is as validated by rnnlm/validate_config_dir.sh." exit 1 fi @@ -53,9 +53,13 @@ if [ $stage -le 1 ]; then echo "$0: copying config directory" mkdir -p $dir/config # copy expected things from $config_dir to $dir/config. - for f in words.txt features.txt data_weights.txt oov.txt xconfig; do + for f in words.txt data_weights.txt oov.txt xconfig; do cp $config_dir/$f $dir/config done + # features.txt is optional, check separately + if [ -f $config_dir/features.txt ]; then + cp $config_dir/features.txt $dir/config + fi fi rnnlm/get_special_symbol_opts.py < $dir/config/words.txt > $dir/special_symbol_opts.txt diff --git a/scripts/rnnlm/prepare_split_data.py b/scripts/rnnlm/prepare_split_data.py index e39f4504f37..427f043df98 100755 --- a/scripts/rnnlm/prepare_split_data.py +++ b/scripts/rnnlm/prepare_split_data.py @@ -9,7 +9,7 @@ import sys import re -tab_or_space = re.compile('[ \t]') + parser = argparse.ArgumentParser(description="This script prepares files containing integerized text, " "for consumption by nnet3-get-egs.", @@ -66,10 +66,10 @@ def get_all_data_sources_except_dev(text_dir): # value is a tuple (repeated_times_per_epoch, weight) def read_data_weights(weights_file, data_sources): data_weights = {} - with open(weights_file, 'r', encoding="latin-1") as f: + with open(weights_file, 'r', encoding="utf-8") as f: for line in f: try: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 3 if fields[0] in data_weights: raise Exception("duplicated data source({0}) specified in " @@ -97,7 +97,7 @@ def distribute_to_outputs(source_filename, weight, output_filehandles): num_outputs = len(output_filehandles) n = 0 try: - f = open(source_filename, 'r', encoding="latin-1") + f = open(source_filename, 'r', encoding="utf-8") except Exception as e: sys.exit(sys.argv[0] + ": failed to open file {0} for reading: {1} ".format( source_filename, str(e))) @@ -124,7 +124,7 @@ def distribute_to_outputs(source_filename, weight, output_filehandles): os.makedirs(args.split_dir + "/info") # set up the 'num_splits' file, which contains an integer. -with open("{0}/info/num_splits".format(args.split_dir), 'w', encoding="latin-1") as f: +with open("{0}/info/num_splits".format(args.split_dir), 'w', encoding="utf-8") as f: print(args.num_splits, file=f) # e.g. set temp_files = [ 'foo/1.tmp', 'foo/2.tmp', ..., 'foo/5.tmp' ] @@ -136,7 +136,7 @@ def distribute_to_outputs(source_filename, weight, output_filehandles): temp_filehandles = [] for fname in temp_files: try: - temp_filehandles.append(open(fname, 'w', encoding="latin-1")) + temp_filehandles.append(open(fname, 'w', encoding="utf-8")) except Exception as e: sys.exit(sys.argv[0] + ": failed to open file: " + str(e) + ".. 
if this is a max-open-filehandles limitation, you may " diff --git a/scripts/rnnlm/rnnlm_cleanup.py b/scripts/rnnlm/rnnlm_cleanup.py new file mode 100644 index 00000000000..6a304f7f4cb --- /dev/null +++ b/scripts/rnnlm/rnnlm_cleanup.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Tilde +# License: Apache 2.0 + +import sys + +import argparse +import os +import re +import glob + +script_name = sys.argv[0] + +parser = argparse.ArgumentParser(description="Removes models from past training iterations of " + "RNNLM. Can use either 'keep_latest' (default) or " + "'keep_best' cleanup strategy, where former keeps " + "the models that are freshest, while latter keeps " + "the models with best training objective score on " + "dev set.", + epilog="E.g. " + script_name + " exp/rnnlm_a --keep_best", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + +parser.add_argument("rnnlm_dir", + help="Directory where the RNNLM has been trained") +parser.add_argument("--iters_to_keep", + help="Max number of iterations to keep", + type=int, + default=3) +parser.add_argument("--keep_latest", + help="Keeps the training iterations that are latest by age", + action="store_const", + const=True, + default=False) +parser.add_argument("--keep_best", + help="Keeps the training iterations that have the best objf", + action="store_const", + const=True, + default=False) + +args = parser.parse_args() + +# validate arguments +if args.keep_latest and args.keep_best: + sys.exit(script_name + ": can only use one of 'keep_latest' or 'keep_best', but not both") +elif not args.keep_latest and not args.keep_best: + sys.exit(script_name + ": no cleanup strategy specified: use 'keep_latest' or 'keep_best'") + + +class IterationInfo: + def __init__(self, model_files, objf, compute_prob_done): + self.model_files = model_files + self.objf = objf + self.compute_prob_done = compute_prob_done + + def __str__(self): + return "{model_files: %s, compute_prob: %s, objf: %2.3f}" % (self.model_files, + self.compute_prob_done, + self.objf) + + def __repr__(self): + return self.__str__() + + +def get_compute_prob_info(log_file): + # we want to know 3 things: iteration number, objf and whether compute prob is done + iteration = int(log_file.split(".")[-2]) + objf = -2000 + compute_prob_done = False + # roughly based on code in get_best_model.py + try: + f = open(log_file, "r", encoding="utf-8") + except: + print(script_name + ": warning: compute_prob log not found for iteration " + + str(iter) + ". 
Skipping", + file=sys.stderr) + return iteration, objf, compute_prob_done + for line in f: + objf_m = re.search('Overall objf .* (\S+)$', str(line)) + if objf_m is not None: + try: + objf = float(objf_m.group(1)) + except Exception as e: + sys.exit(script_name + ": line in file {0} could not be parsed: {1}, error is: {2}".format( + log_file, line, str(e))) + if "# Ended" in line: + compute_prob_done = True + if objf == -2000: + print(script_name + ": warning: could not parse objective function from " + log_file, file=sys.stderr) + return iteration, objf, compute_prob_done + + +def get_iteration_files(exp_dir): + iterations = dict() + compute_prob_logs = glob.glob(exp_dir + "/log/compute_prob.[0-9]*.log") + for log in compute_prob_logs: + iteration, objf, compute_prob_done = get_compute_prob_info(log) + if iteration == 0: + # iteration 0 is special, never consider it for cleanup + continue + if compute_prob_done: + # this iteration can be safely considered for cleanup + # gather all model files belonging to it + model_files = [] + # when there are multiple jobs per iteration, there can be several model files + # we need to potentially clean them all up without mixing them up + model_files.extend(glob.glob("{0}/word_embedding.{1}.mat".format(exp_dir, iteration))) + model_files.extend(glob.glob("{0}/word_embedding.{1}.[0-9]*.mat".format(exp_dir, iteration))) + model_files.extend(glob.glob("{0}/feat_embedding.{1}.mat".format(exp_dir, iteration))) + model_files.extend(glob.glob("{0}/feat_embedding.{1}.[0-9]*.mat".format(exp_dir, iteration))) + model_files.extend(glob.glob("{0}/{1}.raw".format(exp_dir, iteration))) + model_files.extend(glob.glob("{0}/{1}.[0-9]*.raw".format(exp_dir, iteration))) + # compute_prob logs outlive model files, only consider iterations that do still have model files + if len(model_files) > 0: + iterations[iteration] = IterationInfo(model_files, objf, compute_prob_done) + return iterations + + +def remove_model_files_for_iter(iter_info): + for f in iter_info.model_files: + os.remove(f) + + +def keep_latest(iteration_dict): + max_to_keep = args.iters_to_keep + kept = 0 + iterations_in_reverse_order = reversed(sorted(iteration_dict)) + for iter in iterations_in_reverse_order: + if kept < max_to_keep: + kept += 1 + else: + remove_model_files_for_iter(iteration_dict[iter]) + + +def keep_best(iteration_dict): + iters_to_keep = args.iters_to_keep + best = [] + for iter, iter_info in iteration_dict.items(): + objf = iter_info.objf + if objf == -2000: + print(script_name + ": warning: objf unavailable for iter " + str(iter), file=sys.stderr) + continue + # add potential best, sort by objf, trim to iters_to_keep size + best.append((iter, objf)) + best = sorted(best, key=lambda x: -x[1]) + if len(best) > iters_to_keep: + throwaway = best[iters_to_keep:] + best = best[:iters_to_keep] + # remove iters that we know are not the best + for (iter, _) in throwaway: + remove_model_files_for_iter(iteration_dict[iter]) + + +# grab all the iterations mapped to their model files, objf score and compute_prob status +iterations = get_iteration_files(args.rnnlm_dir) +# apply chosen cleanup strategy +if args.keep_latest: + keep_latest(iterations) +else: + keep_best(iterations) diff --git a/scripts/rnnlm/show_word_features.py b/scripts/rnnlm/show_word_features.py index 5fe049cb8ce..4335caed5d8 100755 --- a/scripts/rnnlm/show_word_features.py +++ b/scripts/rnnlm/show_word_features.py @@ -6,10 +6,11 @@ import os import argparse import sys -sys.stdout = open(1, 'w', encoding='latin-1', closefd=False) 
+ +sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) import re -tab_or_space = re.compile('[ \t]') + parser = argparse.ArgumentParser(description="This script turns the word features to a human readable format.", epilog="E.g. " + sys.argv[0] + "exp/rnnlm/word_feats.txt exp/rnnlm/features.txt " @@ -30,9 +31,9 @@ def read_feature_type_and_key(features_file): feat_types = {} - with open(features_file, 'r', encoding="latin-1") as f: + with open(features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) in [2, 3, 4]) feat_id = int(fields[0]) @@ -47,9 +48,9 @@ def read_feature_type_and_key(features_file): feat_type_and_key = read_feature_type_and_key(args.features_file) num_word_feats = 0 -with open(args.word_features_file, 'r', encoding="latin-1") as f: +with open(args.word_features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) % 2 == 1 print(int(fields[0]), end='\t') diff --git a/scripts/rnnlm/train_rnnlm.sh b/scripts/rnnlm/train_rnnlm.sh index aedfc470ac9..013e9a56c2f 100755 --- a/scripts/rnnlm/train_rnnlm.sh +++ b/scripts/rnnlm/train_rnnlm.sh @@ -38,6 +38,11 @@ num_egs_threads=10 # number of threads used for sampling, if we're using use_gpu=true # use GPU for training use_gpu_for_diagnostics=false # set true to use GPU for compute_prob_*.log +# optional cleanup options +cleanup=false # add option --cleanup true to enable automatic cleanup of old models +cleanup_strategy="keep_latest" # determines cleanup strategy, use either "keep_latest" or "keep_best" +cleanup_keep_iters=3 # number of iterations that will have their models retained + trap 'for pid in $(jobs -pr); do kill -KILL $pid; done' INT QUIT TERM . utils/parse_options.sh @@ -208,7 +213,7 @@ while [ $x -lt $num_iters ]; do --read-rnnlm="$src_rnnlm" --write-rnnlm=$dir/$dest_number.raw \ --read-embedding=$dir/${embedding_type}_embedding.$x.mat \ --write-embedding=$dir/${embedding_type}_embedding.$dest_number.mat \ - "ark,bg:cat $repeated_data | rnnlm-get-egs --srand=$num_splits_processed $train_egs_args - ark:- |" || touch $dir/.train_error & + "ark,bg:cat $repeated_data | rnnlm-get-egs --chunk-length=$chunk_length --srand=$num_splits_processed $train_egs_args - ark:- |" || touch $dir/.train_error & done wait # wait for just the training jobs. [ -f $dir/.train_error ] && \ @@ -222,12 +227,16 @@ while [ $x -lt $num_iters ]; do nnet3-average $src_models $dir/$[x+1].raw '&&' \ matrix-sum --average=true $src_matrices $dir/${embedding_type}_embedding.$[x+1].mat fi + # optionally, perform cleanup after training + if [ "$cleanup" = true ] ; then + python3 rnnlm/rnnlm_cleanup.py $dir --$cleanup_strategy --iters_to_keep $cleanup_keep_iters + fi ) - # the error message below is not that informative, but $cmd will # have printed a more specific one. [ -f $dir/.error ] && echo "$0: error with diagnostics on iteration $x of training" && exit 1; fi + x=$[x+1] num_splits_processed=$[num_splits_processed+this_num_jobs] done diff --git a/scripts/rnnlm/validate_features.py b/scripts/rnnlm/validate_features.py index 010ceb72615..e67f03207bb 100755 --- a/scripts/rnnlm/validate_features.py +++ b/scripts/rnnlm/validate_features.py @@ -8,7 +8,7 @@ import sys import re -tab_or_space = re.compile('[ \t]') + parser = argparse.ArgumentParser(description="Validates features file, produced by rnnlm/choose_features.py.", epilog="E.g. 
" + sys.argv[0] + " exp/rnnlm/features.txt", @@ -24,7 +24,7 @@ if not os.path.isfile(args.features_file): sys.exit(sys.argv[0] + ": Expected file {0} to exist".format(args.features_file)) -with open(args.features_file, 'r', encoding="latin-1") as f: +with open(args.features_file, 'r', encoding="utf-8") as f: has_unigram = False has_length = False idx = 0 @@ -33,7 +33,7 @@ final_feats = {} word_feats = {} for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) in [3, 4, 5]) assert idx == int(fields[0]) diff --git a/scripts/rnnlm/validate_text_dir.py b/scripts/rnnlm/validate_text_dir.py index 4b311a8abbd..1f250d4c2f8 100755 --- a/scripts/rnnlm/validate_text_dir.py +++ b/scripts/rnnlm/validate_text_dir.py @@ -8,7 +8,7 @@ import sys import re -tab_or_space = re.compile('[ \t]') + parser = argparse.ArgumentParser(description="Validates data directory containing text " "files from one or more data sources, including dev.txt.", @@ -40,7 +40,7 @@ def check_text_file(text_file): - with open(text_file, 'r', encoding="latin-1") as f: + with open(text_file, 'r', encoding="utf-8") as f: found_nonempty_line = False lineno = 0 if args.allow_internal_eos == 'true': @@ -54,7 +54,7 @@ def check_text_file(text_file): lineno += 1 if args.spot_check == 'true' and lineno > 10: break - words = re.split(tab_or_space, line) + words = line.split() if len(words) != 0: found_nonempty_line = True for word in words: @@ -76,9 +76,9 @@ def check_text_file(text_file): # with some kind of utterance-id first_field_set = set() other_fields_set = set() - with open(text_file, 'r', encoding="latin-1") as f: + with open(text_file, 'r', encoding="utf-8") as f: for line in f: - array = re.split(tab_or_space, line) + array = line.split() if len(array) > 0: first_word = array[0] if first_word in first_field_set or first_word in other_fields_set: diff --git a/scripts/rnnlm/validate_word_features.py b/scripts/rnnlm/validate_word_features.py index f8eb5858d95..372286d8d12 100755 --- a/scripts/rnnlm/validate_word_features.py +++ b/scripts/rnnlm/validate_word_features.py @@ -8,7 +8,7 @@ import sys import re -tab_or_space = re.compile('[ \t]') + parser = argparse.ArgumentParser(description="Validates word features file, produced by rnnlm/get_word_features.py.", epilog="E.g. 
" + sys.argv[0] + " --features-file=exp/rnnlm/features.txt " @@ -28,9 +28,9 @@ unigram_feat_id = -1 length_feat_id = -1 max_feat_id = -1 -with open(args.features_file, 'r', encoding="latin-1") as f: +with open(args.features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) in [3, 4, 5]) feat_id = int(fields[0]) @@ -52,9 +52,9 @@ if feat_id > max_feat_id: max_feat_id = feat_id -with open(args.word_features_file, 'r', encoding="latin-1") as f: +with open(args.word_features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) > 0 and len(fields) % 2 == 1 word_id = int(fields[0]) diff --git a/src/.version b/src/.version index 37c2d9960ec..9ad974f6109 100644 --- a/src/.version +++ b/src/.version @@ -1 +1 @@ -5.4 +5.5 diff --git a/src/Makefile b/src/Makefile index 6dfd146e3d5..a49c912c6ed 100644 --- a/src/Makefile +++ b/src/Makefile @@ -31,15 +31,9 @@ include kaldi.mk # Reset the default goal, so that the all target will become default .DEFAULT_GOAL := -all: - $(MAKE) checkversion - $(MAKE) kaldi.mk - $(MAKE) mklibdir - $(MAKE) subdirs +all: $(SUBDIRS) matrix/test -echo Done -subdirs: $(SUBDIRS) - mklibdir: test -d $(KALDILIBDIR) || mkdir $(KALDILIBDIR) @@ -138,11 +132,11 @@ ext_depend: check_portaudio .PHONY: $(SUBDIRS) -$(SUBDIRS) : mklibdir +$(SUBDIRS) : checkversion kaldi.mk mklibdir $(MAKE) -C $@ .PHONY: $(EXT_SUBDIRS) -$(EXT_SUBDIRS) : mklibdir ext_depend +$(EXT_SUBDIRS) : checkversion kaldi.mk mklibdir ext_depend $(MAKE) -C $@ diff --git a/src/base/Makefile b/src/base/Makefile index 583c6badcf2..49af4f87ff4 100644 --- a/src/base/Makefile +++ b/src/base/Makefile @@ -18,7 +18,7 @@ include ../kaldi.mk TESTFILES = kaldi-math-test io-funcs-test kaldi-error-test timer-test -OBJFILES = kaldi-math.o kaldi-error.o io-funcs.o kaldi-utils.o +OBJFILES = kaldi-math.o kaldi-error.o io-funcs.o kaldi-utils.o timer.o LIBNAME = kaldi-base diff --git a/src/base/io-funcs-inl.h b/src/base/io-funcs-inl.h index 6b87f4c1a24..b703ef5addc 100644 --- a/src/base/io-funcs-inl.h +++ b/src/base/io-funcs-inl.h @@ -47,7 +47,7 @@ template void WriteBasicType(std::ostream &os, os << t << " "; } if (os.fail()) { - throw std::runtime_error("Write failure in WriteBasicType."); + KALDI_ERR << "Write failure in WriteBasicType."; } } @@ -122,7 +122,7 @@ inline void WriteIntegerPairVector(std::ostream &os, bool binary, os << "]\n"; } if (os.fail()) { - throw std::runtime_error("Write failure in WriteIntegerPairVector."); + KALDI_ERR << "Write failure in WriteIntegerPairVector."; } } @@ -224,7 +224,7 @@ template inline void WriteIntegerVector(std::ostream &os, bool binary, os << "]\n"; } if (os.fail()) { - throw std::runtime_error("Write failure in WriteIntegerVector."); + KALDI_ERR << "Write failure in WriteIntegerVector."; } } diff --git a/src/base/io-funcs.cc b/src/base/io-funcs.cc index 90988faf3ea..150f74099be 100644 --- a/src/base/io-funcs.cc +++ b/src/base/io-funcs.cc @@ -138,7 +138,7 @@ void WriteToken(std::ostream &os, bool binary, const char *token) { CheckToken(token); // make sure it's valid (can be read back) os << token << " "; if (os.fail()) { - throw std::runtime_error("Write failure in WriteToken."); + KALDI_ERR << "Write failure in WriteToken."; } } @@ -179,11 +179,8 @@ int PeekToken(std::istream &is, bool binary) { int ans = is.peek(); if (read_bracket) { if (!is.unget()) { - KALDI_WARN << "Error ungetting '<' in PeekToken"; - // Clear the bad bit. 
It seems to be possible for this code to be - // reached, and the C++ standard is very vague on whether even a single - // call to unget() should succeed; see - // http://www.cplusplus.com/reference/istream/istream/unget/ + // Clear the bad bit. This code can be (and is in fact) reached, since the + // C++ standard does not guarantee that a call to unget() must succeed. is.clear(); } } diff --git a/src/base/io-funcs.h b/src/base/io-funcs.h index ca476033950..895f661ecee 100644 --- a/src/base/io-funcs.h +++ b/src/base/io-funcs.h @@ -31,7 +31,9 @@ #include #include #include + #include "base/kaldi-common.h" +#include "base/io-funcs-inl.h" namespace kaldi { @@ -44,7 +46,7 @@ namespace kaldi { We also want to have control over whitespace in text mode without affecting the meaning of the file, for pretty-printing purposes. - Errors are handled by throwing an exception (std::runtime_error). + Errors are handled by throwing a KaldiFatalError exception. For integer and floating-point types (and boolean values): @@ -106,7 +108,7 @@ namespace kaldi { it doesn't throw. It's useful if a class can have various forms based on typedefs and virtual classes, and wants to know which version to read. - ReadToken allow the caller to obtain the next token. PeekToken works just + ReadToken allows the caller to obtain the next token. PeekToken works just like ReadToken, but seeks back to the beginning of the token. A subsequent call to ReadToken will read the same token again. This is useful when different object types are written to the same file; using PeekToken one can @@ -201,13 +203,18 @@ void WriteToken(std::ostream &os, bool binary, const std::string & token); /// value of the stream. int Peek(std::istream &is, bool binary); -/// ReadToken gets the next token and puts it in str (exception on failure). +/// ReadToken gets the next token and puts it in str (exception on failure). If +/// PeekToken() had been previously called, it is possible that the stream had +/// failed to unget the starting '<' character. In this case ReadToken() returns +/// the token string without the leading '<'. You must be prepared to handle +/// this case. ExpectToken() handles this internally, and is not affected. void ReadToken(std::istream &is, bool binary, std::string *token); /// PeekToken will return the first character of the next token, or -1 if end of /// file. It's the same as Peek(), except if the first character is '<' it will -/// skip over it and will return the next character. It will unget the '<' so -/// the stream is where it was before you did PeekToken(). +/// skip over it and will return the next character. It will attempt to unget +/// the '<' so the stream is where it was before you did PeekToken(), however, +/// this is not guaranteed (see ReadToken()). int PeekToken(std::istream &is, bool binary); /// ExpectToken tries to read in the given token, and throws an exception @@ -235,7 +242,4 @@ inline void InitKaldiOutputStream(std::ostream &os, bool binary); inline bool InitKaldiInputStream(std::istream &is, bool *binary); } // end namespace kaldi. 
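As a hedged illustration of the ReadToken()/PeekToken() caveat documented above (the '<' may or may not be successfully ungot), code that calls ReadToken() directly after PeekToken() could accept both spellings of the expected token; the token name "<MyToken>" and the function below are invented for illustration, and ExpectToken() already handles this case internally:

#include <istream>
#include <string>
#include "base/io-funcs.h"

// Sketch only: read a token whose leading '<' PeekToken() may have consumed.
void ReadMyTokenDefensively(std::istream &is, bool binary) {
  int c = kaldi::PeekToken(is, binary);  // first character after any '<', or -1
  if (c == -1)
    KALDI_ERR << "Unexpected end of stream";
  std::string token;
  kaldi::ReadToken(is, binary, &token);
  // Because the unget of '<' is not guaranteed, accept both forms.
  if (token != "<MyToken>" && token != "MyToken>")
    KALDI_ERR << "Expected <MyToken>, got " << token;
}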
- -#include "base/io-funcs-inl.h" - #endif // KALDI_BASE_IO_FUNCS_H_ diff --git a/src/base/kaldi-common.h b/src/base/kaldi-common.h index e0002d91bb7..264565d1812 100644 --- a/src/base/kaldi-common.h +++ b/src/base/kaldi-common.h @@ -36,5 +36,6 @@ #include "base/kaldi-types.h" #include "base/io-funcs.h" #include "base/kaldi-math.h" +#include "base/timer.h" #endif // KALDI_BASE_KALDI_COMMON_H_ diff --git a/src/base/kaldi-error-test.cc b/src/base/kaldi-error-test.cc index 527de852cac..462ad956907 100644 --- a/src/base/kaldi-error-test.cc +++ b/src/base/kaldi-error-test.cc @@ -42,13 +42,12 @@ void UnitTestError() { } // end namespace kaldi. int main() { - kaldi::g_program_name = "/foo/bar/kaldi-error-test"; + kaldi::SetProgramName("/foo/bar/kaldi-error-test"); try { kaldi::UnitTestError(); KALDI_ASSERT(0); // should not happen. exit(1); - } catch(std::runtime_error &r) { - std::cout << "UnitTestError: the error we generated was: " << r.what(); + } catch(kaldi::KaldiFatalError &e) { + std::cout << "The error we generated was: '" << e.KaldiMessage() << "'\n"; } } - diff --git a/src/base/kaldi-error.cc b/src/base/kaldi-error.cc index f2ce1edf37d..9705936466c 100644 --- a/src/base/kaldi-error.cc +++ b/src/base/kaldi-error.cc @@ -1,5 +1,6 @@ // base/kaldi-error.cc +// Copyright 2019 SmartAction LLC (kkm) // Copyright 2016 Brno University of Technology (author: Karel Vesely) // Copyright 2009-2011 Microsoft Corporation; Lukas Burget; Ondrej Glembek @@ -35,88 +36,90 @@ namespace kaldi { + /***** GLOBAL VARIABLES FOR LOGGING *****/ int32 g_kaldi_verbose_level = 0; -const char *g_program_name = NULL; -static LogHandler g_log_handler = NULL; - -// If the program name was set (g_program_name != ""), GetProgramName -// returns the program name (without the path), e.g. "gmm-align". -// Otherwise it returns the empty string "". -const char *GetProgramName() { - return g_program_name == NULL ? "" : g_program_name; +static std::string program_name; +static LogHandler log_handler = NULL; + +void SetProgramName(const char *basename) { + // Using the 'static std::string' for the program name is mostly harmless, + // because (a) Kaldi logging is undefined before main(), and (b) no stdc++ + // string implementation has been found in the wild that would not be just + // an empty string when zero-initialized but not yet constructed. + program_name = basename; } + /***** HELPER FUNCTIONS *****/ -// Given a filename like "/a/b/c/d/e/f.cc", GetShortFileName -// returns "e/f.cc". Does not currently work if backslash is -// the filename separator. -static const char *GetShortFileName(const char *filename) { - const char *last_slash = strrchr(filename, '/'); - if (!last_slash) { - return filename; - } else { - while (last_slash > filename && last_slash[-1] != '/') - last_slash--; - return last_slash; +// Trim filename to at most 1 trailing directory long. Given a filename like +// "/a/b/c/d/e/f.cc", return "e/f.cc". Support both '/' and '\' as the path +// separator. 
+static const char *GetShortFileName(const char *path) { + if (path == nullptr) + return ""; + + const char *prev = path, *last = path; + while ((path = std::strpbrk(path, "\\/")) != nullptr) { + ++path; + prev = last; + last = path; } + return prev; } -/***** STACKTRACE *****/ +/***** STACK TRACE *****/ +#ifdef HAVE_EXECINFO_H static std::string Demangle(std::string trace_name) { -#if defined(HAVE_CXXABI_H) && defined(HAVE_EXECINFO_H) - // at input the string looks like: +#ifdef HAVE_CXXABI_H + // At input the string looks like: // ./kaldi-error-test(_ZN5kaldi13UnitTestErrorEv+0xb) [0x804965d] - // We want to extract the name e.g. '_ZN5kaldi13UnitTestErrorEv", - // demangle it and return it. + // We want to extract the name e.g. '_ZN5kaldi13UnitTestErrorEv" + // and demangle it. - // try to locate '(' and '+', take the string in between, + // Try to locate '(' and '+', take the string in between. size_t begin(trace_name.find("(")), end(trace_name.rfind("+")); if (begin != std::string::npos && end != std::string::npos && begin < end) { - trace_name = trace_name.substr(begin+1,end-(begin+1)); + trace_name = trace_name.substr(begin + 1, end - (begin + 1)); } - // demangle, + // Try to demangle function name. int status; char *demangled_name = abi::__cxa_demangle(trace_name.c_str(), 0, 0, &status); - std::string ans; - if (status == 0) { - ans = demangled_name; + if (status == 0 && demangled_name != NULL) { + trace_name = demangled_name; free(demangled_name); - } else { - ans = trace_name; } - // return, - return ans; -#else +#endif // HAVE_CXXABI_H return trace_name; -#endif } - +#endif // HAVE_EXECINFO_H static std::string KaldiGetStackTrace() { std::string ans; #ifdef HAVE_EXECINFO_H -#define KALDI_MAX_TRACE_SIZE 50 -#define KALDI_MAX_TRACE_PRINT 20 // must be even. - // buffer for the trace, + const size_t KALDI_MAX_TRACE_SIZE = 50; + const size_t KALDI_MAX_TRACE_PRINT = 20; // Must be even. + // Buffer for the trace. void *trace[KALDI_MAX_TRACE_SIZE]; - // get the trace, + // Get the trace. size_t size = backtrace(trace, KALDI_MAX_TRACE_SIZE); - // get the trace symbols, + // Get the trace symbols. char **trace_symbol = backtrace_symbols(trace, size); + if (trace_symbol == NULL) + return ans; - // Compose the 'string', + // Compose a human-readable backtrace string. ans += "[ Stack-Trace: ]\n"; if (size <= KALDI_MAX_TRACE_PRINT) { for (size_t i = 0; i < size; i++) { ans += Demangle(trace_symbol[i]) + "\n"; } - } else { // print out first+last (e.g.) 5. + } else { // Print out first+last (e.g.) 5. for (size_t i = 0; i < KALDI_MAX_TRACE_PRINT/2; i++) { ans += Demangle(trace_symbol[i]) + "\n"; } @@ -125,11 +128,12 @@ static std::string KaldiGetStackTrace() { ans += Demangle(trace_symbol[i]) + "\n"; } if (size == KALDI_MAX_TRACE_SIZE) - ans += ".\n.\n.\n"; // stack was too long, probably a bug. + ans += ".\n.\n.\n"; // Stack was too long, probably a bug. } - // cleanup, - free(trace_symbol); // it's okay, just the pointers, not the strings. + // We must free the array of pointers allocated by backtrace_symbols(), + // but not the strings themselves. + free(trace_symbol); #endif // HAVE_EXECINFO_H return ans; } @@ -142,86 +146,55 @@ MessageLogger::MessageLogger(LogMessageEnvelope::Severity severity, // Obviously, we assume the strings survive the destruction of this object. envelope_.severity = severity; envelope_.func = func; - envelope_.file = GetShortFileName(file); // Pointer inside 'file'. + envelope_.file = GetShortFileName(file); // Points inside 'file'. 
envelope_.line = line; } +void MessageLogger::LogMessage() const { + // Send to the logging handler if provided. + if (log_handler != NULL) { + log_handler(envelope_, GetMessage().c_str()); + return; + } -MessageLogger::~MessageLogger() KALDI_NOEXCEPT(false) { - // remove trailing '\n', - std::string str = ss_.str(); - while (!str.empty() && str[str.length() - 1] == '\n') - str.resize(str.length() - 1); - - // print the mesage (or send to logging handler), - MessageLogger::HandleMessage(envelope_, str.c_str()); -} - - -void MessageLogger::HandleMessage(const LogMessageEnvelope &envelope, - const char *message) { - // Send to a logging handler if provided. - if (g_log_handler != NULL) { - g_log_handler(envelope, message); + // Otherwise, use the default Kaldi logging. + // Build the log-message header. + std::stringstream full_message; + if (envelope_.severity > LogMessageEnvelope::kInfo) { + full_message << "VLOG[" << envelope_.severity << "] ("; } else { - // Otherwise, we use the default Kaldi logging. - // Build the log-message 'header', - std::stringstream header; - if (envelope.severity > LogMessageEnvelope::kInfo) { - header << "VLOG[" << envelope.severity << "] ("; - } else { - switch (envelope.severity) { - case LogMessageEnvelope::kInfo : - header << "LOG ("; - break; - case LogMessageEnvelope::kWarning : - header << "WARNING ("; - break; - case LogMessageEnvelope::kError : - header << "ERROR ("; - break; - case LogMessageEnvelope::kAssertFailed : - header << "ASSERTION_FAILED ("; - break; - default: - abort(); // coding error (unknown 'severity'), - } + switch (envelope_.severity) { + case LogMessageEnvelope::kInfo : + full_message << "LOG ("; + break; + case LogMessageEnvelope::kWarning : + full_message << "WARNING ("; + break; + case LogMessageEnvelope::kAssertFailed : + full_message << "ASSERTION_FAILED ("; + break; + case LogMessageEnvelope::kError : + default: // If not the ERROR, it still an error! + full_message << "ERROR ("; + break; } - // fill the other info from the envelope, - header << GetProgramName() << "[" KALDI_VERSION "]" << ':' - << envelope.func << "():" << envelope.file << ':' << envelope.line - << ")"; - - // Printing the message, - if (envelope.severity >= LogMessageEnvelope::kWarning) { - // VLOG, LOG, WARNING: - fprintf(stderr, "%s %s\n", header.str().c_str(), message); - } else { - // ERROR, ASSERT_FAILED (print with stack-trace): - fprintf(stderr, "%s %s\n\n%s\n", header.str().c_str(), message, - KaldiGetStackTrace().c_str()); + } + // Add other info from the envelope and the message text. + full_message << program_name.c_str() << "[" KALDI_VERSION "]" << ':' + << envelope_.func << "():" << envelope_.file << ':' + << envelope_.line << ") " << GetMessage().c_str(); + + // Add stack trace for errors and assertion failures, if available. + if (envelope_.severity < LogMessageEnvelope::kWarning) { + const std::string& stack_trace = KaldiGetStackTrace(); + if (!stack_trace.empty()) { + full_message << "\n\n" << stack_trace; } } - // Should we throw exception, or abort? - switch (envelope.severity) { - case LogMessageEnvelope::kAssertFailed: - abort(); // ASSERT_FAILED, - break; - case LogMessageEnvelope::kError: - if (!std::uncaught_exception()) { - // throw exception with empty message, - throw std::runtime_error(""); // KALDI_ERR, - } else { - // If we got here, this thread has already thrown exception, - // and this exception has not yet arrived to its 'catch' clause... - // Throwing a new exception would be unsafe! 
- // (can happen during 'stack unwinding', if we have 'KALDI_ERR << msg' - // in a destructor of some local object). - abort(); - } - break; - } + // Print the complete message to stderr. + full_message << "\n"; + std::cerr << full_message.str(); } @@ -229,17 +202,20 @@ void MessageLogger::HandleMessage(const LogMessageEnvelope &envelope, void KaldiAssertFailure_(const char *func, const char *file, int32 line, const char *cond_str) { - MessageLogger ml(LogMessageEnvelope::kAssertFailed, func, file, line); - ml.stream() << ": '" << cond_str << "' "; + MessageLogger::Log() = + MessageLogger (LogMessageEnvelope::kAssertFailed, func, file, line) + << "Assertion failed: (" << cond_str << ")"; + fflush(NULL); // Flush all pending buffers, abort() may not flush stderr. + std::abort(); } /***** THIRD-PARTY LOG-HANDLER *****/ -LogHandler SetLogHandler(LogHandler new_handler) { - LogHandler old_handler = g_log_handler; - g_log_handler = new_handler; +LogHandler SetLogHandler(LogHandler handler) { + LogHandler old_handler = log_handler; + log_handler = handler; return old_handler; } -} // end namespace kaldi +} // namespace kaldi diff --git a/src/base/kaldi-error.h b/src/base/kaldi-error.h index 172ea675312..c90a18b15f1 100644 --- a/src/base/kaldi-error.h +++ b/src/base/kaldi-error.h @@ -1,5 +1,6 @@ // base/kaldi-error.h +// Copyright 2019 SmartAction LLC (kkm) // Copyright 2016 Brno University of Technology (author: Karel Vesely) // Copyright 2009-2011 Microsoft Corporation; Ondrej Glembek; Lukas Burget; // Saarland University @@ -33,18 +34,6 @@ #include "base/kaldi-utils.h" /* Important that this file does not depend on any other kaldi headers. */ -// By adding 'KALDI_NOEXCEPT(bool)' immediately after function declaration, -// we can tell the compiler that the function must-not produce -// exceptions (true), or may produce exceptions (false): -#if _MSC_VER >= 1900 || (!defined(_MSC_VER) && __cplusplus >= 201103L) -#define KALDI_NOEXCEPT(Predicate) noexcept((Predicate)) -#elif defined(__GXX_EXPERIMENTAL_CXX0X__) && \ - (__GNUC__ >= 4 && __GNUC_MINOR__ >= 6) -#define KALDI_NOEXCEPT(Predicate) noexcept((Predicate)) -#else -#define KALDI_NOEXCEPT(Predicate) -#endif - #ifdef _MSC_VER #define __func__ __FUNCTION__ #endif @@ -54,22 +43,23 @@ namespace kaldi { /// \addtogroup error_group /// @{ -/***** VERBOSITY LEVEL *****/ +/***** PROGRAM NAME AND VERBOSITY LEVEL *****/ -/// This is set by util/parse-options.{h, cc} if you set --verbose=? option. -extern int32 g_kaldi_verbose_level; +/// Called by ParseOptions to set base name (no directory) of the executing +/// program. The name is printed in logging code along with every message, +/// because in our scripts, we often mix together the stderr of many programs. +/// This function is very thread-unsafe. +void SetProgramName(const char *basename); -/// This is set by util/parse-options.{h, cc} (from argv[0]) and used (if set) -/// in error reporting code to display the name of the program (this is because -/// in our scripts, we often mix together the stderr of many programs). it is -/// the base-name of the program (no directory), followed by ':' We don't use -/// std::string, due to the static initialization order fiasco. -extern const char *g_program_name; +/// This is set by util/parse-options.{h,cc} if you set --verbose=? option. +/// Do not use directly, prefer {Get,Set}VerboseLevel(). +extern int32 g_kaldi_verbose_level; +/// Get verbosity level, usually set via command line '--verbose=' switch. 
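A hedged sketch of how a program that embeds Kaldi as a library might use these calls (command-line Kaldi binaries get both set automatically by ParseOptions); the program name and verbosity value below are arbitrary:

#include "base/kaldi-error.h"

int main() {
  // Normally ParseOptions does this from argv[0] and the --verbose=N option.
  kaldi::SetProgramName("my-embedding-app");  // base name only, no directory
  kaldi::SetVerboseLevel(2);                  // enables KALDI_VLOG(1) and KALDI_VLOG(2)
  KALDI_VLOG(2) << "Verbose level is " << kaldi::GetVerboseLevel();
  return 0;
}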
inline int32 GetVerboseLevel() { return g_kaldi_verbose_level; } -/// This should be rarely used; command-line programs set the verbose level -/// automatically from ParseOptions. +/// This should be rarely used, except by programs using Kaldi as library; +/// command-line programs set the verbose level automatically from ParseOptions. inline void SetVerboseLevel(int32 i) { g_kaldi_verbose_level = i; } @@ -77,83 +67,115 @@ inline void SetVerboseLevel(int32 i) { g_kaldi_verbose_level = i; } /// Log message severity and source location info. struct LogMessageEnvelope { + /// Message severity. In addition to these levels, positive values (1 to 6) + /// specify verbose logging level. Verbose messages are produced only when + /// SetVerboseLevel() has been called to set logging level to at least the + /// corresponding value. enum Severity { - kAssertFailed = -3, - kError = -2, - kWarning = -1, - kInfo = 0, + kAssertFailed = -3, //!< Assertion failure. abort() will be called. + kError = -2, //!< Fatal error. KaldiFatalError will be thrown. + kWarning = -1, //!< Indicates a recoverable but abnormal condition. + kInfo = 0, //!< Informational message. }; - // An 'enum Severity' value, or a positive number indicating verbosity level. - int severity; - const char *func; - const char *file; - int32 line; + int severity; //!< A Severity value, or positive verbosity level. + const char *func; //!< Name of the function invoking the logging. + const char *file; //!< Source file name with up to 1 leading directory. + int32 line; // + MessageLogger &operator<<(const T &val) { + ss_ << val; + return *this; + } + + // When assigned a MessageLogger, log its contents. + struct Log final { + void operator=(const MessageLogger& logger) { + logger.LogMessage(); + } + }; - /// The hook for the 'insertion operator', e.g. - /// 'KALDI_LOG << "Message,"', - inline std::ostream &stream() { return ss_; } + // When assigned a MessageLogger, log its contents and then throw + // a KaldiFatalError. + struct LogAndThrow final { + [[ noreturn ]] void operator=(const MessageLogger& logger) { + logger.LogMessage(); + throw KaldiFatalError(logger.GetMessage()); + } + }; private: - /// The logging function, - static void HandleMessage(const LogMessageEnvelope &env, const char *msg); + std::string GetMessage() const { return ss_.str(); } + void LogMessage() const; -private: LogMessageEnvelope envelope_; std::ostringstream ss_; }; -// The definition of the logging macros, +// Logging macros. 
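The macros defined next wire these pieces together; as a hedged sketch of what the change means for calling code, KALDI_ERR now logs and then throws kaldi::KaldiFatalError rather than std::runtime_error, so callers that used to catch the latter would be adjusted roughly as below (CheckPositive() is invented for illustration):

#include <iostream>
#include "base/kaldi-error.h"

void CheckPositive(int n) {
  if (n <= 0)
    KALDI_ERR << "Expected a positive value, got " << n;  // logs, then throws
  KALDI_LOG << "Value " << n << " is fine.";
}

int main() {
  try {
    CheckPositive(-1);
  } catch (const kaldi::KaldiFatalError &e) {
    // KaldiMessage() returns the message text that was streamed into KALDI_ERR.
    std::cerr << "Caught fatal Kaldi error: " << e.KaldiMessage() << "\n";
  }
  return 0;
}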
#define KALDI_ERR \ - ::kaldi::MessageLogger(::kaldi::LogMessageEnvelope::kError, \ - __func__, __FILE__, __LINE__).stream() + ::kaldi::MessageLogger::LogAndThrow() = \ + ::kaldi::MessageLogger(::kaldi::LogMessageEnvelope::kError, \ + __func__, __FILE__, __LINE__) #define KALDI_WARN \ - ::kaldi::MessageLogger(::kaldi::LogMessageEnvelope::kWarning, \ - __func__, __FILE__, __LINE__).stream() + ::kaldi::MessageLogger::Log() = \ + ::kaldi::MessageLogger(::kaldi::LogMessageEnvelope::kWarning, \ + __func__, __FILE__, __LINE__) #define KALDI_LOG \ - ::kaldi::MessageLogger(::kaldi::LogMessageEnvelope::kInfo, \ - __func__, __FILE__, __LINE__).stream() -#define KALDI_VLOG(v) if ((v) <= ::kaldi::g_kaldi_verbose_level) \ - ::kaldi::MessageLogger((::kaldi::LogMessageEnvelope::Severity)(v), \ - __func__, __FILE__, __LINE__).stream() + ::kaldi::MessageLogger::Log() = \ + ::kaldi::MessageLogger(::kaldi::LogMessageEnvelope::kInfo, \ + __func__, __FILE__, __LINE__) +#define KALDI_VLOG(v) \ + if ((v) <= ::kaldi::GetVerboseLevel()) \ + ::kaldi::MessageLogger::Log() = \ + ::kaldi::MessageLogger((::kaldi::LogMessageEnvelope::Severity)(v), \ + __func__, __FILE__, __LINE__) /***** KALDI ASSERTS *****/ -void KaldiAssertFailure_(const char *func, const char *file, - int32 line, const char *cond_str); +[[ noreturn ]] void KaldiAssertFailure_(const char *func, const char *file, + int32 line, const char *cond_str); -// Note on KALDI_ASSERT and KALDI_PARANOID_ASSERT -// The original (simple) version of the code was this +// Note on KALDI_ASSERT and KALDI_PARANOID_ASSERT: // -// #define KALDI_ASSERT(cond) if (!(cond)) -// kaldi::KaldiAssertFailure_(__func__, __FILE__, __LINE__, #cond); -// -// That worked well, but we were concerned that it -// could potentially cause a performance issue due to failed branch -// prediction (best practice is to have the if branch be the commonly -// taken one). -// Therefore, we decided to move the call into the else{} branch. // A single block {} around if /else does not work, because it causes // syntax error (unmatched else block) in the following code: // @@ -162,19 +184,21 @@ void KaldiAssertFailure_(const char *func, const char *file, // else // SomethingElse(); // -// do {} while(0) -- note there is no semicolon at the end! --- works nicely +// do {} while(0) -- note there is no semicolon at the end! -- works nicely, // and compilers will be able to optimize the loop away (as the condition // is always false). +// +// Also see KALDI_COMPILE_TIME_ASSERT, defined in base/kaldi-utils.h, and +// KALDI_ASSERT_IS_INTEGER_TYPE and KALDI_ASSERT_IS_FLOATING_TYPE, also defined +// there. #ifndef NDEBUG #define KALDI_ASSERT(cond) do { if (cond) (void)0; else \ ::kaldi::KaldiAssertFailure_(__func__, __FILE__, __LINE__, #cond); } while(0) #else #define KALDI_ASSERT(cond) (void)0 #endif -// Also see KALDI_COMPILE_TIME_ASSERT, defined in base/kaldi-utils.h, -// and KALDI_ASSERT_IS_INTEGER_TYPE and KALDI_ASSERT_IS_FLOATING_TYPE, -// also defined there. -// some more expensive asserts only checked if this defined + +// Some more expensive asserts only checked if this defined. #ifdef KALDI_PARANOID #define KALDI_PARANOID_ASSERT(cond) do { if (cond) (void)0; else \ ::kaldi::KaldiAssertFailure_(__func__, __FILE__, __LINE__, #cond); } while(0) @@ -185,14 +209,15 @@ void KaldiAssertFailure_(const char *func, const char *file, /***** THIRD-PARTY LOG-HANDLER *****/ -/// Type of third-party logging function, +/// Type of third-party logging function. 
typedef void (*LogHandler)(const LogMessageEnvelope &envelope, const char *message); /// Set logging handler. If called with a non-NULL function pointer, the -/// function pointed by it is called to send messages to a caller-provided -/// log. If called with NULL pointer, restores default Kaldi error logging to -/// stderr. SetLogHandler is obviously not thread safe. +/// function pointed by it is called to send messages to a caller-provided log. +/// If called with a NULL pointer, restores default Kaldi error logging to +/// stderr. This function is obviously not thread safe; the log handler must be. +/// Returns a previously set logging handler pointer, or NULL. LogHandler SetLogHandler(LogHandler); /// @} end "addtogroup error_group" diff --git a/src/base/kaldi-math.cc b/src/base/kaldi-math.cc index 991e46a590c..17271f3c46f 100644 --- a/src/base/kaldi-math.cc +++ b/src/base/kaldi-math.cc @@ -21,6 +21,7 @@ #include "base/kaldi-math.h" #ifndef _MSC_VER #include +#include #endif #include #include @@ -42,7 +43,7 @@ int32 RoundUpToNearestPowerOfTwo(int32 n) { static std::mutex _RandMutex; int Rand(struct RandomState* state) { -#if defined(_MSC_VER) || defined(__CYGWIN__) +#if !defined(_POSIX_THREAD_SAFE_FUNCTIONS) // On Windows and Cygwin, just call Rand() return rand(); #else @@ -109,10 +110,8 @@ int32 RandInt(int32 min_val, int32 max_val, struct RandomState* state) { return min_val + ( (unsigned int)( (Rand(state)+RAND_MAX*Rand(state))) % (unsigned int)(max_val+1-min_val)); } else { - throw std::runtime_error(std::string() - +"rand_int failed because we do not support " - +"such large random numbers. " - +"(Extend this function)."); + KALDI_ERR << "rand_int failed because we do not support such large " + "random numbers. (Extend this function)."; } } #else diff --git a/src/base/timer.cc b/src/base/timer.cc new file mode 100644 index 00000000000..ce4ef292783 --- /dev/null +++ b/src/base/timer.cc @@ -0,0 +1,85 @@ +// base/timer.cc + +// Copyright 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/timer.h" +#include "base/kaldi-error.h" +#include +#include +#include +#include + +namespace kaldi { + +class ProfileStats { + public: + void AccStats(const char *function_name, double elapsed) { + std::unordered_map::iterator + iter = map_.find(function_name); + if (iter == map_.end()) { + map_[function_name] = ProfileStatsEntry(function_name); + map_[function_name].total_time = elapsed; + } else { + iter->second.total_time += elapsed; + } + } + ~ProfileStats() { + // This map makes sure we agglomerate the time if there were any duplicate + // addresses of strings. 
+ std::unordered_map total_time; + for (auto iter = map_.begin(); iter != map_.end(); iter++) + total_time[iter->second.name] += iter->second.total_time; + + ReverseSecondComparator comp; + std::vector > pairs(total_time.begin(), + total_time.end()); + std::sort(pairs.begin(), pairs.end(), comp); + for (size_t i = 0; i < pairs.size(); i++) { + KALDI_LOG << "Time taken in " << pairs[i].first << " is " + << std::fixed << std::setprecision(2) << pairs[i].second << "s."; + } + } + private: + + struct ProfileStatsEntry { + std::string name; + double total_time; + ProfileStatsEntry() { } + ProfileStatsEntry(const char *name): name(name) { } + }; + + struct ReverseSecondComparator { + bool operator () (const std::pair &a, + const std::pair &b) { + return a.second > b.second; + } + }; + + // Note: this map is keyed on the address of the string, there is no proper + // hash function. The assumption is that the strings are compile-time + // constants. + std::unordered_map map_; +}; + +ProfileStats g_profile_stats; + +Profiler::~Profiler() { + g_profile_stats.AccStats(name_, tim_.Elapsed()); +} + +} // namespace kaldi diff --git a/src/base/timer.h b/src/base/timer.h index 7889c4a258b..0e033766362 100644 --- a/src/base/timer.h +++ b/src/base/timer.h @@ -20,7 +20,7 @@ #define KALDI_BASE_TIMER_H_ #include "base/kaldi-utils.h" -// Note: Sleep(float secs) is included in base/kaldi-utils.h. +#include "base/kaldi-error.h" #if defined(_MSC_VER) || defined(MINGW) @@ -53,7 +53,7 @@ class Timer { private: LARGE_INTEGER time_start_; }; -} + #else #include @@ -87,9 +87,29 @@ class Timer { struct timeval time_start_; struct timezone time_zone_; }; -} #endif +class Profiler { + public: + // Caution: the 'const char' should always be a string constant; for speed, + // internally the profiling code uses the address of it as a lookup key. + Profiler(const char *function_name): name_(function_name) { } + ~Profiler(); + private: + Timer tim_; + const char *name_; +}; + +// To add timing info for a function, you just put +// KALDI_PROFILE; +// at the beginning of the function. Caution: this doesn't +// include the class name. 
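For instance, a hedged sketch of the pattern the comment above describes (the function name is invented; the elapsed time is accumulated per __func__ key and reported through KALDI_LOG when the global ProfileStats object is destroyed at program exit):

void SomeExpensiveFunction() {
  KALDI_PROFILE;  // i.e. Profiler _profiler(__func__);
  // ... do the real work; the time spent here is added under the key
  // "SomeExpensiveFunction" when _profiler goes out of scope ...
}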
+#define KALDI_PROFILE Profiler _profiler(__func__) + + + +} // namespace kaldi + #endif // KALDI_BASE_TIMER_H_ diff --git a/src/bin/Makefile b/src/bin/Makefile index 627c4f8a131..7cb01b50120 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -21,7 +21,8 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ post-to-pdf-post logprob-to-post prob-to-post copy-post \ matrix-sum build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ - transform-vec align-text matrix-dim post-to-smat + transform-vec align-text matrix-dim post-to-smat compile-graph \ + compare-int-vector OBJFILES = @@ -29,8 +30,8 @@ OBJFILES = ADDLIBS = ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a ../lm/kaldi-lm.a \ ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a TESTFILES = diff --git a/src/bin/acc-lda.cc b/src/bin/acc-lda.cc index 92cd192b9a6..b664135bdc7 100644 --- a/src/bin/acc-lda.cc +++ b/src/bin/acc-lda.cc @@ -37,7 +37,7 @@ int main(int argc, char *argv[]) { "Accumulate LDA statistics based on pdf-ids.\n" "Usage: acc-lda [options] \n" "Typical usage:\n" - " ali-to-post ark:1.ali ark:- | lda-acc 1.mdl \"ark:splice-feats scp:train.scp|\" ark:- ldaacc.1\n"; + " ali-to-post ark:1.ali ark:- | acc-lda 1.mdl \"ark:splice-feats scp:train.scp|\" ark:- ldaacc.1\n"; bool binary = true; BaseFloat rand_prune = 0.0; @@ -126,5 +126,3 @@ int main(int argc, char *argv[]) { return -1; } } - - diff --git a/src/bin/ali-to-phones.cc b/src/bin/ali-to-phones.cc index 2a76000cfae..602e32e9768 100644 --- a/src/bin/ali-to-phones.cc +++ b/src/bin/ali-to-phones.cc @@ -38,7 +38,7 @@ int main(int argc, char *argv[]) { " ali-to-phones 1.mdl ark:1.ali ark:-\n" "or:\n" " ali-to-phones --ctm-output 1.mdl ark:1.ali 1.ctm\n" - "See also: show-alignments lattice-align-phones\n"; + "See also: show-alignments lattice-align-phones, compare-int-vector\n"; ParseOptions po(usage); bool per_frame = false; bool write_lengths = false; @@ -137,5 +137,3 @@ int main(int argc, char *argv[]) { return -1; } } - - diff --git a/src/bin/align-text.cc b/src/bin/align-text.cc index 616dac858d7..1c695675274 100644 --- a/src/bin/align-text.cc +++ b/src/bin/align-text.cc @@ -86,28 +86,34 @@ int main(int argc, char *argv[]) { if (!text2_reader.HasKey(key)) { KALDI_WARN << "Key " << key << " is in " << text1_rspecifier - << ", but not in " << text2_rspecifier; + << ", but not in " << text2_rspecifier; n_fail++; continue; } const std::vector &text1 = text1_reader.Value(); const std::vector &text2 = text2_reader.Value(key); - // Checks if the special symbol is in the string. - KALDI_ASSERT(std::find(text1.begin(), - text1.end(), special_symbol) == text1.end()); - KALDI_ASSERT(std::find(text2.begin(), - text2.end(), special_symbol) == text2.end()); - if (std::find_if(text1.begin(), text1.end(), IsNotToken) != text1.end()) { - KALDI_ERR << "In text1, the utterance " << key << " contains unprintable characters." \ - << "That means there is a problem with the text (such as incorrect encoding)." << std::endl; - return -1; + KALDI_ERR << "In text1, the utterance " << key + << " contains unprintable characters. 
That means there is" + << " a problem with the text (such as incorrect encoding)."; } if (std::find_if(text2.begin(), text2.end(), IsNotToken) != text2.end()) { - KALDI_ERR << "In text2, the utterance " << key << " contains unprintable characters." \ - << "That means there is a problem with the text (such as incorrect encoding)." << std::endl; - return -1; + KALDI_ERR << "In text2, the utterance " << key + << " contains unprintable characters. That means there is" + << " a problem with the text (such as incorrect encoding)."; + } + + // Verify that the special symbol is not in the string. + if (std::find(text1.begin(), text1.end(), special_symbol) != text1.end()){ + KALDI_ERR << "In text1, the utterance " << key + << " contains the special symbol '" << special_symbol + << "'. This is not allowed."; + } + if (std::find(text2.begin(), text2.end(), special_symbol) != text2.end()){ + KALDI_ERR << "In text2, the utterance " << key + << " contains the special symbol '" << special_symbol + << "'. This is not allowed."; } std::vector > aligned; diff --git a/src/bin/compare-int-vector.cc b/src/bin/compare-int-vector.cc new file mode 100644 index 00000000000..5f80ff5ee6c --- /dev/null +++ b/src/bin/compare-int-vector.cc @@ -0,0 +1,184 @@ +// bin/compare-int-vector.cc + +// Copyright 2018 Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-vector.h" +#include "transform/transform-common.h" +#include + + +namespace kaldi { +void AddToCount(int32 location_to_add, + double value_to_add, + std::vector *counts) { + if (location_to_add < 0) + KALDI_ERR << "Contents of vectors cannot be " + "negative if --write-tot-counts or --write-diff-counts " + "options are provided."; + if (counts->size() <= static_cast(location_to_add)) + counts->resize(location_to_add + 1, 0.0); + (*counts)[location_to_add] += value_to_add; +} + +void AddToConfusionMatrix(int32 phone1, int32 phone2, + Matrix *counts) { + if (phone1 < 0 || phone2 < 0) + KALDI_ERR << "Contents of vectors cannot be " + "negative if --write-confusion-matrix option is " + "provided."; + int32 max_size = std::max(phone1, phone2) + 1; + if (counts->NumRows() < max_size) + counts->Resize(max_size, max_size, kCopyData); + (*counts)(phone1, phone2) += 1.0; +} + + +void WriteAsKaldiVector(const std::vector &counts, + std::string &wxfilename, + bool binary) { + Vector counts_vec(counts.size()); + for (size_t i = 0; i < counts.size(); i++) + counts_vec(i) = counts[i]; + WriteKaldiObject(counts_vec, wxfilename, binary); +} + +} // namespace kaldi + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Compare vectors of integers (e.g. 
phone alignments)\n" + "Prints to stdout fields of the form:\n" + " \n" + "\n" + "e.g.:\n" + " SWB1_A_31410_32892 420 36\n" + "\n" + "Usage:\n" + "compare-int-vector [options] \n" + "\n" + "e.g. compare-int-vector scp:foo.scp scp:bar.scp > comparison\n" + "E.g. the inputs might come from ali-to-phones.\n" + "Warnings are printed if the vector lengths differ for a given utterance-id,\n" + "and in those cases, the number of frames printed will be the smaller of the\n" + "\n" + "See also: ali-to-phones, copy-int-vector\n"; + + + ParseOptions po(usage); + + std::string tot_wxfilename, + diff_wxfilename, + confusion_matrix_wxfilename; + bool binary = true; + + po.Register("binary", &binary, "If true, write in binary mode (only applies " + "if --write-tot-counts or --write-diff-counts options are supplied)."); + po.Register("write-tot-counts", &tot_wxfilename, "Filename to write " + "vector of total counts. These may be summed with 'vector-sum'."); + po.Register("write-diff-counts", &diff_wxfilename, "Filename to write " + "vector of counts of phones (or whatever is in the inputs) " + "that differ from one vector to the other. Each time a pair differs, " + "0.5 will be added to each one's location."); + po.Register("write-confusion-matrix", &confusion_matrix_wxfilename, + "Filename to write confusion matrix, indexed by [phone1][phone2]." + "These may be summed by 'matrix-sum'."); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string vector1_rspecifier = po.GetArg(1), + vector2_rspecifier = po.GetArg(2); + + int64 num_done = 0, + num_not_found = 0, + num_mismatched_lengths = 0, + tot_frames = 0, tot_difference = 0; + + std::vector diff_counts; + std::vector tot_counts; + Matrix confusion_matrix; + + SequentialInt32VectorReader reader1(vector1_rspecifier); + RandomAccessInt32VectorReader reader2(vector2_rspecifier); + + for (; !reader1.Done(); reader1.Next(), num_done++) { + const std::string &key = reader1.Key(); + if (!reader2.HasKey(key)) { + KALDI_WARN << "No key " << key << " found in second input."; + num_not_found++; + continue; + } + const std::vector &value1 = reader1.Value(), + &value2 = reader2.Value(key); + size_t len1 = value1.size(), len2 = value2.size(); + if (len1 != len2) { + KALDI_WARN << "For utterance " << key << ", lengths differ " + << len1 << " vs. 
" << len2; + num_mismatched_lengths++; + } + size_t len = std::min(len1, len2), + difference = 0; + for (size_t i = 0; i < len; i++) { + int32 phone1 = value1[i], phone2 = value2[i]; + if (phone1 != phone2) { + difference++; + if (!diff_wxfilename.empty()) { + AddToCount(phone1, 0.5, &diff_counts); + AddToCount(phone2, 0.5, &diff_counts); + } + } + if (!tot_wxfilename.empty()) + AddToCount(phone1, 1.0, &tot_counts); + if (!confusion_matrix_wxfilename.empty()) + AddToConfusionMatrix(phone1, phone2, &confusion_matrix); + } + num_done++; + std::cout << key << " " << len << " " << difference << "\n"; + tot_frames += len; + tot_difference += difference; + } + + BaseFloat difference_percent = tot_difference * 100.0 / tot_frames; + KALDI_LOG << "Computed difference for " << num_done << " utterances, of which " + << num_mismatched_lengths << " had mismatched lengths; corresponding " + "utterance not found for " << num_not_found; + KALDI_LOG << "Average p(different) is " << std::setprecision(4) << difference_percent + << "%, over " << tot_frames << " frames."; + + if (!tot_wxfilename.empty()) + WriteAsKaldiVector(tot_counts, tot_wxfilename, binary); + if (!diff_wxfilename.empty()) + WriteAsKaldiVector(diff_counts, diff_wxfilename, binary); + if (!confusion_matrix_wxfilename.empty()) + WriteKaldiObject(confusion_matrix, confusion_matrix_wxfilename, binary); + + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/bin/compile-graph.cc b/src/bin/compile-graph.cc new file mode 100644 index 00000000000..7174fdf8113 --- /dev/null +++ b/src/bin/compile-graph.cc @@ -0,0 +1,200 @@ +// bin/compile-graph.cc + +// Copyright 2018 Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "hmm/transition-model.h" +#include "hmm/hmm-utils.h" +#include "fstext/fstext-lib.h" +#include "fstext/push-special.h" +#include "fstext/grammar-context-fst.h" +#include "decoder/grammar-fst.h" + + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::VectorFst; + using fst::StdArc; + + + const char *usage = + "Creates HCLG decoding graph. Similar to mkgraph.sh but done in code.\n" + "\n" + "Usage: compile-graph [options] " + " \n" + "e.g.: \n" + " compile-train-graphs-fsts tree 1.mdl L_disambig.fst G.fst HCLG.fst\n"; + ParseOptions po(usage); + + + BaseFloat transition_scale = 1.0; + BaseFloat self_loop_scale = 1.0; // Caution: the script default is 0.1. 
+ int32 nonterm_phones_offset = -1; + std::string disambig_rxfilename; + + + po.Register("read-disambig-syms", &disambig_rxfilename, "File containing " + "list of disambiguation symbols in phone symbol table"); + po.Register("transition-scale", &transition_scale, "Scale of transition " + "probabilities (excluding self-loops)."); + po.Register("self-loop-scale", &self_loop_scale, "Scale of self-loop vs. " + "non-self-loop probability mass. Caution: the default of " + "mkgraph.sh is 0.1, but this defaults to 1.0."); + po.Register("nonterm-phones-offset", &nonterm_phones_offset, "Integer " + "value of symbol #nonterm_bos in phones.txt, if present. " + "(Only relevant for grammar decoding)."); + + po.Read(argc, argv); + + if (po.NumArgs() != 5) { + po.PrintUsage(); + exit(1); + } + + std::string tree_rxfilename = po.GetArg(1), + model_rxfilename = po.GetArg(2), + lex_rxfilename = po.GetArg(3), + grammar_rxfilename = po.GetArg(4), + hclg_wxfilename = po.GetArg(5); + + ContextDependency ctx_dep; // the tree. + ReadKaldiObject(tree_rxfilename, &ctx_dep); + + TransitionModel trans_model; + ReadKaldiObject(model_rxfilename, &trans_model); + + VectorFst *lex_fst = fst::ReadFstKaldi(lex_rxfilename), + *grammar_fst = fst::ReadFstKaldi(grammar_rxfilename); + + std::vector disambig_syms; + if (disambig_rxfilename != "") + if (!ReadIntegerVectorSimple(disambig_rxfilename, &disambig_syms)) + KALDI_ERR << "Could not read disambiguation symbols from " + << disambig_rxfilename; + if (disambig_syms.empty()) + KALDI_WARN << "You supplied no disambiguation symbols; note, these are " + << "typically necessary when compiling graphs from FSTs (i.e. " + << "supply L_disambig.fst and the list of disambig syms with\n" + << "--read-disambig-syms)"; + + const std::vector &phone_syms = trans_model.GetPhones(); + SortAndUniq(&disambig_syms); + for (int32 i = 0; i < disambig_syms.size(); i++) + if (std::binary_search(phone_syms.begin(), phone_syms.end(), + disambig_syms[i])) + KALDI_ERR << "Disambiguation symbol " << disambig_syms[i] + << " is also a phone."; + + VectorFst lg_fst; + TableCompose(*lex_fst, *grammar_fst, &lg_fst); + + DeterminizeStarInLog(&lg_fst, fst::kDelta); + + MinimizeEncoded(&lg_fst, fst::kDelta); + + fst::PushSpecial(&lg_fst, fst::kDelta); + + delete grammar_fst; + delete lex_fst; + + VectorFst clg_fst; + + std::vector > ilabels; + + int32 context_width = ctx_dep.ContextWidth(), + central_position = ctx_dep.CentralPosition(); + + if (nonterm_phones_offset < 0) { + // The normal case. + ComposeContext(disambig_syms, context_width, central_position, + &lg_fst, &clg_fst, &ilabels); + } else { + // The grammar-FST case. See ../doc/grammar.dox for an intro. + if (context_width != 2 || central_position != 1) { + KALDI_ERR << "Grammar-fst graph creation only supports models with left-" + "biphone context. (--nonterm-phones-offset option was supplied)."; + } + ComposeContextLeftBiphone(nonterm_phones_offset, disambig_syms, + lg_fst, &clg_fst, &ilabels); + } + lg_fst.DeleteStates(); + + HTransducerConfig h_cfg; + h_cfg.transition_scale = transition_scale; + h_cfg.nonterm_phones_offset = nonterm_phones_offset; + std::vector disambig_syms_h; // disambiguation symbols on + // input side of H. + VectorFst *h_fst = GetHTransducer(ilabels, + ctx_dep, + trans_model, + h_cfg, + &disambig_syms_h); + + VectorFst hclg_fst; // transition-id to word. 
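Relating the options above to the usage string: a command line roughly equivalent to what mkgraph.sh does (the lang/ and exp/ layout shown here is illustrative only) might look like

  compile-graph --read-disambig-syms=lang/phones/disambig.int --self-loop-scale=0.1 \
      exp/tri3a/tree exp/tri3a/final.mdl lang/L_disambig.fst lang/G.fst exp/tri3a/graph/HCLG.fst

with --self-loop-scale passed explicitly because, as the option help above warns, the in-code default of 1.0 differs from the 0.1 that mkgraph.sh uses.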
+ TableCompose(*h_fst, clg_fst, &hclg_fst); + clg_fst.DeleteStates(); + delete h_fst; + + KALDI_ASSERT(hclg_fst.Start() != fst::kNoStateId); + + // Epsilon-removal and determinization combined. This will fail if not determinizable. + DeterminizeStarInLog(&hclg_fst); + + if (!disambig_syms_h.empty()) { + RemoveSomeInputSymbols(disambig_syms_h, &hclg_fst); + RemoveEpsLocal(&hclg_fst); + } + + // Encoded minimization. + MinimizeEncoded(&hclg_fst); + + std::vector disambig; + bool check_no_self_loops = true, + reorder = true; + AddSelfLoops(trans_model, + disambig, + self_loop_scale, + reorder, + check_no_self_loops, + &hclg_fst); + + if (nonterm_phones_offset >= 0) + PrepareForGrammarFst(nonterm_phones_offset, &hclg_fst); + + { // convert 'hclg' to ConstFst and write. + fst::ConstFst const_hclg(hclg_fst); + bool binary = true, write_binary_header = false; // suppress the ^@B + Output ko(hclg_wxfilename, binary, write_binary_header); + fst::FstWriteOptions wopts(PrintableWxfilename(hclg_wxfilename)); + const_hclg.Write(ko.Stream(), wopts); + } + + KALDI_LOG << "Wrote graph with " << hclg_fst.NumStates() + << " states to " << hclg_wxfilename; + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/bin/compute-wer-bootci.cc b/src/bin/compute-wer-bootci.cc index c6dcd051749..ba2a4ce739c 100644 --- a/src/bin/compute-wer-bootci.cc +++ b/src/bin/compute-wer-bootci.cc @@ -162,10 +162,10 @@ int main(int argc, char *argv[]) { try { const char *usage = - "Compute a bootstrapping of WER to extract the 95\% confidence interval.\n" + "Compute a bootstrapping of WER to extract the 95% confidence interval.\n" "Take a reference and a transcription file, in integer or text format,\n" "and outputs overall WER statistics to standard output along with its\n" - "confidence interval using the bootstrap methos of Bisani and Ney.\n" + "confidence interval using the bootstrap method of Bisani and Ney.\n" "If a second transcription file corresponding to the same reference is\n" "provided, a bootstrap comparison of the two transcription is performed\n" "to estimate the probability of improvement.\n" @@ -234,12 +234,12 @@ int main(int argc, char *argv[]) { std::cout.precision(2); std::cerr.precision(2); std::cout << "Set1: %WER " << std::fixed << 100*mean_wer << - " 95\% Conf Interval [ " << 100*mean_wer-100*interval << + " 95% Conf Interval [ " << 100*mean_wer-100*interval << ", " << 100*mean_wer+100*interval << " ]" << '\n'; if(!hyp2_rspecifier.empty()) { std::cout << "Set2: %WER " << std::fixed << 100*mean_wer2 << - " 95\% Conf Interval [ " << 100*mean_wer2-100*interval2 << + " 95% Conf Interval [ " << 100*mean_wer2-100*interval2 << ", " << 100*mean_wer2+100*interval2 << " ]" << '\n'; std::cout << "Probability of Set2 improving Set1: " << std::fixed << diff --git a/src/bin/copy-post.cc b/src/bin/copy-post.cc index 6d0d351a594..d5ca3f42980 100644 --- a/src/bin/copy-post.cc +++ b/src/bin/copy-post.cc @@ -26,13 +26,13 @@ int main(int argc, char *argv[]) { try { using namespace kaldi; - typedef kaldi::int32 int32; + typedef kaldi::int32 int32; const char *usage = "Copy archives of posteriors, with optional scaling\n" - "(Also see rand-prune-post and sum-post)\n" "\n" - "Usage: copy-post \n"; + "Usage: copy-post \n" + "See also: post-to-weights, scale-post, sum-post, weight-post ...\n"; BaseFloat scale = 1.0; ParseOptions po(usage); @@ -43,15 +43,15 @@ int main(int argc, char *argv[]) { po.PrintUsage(); exit(1); } - + std::string post_rspecifier = po.GetArg(1), 
post_wspecifier = po.GetArg(2); kaldi::SequentialPosteriorReader posterior_reader(post_rspecifier); - kaldi::PosteriorWriter posterior_writer(post_wspecifier); + kaldi::PosteriorWriter posterior_writer(post_wspecifier); int32 num_done = 0; - + for (; !posterior_reader.Done(); posterior_reader.Next()) { std::string key = posterior_reader.Key(); @@ -71,4 +71,3 @@ int main(int argc, char *argv[]) { return -1; } } - diff --git a/src/bin/draw-tree.cc b/src/bin/draw-tree.cc index ad1dd41a53f..d107ab1cfac 100644 --- a/src/bin/draw-tree.cc +++ b/src/bin/draw-tree.cc @@ -18,6 +18,7 @@ // limitations under the License. #include "tree/tree-renderer.h" +#include "tree/context-dep.h" void MakeEvent(std::string &qry, fst::SymbolTable *phone_syms, kaldi::EventType **query) @@ -33,25 +34,23 @@ void MakeEvent(std::string &qry, fst::SymbolTable *phone_syms, if (key == kPdfClass) { value = static_cast(atoi(valstr.c_str())); if (value < 0) { // not valid pdf-class - KALDI_ERR << "Bad query: invalid pdf-class (" - << valstr << ')' << std::endl << std::endl; + KALDI_ERR << "Bad query: invalid pdf-class (" << valstr << ')'; } } else { value = static_cast(phone_syms->Find(valstr.c_str())); if (value == -1) { // fst::kNoSymbol - KALDI_ERR << "Bad query: invalid symbol (" - << valstr << ')' << std::endl << std::endl; + KALDI_ERR << "Bad query: invalid symbol (" << valstr << ')'; } } query_event->push_back(std::make_pair(key++, value)); old_found = found + 1; } std::string valstr = qry.substr(old_found); - EventValueType value = static_cast(phone_syms->Find(valstr.c_str())); + EventValueType value = + static_cast(phone_syms->Find(valstr.c_str())); if (value == -1) { // fst::kNoSymbol - KALDI_ERR << "Bad query: invalid symbol (" - << valstr << ')' << std::endl << std::endl; + KALDI_ERR << "Bad query: invalid symbol (" << valstr << ')'; } query_event->push_back(std::make_pair(key, value)); diff --git a/src/bin/matrix-sum.cc b/src/bin/matrix-sum.cc index 8a7b5a39e00..3c93dfd0d39 100644 --- a/src/bin/matrix-sum.cc +++ b/src/bin/matrix-sum.cc @@ -238,17 +238,20 @@ int32 TypeThreeUsage(const ParseOptions &po, << "tables, the intermediate arguments must not be tables."; } - bool add = true; - Matrix mat; + Matrix sum; for (int32 i = 1; i < po.NumArgs(); i++) { - bool binary_in; - Input ki(po.GetArg(i), &binary_in); - // this Read function will throw if there is a size mismatch. 
- mat.Read(ki.Stream(), binary_in, add); + Matrix this_mat; + ReadKaldiObject(po.GetArg(i), &this_mat); + if (sum.NumRows() < this_mat.NumRows() || + sum.NumCols() < this_mat.NumCols()) + sum.Resize(std::max(sum.NumRows(), this_mat.NumRows()), + std::max(sum.NumCols(), this_mat.NumCols()), + kCopyData); + sum.AddMat(1.0, this_mat); } if (average) - mat.Scale(1.0 / (po.NumArgs() - 1)); - WriteKaldiObject(mat, po.GetArg(po.NumArgs()), binary); + sum.Scale(1.0 / (po.NumArgs() - 1)); + WriteKaldiObject(sum, po.GetArg(po.NumArgs()), binary); KALDI_LOG << "Summed " << (po.NumArgs() - 1) << " matrices; " << "wrote sum to " << PrintableWxfilename(po.GetArg(po.NumArgs())); return 0; @@ -335,4 +338,3 @@ int main(int argc, char *argv[]) { return -1; } } - diff --git a/src/bin/post-to-smat.cc b/src/bin/post-to-smat.cc index 8cd8df41647..2d043000866 100644 --- a/src/bin/post-to-smat.cc +++ b/src/bin/post-to-smat.cc @@ -48,15 +48,16 @@ int main(int argc, char *argv[]) { po.Read(argc, argv); - if (dim <= 0) { - KALDI_ERR << "The --dim option must be specified."; - } - if (po.NumArgs() != 2) { po.PrintUsage(); exit(1); } + if (dim <= 0) { + KALDI_ERR << "The --dim option must be specified."; + } + + std::string posteriors_rspecifier = po.GetArg(1), sparse_matrix_wspecifier = po.GetArg(2); diff --git a/src/bin/vector-sum.cc b/src/bin/vector-sum.cc index 42404e38384..3e622cafdc7 100644 --- a/src/bin/vector-sum.cc +++ b/src/bin/vector-sum.cc @@ -1,7 +1,7 @@ // bin/vector-sum.cc -// Copyright 2014 Vimal Manohar -// 2014 Johns Hopkins University (author: Daniel Povey) +// Copyright 2014 Vimal Manohar +// 2014-2018 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -43,7 +43,7 @@ int32 TypeOneUsage(const ParseOptions &po) { // Input vectors SequentialBaseFloatVectorReader vector_reader1(vector_in_fn1); - std::vector vector_readers(num_args-2, + std::vector vector_readers(num_args-2, static_cast(NULL)); std::vector vector_in_fns(num_args-2); for (int32 i = 2; i < num_args; ++i) { @@ -51,7 +51,7 @@ int32 TypeOneUsage(const ParseOptions &po) { vector_in_fns[i-2] = po.GetArg(i); } - int32 n_utts = 0, n_total_vectors = 0, + int32 n_utts = 0, n_total_vectors = 0, n_success = 0, n_missing = 0, n_other_errors = 0; for (; !vector_reader1.Done(); vector_reader1.Next()) { @@ -70,10 +70,10 @@ int32 TypeOneUsage(const ParseOptions &po) { if (vector2.Dim() == vector_out.Dim()) { vector_out.AddVec(1.0, vector2); } else { - KALDI_WARN << "Dimension mismatch for utterance " << key + KALDI_WARN << "Dimension mismatch for utterance " << key << " : " << vector2.Dim() << " for " << "system " << (i + 2) << ", rspecifier: " - << vector_in_fns[i] << " vs " << vector_out.Dim() + << vector_in_fns[i] << " vs " << vector_out.Dim() << " primary vector, rspecifier:" << vector_in_fn1; n_other_errors++; } @@ -94,9 +94,9 @@ int32 TypeOneUsage(const ParseOptions &po) { << " different systems"; KALDI_LOG << "Produced output for " << n_success << " utterances; " << n_missing << " total missing vectors"; - + DeletePointers(&vector_readers); - + return (n_success != 0 && n_missing < (n_success - n_missing)) ? 0 : 1; } @@ -108,13 +108,13 @@ int32 TypeTwoUsage(const ParseOptions &po, "vector-sum: first argument must be an rspecifier"); // if next assert fails it would be bug in the code as otherwise we shouldn't // be called. 
- KALDI_ASSERT(ClassifyWspecifier(po.GetArg(2), NULL, NULL, NULL) == + KALDI_ASSERT(ClassifyWspecifier(po.GetArg(2), NULL, NULL, NULL) == kNoWspecifier); SequentialBaseFloatVectorReader vec_reader(po.GetArg(1)); Vector sum; - + int32 num_done = 0, num_err = 0; for (; !vec_reader.Done(); vec_reader.Next()) { @@ -134,7 +134,7 @@ int32 TypeTwoUsage(const ParseOptions &po, } } } - + if (num_done > 0 && average) sum.Scale(1.0 / num_done); Vector sum_float(sum); @@ -157,21 +157,21 @@ int32 TypeThreeUsage(const ParseOptions &po, << "tables, the intermediate arguments must not be tables."; } } - if (ClassifyWspecifier(po.GetArg(po.NumArgs()), NULL, NULL, NULL) != + if (ClassifyWspecifier(po.GetArg(po.NumArgs()), NULL, NULL, NULL) != kNoWspecifier) { KALDI_ERR << "Wrong usage (type 3): if first and last arguments are not " << "tables, the intermediate arguments must not be tables."; } - bool add = true; - Vector vec; + Vector sum; for (int32 i = 1; i < po.NumArgs(); i++) { - bool binary_in; - Input ki(po.GetArg(i), &binary_in); - // this Read function will throw if there is a size mismatch. - vec.Read(ki.Stream(), binary_in, add); + Vector this_vec; + ReadKaldiObject(po.GetArg(i), &this_vec); + if (sum.Dim() < this_vec.Dim()) + sum.Resize(this_vec.Dim(), kCopyData);; + sum.AddVec(1.0, this_vec); } - WriteKaldiObject(vec, po.GetArg(po.NumArgs()), binary); + WriteKaldiObject(sum, po.GetArg(po.NumArgs()), binary); KALDI_LOG << "Summed " << (po.NumArgs() - 1) << " vectors; " << "wrote sum to " << PrintableWxfilename(po.GetArg(po.NumArgs())); return 0; @@ -201,15 +201,15 @@ int main(int argc, char *argv[]) { " \n" " e.g.: vector-sum --binary=false 1.vec 2.vec 3.vec sum.vec\n" "See also: copy-vector, dot-weights\n"; - + bool binary, average = false; - + ParseOptions po(usage); po.Register("binary", &binary, "If true, write output as binary (only " "relevant for usage types two or three"); po.Register("average", &average, "Do average instead of sum"); - + po.Read(argc, argv); int32 N = po.NumArgs(), exit_status; @@ -226,11 +226,11 @@ int main(int argc, char *argv[]) { exit_status = TypeTwoUsage(po, binary, average); } else if (po.NumArgs() >= 2 && ClassifyRspecifier(po.GetArg(1), NULL, NULL) == kNoRspecifier && - ClassifyWspecifier(po.GetArg(N), NULL, NULL, NULL) == + ClassifyWspecifier(po.GetArg(N), NULL, NULL, NULL) == kNoWspecifier) { // summing flat files. 
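Tying this back to compare-int-vector earlier in the patch: its --write-tot-counts and --write-diff-counts outputs are ordinary Kaldi vectors, so per-job counts can be combined with exactly this "flat file" usage, e.g. (file names illustrative):

  vector-sum --binary=false tot_counts.1.vec tot_counts.2.vec tot_counts.vec
  vector-sum --binary=false diff_counts.1.vec diff_counts.2.vec diff_counts.vec

and the corresponding confusion matrices can be summed with matrix-sum in the same way.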
exit_status = TypeThreeUsage(po, binary); - } else { + } else { po.PrintUsage(); exit(1); } diff --git a/src/chain/Makefile b/src/chain/Makefile index 2a735c2ca2d..fbad28f7de6 100644 --- a/src/chain/Makefile +++ b/src/chain/Makefile @@ -18,8 +18,7 @@ LIBNAME = kaldi-chain ADDLIBS = ../cudamatrix/kaldi-cudamatrix.a ../lat/kaldi-lat.a \ ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a \ - ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a # Make sure we have CUDA_ARCH from kaldi.mk, ifeq ($(CUDA), true) diff --git a/src/chain/chain-den-graph.cc b/src/chain/chain-den-graph.cc index 62d2f3aaa56..11c851091bd 100644 --- a/src/chain/chain-den-graph.cc +++ b/src/chain/chain-den-graph.cc @@ -1,6 +1,6 @@ // chain/chain-den-graph.cc -// Copyright 2015 Johns Hopkins University (author: Daniel Povey) +// Copyright 2015-2018 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -315,11 +315,18 @@ void CreateDenominatorFst(const ContextDependency &ctx_dep, fst::Project(&phone_lm, fst::PROJECT_INPUT); } std::vector disambig_syms; // empty list of diambiguation symbols. - fst::ContextFst cfst(subsequential_symbol, trans_model.GetPhones(), - disambig_syms, ctx_dep.ContextWidth(), - ctx_dep.CentralPosition()); - StdVectorFst context_dep_lm; - fst::ComposeContextFst(cfst, phone_lm, &context_dep_lm); + + // inv_cfst will be expanded on the fly, as needed. + fst::InverseContextFst inv_cfst(subsequential_symbol, + trans_model.GetPhones(), + disambig_syms, + ctx_dep.ContextWidth(), + ctx_dep.CentralPosition()); + + fst::StdVectorFst context_dep_lm; + fst::ComposeDeterministicOnDemandInverse(phone_lm, &inv_cfst, + &context_dep_lm); + // at this point, context_dep_lm will have indexes into 'ilabels' as its // input symbol (representing context-dependent phones), and phones on its // output. We don't need the phones, so we'll project. @@ -335,7 +342,7 @@ void CreateDenominatorFst(const ContextDependency &ctx_dep, // we'll use the same value in test time. Consistency is the key here. h_config.transition_scale = 1.0; - StdVectorFst *h_fst = GetHTransducer(cfst.ILabelInfo(), + StdVectorFst *h_fst = GetHTransducer(inv_cfst.IlabelInfo(), ctx_dep, trans_model, h_config, @@ -355,7 +362,7 @@ void CreateDenominatorFst(const ContextDependency &ctx_dep, AddSelfLoops(trans_model, disambig_syms_h, self_loop_scale, reorder, check_no_self_loops, &transition_id_fst); // at this point transition_id_fst will have transition-ids as its ilabels and - // context-dependent phones (indexes into ILabelInfo()) as its olabels. + // context-dependent phones (indexes into IlabelInfo()) as its olabels. // Discard the context-dependent phones by projecting on the input, keeping // only the transition-ids. fst::Project(&transition_id_fst, fst::PROJECT_INPUT); diff --git a/src/chain/chain-denominator.cc b/src/chain/chain-denominator.cc index e41e942e266..b644e429b67 100644 --- a/src/chain/chain-denominator.cc +++ b/src/chain/chain-denominator.cc @@ -61,10 +61,11 @@ DenominatorComputation::DenominatorComputation( num_sequences_).SetZero(); KALDI_ASSERT(nnet_output.NumRows() % num_sequences == 0); - // the kStrideEqualNumCols argument means we'll allocate a contiguous block of - // memory for this; it is added to ensure that the same block of memory - // (cached in the allocator) can be used for xent_output_deriv when allocated - // from chain-training.cc. 
+ // the kStrideEqualNumCols argument is so that we can share the same + // memory block with xent_output_deriv (see chain-training.cc, search for + // kStrideEqualNumCols). This depends on how the allocator works, and + // actually might not happen, but anyway, the impact on speed would + // likely be un-measurably small. exp_nnet_output_transposed_.Resize(nnet_output.NumCols(), nnet_output.NumRows(), kUndefined, kStrideEqualNumCols); diff --git a/src/chain/chain-supervision.cc b/src/chain/chain-supervision.cc index 8f95034c437..f8a2c1d11cc 100644 --- a/src/chain/chain-supervision.cc +++ b/src/chain/chain-supervision.cc @@ -299,6 +299,7 @@ bool ProtoSupervisionToSupervision( using fst::VectorFst; using fst::StdArc; VectorFst phone_fst(proto_supervision.fst); + std::vector disambig_syms; // empty list of diambiguation symbols. int32 subsequential_symbol = trans_model.GetPhones().back() + 1; if (ctx_dep.CentralPosition() != ctx_dep.ContextWidth() - 1) { // note: this function only adds the subseq symbol to the input of what was @@ -307,19 +308,28 @@ bool ProtoSupervisionToSupervision( AddSubsequentialLoop(subsequential_symbol, &phone_fst); fst::Project(&phone_fst, fst::PROJECT_INPUT); } - std::vector disambig_syms; // empty list of diambiguation symbols. - fst::ContextFst cfst(subsequential_symbol, trans_model.GetPhones(), - disambig_syms, ctx_dep.ContextWidth(), - ctx_dep.CentralPosition()); + + // inv_cfst will be expanded on the fly, as needed. + fst::InverseContextFst inv_cfst(subsequential_symbol, + trans_model.GetPhones(), + disambig_syms, + ctx_dep.ContextWidth(), + ctx_dep.CentralPosition()); + + VectorFst context_dep_fst; - fst::ComposeContextFst(cfst, phone_fst, &context_dep_fst); - // at this point, context_dep_fst will have indexes into 'ilabels' as its - // input symbol (representing context-dependent phones), and phones on its - // output. We don't need the phones, so we'll project. + ComposeDeterministicOnDemandInverse(phone_fst, &inv_cfst, &context_dep_fst); + + + // at this point, context_dep_fst will have indexes into + // 'inv_cfst.IlabelInfo()' as its input symbol (representing context-dependent + // phones), and phones on its output. We don't need the phones, so we'll + // project. fst::Project(&context_dep_fst, fst::PROJECT_INPUT); - std::vector disambig_syms_h; // disambiguation symbols on input side - // of H -- will be empty. + std::vector disambig_syms_h; // disambiguation symbols on input side of + // H -- will be empty, as there were no + // disambiguation symbols on the output. HTransducerConfig h_cfg; @@ -327,7 +337,7 @@ bool ProtoSupervisionToSupervision( // when we compose with the denominator graph. 
h_cfg.transition_scale = 0.0; - VectorFst *h_fst = GetHTransducer(cfst.ILabelInfo(), + VectorFst *h_fst = GetHTransducer(inv_cfst.IlabelInfo(), ctx_dep, trans_model, h_cfg, diff --git a/src/chain/language-model.cc b/src/chain/language-model.cc index 41e06116ea8..dd69340a6b8 100644 --- a/src/chain/language-model.cc +++ b/src/chain/language-model.cc @@ -129,7 +129,6 @@ int32 LanguageModelEstimator::FindOrCreateLmStateIndexForHistory( int32 backoff_lm_state = FindOrCreateLmStateIndexForHistory( backoff_hist); lm_states_[ans].backoff_lmstate_index = backoff_lm_state; - hist_to_lmstate_index_[backoff_hist] = backoff_lm_state; } return ans; } @@ -298,7 +297,7 @@ int32 LanguageModelEstimator::AssignFstStates() { void LanguageModelEstimator::Estimate(fst::StdVectorFst *fst) { KALDI_LOG << "Estimating language model with --no-prune-ngram-order=" << opts_.no_prune_ngram_order << ", --ngram-order=" - << opts_.ngram_order << ", --num-extra-lm-state=" + << opts_.ngram_order << ", --num-extra-lm-states=" << opts_.num_extra_lm_states; SetParentCounts(); num_basic_lm_states_ = CheckActiveStates(); @@ -408,5 +407,3 @@ void LanguageModelEstimator::OutputToFst( } // namespace chain } // namespace kaldi - - diff --git a/src/chainbin/Makefile b/src/chainbin/Makefile index 61f653f174f..41ac7342d17 100644 --- a/src/chainbin/Makefile +++ b/src/chainbin/Makefile @@ -25,7 +25,7 @@ ADDLIBS = ../nnet3/kaldi-nnet3.a ../chain/kaldi-chain.a \ ../cudamatrix/kaldi-cudamatrix.a ../decoder/kaldi-decoder.a \ ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/chainbin/chain-get-supervision.cc b/src/chainbin/chain-get-supervision.cc index 6090d9f0058..1ac89d4630b 100644 --- a/src/chainbin/chain-get-supervision.cc +++ b/src/chainbin/chain-get-supervision.cc @@ -22,6 +22,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "chain/chain-supervision.h" +#include "tree/context-dep.h" namespace kaldi { namespace chain { diff --git a/src/chainbin/nnet3-chain-train.cc b/src/chainbin/nnet3-chain-train.cc index 9ea7ba1b06f..536669a17d3 100644 --- a/src/chainbin/nnet3-chain-train.cc +++ b/src/chainbin/nnet3-chain-train.cc @@ -20,6 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "nnet3/nnet-chain-training.h" +#include "cudamatrix/cu-allocator.h" int main(int argc, char *argv[]) { @@ -52,6 +53,7 @@ int main(int argc, char *argv[]) { "yes|no|optional|wait, only has effect if compiled with CUDA"); opts.Register(&po); + RegisterCuAllocatorOptions(&po); po.Read(argc, argv); diff --git a/src/configure b/src/configure index a954583d3fb..b21cc48f7ee 100755 --- a/src/configure +++ b/src/configure @@ -22,6 +22,8 @@ # ./configure --atlas-root=../tools/ATLAS/build # ./configure --use-cuda=no # disable CUDA detection (will build cpu-only # # version of kaldi even on CUDA-enabled machine +# ./configure --use-cuda --cudatk-dir=/usr/local/cuda/ --cuda-arch=-arch=sm_70 +# # Use cuda in /usr/local/cuda and set the arch to sm_70 # ./configure --static --fst-root=/opt/cross/armv8hf \ # --atlas-root=/opt/cross/armv8hf --host=armv8-rpi3-linux-gnueabihf # # Cross compile for armv8hf, this assumes that you have openfst built @@ -42,7 +44,7 @@ # This should be incremented after any significant change to 
the configure # script, i.e. any change affecting kaldi.mk or the build system as a whole. -CONFIGURE_VERSION=7 +CONFIGURE_VERSION=10 if ! [ -x "$PWD/configure" ]; then echo 'You must run "configure" from the src/ directory.' @@ -65,6 +67,7 @@ Configuration options: --shared Build and link against shared libraries [default=no] --use-cuda Build with CUDA [default=yes] --cudatk-dir=DIR CUDA toolkit directory + --cuda-arch=FLAGS Override the default CUDA_ARCH flags. See https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#nvcc-examples. --double-precision Build with BaseFloat set to double if yes [default=no], mostly useful for testing purposes. --static-fst Build with static OpenFst libraries [default=no] @@ -114,8 +117,13 @@ function rel2abs { fi } +function read_value { + local val=`expr "X$1" : '[^=]*=\(.*\)'`; + echo $val +} + function read_dirname { - local dir_name=`expr "X$1" : '[^=]*=\(.*\)'`; + local dir_name=`read_value $1` local retval=`rel2abs $dir_name` [ -z $retval ] && echo "Bad option '$1': no such directory" && exit 1; echo $retval @@ -367,7 +375,7 @@ function linux_configure_mkl_threading { function configure_cuda { # Check for CUDA toolkit in the system if [ ! -d "$CUDATKDIR" ]; then - for base in /Developer/NVIDIA/CUDA-6.0 /usr/local/share/cuda /usr/local/cuda /pkgs_local/cuda-3.2/ /opt/nvidia_cuda/cuda-6.0/ /usr/; do + for base in /usr/local/share/cuda /usr/local/cuda /usr/; do if [ -f $base/bin/nvcc ]; then CUDATKDIR=$base fi @@ -395,14 +403,6 @@ function configure_cuda { GCC_VER=$($COMPILER -dumpversion) GCC_VER_NUM=$(echo $GCC_VER | sed 's/\./ /g' | xargs printf "%d%02d%02d") case $CUDA_VERSION in - 5_5) - MIN_UNSUPPORTED_GCC_VER="5.0" - MIN_UNSUPPORTED_GCC_VER_NUM=50000; - ;; - 6_*) - MIN_UNSUPPORTED_GCC_VER="5.0" - MIN_UNSUPPORTED_GCC_VER_NUM=50000; - ;; 7_*) MIN_UNSUPPORTED_GCC_VER="5.0" MIN_UNSUPPORTED_GCC_VER_NUM=50000; @@ -415,7 +415,7 @@ function configure_cuda { MIN_UNSUPPORTED_GCC_VER="7.0" MIN_UNSUPPORTED_GCC_VER_NUM=70000; ;; - 9_2 | 9_*) + 9_2 | 9_* | 10_*) MIN_UNSUPPORTED_GCC_VER="8.0" MIN_UNSUPPORTED_GCC_VER_NUM=80000; ;; @@ -429,14 +429,17 @@ function configure_cuda { fi fi - case $CUDA_VERSION in - 5_5) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35" ;; - 6_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50" ;; - 7_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53" ;; - 8_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62" ;; - 9_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70" ;; - *) echo "Unsupported CUDA_VERSION (CUDA_VERSION=$CUDA_VERSION), please report it to Kaldi mailing list, together with 'nvcc -h' or 'ptxas -h' which lists allowed -gencode values..."; exit 1 ;; - esac + if [ -z "$CUDA_ARCH" ]; then + case $CUDA_VERSION in + 5_5) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35" ;; + 6_*) CUDA_ARCH="-gencode 
arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50" ;; + 7_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53" ;; + 8_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62" ;; + 9_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70" ;; + 10_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_72,code=sm_72 -gencode arch=compute_75,code=sm_75" ;; + *) echo "Unsupported CUDA_VERSION (CUDA_VERSION=$CUDA_VERSION), please report it to Kaldi mailing list, together with 'nvcc -h' or 'ptxas -h' which lists allowed -gencode values..."; exit 1 ;; + esac + fi echo "Using CUDA toolkit $CUDATKDIR (nvcc compiler and runtime libraries)" echo >> kaldi.mk @@ -459,7 +462,8 @@ function configure_cuda { elif [ "`uname -m`" == "ppc64le" ]; then cat makefiles/cuda_64bit.mk >> kaldi.mk else - cat makefiles/cuda_32bit.mk >> kaldi.mk + echo "CUDA will not be used! CUDA is not supported with 32-bit builds." + exit 1; fi else @@ -555,66 +559,23 @@ function linux_check_static { fi } -function linux_configure_debian_ubuntu { - m=$1 - ATLASLIBS="/usr/lib$m/atlas-base/libatlas.so.3gf /usr/lib$m/atlas-base/libf77blas.so.3gf /usr/lib$m/atlas-base/libcblas.so.3gf /usr/lib$m/atlas-base/liblapack_atlas.so.3gf" - for f in $ATLASLIBS; do - [ ! -f $f ] && return 1; - done - lapacklib=$(echo $ATLASLIBS | awk '{print $NF}') - if ! nm --dynamic $lapacklib | grep ATL_cgetrf >/dev/null; then - exit 1; - fi - echo ATLASINC = $ATLASROOT/include >> kaldi.mk - echo ATLASLIBS = $ATLASLIBS >> kaldi.mk - echo >> kaldi.mk - if [[ "$TARGET_ARCH" == arm* ]]; then - cat makefiles/linux_atlas_arm.mk >> kaldi.mk - elif [[ "$TARGET_ARCH" == ppc64le ]]; then - cat makefiles/linux_atlas_ppc64le.mk >> kaldi.mk - else - cat makefiles/linux_atlas.mk >> kaldi.mk - fi - echo "Successfully configured for Debian/Ubuntu Linux [dynamic libraries] with ATLASLIBS =$ATLASLIBS" - $use_cuda && configure_cuda - linux_configure_speex -} - -function linux_configure_debian_ubuntu3 { - ATLASLIBS="/usr/lib/libatlas.so.3 /usr/lib/libf77blas.so.3 /usr/lib/libcblas.so.3 /usr/lib/liblapack_atlas.so.3" - for f in $ATLASLIBS; do - [ ! -f $f ] && return 1; - done - lapacklib=$(echo $ATLASLIBS | awk '{print $NF}') - if ! 
nm --dynamic $lapacklib | grep ATL_cgetrf >/dev/null; then - exit 1; - fi - echo ATLASINC = $ATLASROOT/include >> kaldi.mk - echo ATLASLIBS = $ATLASLIBS >> kaldi.mk - echo >> kaldi.mk - if [[ "$TARGET_ARCH" == arm* ]]; then - cat makefiles/linux_atlas_arm.mk >> kaldi.mk - elif [[ "$TARGET_ARCH" == ppc64le ]]; then - cat makefiles/linux_atlas_ppc64le.mk >> kaldi.mk - else - cat makefiles/linux_atlas.mk >> kaldi.mk - fi - echo "Successfully configured for Debian/Ubuntu Linux [dynamic libraries] with ATLASLIBS =$ATLASLIBS" - $use_cuda && configure_cuda - linux_configure_speex -} - -function linux_configure_debian7 { - ATLASLIBS="/usr/lib/atlas-base/libatlas.so.3.0 /usr/lib/atlas-base/libf77blas.so.3.0 /usr/lib/atlas-base/libcblas.so.3 /usr/lib/atlas-base/liblapack_atlas.so.3" +function linux_configure_atlas_generic { + # You pass in a directory (e.g. /usr/lib/atlas-base) and a suffix (e.g. so.3.0) + # and it tries to find ATLAS libraries with that dir and suffix. On success it + # returns 0; on failure, it returns 1. + dir=$1 + suffix=$2 + ATLASLIBS="$dir/libatlas.$suffix $dir/libf77blas.$suffix $dir/libcblas.$suffix $dir/liblapack_atlas.$suffix" for f in $ATLASLIBS; do [ ! -f $f ] && return 1; done lapacklib=$(echo $ATLASLIBS | awk '{print $NF}') if ! nm --dynamic $lapacklib | grep ATL_cgetrf >/dev/null; then + echo "configure: failed to find symbol ATL_cgetrf in library $lapacklib" exit 1; fi libdir=$(dirname $(echo $ATLASLIBS | awk '{print $1}')) - [ -z "$libdir" ] && echo "Error getting libdir in linux_configure_debian7" && exit 1; + [ -z "$libdir" ] && echo "Error getting libdir in linux_configure_atlas_generic: dir=$dir,suffix=$suffix" && exit 1; echo ATLASINC = $ATLASROOT/include >> kaldi.mk echo ATLASLIBS = $ATLASLIBS -Wl,-rpath=$libdir >> kaldi.mk echo >> kaldi.mk @@ -625,33 +586,11 @@ function linux_configure_debian7 { else cat makefiles/linux_atlas.mk >> kaldi.mk fi - echo "Successfully configured for Debian 7 [dynamic libraries] with ATLASLIBS =$ATLASLIBS" + echo "Successfully configured ATLAS with ATLASLIBS=$ATLASLIBS" $use_cuda && configure_cuda linux_configure_speex } -function linux_configure_redhat { - m=$1 # 64 or empty. - ATLASLIBS="/usr/lib$m/atlas/libatlas.so.3 /usr/lib$m/atlas/libf77blas.so.3 /usr/lib$m/atlas/libcblas.so.3 /usr/lib$m/atlas/libclapack.so.3" - for f in $ATLASLIBS; do - [ ! -f $f ] && return 1; - done - libdir=$(dirname $(echo $ATLASLIBS | awk '{print $1}')) - [ -z "$libdir" ] && echo "Error getting libdir in linux_configure_redhat" && exit 1; - echo ATLASINC = $ATLASROOT/include >> kaldi.mk - echo ATLASLIBS = $ATLASLIBS -Wl,-rpath=$libdir >> kaldi.mk - echo >> kaldi.mk - if [[ "$TARGET_ARCH" == arm* ]]; then - cat makefiles/linux_atlas_arm.mk >> kaldi.mk - elif [[ "$TARGET_ARCH" == ppc64le ]]; then - cat makefiles/linux_atlas_ppc64le.mk >> kaldi.mk - else - cat makefiles/linux_atlas.mk >> kaldi.mk - fi - echo "Successfully configured for red hat [dynamic libraries] with ATLASLIBS =$ATLASLIBS" - $use_cuda && configure_cuda -} - function linux_configure_redhat_fat { # This is for when only two so-called 'fat' ATLAS libs are provided: # libsatlas.so.3 and libtatlas.so.3. 
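The consolidated helper above replaces the separate Debian/Ubuntu and Red Hat functions with one parameterized probe; the call sites further down in this patch invoke it along the lines of

  linux_configure_atlas_generic /usr/lib/atlas-base "so.3gf"   # Debian/Ubuntu-style layout
  linux_configure_atlas_generic /usr/lib64/atlas "so.3"        # Red Hat-style layout

i.e. it looks for libatlas, libf77blas, libcblas and liblapack_atlas with the given suffix in the given directory, and only accepts the LAPACK library if nm can find the ATL_cgetrf symbol in it.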
@@ -673,11 +612,11 @@ function linux_configure_redhat_fat { else cat makefiles/linux_atlas.mk >> kaldi.mk fi - echo "Successfully configured for red hat [dynamic libraries, fat] with ATLASLIBS =$ATLASLIBS" $use_cuda && configure_cuda + echo "Successfully configured for red hat [dynamic libraries, fat] with ATLASLIBS =$ATLASLIBS" } -function linux_configure_static { +function linux_configure_atlas_static { if $threaded_atlas; then pt=pt; else pt=""; fi if [ -z $ATLASLIBDIR ]; then # Note: it'll pick up the last one below. @@ -696,11 +635,11 @@ function linux_configure_static { echo "Validating presence of ATLAS libs in $ATLASLIBDIR" ATLASLIBS= # The Lapack part of ATLAS seems to appear under various different names.. but it - # should always have symbols like ATL_cgetrf defined, so we test for this, - # for all the names we have encountered. + # should always have symbols like ATL_cgetrf and clapack_cgetrf defined, so we test for this. for libname in liblapack liblapack_atlas libclapack; do if [ -f $ATLASLIBDIR/${libname}.a -a "$ATLASLIBS" == "" ]; then - if nm $ATLASLIBDIR/${libname}.a | grep ATL_cgetrf >/dev/null; then + if nm $ATLASLIBDIR/${libname}.a | grep ATL_cgetrf >/dev/null && \ + nm $ATLASLIBDIR/${libname}.a | grep clapack_cgetrf >/dev/null; then ATLASLIBS=$ATLASLIBDIR/${libname}.a echo "Using library $ATLASLIBS as ATLAS's CLAPACK library." fi @@ -779,11 +718,11 @@ function linux_configure_dynamic { echo "Validating presence of ATLAS libs in $ATLASLIBDIR" ATLASLIBS= # The Lapack part of ATLAS seems to appear under various different names.. but it - # should always have symbols like ATL_cgetrf defined, so we test for this, - # for all the names we have encountered. + # should always have symbols like clapack_cgetrf and ATL_cgetrf defined, so we test for this. for libname in lapack lapack_atlas clapack; do if [ -f $ATLASLIBDIR/lib${libname}.so -a "$ATLASLIBS" == "" ]; then - if nm --dynamic $ATLASLIBDIR/lib${libname}.so | grep ATL_cgetrf >/dev/null; then + if nm --dynamic $ATLASLIBDIR/lib${libname}.so | grep clapack_cgetrf >/dev/null && \ + nm --dynamic $ATLASLIBDIR/lib${libname}.so | grep ATL_cgetrf >/dev/null; then ATLASLIBS="$ATLASLIBDIR/lib${libname}.so" echo "Using library $ATLASLIBS as ATLAS's CLAPACK library." fi @@ -860,6 +799,7 @@ android=false MATHLIB='ATLAS' ATLASROOT=`rel2abs ../tools/ATLAS_headers/` FSTROOT=`rel2abs ../tools/openfst` +CUBROOT=`rel2abs ../tools/cub` # Save the command line to include in kaldi.mk cmd_line="$0 $@" @@ -946,12 +886,15 @@ do mkl_threading=sequential; shift ;; --mkl-threading=*) - mkl_threading=`expr "X$1" : '[^=]*=\(.*\)'`; + mkl_threading=`read_value $1`; threaded_atlas=true; shift ;; --fst-root=*) FSTROOT=`read_dirname $1`; shift ;; + --cub-root=*) + CUBROOT=`read_dirname $1`; + shift ;; --clapack-root=*) CLAPACKROOT=`read_dirname $1`; shift ;; @@ -977,19 +920,22 @@ do OMPLIBDIR=`read_dirname $1`; shift ;; --mathlib=*) - MATHLIB=`expr "X$1" : '[^=]*=\(.*\)'`; + MATHLIB=`read_value $1`; shift ;; --cudatk-dir=*) CUDATKDIR=`read_dirname $1`; shift ;; #CUDA is used in src/cudamatrix and src/nnet{,bin} only + --cuda-arch=*) + CUDA_ARCH=`read_value $1`; + shift;; --fst-version=*) - OPENFST_VER=`expr "X$1" : '[^=]*=\(.*\)'`; + OPENFST_VER=`read_value $1`; shift;; --host=*) # The type of system where built programs and libraries will run. # It should be in the format cpu-vendor-os. If specified, this script # will infer the target architecture from the specified host triple. 
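The small read_value helper defined earlier simply strips everything up to the first '=' from an option argument; it is used for --mkl-threading above and for --mathlib, --cuda-arch, --fst-version and --host below. For instance (hypothetical invocation):

  val=$(read_value --fst-version=1.6.7)   # val becomes "1.6.7"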
- HOST=`expr "X$1" : '[^=]*=\(.*\)'`; + HOST=`read_value $1`; shift ;; --android-incdir=*) android=true; @@ -1118,6 +1064,16 @@ echo "OPENFSTLIBS = $OPENFSTLIBS" >> kaldi.mk echo "OPENFSTLDFLAGS = $OPENFSTLDFLAGS" >> kaldi.mk echo >> kaldi.mk +$use_cuda && echo "Checking cub library in $CUBROOT ..." +if [[ "$use_cuda" = true && ! -f $CUBROOT/cub/cub.cuh ]]; then + failure "Could not find file $CUBROOT/cub/cub.cuh: + you may not have installed cub. Go to ../tools/ and type + e.g. 'make cub'; cub is a new requirement." +else + echo "CUBROOT = $CUBROOT" >> kaldi.mk +fi + + # OS-specific steps given below append to kaldi.mk echo "Doing OS specific configurations ..." @@ -1209,33 +1165,18 @@ elif [ "`uname`" == "Linux" ]; then # containing {liblapack.a,libblas.a}, and linking against just these two # libraries worked. - if $static_math; then - # Prefer static to dynamic math. - linux_configure_static || \ - linux_configure_debian_ubuntu3 || \ - linux_configure_dynamic || \ - linux_configure_debian_ubuntu 64 || \ - linux_configure_debian_ubuntu || \ - linux_configure_debian7 || \ - linux_configure_redhat 64 || \ - linux_configure_redhat || \ - linux_configure_redhat_fat 64 || \ - linux_configure_redhat_fat || \ - linux_atlas_failure "Failed to configure ATLAS libraries"; - else - # Prefer dynamic to static math. - linux_configure_debian_ubuntu3 || \ - linux_configure_dynamic || \ - linux_configure_static || \ - linux_configure_debian_ubuntu 64 || \ - linux_configure_debian_ubuntu || \ - linux_configure_debian7 || \ - linux_configure_redhat 64 || \ - linux_configure_redhat || \ - linux_configure_redhat_fat 64 || \ - linux_configure_redhat_fat || \ - linux_atlas_failure "Failed to configure ATLAS libraries"; - fi + ( $static_math && linux_configure_atlas_static ) || \ + linux_configure_atlas_generic /usr/lib "so.3" || \ + linux_configure_atlas_generic /usr/lib/atlas-base "so.3gf" || \ + linux_configure_atlas_generic /usr/lib64/atlas-base "so.3gf" \ + linux_configure_atlas_generic /usr/lib/atlas "so.3" || \ + linux_configure_atlas_generic /usr/lib64/atlas "so.3" || \ + linux_configure_atlas_generic /usr/lib/x86_64-linux-gnu/ "so.3" || \ + linux_configure_atlas_generic /usr/lib/x86_64-linux-gnu/ "so" || \ + linux_configure_redhat_fat 64 || \ + linux_configure_redhat_fat || \ + linux_configure_atlas_static || \ + linux_atlas_failure "Failed to configure ATLAS libraries"; elif [ "$MATHLIB" == "MKL" ]; then if [ "$TARGET_ARCH" != "x86_64" ]; then @@ -1317,22 +1258,42 @@ elif [ "`uname`" == "Linux" ]; then if [ -z "$OPENBLASROOT" ]; then failure "Must specify the location of OPENBLAS with --openblas-root option (and it must exist)" fi - if [ ! -f $OPENBLASROOT/lib/libopenblas.so ]; then + if [ -f $OPENBLASROOT/lib/libopenblas.so ]; then + OPENBLASLIBDIR=$OPENBLASROOT/lib + elif [ -f $OPENBLASROOT/lib64/libopenblas.so ]; then + # in REDHAT/CentOS package installs, the library is located here + OPENBLASLIBDIR=$OPENBLASROOT/lib64 + else failure "Expected to find the file $OPENBLASROOT/lib/libopenblas.so" fi + if [ -f $OPENBLASROOT/include/cblas.h ] ; then + OPENBLASINCDIR=$OPENBLASROOT/include + elif [ -f $OPENBLASROOT/include/openblas/cblas.h ] ; then + # in REDHAT/CentOS/Ubuntu package installs, the includes are located here + OPENBLASINCDIR=$OPENBLASROOT/include/openblas + else + echo "$0: ***** Using OpenBlas from $OPENBLASROOT but cblas.h is not found. 
" + echo " ****** Assuming openblas is aleady in a default include path, but" + echo " ***** if you get compilation messages about not finding files like cblas.h," + echo " ***** you should look into this (e.g. make sure to install the 'openblas-dev' package," + echo " ***** if it is a package-based install)." + OPENBLASINCDIR="/usr/include" + fi echo "Your math library seems to be OpenBLAS from $OPENBLASROOT. Configuring appropriately." if $static_math; then echo "Configuring static OpenBlas since --static-math=yes" - OPENBLASLIBS="$OPENBLASROOT/lib/libopenblas.a -lgfortran" + OPENBLASLIBS="$OPENBLASLIBDIR/libopenblas.a -lgfortran" else echo "Configuring dynamically loaded OpenBlas since --static-math=no (the default)" - OPENBLASLIBS="-L$OPENBLASROOT/lib -lopenblas -lgfortran -Wl,-rpath=$OPENBLASROOT/lib" + OPENBLASLIBS="-L$OPENBLASLIBDIR -lopenblas -lgfortran -Wl,-rpath=$OPENBLASLIBDIR" fi - echo "OPENBLASINC = $OPENBLASROOT/include" >> kaldi.mk + echo "OPENBLASINC = $OPENBLASINCDIR" >> kaldi.mk echo "OPENBLASLIBS = $OPENBLASLIBS" >> kaldi.mk echo >> kaldi.mk if [[ "$TARGET_ARCH" == arm* ]]; then cat makefiles/linux_openblas_arm.mk >> kaldi.mk + elif [[ "$TARGET_ARCH" == aarch64* ]]; then + cat makefiles/linux_openblas_aarch64.mk >> kaldi.mk elif [[ "$TARGET_ARCH" == ppc64le ]]; then cat makefiles/linux_openblas_ppc64le.mk >> kaldi.mk else diff --git a/src/cudamatrix/Makefile b/src/cudamatrix/Makefile index ca831390ea9..45c2ba44fd7 100644 --- a/src/cudamatrix/Makefile +++ b/src/cudamatrix/Makefile @@ -18,8 +18,7 @@ endif LIBNAME = kaldi-cudamatrix -ADDLIBS = ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a +ADDLIBS = ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a # Make sure we have CUDA_ARCH from kaldi.mk, ifeq ($(CUDA), true) diff --git a/src/cudamatrix/cu-allocator.cc b/src/cudamatrix/cu-allocator.cc index fec75b01a3f..80380bdb92c 100644 --- a/src/cudamatrix/cu-allocator.cc +++ b/src/cudamatrix/cu-allocator.cc @@ -1,6 +1,6 @@ // cudamatrix/cu-allocator.cc -// Copyright 2015 Johns Hopkins University (author: Daniel Povey) +// Copyright 2015-2018 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -19,6 +19,8 @@ +#include "cudamatrix/cu-allocator.h" + #if HAVE_CUDA == 1 #include @@ -28,6 +30,10 @@ #include #include #include +#ifndef _MSC_VER +#include +#endif + #include "cudamatrix/cu-common.h" #include "cudamatrix/cu-device.h" #include "cudamatrix/cu-matrix.h" @@ -39,10 +45,223 @@ namespace kaldi { void* CuMemoryAllocator::Malloc(size_t size) { - // For now just call MallocPitch and throw away the pitch, to avoid - // duplicating code here. Apparently the time difference is quite small. - size_t pitch; - return MallocPitch(size, 1, &pitch); + Timer tim; + if (!opts_.cache_memory) { + void *ans; + CU_SAFE_CALL(cudaMalloc(&ans, size)); + double elapsed = tim.Elapsed(); + tot_time_taken_ += elapsed; + malloc_time_taken_ += elapsed; + t_++; + return ans; + } + + // We could perhaps change this to KALDI_PARANOID_ASSERT to save time. + KALDI_ASSERT(size != 0); + + // Round up 'size' to a multiple of 256; this ensures the right kind of + // memory alignment. 
+ size = (size + 255) & ~((size_t)255); + void *ans = MallocInternal(size); + tot_time_taken_ += tim.Elapsed(); + return ans; +} + + +CuMemoryAllocator::MemoryBlock *CuMemoryAllocator::SplitBlock( + MemoryBlock *block, size_t size) { + SubRegion *subregion = block->subregion; + // new_block will become the right-most part of 'block', and 'block' will + // be the left-most part. + MemoryBlock *new_block = new MemoryBlock; + bool return_new_block; + char *new_begin; + + // We now decide whether to make the left part of 'block' be of size ('size') + // and return it (the 'if' branch of the if-else block below), or the right + // part (the 'else' branch). We decide this based on heuristics. Basically, + // we want to allocate the sub-block that's either next to the edge of the + // MemoryRegion, or next to something that was allocated long ago (and which, + // we assume won't be deallocated for a relatively long time). That is: we + // want to leave the un-allocated memory next to a memory block that was + // recently allocated (and thus is likely to be freed sooner), so that when + // that block is freed we can merge it with the still-unallocated piece into a + // larger block; this will reduce fragmentation. But if this block spans + // multiple sub-regions we don't want to do that, as that would be against our + // heuristic of, where possible, allocating memory from lower-numbered + // sub-regions. + // + // Bear in mind that we can assume block->next and block->prev, if they are + // non-NULL, are both currently allocated, since 'block' is un-allocated and + // we would have merged any adjacent un-allocated sub-regions. + if (block->next != NULL && block->prev != NULL && + block->prev->t < block->next->t && + block->next->subregion == subregion) { + // We'll allocate the right part of the block, since the left side is next + // to a relatively recently-allocated block. + return_new_block = true; + new_begin = block->end - size; + } else { + // We'll allocate the left part of the block. + return_new_block = false; + new_begin = block->begin + size; + } + + // The following code makes sure the SubRegion for 'new_block' is correct, + // i.e. its 'begin' is >= the 'begin' of the subregion and < the 'end' of the + // subregion. If the following loop segfaults, it indicates a bug somewhere + // else. + while (new_begin >= subregion->end) + subregion = subregion->next; + MemoryBlock *next_block = block->next; + new_block->begin = new_begin; + new_block->end = block->end; + new_block->subregion = subregion; + new_block->allocated = false; + new_block->thread_id = block->thread_id; + new_block->t = block->t; + new_block->next = next_block; + new_block->prev = block; + if (next_block) + next_block->prev = new_block; + block->next = new_block; + block->end = new_begin; + + // Add the split-up piece that we won't be allocating, to the + // 'free_blocks' member of its subregion. + if (return_new_block) { + AddToFreeBlocks(block); + return new_block; + } else { + AddToFreeBlocks(new_block); + return block; + } +} + + +void CuMemoryAllocator::RemoveFromFreeBlocks(MemoryBlock *block) { + SubRegion *subregion = block->subregion; + size_t block_size = block->end - block->begin; + std::pair p(block_size, block); + size_t num_removed = subregion->free_blocks.erase(p); + KALDI_ASSERT(num_removed != 0); + // Update largest_free_block_, if needed. 
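// (Context for the update below: largest_free_block_[i] tracks the size of the
// largest free block in subregion i, and MallocInternal() scans that vector so
// it can skip subregions that cannot possibly satisfy a request without having
// to consult their free_blocks sets.)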
+ size_t subregion_index = subregion->subregion_index; + if (block_size == largest_free_block_[subregion_index]) { + if (subregion->free_blocks.empty()) + largest_free_block_[subregion_index] = 0; + else + largest_free_block_[subregion_index] = + subregion->free_blocks.begin()->first; + } +} + +void CuMemoryAllocator::AddToFreeBlocks(MemoryBlock *block) { + SubRegion *subregion = block->subregion; + KALDI_PARANOID_ASSERT(block->begin >= subregion->begin && + block->begin < subregion->end); + size_t block_size = block->end - block->begin, + subregion_index = subregion->subregion_index; + // Update largest_free_block_, if needed. + if (block_size > largest_free_block_[subregion_index]) { + largest_free_block_[subregion_index] = block_size; + } + subregion->free_blocks.insert(std::pair(block_size, block)); +} + + +void* CuMemoryAllocator::MallocFromSubregion(SubRegion *subregion, + size_t size) { + // NULL is implementation defined and doesn't have to be zero so we can't + // guarantee that NULL will be <= a valid pointer-- so we cast to a pointer + // from zero instead of using NULL. + std::pair p(size, (MemoryBlock*)0); + + std::set >::iterator iter = + subregion->free_blocks.lower_bound(p); + // so now 'iter' is the first member of free_blocks whose size_t value is >= + // size. If 'iter' was equal to the end() of that multi_map, it would be a + // bug because the calling code checked that the largest free block in this + // region was sufficiently large. We don't check this; if it segfaults, we'll + // debug. + + // search for a block that we don't have to synchronize on + int max_iters = 20; + auto search_iter = iter; + for (int32 i = 0; + search_iter != subregion->free_blocks.end() && i < max_iters; + ++i, ++search_iter) { + if (search_iter->second->thread_id == std::this_thread::get_id() || + search_iter->second->t <= synchronize_gpu_t_) { + iter = search_iter; + break; + } + } + + MemoryBlock *block = iter->second; + // Erase 'block' from its subregion's free blocks list... the next lines are + // similar to RemoveFromFreeBlocks(), but we code it directly as we have the + // iterator here, and it would be wasteful to do another lookup. + subregion->free_blocks.erase(iter); + // Update largest_free_block_, if needed. The following few lines of code also appear + // in RemoveFromFreeBlocks(). + size_t block_size = block->end - block->begin, + subregion_index = subregion->subregion_index; + if (block_size == largest_free_block_[subregion_index]) { + if (subregion->free_blocks.empty()) + largest_free_block_[subregion_index] = 0; + else + largest_free_block_[subregion_index] = + subregion->free_blocks.begin()->first; + } + + KALDI_PARANOID_ASSERT(block_size >= size && block->allocated == false); + + // the most memory we allow to be 'wasted' by failing to split a block, is the + // smaller of: 1/16 of the size we're allocating, or half a megabyte. + size_t allowed_extra_size = std::min(size >> 4, 524288); + if (block_size > size + allowed_extra_size) { + // If the requested block is substantially larger than what was requested, + // split it so we don't waste memory. + block = SplitBlock(block, size); + } + + if (std::this_thread::get_id() != block->thread_id && + block->t > synchronize_gpu_t_) { + // see NOTE ON SYNCHRONIZATION in the header. 
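// (What the check above amounts to: the chosen block was last touched by a
// different thread more recently than the last device-wide sync, so we
// synchronize once and record the current time t_ in synchronize_gpu_t_;
// blocks older than that can afterwards be reused across threads without
// another sync.)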
+  if (std::this_thread::get_id() != block->thread_id &&
+      block->t > synchronize_gpu_t_) {
+    // see NOTE ON SYNCHRONIZATION in the header.
+    SynchronizeGpu();
+    synchronize_gpu_t_ = t_;
+    num_synchronizations_++;
+  }
+  block->allocated = true;
+  block->t = t_;
+  allocated_block_map_[block->begin] = block;
+  allocated_memory_ += (block->end - block->begin);
+  if (allocated_memory_ > max_allocated_memory_)
+    max_allocated_memory_ = allocated_memory_;
+  return block->begin;
+}
+
+// By the time MallocInternal is called, we will have ensured that 'size' is
+// a nonzero multiple of 256 (for memory alignment reasons).
+// inline
+void* CuMemoryAllocator::MallocInternal(size_t size) {
+start:
+  std::vector<size_t>::const_iterator iter = largest_free_block_.begin(),
+      end = largest_free_block_.end();
+  size_t subregion_index = 0;
+  for (; iter != end; ++iter, ++subregion_index) {
+    if (*iter > size) {
+      return MallocFromSubregion(subregions_[subregion_index], size);
+    }
+  }
+  // We dropped off the loop without finding a subregion with enough memory
+  // to satisfy the request -> allocate a new region.
+  AllocateNewRegion(size);
+  // An infinite loop shouldn't be possible because after calling
+  // AllocateNewRegion(size), there should always be a SubRegion
+  // with that size available.
+  goto start;
 }
 
 // Returns max(0, floor(log_2(i))). Not tested independently.
@@ -63,311 +282,363 @@ static inline size_t IntegerLog2(size_t i) {
   return ans;
 }
 
-//inline
-CuMemoryAllocator::MruCache& CuMemoryAllocator::GetCacheForSize(
-    size_t num_bytes) {
-  size_t bucket_index = IntegerLog2(num_bytes);
-  KALDI_ASSERT(num_bytes > 0 && bucket_index < caches_.size());
-  return caches_[bucket_index];
-}
-
-//inline
-void* CuMemoryAllocator::MallocPitchInternal(size_t row_bytes,
-                                             size_t num_rows,
-                                             size_t *pitch) {
-  num_system_allocations_++;
-  void *ans;
-  cudaError_t e;
-  for (int32 i = 0; i <= 2; i++) {
-    if (num_rows != 1) {
-      CuTimer tim;
-      e = cudaMallocPitch(&ans, pitch, row_bytes, num_rows);
-      tot_time_taken_in_cuda_malloc_pitch_ += tim.Elapsed();
+std::string GetFreeGpuMemory(int64* free, int64* total) {
+#ifdef _MSC_VER
+  size_t mem_free, mem_total;
+  cuMemGetInfo_v2(&mem_free, &mem_total);
+#else
+  size_t mem_free, mem_total;
+  {
+    // we will load cuMemGetInfo_v2 dynamically from libcuda.so
+    // pre-fill "safe" values that will not cause problems
+    mem_free = 1; mem_total = 1;
+    // open libcuda.so
+    void* libcuda = dlopen("libcuda.so", RTLD_LAZY);
+    if (NULL == libcuda) {
+      KALDI_WARN << "cannot open libcuda.so";
     } else {
-      CuTimer tim;
-      // we might save a little time this way.
-      e = cudaMalloc(&ans, row_bytes);
-      tot_time_taken_in_cuda_malloc_ += tim.Elapsed();
-      *pitch = row_bytes;
-    }
-    if (e != cudaSuccess) {
-      PrintMemoryUsage();
-      // On the first 2 out of the 3 iters, try freeing memory.
-      if (i <= 1) {
-        KALDI_WARN << "Allocation of " << row_bytes << " x "
-                   << num_rows << " region failed: freeing some memory and "
-                   << "trying again. 
"; - BaseFloat new_memory_factor = 1.1; - if (opts_.memory_factor > new_memory_factor) { - KALDI_LOG << "To avoid future problems like this, changing " - << "memory_factor from " << opts_.memory_factor << " to " - << new_memory_factor; - opts_.memory_factor = new_memory_factor; - } - size_t memory_cached = MemoryCached(), - memory_requested = row_bytes * num_rows, - memory_to_free = std::max(memory_cached / 2, - std::min(memory_cached, - memory_requested)); - FreeSomeCachedMemory(memory_to_free); + // define the function signature type + // and get the symbol + typedef CUresult (*cu_fun_ptr)(size_t*, size_t*); + cu_fun_ptr dl_cuMemGetInfo = (cu_fun_ptr)dlsym(libcuda,"cuMemGetInfo_v2"); + if (NULL == dl_cuMemGetInfo) { + KALDI_WARN << "cannot load cuMemGetInfo from libcuda.so"; } else { - KALDI_ERR << "Cannot allocate the requested memory (" - << row_bytes << " x " << num_rows << " = " - << row_bytes * num_rows << " bytes)"; + // call the function + dl_cuMemGetInfo(&mem_free, &mem_total); } - cudaGetLastError(); // Clear the error state. - } else { - break; + // close the library + dlclose(libcuda); } } - return ans; +#endif + // copy the output values outside + if (NULL != free) *free = mem_free; + if (NULL != total) *total = mem_total; + // prepare the text output + std::ostringstream os; + os << "free:" << mem_free/(1024*1024) << "M, " + << "used:" << (mem_total-mem_free)/(1024*1024) << "M, " + << "total:" << mem_total/(1024*1024) << "M, " + << "free/total:" << mem_free/(float)mem_total; + return os.str(); } void CuMemoryAllocator::PrintMemoryUsage() const { - KALDI_LOG << "Memory usage: " << cur_bytes_allocated_ - << " bytes currently allocated (max: " - << max_bytes_allocated_ << "); " << cur_bytes_used_ - << " currently in use by user (max: " << max_bytes_used_ << ")" - << "; " << num_system_allocations_ << '/' - << num_user_allocations_ << " calls to Malloc* resulted in " - << "CUDA calls."; - if (GetVerboseLevel() >= 1) { - // CuTimer only accumulates stats at verbose level 1 or above. - KALDI_LOG << "Time taken in cudaMallocPitch=" << tot_time_taken_in_cuda_malloc_pitch_ - << ", in cudaMalloc=" << tot_time_taken_in_cuda_malloc_ - << ", in cudaFree=" << tot_time_taken_in_cuda_free_ - << ", in this->MallocPitch()=" << tot_time_taken_in_malloc_pitch_; + if (!opts_.cache_memory) { + KALDI_LOG << "Not caching allocations; time taken in " + << "malloc/free is " << malloc_time_taken_ + << "/" << (tot_time_taken_ - malloc_time_taken_) + << ", num operations is " << t_ + << "; device memory info: " + << GetFreeGpuMemory(NULL, NULL); + return; + } + + size_t num_blocks_allocated = 0, num_blocks_free = 0, + memory_allocated = 0, memory_held = 0, + largest_free_block = 0, largest_allocated_block = 0; + + for (size_t i = 0; i < memory_regions_.size(); i++) { + MemoryBlock *m = memory_regions_[i].block_begin; + KALDI_ASSERT(m->begin == memory_regions_[i].begin); + for (; m != NULL; m = m->next) { + size_t size = m->end - m->begin; + if (m->allocated) { + num_blocks_allocated++; + memory_allocated += size; + if (size > largest_allocated_block) + largest_allocated_block = size; + } else { + num_blocks_free++; + if (size > largest_free_block) + largest_free_block = size; + } + memory_held += size; + // The following is just some sanity checks; this code is rarely called so + // it's a reasonable place to put them. 
+      if (m->next) {
+        KALDI_ASSERT(m->next->prev == m && m->end == m->next->begin);
+      } else {
+        KALDI_ASSERT(m->end == memory_regions_[m->subregion->memory_region].end);
+      }
+    }
+  }
+  KALDI_LOG << "Memory usage: " << memory_allocated << "/"
+            << memory_held << " bytes currently allocated/total-held; "
+            << num_blocks_allocated << "/" << num_blocks_free
+            << " blocks currently allocated/free; largest "
+            << "allocated/free block sizes are "
+            << largest_allocated_block << "/" << largest_free_block
+            << "; time taken total/cudaMalloc is "
+            << tot_time_taken_ << "/" << malloc_time_taken_
+            << ", synchronized the GPU " << num_synchronizations_
+            << " times out of " << (t_/2) << " frees; "
+            << "device memory info: " << GetFreeGpuMemory(NULL, NULL)
+            << "; maximum allocated: " << max_allocated_memory_
+            << "; current allocated: " << allocated_memory_;
+}
+
+// Note: we just initialize with the default options, but they can be changed
+// later (as long as it's before we first use the class) by calling SetOptions().
+CuMemoryAllocator::CuMemoryAllocator():
+    opts_(CuAllocatorOptions()),
+    t_(0),
+    synchronize_gpu_t_(0),
+    num_synchronizations_(0),
+    tot_time_taken_(0.0),
+    malloc_time_taken_(0.0),
+    max_allocated_memory_(0),
+    allocated_memory_(0) {
+  // Note: we don't allocate any memory regions at the start; we wait for the
+  // user to call Malloc() or MallocPitch(), and then allocate one when needed.
 }
 
-CuMemoryAllocator::CuMemoryAllocator(CuAllocatorOptions opts):
-    opts_(opts),
-    caches_(40),
-    cur_bytes_allocated_(0),
-    max_bytes_allocated_(0),
-    cur_bytes_used_(0),
-    max_bytes_used_(0),
-    t_(1),
-    num_user_allocations_(0),
-    num_system_allocations_(0),
-    tot_time_taken_in_cuda_malloc_(0.0),
-    tot_time_taken_in_cuda_malloc_pitch_(0.0),
-    tot_time_taken_in_cuda_free_(0.0),
-    tot_time_taken_in_malloc_pitch_(0.0) { }
 
 void* CuMemoryAllocator::MallocPitch(size_t row_bytes,
                                      size_t num_rows,
                                      size_t *pitch) {
-  CuTimer tim;
-  t_++;
-  num_user_allocations_++;
-  size_t requested_bytes = row_bytes * num_rows;
-  if (cur_bytes_used_ + requested_bytes > max_bytes_used_)
-    max_bytes_used_ = cur_bytes_used_ + requested_bytes;
-  MruCache &cache = GetCacheForSize(requested_bytes);
-  MemoryRequest request(row_bytes, num_rows);
-  CachedMemoryElement output;
-  if (cache.Lookup(request, &output)) {
-    // we have cached memory with this value.
-    void *ans = output.pointer;
-    *pitch = output.pitch;
-    used_map_[ans] = UsedMemoryElement(row_bytes, num_rows, output.pitch);
-    cur_bytes_used_ += requested_bytes;
-    tot_time_taken_in_malloc_pitch_ += tim.Elapsed();
-    return ans;
-  } else {
-    // note: it's important that we already updated max_bytes_used_.
-    size_t next_bytes_allocated = cur_bytes_allocated_ + requested_bytes,
-        max_bytes_to_allocate =
-        static_cast<size_t>(opts_.memory_factor * max_bytes_used_);
-    ssize_t bytes_overflow = next_bytes_allocated - max_bytes_to_allocate;
-    if (bytes_overflow > 0) {
-      // The amount we would have allocated, after fulfilling this request,
-      // would exceed our limits (we don't allow ourselves to allocate more than
-      // memory_factor times the maximum amount of memory the user ever owns
-      // during the lifetime of the program).  So free some memory.
-      KALDI_ASSERT(bytes_overflow <= MemoryCached());  // sanity check.
- FreeSomeCachedMemory(static_cast(bytes_overflow)); - KALDI_ASSERT(cur_bytes_allocated_ + requested_bytes <= - max_bytes_to_allocate); - } - void *ans = MallocPitchInternal(row_bytes, num_rows, pitch); - cur_bytes_allocated_ += requested_bytes; - if (cur_bytes_allocated_ > max_bytes_allocated_) - max_bytes_allocated_ = cur_bytes_allocated_; - used_map_[ans] = UsedMemoryElement(row_bytes, num_rows, *pitch); - cur_bytes_used_ += requested_bytes; - tot_time_taken_in_malloc_pitch_ += tim.Elapsed(); + Timer tim; + if (!opts_.cache_memory) { + void *ans; + CU_SAFE_CALL(cudaMallocPitch(&ans, pitch, row_bytes, num_rows)); + double elapsed = tim.Elapsed(); + tot_time_taken_ += elapsed; + malloc_time_taken_ += elapsed; return ans; } -} -void CuMemoryAllocator::FreeSomeCachedMemory(size_t bytes_to_free_in) { - CuTimer tim; - // the next few lines are responsible for increasing the amount of memory we - // are going to free, in case the user requested an amount that's very tiny - // compared with the total amount of memory ever used. This helps us - // to amortize the cost of visiting all of the buckets inside this code. - // (there are only 40 buckets so it's not so big, but we're being careful. - size_t bytes_cached = cur_bytes_allocated_ - cur_bytes_used_, - min_to_free = static_cast(max_bytes_used_ * opts_.delete_factor); - size_t bytes_to_free = std::min(bytes_cached, - std::max(bytes_to_free_in, min_to_free)), - bytes_freed = 0; - - size_t num_caches = caches_.size(), - t = t_; - // size_factor contains the approximate (power-of-two) size of the pointers - // that each cache's pointers contain. The 'cost' of keeping any given pointer, - // we declare to be the time since we last used it multiplied by the size - // of the memory in the pointer. - std::vector size_factor(num_caches); - for (size_t i = 0, j=1; i < num_caches; i++, j *= 2) - size_factor[i] = j; - - std::priority_queue > queue; - // Set up the queue. - for (int32 i = 0; i < num_caches; i++) { - const MruCache &cache = caches_[i]; - size_t cache_t = cache.LeastRecentTime(); - if (cache_t > 0) { // t == 0 means the cache is empty. - size_t interval = t - cache_t; - BaseFloat cost = size_factor[i] * interval; - KALDI_ASSERT(interval > 0); - queue.push(std::pair(cost, i)); - } - } - while (bytes_freed < bytes_to_free) { - // If the following fails it means I made some kind of bookkeeping error, - // and most likely we are trying to free more memory than we really have - // cached. - KALDI_ASSERT(!queue.empty() && "Code error."); - std::pair p = queue.top(); - int32 cache_index = p.second; - MruCache &cache = caches_[cache_index]; - queue.pop(); - if (queue.empty()) { - while (bytes_freed < bytes_to_free) { - bytes_freed += cache.RemoveLeastRecentlyUsed(); - } - } else { - BaseFloat next_worst_cost = queue.top().first; - while (1) { - bytes_freed += cache.RemoveLeastRecentlyUsed(); - if (bytes_freed >= bytes_to_free) - break; - size_t least_recent_time = cache.LeastRecentTime(); - if (least_recent_time == 0) // this cache is now empty - break; - size_t interval = t - least_recent_time; - KALDI_ASSERT(interval > 0); - BaseFloat cost = size_factor[cache_index] * interval; - if (cost < next_worst_cost) { - // There is another bucket that has worse cost than this, - // so stop processing this bucket-- but first put it - // back in the queue. 
- queue.push(std::pair(cost, cache_index)); - break; - } - } - } - } - KALDI_ASSERT(bytes_freed <= cur_bytes_allocated_); - cur_bytes_allocated_ -= bytes_freed; - tot_time_taken_in_cuda_free_ += tim.Elapsed(); + // Round up row_bytes to a multiple of 256. + row_bytes = (row_bytes + 255) & ~((size_t)255); + *pitch = row_bytes; + void *ans = MallocInternal(row_bytes * num_rows); + tot_time_taken_ += tim.Elapsed(); + return ans; } void CuMemoryAllocator::Free(void *ptr) { + Timer tim; + if (!opts_.cache_memory) { + CU_SAFE_CALL(cudaFree(ptr)); + tot_time_taken_ += tim.Elapsed(); + t_++; + return; + } t_++; - unordered_map::iterator iter = - used_map_.find(ptr); - if (iter == used_map_.end()) { + unordered_map::iterator iter = + allocated_block_map_.find(ptr); + if (iter == allocated_block_map_.end()) { KALDI_ERR << "Attempt to free CUDA memory pointer that was not allocated: " << ptr; } - const UsedMemoryElement &elem = iter->second; - size_t num_bytes = elem.row_bytes * elem.num_rows; - - cur_bytes_used_ -= num_bytes; - MruCache &cache = GetCacheForSize(num_bytes); + MemoryBlock *block = iter->second; + allocated_memory_ -= (block->end - block->begin); + allocated_block_map_.erase(iter); + block->t = t_; + block->thread_id = std::this_thread::get_id(); + block->allocated = false; + + // If this is not the first block of the memory region and the previous block + // is not allocated, merge this block into the previous block. + MemoryBlock *prev_block = block->prev; + if (prev_block != NULL && !prev_block->allocated) { + RemoveFromFreeBlocks(prev_block); + prev_block->end = block->end; + if (prev_block->thread_id != block->thread_id) { + // the two blocks we're merging were freed by different threads, so we + // give the 'nonexistent thread' as their thread, which means that + // whichever thread requests that block, we force synchronization. We can + // assume that prev_block was previously allocated (prev_block->t > 0) + // because we always start from the left when allocating blocks, and we + // know that this block was previously allocated. + prev_block->thread_id = std::thread::id(); + } + prev_block->t = t_; + prev_block->next = block->next; + if (block->next) + block->next->prev = prev_block; + delete block; + block = prev_block; + } - cache.Insert(MemoryRequest(elem.row_bytes, elem.num_rows), - CachedMemoryElement(ptr, t_, elem.pitch)); - used_map_.erase(iter); + // If this is not the last block of the memory region and the next block is + // not allocated, merge the next block into this block. + MemoryBlock *next_block = block->next; + if (next_block != NULL && !next_block->allocated) { + // merge next_block into 'block', deleting 'next_block'. Note: at this + // point, if we merged with the previous block, the variable 'block' may now + // be pointing to that previous block, so it would be a 3-way merge. + RemoveFromFreeBlocks(next_block); + block->end = next_block->end; + if (next_block->thread_id != block->thread_id && next_block->t > 0) { + // the two blocks we're merging were freed by different threads, so we + // give the 'nonexistent thread' as their thread, which means that + // whichever thread requests that block, we force synchronization. there + // is no need to do this if next_block->t == 0, which would mean it had + // never been allocated. + block->thread_id = std::thread::id(); + } + // We don't need to inspect the 't' value of next_block; it can't be + // larger than t_ because t_ is now. 
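// (Aside: the 'nonexistent thread' trick above relies on a guarantee of the
// standard library: a default-constructed std::thread::id identifies no thread
// of execution, so it never compares equal to std::this_thread::get_id() in any
// live thread.  A minimal, self-contained sketch; the helper name is
// illustrative and not part of this class.)

#include <thread>

// Returns true if a block stamped with 'owner' must be treated as foreign by
// the calling thread.  A default-constructed id is foreign to every thread,
// which is exactly what forces the SynchronizeGpu() path in
// MallocFromSubregion() for whichever thread next receives the merged block.
inline bool IsForeignToThisThread(const std::thread::id &owner) {
  return owner != std::this_thread::get_id();
}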
+ block->next = next_block->next; + if (block->next) + block->next->prev = block; + delete next_block; + } + AddToFreeBlocks(block); + tot_time_taken_ += tim.Elapsed(); } -size_t CuMemoryAllocator::MruCache::LeastRecentTime() const { - if (list_.empty()) { - KALDI_PARANOID_ASSERT(map_.empty()); - return 0; - } else { - const MemoryRequest &mr = list_.front(); - MapType::const_iterator iter = map_.find(mr); - KALDI_ASSERT(iter != map_.end()); - const MapValueType &queue = iter->second; - KALDI_ASSERT(!queue.empty()); - return queue.front().first.t; +void CuMemoryAllocator::AllocateNewRegion(size_t size) { + int64 free_memory, total_memory; + std::string mem_info = GetFreeGpuMemory(&free_memory, &total_memory); + opts_.Check(); + size_t region_size = static_cast(free_memory * opts_.memory_proportion); + if (region_size < size) + region_size = size; + // Round up region_size to an exact multiple of 1M (note: we expect it will + // be much larger than that). 1048575 is 2^20 - 1. + region_size = (region_size + 1048575) & ~((size_t)1048575); + + if (!memory_regions_.empty()) { + // If this is not the first region allocated, print some information. + KALDI_LOG << "About to allocate new memory region of " << region_size + << " bytes; current memory info is: " << mem_info; + } + void *memory_region; + cudaError_t e; + { + Timer tim; + e = cudaMalloc(&memory_region, region_size); + malloc_time_taken_ += tim.Elapsed(); + } + if (e != cudaSuccess) { + PrintMemoryUsage(); + if (!CuDevice::Instantiate().IsComputeExclusive()) { + KALDI_ERR << "Failed to allocate a memory region of " << region_size + << " bytes. Possibly this is due to sharing the GPU. Try " + << "switching the GPUs to exclusive mode (nvidia-smi -c 3) and using " + << "the option --use-gpu=wait to scripts like " + << "steps/nnet3/chain/train.py. Memory info: " + << mem_info; + } else { + KALDI_ERR << "Failed to allocate a memory region of " << region_size + << " bytes. Possibly smaller minibatch size would help. " + << "Memory info: " << mem_info; + } } + // this_num_subregions would be approximately 'opts_.num_subregions' if + // 'region_size' was all the device's memory. (We add one to round up). + // We're aiming to get a number of sub-regions approximately equal to + // opts_.num_subregions by the time we allocate all the device's memory. + size_t this_num_subregions = 1 + + (region_size * opts_.num_subregions) / total_memory; + + size_t memory_region_index = memory_regions_.size(); + memory_regions_.resize(memory_region_index + 1); + MemoryRegion &this_region = memory_regions_.back(); + + this_region.begin = static_cast(memory_region); + this_region.end = this_region.begin + region_size; + // subregion_size will be hundreds of megabytes. + size_t subregion_size = region_size / this_num_subregions; + + std::vector new_subregions; + char* subregion_begin = static_cast(memory_region); + for (size_t i = 0; i < this_num_subregions; i++) { + SubRegion *subregion = new SubRegion(); + subregion->memory_region = memory_region_index; + subregion->begin = subregion_begin; + if (i + 1 == this_num_subregions) { + subregion->end = this_region.end; + KALDI_ASSERT(subregion->end > subregion->begin); + } else { + subregion->end = subregion_begin + subregion_size; + subregion_begin = subregion->end; + } + subregion->next = NULL; + if (i > 0) { + new_subregions.back()->next = subregion; + } + new_subregions.push_back(subregion); + } + // Initially the memory is in a single block, owned by + // the first subregion. It will be split up gradually. 
+ MemoryBlock *block = new MemoryBlock(); + block->begin = this_region.begin; + block->end = this_region.end; + block->subregion = new_subregions.front(); + block->allocated = false; + block->t = 0; // was never allocated. + block->next = NULL; + block->prev = NULL; + for (size_t i = 0; i < this_num_subregions; i++) + subregions_.push_back(new_subregions[i]); + SortSubregions(); + this_region.block_begin = block; + + AddToFreeBlocks(block); } -bool CuMemoryAllocator::MruCache::Lookup(const MemoryRequest &request, - CachedMemoryElement *output) { - MapType::iterator iter = map_.find(request); - if (iter == map_.end()) - return false; - MapValueType &q = iter->second; - KALDI_ASSERT(!q.empty()); - // use q.back() as we want to return the most recently used one if there - // is a choice. We believe this will give better caching behavior. - *output = q.back().first; - list_.erase(q.back().second); - q.pop_back(); - if (q.empty()) - map_.erase(request); - return true; +// We sort the sub-regions according to the distance between the start of the +// MemoryRegion of which they are a part, and the start of the SubRegion. This +// will generally mean that the highest-numbered SubRegion-- the one we keep +// free at all costs-- will be the end of the first block which we allocated +// (which under most situations will be the largest block). +void CuMemoryAllocator::SortSubregions() { + largest_free_block_.resize(subregions_.size()); + + std::vector > pairs; + for (size_t i = 0; i < subregions_.size(); i++) { + SubRegion *subregion = subregions_[i]; + MemoryRegion &memory_region = memory_regions_[subregion->memory_region]; + size_t distance = subregion->begin - memory_region.begin; + pairs.push_back(std::pair(distance, subregion)); + } + std::sort(pairs.begin(), pairs.end()); + for (size_t i = 0; i < subregions_.size(); i++) { + subregions_[i] = pairs[i].second; + subregions_[i]->subregion_index = i; + if (subregions_[i]->free_blocks.empty()) + largest_free_block_[i] = 0; + else + largest_free_block_[i] = subregions_[i]->free_blocks.begin()->first; + } } -void CuMemoryAllocator::MruCache::Insert(const MemoryRequest &request, - const CachedMemoryElement &element) { - list_.push_back(request); - map_[request].push_back(std::pair( - element, - --list_.end())); +CuMemoryAllocator::~CuMemoryAllocator() { + // We mainly free these blocks of memory so that cuda-memcheck doesn't report + // spurious errors. + for (size_t i = 0; i < memory_regions_.size(); i++) { + // No need to check the return status here-- the program is exiting anyway. + cudaFree(memory_regions_[i].begin); + } + for (size_t i = 0; i < subregions_.size(); i++) { + SubRegion *subregion = subregions_[i]; + for (auto iter = subregion->free_blocks.begin(); + iter != subregion->free_blocks.end(); ++iter) + delete iter->second; + delete subregion; + } } -size_t CuMemoryAllocator::MruCache::RemoveLeastRecentlyUsed() { - // Remove least-recently-used element from cache. - KALDI_ASSERT(!list_.empty()); - MemoryRequest request = list_.front(); - MapType::iterator iter = map_.find(request); - KALDI_ASSERT(iter != map_.end()); - MapValueType &queue = iter->second; - KALDI_ASSERT(!queue.empty()); - // least recently used elements are at the front of the queue. 
- std::pair &p = queue.front(); - KALDI_ASSERT(p.second == list_.begin()); - CU_SAFE_CALL(cudaFree(p.first.pointer)); - queue.pop_front(); - if (queue.empty()) - map_.erase(request); - list_.pop_front(); - return request.first * request.second; -} -CuMemoryAllocator::MruCache& CuMemoryAllocator::MruCache::operator = ( - const CuMemoryAllocator::MruCache &other) { - KALDI_ASSERT(other.list_.empty()); - return *this; -} -CuMemoryAllocator::MruCache::MruCache( - const CuMemoryAllocator::MruCache &other) { - KALDI_ASSERT(other.list_.empty()); -} +CuMemoryAllocator g_cuda_allocator; +} // namespace kaldi -} +#endif // HAVE_CUDA -#endif // HAVE_CUDA +namespace kaldi { + +// Define/initialize this global variable. It was declared in cu-allocator.h. +// This has to be done outside of the ifdef, because we register the options +// whether or not CUDA is compiled in (so that the binaries accept the same +// options). +CuAllocatorOptions g_allocator_options; + +} diff --git a/src/cudamatrix/cu-allocator.h b/src/cudamatrix/cu-allocator.h index 0f96315e848..d7d65da806a 100644 --- a/src/cudamatrix/cu-allocator.h +++ b/src/cudamatrix/cu-allocator.h @@ -23,54 +23,137 @@ #define KALDI_CUDAMATRIX_CU_ALLOCATOR_H_ #if HAVE_CUDA == 1 - #include +#include +#include +#endif + #include +#include #include #include #include +#include #include -#include -#include #include "base/kaldi-common.h" #include "util/stl-utils.h" +#include "itf/options-itf.h" namespace kaldi { // For now we don't give the user a way to modify these from the command line. +// or the code, it just documents what the default options are. To change +// the options, you have to do it in the code. struct CuAllocatorOptions { - // memory_factor is the total amount of (allocated + cached) memory that we - // allow to be held, relative to the max amount of memory the program has ever - // allocated. It will increase the amount of memory the program will - // potentially consume, by this factor. - BaseFloat memory_factor; - - // This is the minimum amount of memory that we will delete when we are forced - // to delete stuff, relative to the max amount of memory the program has ever - // allocated. This should be less than memory_factor - 1.0 and > 0. It - // shouldn't be too critical. The reason it exists is to avoid calling the - // cleanup code and only releasing very small amounts of memory, because there - // is a constant overhead proportional to the number of buckets. - BaseFloat delete_factor; - - CuAllocatorOptions(): memory_factor(1.3), - delete_factor(0.001) { } + // True if we are going to actually cache memory allocations on this device. + // You'd normally set it to false only if you wanted to debug a possible + // memory problem using cuda-memcheck or cuda-gdb. It will be slower, but + // using CUDA's native allocator allows those tools to detect out-of-region + // memory accesses. + bool cache_memory; + + // The proportion of the device's memory that the CuAllocator allocates to + // start with; by default this is 0.5, although if you want to share the + // device (not recommended!) you should set this lower. + BaseFloat memory_proportion; + + // The target number of subregions of the entire CUDA device memory (we'll + // start with a smaller number of memory_proportion is << 1). Kind of + // a tuning knob.. more regions will make it more aggressively consolidate + // memory low addresses. 
+ int32 num_subregions; + + CuAllocatorOptions(): + cache_memory(true), memory_proportion(0.5), num_subregions(20) { } + + void Register(OptionsItf *po) { + po->Register("cuda-cache-memory", &cache_memory, "True if you want " + "to use the caching allocator. Set this to false only if you " + "want to use cuda-memcheck or cuda-gdb; it will be slower."); + po->Register("cuda-memory-proportion", &memory_proportion, + "Proportion of the GPU device memory that the allocator " + "should allocate at the start"); + } void Check() { - KALDI_ASSERT(delete_factor < memory_factor - 1.0 && delete_factor > 0.0); + // don't let it get too close to 1; + KALDI_ASSERT(memory_proportion >= 0.05 && memory_proportion < 0.99); } }; +extern CuAllocatorOptions g_allocator_options; + +inline void RegisterCuAllocatorOptions(OptionsItf *po) { + g_allocator_options.Register(po); +} +} // namespace kaldi + + +#if HAVE_CUDA == 1 +namespace kaldi { + +/** + This class allocates large regions of memory from the GPU and allocates + sub-blocks of it for the user. This is needed because the CUDA malloc and + free routines are very slow. + + The user doesn't access this class directly, it is accessed via the CuDevice + object. The CuDevice class allocates memory using this class's Malloc() and + MallocPitch() functions, and frees them with its Free() function, and this + class caches the memory blocks to avoid calling the CUDA library's + malloc/free functions too often. If the application is using multiple + threads, it's necessary to lock this class before using it, and in that case + the CuDevice class calls the MallocLocking() and MallocPitchLocking() + versions of the allocation functions (but the user should call + CuDevice::AllowMultithreading() if the application plans to use GPU + functionality from multiple CPU threads). + + NOTE ON SYNCHRONIZATION: if multiple CUDA streams are used there is a + potential problem with any caching allocator which shares its pool across + CUDA streams. That is: if a memory block is freed by stream 1 and allocated to + stream 2, an operation might start in stream 2 before stream 1 has finished + working with that memory location. We solve this here using a rather low-tech + solution, relying on calling SynchronizeGpu() which submits a no-op kernel + into the legacy default stream. Each + time CuMemoryAllocator()::Free() is called and we cache the memory block + in this class, we record the thread-id of the CPU thread from which it was + freed, as well as a timestamp (the t_ member of CuMemoryAllocator, which + we increment every time the class is used). When we allocate memory + that was cached, we try to allocate it from a block that was relased by the + same CPU thread; and if that is not possible and we haven't called + SynchronizeGpu() since the block was freed, then we call + SynchronizeGpu(). The hope is that this will happen quite rarely. + Note that this is based on the assumption that the user is using the + per-thread default stream (indeed this is how we compile). If the + user were to make explicit use of CUDA streams, this mechanism would + not necessarily be sufficient to prevent data-race conditions and the + user might have to take further precautions. + + NOTE ON FRAGMENTATION: Memory fragmentation is one of the main problems that + you'll run into with allocators like this. This allocator will allocate a + small number of large regions of memory, and allocate smaller pieces of + memory that it splits off from the regions as needed. 
It will always merge + adjacent blocks as much as it can when the user frees memory. The main + heuristic to avoid memory fragmenting too much is that it always allocates, + where possible, from memory that's as close as possible to the start of a + memory region. This will tend to keep all the small allocations together at + the beginning of the memory region, and hopefully keep large blocks availale + at the end. The mechanism to always allocate from as close as possible to + the start of the memory region, is that we split up the memory regions into + a small number of sub-regions and, when handling a request for allocation, + allocate it from the lowest-numbered sub-region that can meet a request for + that size. (Note: we can allocate blocks that span sub-regions, so this + approach does not limit the block size we can allocate). + +*/ -// Class that caches memory for us (the CUDA -// malloc and free routines are very slow). -// This is a member of the CuDevice class. class CuMemoryAllocator { public: - /// Allocates memory on the CUDA device, of size 'size'. + /// Allocates memory on the CUDA device, of size 'size'. size == 0 is not + /// allowed and is an error. void* Malloc(size_t size); /// Allocation function for matrix-like things. @@ -95,156 +178,187 @@ class CuMemoryAllocator { Free(ptr); } + void PrintMemoryUsage() const; + + // returns the current memory allocated within the cache + size_t GetAllocatedMemory() { return allocated_memory_; } - // the maximum amount of memory that was ever allocated in the lifetime of the - // program, in bytes. - size_t MaxMemoryAllocated() const { return max_bytes_allocated_; } + // returns the maximum memory used within the cache during current execution + size_t GetMaxAllocatedMemory() { return max_allocated_memory_; } - // memory held in the cache currently, in bytes. - size_t MemoryCached() const { return cur_bytes_allocated_ - cur_bytes_used_; } + CuMemoryAllocator(); - // memory that's cached plus memory that's allocated, in bytes. - size_t MemoryAllocated() const { return cur_bytes_allocated_; } + // Allows you to set options: must be called before any Malloc function is + // called on this class. It's done this way so the options can be changed + // by the user (c.f. RegisterCuAllocatorOptions()) before the options are read. + void SetOptions(const CuAllocatorOptions &opts) { opts_ = opts; } - void PrintMemoryUsage() const; + ~CuMemoryAllocator(); - CuMemoryAllocator(CuAllocatorOptions opts); private: - void FreeSomeCachedMemory(size_t bytes_to_free); + struct SubRegion; + + struct MemoryBlock { + char *begin; // The beginning of the block (in CUDA memory) + char *end; // the end of the block (in CUDA memory) + SubRegion *subregion; // Pointer to the SubRegion to which this memory + // block belongs. + bool allocated; // True if this MemoryBlock has currently been given to the + // user; false if not. + + size_t t; // Zero if this memory block was never given to the user; + // otherwise, the time value (t_ in the CuAllocator class) + // when it was most recently either allocated to the user + // or freed by the user. + + std::thread::id thread_id; // If allocated == false and t > 0 (i.e. this + // memory block was released by the user), the + // thread-id of the user thread that freed this + // block, or the invalid thread-id as created by + // the constructor of std::thread::id if this + // block was created by merging blocks from + // different threads. 
Required for + // synchronization; and note that we assume + // there is one CUDA stream per CPU thread. + + MemoryBlock *next; // The next MemoryBlock within this MemoryRegion (or + // NULL if this is the last one); its 'begin' would be + // the same as the 'end' of this block. + MemoryBlock *prev; // The previous MemoryBlock within this MemoryRegion (or + // NULL if this is the first one); its 'end' would be the + // same as the 'begin' of this block. - // This calls CudaMallocPitch, checks for errors (dies if it has to), and - // returns the result. It's up to the caller to do all the bookkeeping though. - inline void* MallocPitchInternal(size_t row_bytes, size_t num_rows, size_t *pitch); + }; - typedef std::pair MemoryRequest; // (row_bytes, num_rows). - struct CachedMemoryElement { - void *pointer; // the CUDA memory location that we own - size_t t; // time value when we put this in the cache. - size_t pitch; // pitch of this memory region (c.f. cudaMallocPitch()). - CachedMemoryElement() { } - CachedMemoryElement(void *pointer, size_t t, size_t pitch): - pointer(pointer), t(t), pitch(pitch) { } + // a MemoryRegion is a large piece of memory that we allocated via CudaMalloc. + // there normally won't be more than about 3 or 4 of these. + // We'll identify MemoryRegions by a size_t (e.g 0, 1, 2, 3... ) which is an + // index into the memory_regions_ vector. + struct MemoryRegion { + char *begin; // 'begin' is the start of the memory region. + char *end; // 'end' is the end of the memory region. + SubRegion *subregion_begin; // The first SubRegion that belongs to this + // MemoryRegion. + MemoryBlock *block_begin; // The first MemoryBlock that belongs to this + // MemoryRegion. }; - // This class caches a map from MemoryRequest to a list of CachedMemoryElements, - // and gives us access to the least-recently-used element for efficient. - // removal. - // We will have an instance of this class for each power-of-2 of size in - // bytes. This makes it easier to, when we need to delete something, find - // the item for which the (time-since-used * size-in-bytes) is approximately - // greatest. - class MruCache { - public: - size_t LeastRecentTime() const; // t value of least recent CachedMemoryElement (0 - // if empty). - - size_t RemoveLeastRecentlyUsed(); // Remove least-recently-used element - // from cache. Return size in bytes of - // that removed memory region. Crash if - // this was empty. - - // Attempts lookup of the most recently cached element corresponding to - // 'request'. If available, removes it from the cache and puts it to - // 'output', and returns true. Otherwise returns false. - bool Lookup(const MemoryRequest &request, - CachedMemoryElement *output); - - // Inserts this CachedMemoryElement to the list of CachedMemoryElements for this - // MemoryRequest. The time in the CachedMemoryElement is expected to be greater - // than times in previously supplied CachedMemoryElements. - void Insert(const MemoryRequest &request, - const CachedMemoryElement &element); - - struct MemoryRequestHasher { - // input is interpreted as (row_bytes, num_rows). row_bytes will always - // be a multiple of 4, and num_rows will frequently be a multiple of - // powers of 2 also. We need to shift right and add so that there will be - // some action in the lower-order bits. 
- size_t operator () (const std::pair &p) const noexcept { - size_t temp = p.first + 1867 * p.second; - return temp + (temp >> 2) + (temp >> 8); - } - }; - - MruCache() { } - // Define these to make inclusion in std::vector possible, but make them - // fail if called on anything but empty cache objects-- we never resize - // the vector of caches after initializing it. - MruCache &operator = (const MruCache &other); - MruCache(const MruCache &other); - private: - typedef std::list ListType; - typedef std::list::iterator ListIterType; - typedef std::deque > MapValueType; - typedef unordered_map MapType; - // 'list_' contains MemoryRequests with the most recent on the back (where they are added), - // and least recent on the front (where they are removed by RemoveLeastRecentlyUsed, although - // they are also removed from random parts of the list by Lookup(). - // There will in general be duplicates of MemoryRequests in the list, as - // many as there are entries in the MapValueType. - ListType list_; - // 'map_' maps from a MemoryRequest to a queue of (memory-element, - // iterator), with the most-recently-added things at the back; we remove - // things from the front of these queues (oldest) inside - // RemoveLeastRecentlyUsed(), and from the back (newest) in Lookup. - MapType map_; + // a SubRegion is a smaller zone of memory within a MemoryRegion. For + // example, we divide the first MemoryRegion we allocate into 10 blocks, and + // if we allocate blocks of memory later on, we'll sub-divide them into blocks + // of about the same size. A SubRegion is just a largish bin into which we + // put any blocks of memory that happen to start within that SubRegion; + // actually, memory blocks may cross over the boundaries of SubRegions. The + // motivation for dividing up MemoryRegions into SubRegions is that it allos + // us an efficient mechanism to segregate smaller memory blocks into higher + // memory and larger ones into lower memory: for each allocation, we allocate + // it from the highest-numbered SubRegion that is able to allocate something of + // that size. Over time, this will lead to smaller memory blocks being + // concentrated in higher-numbered SubRegions. + struct SubRegion { + size_t memory_region; // This is an index into the memory_regions_ vector + // which identifies which MemoryRegion this SubRegion + // is a part of. + size_t subregion_index; // The index of this SubRegion within the + // subregions_ vector; this can change when we + // allocate more MemoryRegions. + char *begin; // 'begin' is the start of the memory in this SubRegion. + char *end; // 'end' is the end of the memory in this SubRegion. + + // Contains the free MemoryBlocks starting within this SubRegion. + std::set > free_blocks; + + // Pointer to the next SubRegion within this MemoryRegion (i.e. the SubRegion + // whose begin equals this one's end), or NULL if this is the last one. + SubRegion *next; }; + // Tries to allocate CUDA memory of the given size; will crash if it was not + // able to. + inline void* MallocInternal(size_t size); - inline MruCache &GetCacheForSize(size_t num_bytes); + // Allocates from a given SubRegion, after we have determined that it + // can satisfy this request. Broken out of MallocInternal for clarity. + inline void* MallocFromSubregion(SubRegion *subregion, size_t size); - CuAllocatorOptions opts_; - // indexed by log_2 (amount of memory requested), the caches. 
- std::vector caches_; + // Splits the given MemoryBlock so that one piece is of size 'size', and + // returns the piece which is of size 'size'. The caller guarantees that + // 'size' is less than the current size of the memory block, that 'block' is + // not currently allocated (i.e. block->allocated == false). This function + // assumes that, at entry, 'block' is not present in its subregion's + // 'free_blocks' (because the caller has removed it), and it takes + // responsibility for entering the 'unused' part (the part we're not + // returning) into its subregion's 'free_blocks' by calling AddToFreeBlocks(). + inline MemoryBlock *SplitBlock(MemoryBlock *block, size_t size); - size_t cur_bytes_allocated_; // number of bytes currently owned by callers or - // cached. - size_t max_bytes_allocated_; // the max over all time, of cur_bytes_allocated_. - size_t cur_bytes_used_; // number of bytes currently owned by callers. - size_t max_bytes_used_; // the max over all time, of cur_bytes_used_. - size_t t_; // time counter, incremented with each call. - size_t num_user_allocations_; // number of times user calls Malloc* - size_t num_system_allocations_; // number of times we call cudaMalloc*. - double tot_time_taken_in_cuda_malloc_; // time in cudaMalloc - double tot_time_taken_in_cuda_malloc_pitch_; // time in cudaMallocPitch - double tot_time_taken_in_cuda_free_; // time in cudaFree - double tot_time_taken_in_malloc_pitch_; // time in this->MallocPitch() - - - // a memory element is 'used' when it is currently possessed by the caller - // (and is not in our cache). - struct UsedMemoryElement { - size_t row_bytes; - size_t num_rows; - size_t pitch; - UsedMemoryElement() { } - UsedMemoryElement(size_t row_bytes, size_t num_rows, size_t pitch): - row_bytes(row_bytes), num_rows(num_rows), pitch(pitch) { } - }; + // Removes this block from the 'free_blocks' set of the SubRegion to which + // it belongs. This is called when allocating a block, and from other places. + void RemoveFromFreeBlocks(MemoryBlock *block); + + // Adds this block to the 'free_blocks' set of the SubRegion to which it + // belongs. This is called when freeing a block, and from other places. + void AddToFreeBlocks(MemoryBlock *block); + + // This function is called when an allocation failed and we need to try to + // allocate more memory from the evice. The 'size' is the size of the + // requested memory block whose allocation failed-- it's provided so that + // we can be sure to allocate a new region of at least this size. + void AllocateNewRegion(size_t size); + + // Called from AllocateNewRegion(), this ensures that the subregions are + // sorted as we want (which is a kind of heuristic that will be discussed in + // the code), and it also recomputes the largest_free_block_ array. + void SortSubregions(); - struct PointerHasher { - size_t operator() (const void *arg) const noexcept { - // the last few bits tend to be very predictable, for alignment reasons (CUDA - // allocation may align on 256 byte or 512 byte boundaries or something similar). - size_t temp = reinterpret_cast(arg); - return (temp >> 4) + (temp >> 9); - } - }; - // This is a map from memory locations owned by the user, so we can recover - // the information when people call Free() and we add it back into the cache. - unordered_map used_map_; - // this is only locked by the '*Locking' versions of the functions. 
+ CuAllocatorOptions opts_; + + std::vector memory_regions_; + + std::vector subregions_; + + // For each SubRegion in sub_regions_, this vector gives us the size of the + // largest free block present in that SubRegion, which is equal to + // sub_regions_[i]->free_blocks.begin()->first. It allows us to fairly + // efficiently find the lowest-numbered SubRegion which can handle a + // particular request for memory. + std::vector largest_free_block_; + + size_t t_; // time counter, incremented with each call. + size_t synchronize_gpu_t_; // value of t_ at the last time we called + // SynchronizeGpu(). + size_t num_synchronizations_; // number of times we called SynchronizeGpu() + double tot_time_taken_; // Total time taken in calls to this object. + double malloc_time_taken_; // Total time we spent calling cudaMalloc(). + + // This is a map from memory locations currently owned by the user, to the + // MemoryBlock which stores the information about that location. + std::unordered_map allocated_block_map_; + + // this is only locked by the '*Locking' versions of the functions (necessary only + // in multi-threaded applications). std::mutex mutex_; + // Keep track of the memory usage from the cache to track the maximum memory used by + // the application + size_t max_allocated_memory_; + size_t allocated_memory_; }; -} // namespace +// This function returns some printable information about the memory used +// as a string: an example showing the format is: +// "free: 10M, used: 490M, total: 500M: free/total: 0.02" +// In addition, if the pointers 'free' and 'total' are non-NULL, it will +// output to them the free memory and the total memory of the device. +std::string GetFreeGpuMemory(int64* free, int64* total); + +extern CuMemoryAllocator g_cuda_allocator; + +} // namespace kaldi #endif // HAVE_CUDA diff --git a/src/cudamatrix/cu-array-inl.h b/src/cudamatrix/cu-array-inl.h index ddae19b9a4e..567cc0f6d18 100644 --- a/src/cudamatrix/cu-array-inl.h +++ b/src/cudamatrix/cu-array-inl.h @@ -105,8 +105,9 @@ void CuArrayBase::CopyFromVec(const std::vector &src) { if (CuDevice::Instantiate().Enabled()) { CuTimer tim; CU_SAFE_CALL( - cudaMemcpy(data_, &src.front(), src.size() * sizeof(T), - cudaMemcpyHostToDevice)); + cudaMemcpyAsync(data_, &src.front(), src.size() * sizeof(T), + cudaMemcpyHostToDevice, cudaStreamPerThread)); + CU_SAFE_CALL(cudaStreamSynchronize(cudaStreamPerThread)); CuDevice::Instantiate().AccuProfile(__func__, tim); } else #endif @@ -122,7 +123,9 @@ void CuArray::CopyFromVec(const std::vector &src) { #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuTimer tim; - CU_SAFE_CALL(cudaMemcpy(this->data_, &src.front(), src.size()*sizeof(T), cudaMemcpyHostToDevice)); + CU_SAFE_CALL(cudaMemcpyAsync(this->data_, &src.front(), + src.size()*sizeof(T), cudaMemcpyHostToDevice, cudaStreamPerThread)); + CU_SAFE_CALL(cudaStreamSynchronize(cudaStreamPerThread)); CuDevice::Instantiate().AccuProfile(__func__, tim); } else #endif @@ -139,8 +142,9 @@ void CuArray::CopyFromArray(const CuArrayBase &src) { #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuTimer tim; - CU_SAFE_CALL(cudaMemcpy(this->data_, src.data_, this->dim_ * sizeof(T), - cudaMemcpyDeviceToDevice)); + CU_SAFE_CALL(cudaMemcpyAsync(this->data_, src.data_, this->dim_ * sizeof(T), + cudaMemcpyDeviceToDevice, + cudaStreamPerThread)); CuDevice::Instantiate().AccuProfile(__func__, tim); } else #endif @@ -158,8 +162,8 @@ void CuArrayBase::CopyFromArray(const CuArrayBase &src) { if (CuDevice::Instantiate().Enabled()) { 
CuTimer tim; CU_SAFE_CALL( - cudaMemcpy(this->data_, src.data_, dim_ * sizeof(T), - cudaMemcpyDeviceToDevice)); + cudaMemcpyAsync(this->data_, src.data_, dim_ * sizeof(T), + cudaMemcpyDeviceToDevice, cudaStreamPerThread)); CuDevice::Instantiate().AccuProfile(__func__, tim); } else #endif @@ -178,7 +182,9 @@ void CuArrayBase::CopyToVec(std::vector *dst) const { #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuTimer tim; - CU_SAFE_CALL(cudaMemcpy(&dst->front(), Data(), this->dim_ * sizeof(T), cudaMemcpyDeviceToHost)); + CU_SAFE_CALL(cudaMemcpyAsync(&dst->front(), Data(), this->dim_ * sizeof(T), + cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CU_SAFE_CALL(cudaStreamSynchronize(cudaStreamPerThread)); CuDevice::Instantiate().AccuProfile("CuArray::CopyToVecD2H", tim); } else #endif @@ -195,7 +201,9 @@ void CuArrayBase::CopyToHost(T *dst) const { #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuTimer tim; - CU_SAFE_CALL(cudaMemcpy(dst, Data(), this->dim_ * sizeof(T), cudaMemcpyDeviceToHost)); + CU_SAFE_CALL(cudaMemcpyAsync(dst, Data(), this->dim_ * sizeof(T), + cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CU_SAFE_CALL(cudaStreamSynchronize(cudaStreamPerThread)); CuDevice::Instantiate().AccuProfile("CuArray::CopyToVecD2H", tim); } else #endif @@ -211,7 +219,9 @@ void CuArrayBase::SetZero() { #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuTimer tim; - CU_SAFE_CALL(cudaMemset(this->data_, 0, this->dim_ * sizeof(T))); + CU_SAFE_CALL(cudaMemsetAsync(this->data_, 0, this->dim_ * sizeof(T), + cudaStreamPerThread)); + CU_SAFE_CALL(cudaStreamSynchronize(cudaStreamPerThread)); CuDevice::Instantiate().AccuProfile("CuArray::SetZero", tim); } else #endif diff --git a/src/cudamatrix/cu-device.cc b/src/cudamatrix/cu-device.cc index c5114ed8b22..85c2492c074 100644 --- a/src/cudamatrix/cu-device.cc +++ b/src/cudamatrix/cu-device.cc @@ -23,7 +23,6 @@ #if HAVE_CUDA == 1 - #include #include #include @@ -42,23 +41,15 @@ #include "base/kaldi-utils.h" #include "util/common-utils.h" #include "util/kaldi-io.h" +// the following is for cuda_legacy_noop(). +#include "cudamatrix/cu-kernels-ansi.h" namespace kaldi { -/** - This function was added by Dan in July 2015 after upgrading on the CLSP - cluster to the CUDA 7.0 toolkit; the old mechanism of just calling - cudaThreadSynchronize() [==cudaDeviceSynchronize()] and having it - automagically select a GPU (when exclusive mode is on) doesn't seem to work - any more, in situations where GPU 0 is already being used. This works. It's - not 100% clear if the fact that the old code wasn't working was a bug, or a - changed feature (the NVidia docs were never super-clear regarding device - initialization). But regardless, changing to this new mechanism should be - harmless even if the problem was specific to the CLSP grid. -*/ - +/// This function attempts to get a CUDA device context on some available device +/// by doing 'cudaFree(0)'. If it succeeds it returns true; if it fails, it +/// outputs some debugging information into 'debug_str' and returns false. static bool GetCudaContext(int32 num_gpus, std::string *debug_str) { - // Our first attempt to get a device context is: we do cudaFree(0) and see if // that returns no error code. If it succeeds then we have a device // context. Apparently this is the canonical way to get a context. 
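The cu-array-inl.h hunks above all follow one pattern: issue the copy or memset
with the Async variant on cudaStreamPerThread, then call cudaStreamSynchronize()
only when the host must observe the result (device-to-device copies skip the
synchronization because later work on the same stream is ordered after them).
A minimal self-contained sketch of that pattern follows; the helper name
CopyDeviceToHostSync is illustrative, not Kaldi's API.

#include <cuda_runtime.h>
#include <cstddef>

// Copy 'num_bytes' from device memory to a host buffer using this thread's
// default stream, and wait for the copy to finish before returning so the
// caller can safely read the host buffer.
inline cudaError_t CopyDeviceToHostSync(void *host_dst, const void *device_src,
                                        size_t num_bytes) {
  cudaError_t e = cudaMemcpyAsync(host_dst, device_src, num_bytes,
                                  cudaMemcpyDeviceToHost, cudaStreamPerThread);
  if (e != cudaSuccess) return e;
  return cudaStreamSynchronize(cudaStreamPerThread);
}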
@@ -88,53 +79,79 @@ static bool GetCudaContext(int32 num_gpus, std::string *debug_str) { return false; } -/** - * SelectGpuId(use_gpu) - * - * There are 3 'use_gpu' modes for GPU selection: - * "yes" -- Select GPU automatically (or get one by exclusive mode) - * and die if this fails. - * "optional" -- Do as above, but if it fails, back off to CPU. - * "no" -- Run on CPU. - * - * In case of Compute exclusive mode, the GPU is selected by OS. - * - * Otherwise GPU selection is based on largest proportion of free memory. - * This can eventually lead to multiple processes computing on single GPU, - * which is slow. More practical is to use "compute exclusive mode". - * - * This method is to be called at the very beginning of the program - * (before first allocation in cudamatrix), or not at all (default to CPU). - * - */ + +void CuDevice::Initialize() { + // This function may be called in the following two situations: + // + // (1) in the main thread, only when a GPU is not currently being used, either + // within a call like CuDevice()::Instantiate().SelectGpuId(..) + // (where the Instantiate() call will call Initialize() before SelectGpuId() + // is called, just because of how Instantiate() works), or in a call + // to 'CuDevice::Instantiate().Enabled()'. In this case it will just + // set initialized_ to true and notice that device_id_ == 1, and do nothing. + // + // (2) in threads created by the user, as soon as someone calls something that + // might potentially use the GPU, via CuDevice()::Instantiate(). + // If device_id_ is >= 0, this will create the cuBLAS and cuSparse handles. + KALDI_ASSERT(!initialized_); + initialized_ = true; + if (device_id_ == -1) { + // There is nothing to do; we are not using a GPU. + return; + } else { + if (!multi_threaded_) { + multi_threaded_ = true; + KALDI_WARN << "For multi-threaded code that might use GPU, you should call " + "CuDevice::Instantiate().AllowMultithreading() at the start of " + "the program."; + } + device_id_copy_ = device_id_; + cudaSetDevice(device_id_); + // Initialize CUBLAS. + CUBLAS_SAFE_CALL(cublasCreate(&cublas_handle_)); + CUBLAS_SAFE_CALL(cublasSetStream(cublas_handle_, cudaStreamPerThread)); + + #if CUDA_VERSION >= 9000 + if (device_options_.use_tensor_cores) { + // Enable tensor cores in CUBLAS + // Note if the device does not support tensor cores this will fall back to normal math mode + CUBLAS_SAFE_CALL(cublasSetMathMode(cublas_handle_, + CUBLAS_TENSOR_OP_MATH)); + } + #endif + + // Initialize the cuSPARSE library + CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_)); + CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread)); + } +} + void CuDevice::SelectGpuId(std::string use_gpu) { - // Possible modes + if (device_id_ != -1) { + KALDI_ERR << "You cannot call SelectGpuId twice if, on the first time, " + "you requested a GPU."; + } if (use_gpu != "yes" && use_gpu != "no" && use_gpu != "optional" && use_gpu != "wait") { KALDI_ERR << "Please choose : --use-gpu=yes|no|optional|wait, passed '" << use_gpu << "'"; } - - // Make sure this function is not called twice! 
- if (Enabled()) { - KALDI_ERR << "There is already an active GPU " << active_gpu_id_ - << ", cannot change it on the fly!"; - } - // Allow the GPU to stay disabled - if (!Enabled() && use_gpu == "no") { + if (use_gpu == "no") { KALDI_LOG << "Manually selected to compute on CPU."; return; } - // Check that we have a gpu available int32 num_gpus = 0; cudaError_t e = cudaGetDeviceCount(&num_gpus); + // Make sure the global allocator object has the up-to-date options. + g_cuda_allocator.SetOptions(g_allocator_options); + if (num_gpus == 0) { if (use_gpu == "yes" || use_gpu == "wait") { KALDI_CUDA_ERR(e, "No CUDA GPU detected!"); } if (use_gpu == "optional") { - KALDI_WARN << "Running on CPU!!! No CUDA GPU detected..."; + KALDI_WARN << "No CUDA GPU detected; running on CPU since --use-gpu=optional specified."; return; } } @@ -183,8 +200,8 @@ void CuDevice::SelectGpuId(std::string use_gpu) { << " seconds before creating CUDA context"; } - // Re-assure we have the context - KALDI_ASSERT(cudaSuccess == cudaThreadSynchronize()); + // Double check that we have the context + KALDI_ASSERT(cudaSuccess == cudaDeviceSynchronize()); // Check if the machine use compute exclusive mode if (IsComputeExclusive()) { @@ -196,7 +213,7 @@ void CuDevice::SelectGpuId(std::string use_gpu) { KALDI_WARN << "Not in compute-exclusive mode. Suggestion: use " "'nvidia-smi -c 3' to set compute exclusive mode"; // We want to choose the device more carefully, so release the CUDA context. - e = cudaThreadExit(); // deprecated, but for legacy reason not cudaDeviceReset + e = cudaDeviceReset(); if (e != cudaSuccess) { KALDI_CUDA_ERR(e, "Failed to release CUDA context on a GPU"); } @@ -206,8 +223,8 @@ void CuDevice::SelectGpuId(std::string use_gpu) { FinalizeActiveGpu(); return; } else { - // Could not get GPU, after prevously having the CUDA context? - // Strange but not impossible... + // We could not get a GPU the second time, after prevously having the CUDA + // context. Strange but not impossible. if (use_gpu == "yes") { KALDI_ERR << "Error acquiring GPU."; } @@ -221,37 +238,40 @@ void CuDevice::SelectGpuId(std::string use_gpu) { void CuDevice::FinalizeActiveGpu() { - // The device at this point should have active GPU, so we can query its name - // and memory stats and notify user which GPU is finally used. + // The device at this point should have an active GPU, so we can query its + // name and memory stats and notify user which GPU is being used. - // Get the device-id of active device: + // Get the device-id of the active device. { - int32 act_gpu_id; - cudaError_t e = cudaGetDevice(&act_gpu_id); + int device_id; + cudaError_t e = cudaGetDevice(&device_id); if (e != cudaSuccess) { KALDI_CUDA_ERR(e, "Failed to get device-id of active device."); } - // Remember the id of active GPU - active_gpu_id_ = act_gpu_id; // CuDevice::Enabled() is true from now on + device_id_ = device_id; + device_id_copy_ = device_id; + initialized_ = true; // Prevent Initialize() from being called on this, + // the main thread. // Initialize CUBLAS. - CUBLAS_SAFE_CALL(cublasCreate(&handle_)); + CUBLAS_SAFE_CALL(cublasCreate(&cublas_handle_)); + CUBLAS_SAFE_CALL(cublasSetStream(cublas_handle_, cudaStreamPerThread)); // Initialize the cuSPARSE library CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_)); + CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread)); - // Notify user which GPU is finally used + // Notify the user which GPU is being userd. 
char name[128]; - DeviceGetName(name,128,act_gpu_id); + DeviceGetName(name,128, device_id); - CU_SAFE_CALL(cudaGetDeviceProperties(&properties_, act_gpu_id)); + CU_SAFE_CALL(cudaGetDeviceProperties(&properties_, device_id)); - KALDI_LOG << "The active GPU is [" << act_gpu_id << "]: " << name << "\t" - << GetFreeMemory(&free_memory_at_startup_, NULL) << " version " + KALDI_LOG << "The active GPU is [" << device_id << "]: " << name << "\t" + << GetFreeGpuMemory(&free_memory_at_startup_, NULL) << " version " << properties_.major << "." << properties_.minor; } return; } - bool CuDevice::DoublePrecisionSupported() { if (!Enabled()) return true; return properties_.major > 1 || (properties_.major == 1 && properties_.minor >= 3); @@ -261,10 +281,10 @@ bool CuDevice::DoublePrecisionSupported() { bool CuDevice::IsComputeExclusive() { // assume we already have an CUDA context created - KALDI_ASSERT(cudaSuccess == cudaThreadSynchronize()); + KALDI_ASSERT(cudaSuccess == cudaDeviceSynchronize()); // get the device-id and its device-properties - int32 gpu_id = -1; + int gpu_id = -1; cudaError_t e = cudaGetDevice(&gpu_id); if (e != cudaSuccess) { KALDI_CUDA_ERR(e, "Failed to get current device"); @@ -279,11 +299,9 @@ bool CuDevice::IsComputeExclusive() { case cudaComputeModeExclusive : return true; break; -#if (CUDA_VERSION >= 4000) case cudaComputeModeExclusiveProcess : return true; break; -#endif default : // in this case we release the GPU context... return false; @@ -318,37 +336,35 @@ bool CuDevice::SelectGpuIdAuto() { switch(ret) { case cudaSuccess : { // create the CUDA context for the thread - cudaThreadSynchronize(); // deprecated, but for legacy not cudaDeviceSynchronize + cudaDeviceSynchronize(); // get GPU name char name[128]; DeviceGetName(name,128,n); // get GPU memory stats int64 free, total; std::string mem_stats; - mem_stats = GetFreeMemory(&free, &total); + mem_stats = GetFreeGpuMemory(&free, &total); // log KALDI_LOG << "cudaSetDevice(" << n << "): " << name << "\t" << mem_stats; - // We have seen that in some cases GetFreeMemory returns zero + // We have seen that in some cases GetFreeGpuMemory returns zero // That will produce nan after division, which might confuse // the sorting routine. Or maybe not, but let's keep it clean if (total <= 0) { - KALDI_LOG << "Total memory reported for device " << n << " is zero (or less)."; + KALDI_LOG << "Total memory reported for device " << n + << " is zero (or less)."; } float mem_ratio = total > 0 ? 
free/(float)total : 0; free_mem_ratio[n] = std::make_pair(n, mem_ratio); // destroy the CUDA context for the thread - cudaThreadExit(); // deprecated, but for legacy reason not cudaDeviceReset + cudaDeviceReset(); } break; - -#if (CUDA_VERSION > 3020) case cudaErrorDeviceAlreadyInUse : KALDI_LOG << "cudaSetDevice(" << n << "): " << "Device cannot be accessed, used EXCLUSIVE-THREAD mode..."; break; -#endif case cudaErrorInvalidDevice : KALDI_LOG << "cudaSetDevice(" << n << "): " << "Device cannot be accessed, not a VALID CUDA device!"; @@ -366,7 +382,7 @@ bool CuDevice::SelectGpuIdAuto() { // the free_mem_ratio should be bigger than zero KALDI_ASSERT(free_mem_ratio[max_id].second > 0.0); - float dev_id; + int dev_id; float mem_ratio; do { // try to select the GPU in the best to worst order @@ -382,7 +398,7 @@ bool CuDevice::SelectGpuIdAuto() { KALDI_WARN << "Cannot select this device: return code " << e << ", Error message: \"" << cudaGetErrorString(e) << "\""; } else { - e = cudaThreadSynchronize(); // deprecated, but for legacy not cudaDeviceSynchronize + e = cudaDeviceSynchronize(); if (e != cudaSuccess) { KALDI_WARN << "Cannot select this device: return code " << e << ", Error message: \"" << cudaGetErrorString(e) << "\""; @@ -403,10 +419,16 @@ bool CuDevice::SelectGpuIdAuto() { void CuDevice::AccuProfile(const char *function_name, const CuTimer &timer) { if (GetVerboseLevel() >= 1) { + std::unique_lock lock(profile_mutex_, std::defer_lock_t()); + if (multi_threaded_) + lock.lock(); std::string key(function_name); - cudaDeviceSynchronize(); + // by passing 0 as the stream to cudaStreamSynchronize, we are using the + // per-thread default stream. Since we compile with + // -DCUDA_API_PER_THREAD_DEFAULT_STREAM, this equates to a per-thread + // stream. + cudaStreamSynchronize(0); double elapsed = timer.Elapsed(); - if (profile_map_.find(key) == profile_map_.end()) profile_map_[key] = elapsed; else @@ -415,13 +437,8 @@ void CuDevice::AccuProfile(const char *function_name, } void CuDevice::PrintMemoryUsage() const { - if (Enabled()) { - allocator_.PrintMemoryUsage(); - int64 free_memory_now; - GetFreeMemory(&free_memory_now, NULL); - KALDI_LOG << "Memory used (according to the device): " - << (free_memory_at_startup_ - free_memory_now) << " bytes."; - } + if (Enabled()) + g_cuda_allocator.PrintMemoryUsage(); } void CuDevice::PrintProfile() { @@ -452,60 +469,6 @@ void CuDevice::PrintProfile() { } -std::string CuDevice::GetFreeMemory(int64* free, int64* total) const { - // WARNING! the CUDA API is inconsistent accross versions! 
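The dlopen-based GetFreeMemory() being removed here is replaced elsewhere by GetFreeGpuMemory(), whose definition is not part of this hunk. For orientation, the sketch below shows how the same report can be produced with the plain runtime API, with no dynamic loading of libcuda.so; the function name and the int64_t outputs are illustrative, not the patch's actual implementation.

#include <cstdint>
#include <sstream>
#include <string>
#include <cuda_runtime.h>

// Illustrative only: reports free/used/total device memory via cudaMemGetInfo().
std::string ReportFreeGpuMemory(int64_t *free_out, int64_t *total_out) {
  size_t mem_free = 0, mem_total = 0;
  if (cudaMemGetInfo(&mem_free, &mem_total) != cudaSuccess)
    return "cudaMemGetInfo() failed";
  if (free_out != NULL) *free_out = static_cast<int64_t>(mem_free);
  if (total_out != NULL) *total_out = static_cast<int64_t>(mem_total);
  std::ostringstream os;
  os << "free:" << mem_free / (1024 * 1024) << "M, "
     << "used:" << (mem_total - mem_free) / (1024 * 1024) << "M, "
     << "total:" << mem_total / (1024 * 1024) << "M, "
     << "free/total:" << mem_free / static_cast<float>(mem_total);
  return os.str();
}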
-#ifdef _MSC_VER - size_t mem_free, mem_total; - cuMemGetInfo_v2(&mem_free, &mem_total); -#else -#if (CUDA_VERSION >= 3020) - // define the function signature type - size_t mem_free, mem_total; -#else - unsigned int mem_free, mem_total; -#endif - { - // we will load cuMemGetInfo_v2 dynamically from libcuda.so - // pre-fill ``safe'' values that will not cause problems - mem_free = 1; mem_total = 1; - // open libcuda.so - void* libcuda = dlopen("libcuda.so",RTLD_LAZY); - if (NULL == libcuda) { - KALDI_WARN << "cannot open libcuda.so"; - } else { - // define the function signature type - // and get the symbol -#if (CUDA_VERSION >= 3020) - typedef CUresult (*cu_fun_ptr)(size_t*, size_t*); - cu_fun_ptr dl_cuMemGetInfo = (cu_fun_ptr)dlsym(libcuda,"cuMemGetInfo_v2"); -#else - typedef CUresult (*cu_fun_ptr)(int*, int*); - cu_fun_ptr dl_cuMemGetInfo = (cu_fun_ptr)dlsym(libcuda,"cuMemGetInfo"); -#endif - if (NULL == dl_cuMemGetInfo) { - KALDI_WARN << "cannot load cuMemGetInfo from libcuda.so"; - } else { - // call the function - dl_cuMemGetInfo(&mem_free, &mem_total); - } - // close the library - dlclose(libcuda); - } - } -#endif - // copy the output values outside - if (NULL != free) *free = mem_free; - if (NULL != total) *total = mem_total; - // prepare the text output - std::ostringstream os; - os << "free:" << mem_free/(1024*1024) << "M, " - << "used:" << (mem_total-mem_free)/(1024*1024) << "M, " - << "total:" << mem_total/(1024*1024) << "M, " - << "free/total:" << mem_free/(float)mem_total; - return os.str(); -} - - void CuDevice::DeviceGetName(char* name, int32 len, int32 dev) { // prefill with something reasonable strncpy(name,"Unknown GPU",len); @@ -554,15 +517,49 @@ void CuDevice::CheckGpuHealth() { AccuProfile(__func__, t); } -CuDevice::CuDevice() : - active_gpu_id_(-1), debug_stride_mode_(false), - num_debug_stride_allocations_(0), allocator_(CuAllocatorOptions()), - multi_threaded_(false) { } +CuDevice::CuDevice(): + initialized_(false), + device_id_copy_(-1), + cublas_handle_(NULL), + cusparse_handle_(NULL) { +} + +CuDevice::~CuDevice() { + if (cublas_handle_) + CUBLAS_SAFE_CALL(cublasDestroy(cublas_handle_)); + if (cusparse_handle_) + CUSPARSE_SAFE_CALL(cusparseDestroy(cusparse_handle_)); +} + +// Each thread has its own copy of the CuDevice object. +// Note: this was declared "static". +thread_local CuDevice CuDevice::this_thread_device_; + +CuDevice::CuDeviceOptions CuDevice::device_options_; -// The instance of the static singleton -CuDevice CuDevice::global_device_; +// define and initialize the static members of the CuDevice object. +int32 CuDevice::device_id_ = -1; +bool CuDevice::multi_threaded_ = false; +unordered_map CuDevice::profile_map_; +std::mutex CuDevice::profile_mutex_; +int64 CuDevice::free_memory_at_startup_; +cudaDeviceProp CuDevice::properties_; +bool CuDevice::debug_stride_mode_ = false; + + +void SynchronizeGpu() { + cuda_legacy_noop(); + CU_SAFE_CALL(cudaGetLastError()); } +} // namespace kaldi + +#else // #if HAVE_CUDA == 1 + +namespace kaldi { +// SynchronizeGpu() does nothing if we didn't compile for GPU. 
+void SynchronizeGpu() { } +} -#endif // HAVE_CUDA +#endif // #if HAVE_CUDA == 1 diff --git a/src/cudamatrix/cu-device.h b/src/cudamatrix/cu-device.h index 99105355a8f..8816f9d223b 100644 --- a/src/cudamatrix/cu-device.h +++ b/src/cudamatrix/cu-device.h @@ -24,7 +24,6 @@ #define KALDI_CUDAMATRIX_CU_DEVICE_H_ #if HAVE_CUDA == 1 - #include #include #include @@ -41,61 +40,95 @@ namespace kaldi { class CuTimer; /** - * Singleton object which represents the CUDA device - * responsible for CUBLAS initilalisation, collects profiling info + This class contains code for selecting the CUDA device, initializing the + cuBLAS and cuSparse handles, and providing an interface for memory allocation + (which supports caching, to avoid the slowness of the CUDA memory allocator). + + There is a separate instance of the CuDevice object for each thread of the + program, but many of its variables are static (hence, shared between all + instances). + + We only (currently) support using a single GPU device; however, we support + multiple CUDA streams. The expected programming model here is that you will + have multiple CPU threads, and each CPU thread automatically gets its own + CUDA stream because we compile with -DCUDA_API_PER_THREAD_DEFAULT_STREAM. + + In terms of synchronizing the activities of multiple threads: The CuDevice + object (with help from the underlying CuAllocator object) ensures that the + memory caching code won't itself be a cause of synchronization problems, + i.e. you don't have to worry that when you allocate with CuDevice::Malloc(), + the memory will still be in use by another thread on the GPU. However, it + may sometimes still be necessary to synchronize the activities of multiple + streams by calling the function SynchronizeGpu()-- probably right before a + thread increments a semaphore, right after it waits on a semaphore, or + right after it acquires a mutex, or something like that. + */ class CuDevice { - // Singleton object (there should only be one instantiated per program) public: - static inline CuDevice& Instantiate() { return global_device_; } - inline cublasHandle_t GetHandle() { return handle_; } + // You obtain the CuDevice for the current thread by calling + // CuDevice::Instantiate() + // At the beginning of the program, if you want to use a GPU, you + // should call CuDevice::Instantiate().SelectGpuId(..). + static inline CuDevice& Instantiate() { + CuDevice &ans = this_thread_device_; + if (!ans.initialized_) + ans.Initialize(); + return ans; + } + + inline cublasHandle_t GetCublasHandle() { return cublas_handle_; } inline cusparseHandle_t GetCusparseHandle() { return cusparse_handle_; } - // We provide functions Malloc, MallocPitch and Free which replace cudaMalloc, - // cudaMallocPitch and cudaFree. Their function is to cache the results of - // previous allocations to avoid the very large overhead that CUDA's - // allocation seems to give for some setups. + // We provide functions Malloc(), MallocPitch() and Free() which replace + // cudaMalloc(), cudaMallocPitch() and cudaFree(). Their function is to cache + // the results of previous allocations to avoid the very large overhead that + // CUDA's allocation seems to give for some setups. inline void* Malloc(size_t size) { - return multi_threaded_ ? allocator_.MallocLocking(size) : - allocator_.Malloc(size); + return multi_threaded_ ? 
g_cuda_allocator.MallocLocking(size) : + g_cuda_allocator.Malloc(size); } inline void* MallocPitch(size_t row_bytes, size_t num_rows, size_t *pitch) { if (multi_threaded_) { - return allocator_.MallocPitchLocking(row_bytes, num_rows, pitch); + return g_cuda_allocator.MallocPitchLocking(row_bytes, num_rows, pitch); } else if (debug_stride_mode_) { // The pitch bucket size is hardware dependent. // It is 512 on K40c with CUDA 7.5 // "% 8" ensures that any 8 adjacent allocations have different pitches // if their original pitches are same in the normal mode. - return allocator_.MallocPitch( - row_bytes + 512 * ((num_debug_stride_allocations_++) % 8), num_rows, + return g_cuda_allocator.MallocPitch( + row_bytes + 512 * RandInt(0, 4), num_rows, pitch); } else { - return allocator_.MallocPitch(row_bytes, num_rows, pitch); + return g_cuda_allocator.MallocPitch(row_bytes, num_rows, pitch); } } + inline void Free(void *ptr) { - if (multi_threaded_) allocator_.FreeLocking(ptr); - else allocator_.Free(ptr); + if (multi_threaded_) g_cuda_allocator.FreeLocking(ptr); + else g_cuda_allocator.Free(ptr); } - /// Select a GPU for computation, the 'use_gpu' modes are: - /// "yes" -- Select GPU automatically and die if this fails. + /// Select a GPU for computation. You are supposed to call this function just + /// once, at the beginning of the program (from the main thread), or not at + /// all. + /// The 'use_gpu' modes are: + /// "yes" -- Select GPU automatically and die if this fails. If you have set + /// the GPUs to exclusive mode it will select one + /// pseudo-randomly; otherwise it will choose whichever one has + /// the most free memory (but we recommend to set GPUs to + /// exclusive mode, or controlling which GPU to use by setting + /// the variable CUDA_VISIBLE_DEVICES to the id of the GPU you + /// want the program to use. /// "optional" -- Do as above, but if it fails, back off to CPU. /// "no" -- Run on CPU. - /// (more comments in cu-device.cc) void SelectGpuId(std::string use_gpu); /// Check if the CUDA GPU is selected for use bool Enabled() const { - return (active_gpu_id_ > -1); - } - - /// Get the active GPU id - int32 ActiveGpuId() { - return active_gpu_id_; + return (device_id_ > -1); } /// Returns true if either we have no GPU, or we have a GPU @@ -106,21 +139,19 @@ class CuDevice { /// are printed out when you call PrintProfile(). However, /// it only does something if VerboseLevel() >= 1. void AccuProfile(const char *function_name, const CuTimer &timer); + + /// Print some profiling information using KALDI_LOG. void PrintProfile(); + /// Print some memory-usage information using KALDI_LOG. void PrintMemoryUsage() const; /// The user should call this if the program plans to access the GPU (e.g. via /// using class CuMatrix) from more than one thread. If you fail to call this - /// for a multi-threaded program, it will occasionally segfault. + /// for a multi-threaded program, it may occasionally segfault (and also + /// the code will detect that you failed to call it, and will print a warning). inline void AllowMultithreading() { multi_threaded_ = true; } - void ResetProfile() { - profile_map_.clear(); - } - - /// Get the actual GPU memory use stats - std::string GetFreeMemory(int64* free = NULL, int64* total = NULL) const; /// Get the name of the GPU void DeviceGetName(char* name, int32 len, int32 dev); @@ -153,22 +184,56 @@ class CuDevice { /// (i.e. from outside the class), call this only if Enabled() returns true. 
bool IsComputeExclusive(); + // Register command line options for CUDA device. + // This must be done before calling CuDevice::Initialize() + // Example: + // CuDevice::RegisterDeviceOptions(&po); + // po.Read(argc, argv); + // CuDevice::Initialize(); + static void RegisterDeviceOptions(OptionsItf *po) { + CuDevice::device_options_.Register(po); + } + ~CuDevice(); private: + + struct CuDeviceOptions { + bool use_tensor_cores; // Enable tensor cores + CuDeviceOptions () : use_tensor_cores(false) {}; + void Register(OptionsItf *po) { + po->Register("cuda-use-tensor-cores", &use_tensor_cores, + "Enable FP16 tensor math. " + "This is higher performance but less accuracy. " + "This is only recommended for inference."); + } + }; + + static CuDeviceOptions device_options_; + + // Default constructor used to initialize this_thread_device_ CuDevice(); CuDevice(CuDevice&); // Disallow. CuDevice &operator=(CuDevice&); // Disallow. - static CuDevice global_device_; - cublasHandle_t handle_; - cusparseHandle_t cusparse_handle_; + /// The Initialize() function exists to do the following, in threads other + /// than the main thread, and only if we are using a GPU: call + /// cudaSetDevice(), and set up cublas_handle_ and cusparse_handle_. It does + /// get called in the main thread (see documentation by its definition), but + /// does nothing interesting there. + void Initialize(); - /// Automatically select GPU and get CUDA context. Returns true on success. + /// Automatically select GPU and get CUDA context (this is only called, from + /// SelectGpuId(), if the GPUs are in non-exclusive mode). Returns true on + /// success. bool SelectGpuIdAuto(); - /// Try to get CUDA context on manually selected GPU. Return true on success. - bool SelectGpuIdManual(int32 gpu_id); - + /// This function, called from SelectGpuId(), is to be called when a + /// GPU context corresponding to the GPU we want to use exists; it + /// works out the device-id, creates the cuBLAS and cuSparse handles, + /// and prints out some information that's useful for debugging. + /// It also sets initialized_ to true, to suppress Initialize() from + /// being called on this, the main thread, in future, since + /// that would try to create the handles again. void FinalizeActiveGpu(); /// Should only be called if Enabled() == true. @@ -177,29 +242,58 @@ class CuDevice { /// Should only be called if Enabled() == true. int32 MinorDeviceVersion(); - unordered_map profile_map_; - /// active_gpu_id_ values: - /// -3 default (default, the SelectGpuId was not called, we did not want to use GPU) - /// -2 SelectGpuId was called, but no GPU was present - /// -1 SelectGpuId was called, but the GPU was manually disabled - /// 0..N Normal GPU IDs - int32 active_gpu_id_; + // Each thread has its own CuDevice object, which contains the cublas and + // cusparse handles. These are unique to the thread (which is what is + // recommended by NVidia). + static thread_local CuDevice this_thread_device_; - int64 free_memory_at_startup_; + // The GPU device-id that we are using. This will be initialized to -1, and will + // be set when the user calls + // CuDevice::Instantiate::SelectGpuId(...) + // from the main thread. Background threads will, when spawned and when + // CuDevice::Instantiate() is called from them the first time, will + // call cudaSetDevice(device_id)) + static int32 device_id_; - cudaDeviceProp properties_; + // This will automatically be set to true if the application has multiple + // threads that access the GPU device. 
It is used to know whether to + // use locks when accessing the allocator and the profiling-related code. + static bool multi_threaded_; - // there used to be a 'bool verbose_' here. I'm leaving a placeholder here - // instead of removing it because it causes particularly hard-to-debug errors - // if compilation is not done right (e.g. make depend was not done), and this - // class's members move about. - bool unused_; - bool debug_stride_mode_; - uint32 num_debug_stride_allocations_; + // The variable profile_map_ will only be used if the verbose level is >= 1; + // it will accumulate some function-level timing information that is printed + // out at program end. This makes things a bit slower as we have to call + // cudaDeviceSynchronize() to make the timing information meaningful. + static unordered_map profile_map_; + // profile_mutex_ guards profile_map_ in case multi_threaded_ is true. + static std::mutex profile_mutex_; + + // free_memory_at_startup_ is just used in printing the memory used according + // to the device. + static int64 free_memory_at_startup_; + static cudaDeviceProp properties_; + + // If set to true by SetDebugStrideMode(), code will be activated to use + // pseudo-random stride values when allocating data (to detect errors which + // otherwise would be rare). + static bool debug_stride_mode_; + + + // The following member variable is initialized to false; if the user calls + // Instantiate() in a thread where it is still false, Initialize() will be + // called, in order to -- if a GPU is being used-- call cudaSetDevice() and + // set up the cublas and cusparse handles. + bool initialized_; + + // This variable is just a copy of the static variable device_id_. It's used + // to detect when this code is called in the wrong way. + int32 device_id_copy_; + + cublasHandle_t cublas_handle_; + + cusparseHandle_t cusparse_handle_; - CuMemoryAllocator allocator_; - bool multi_threaded_; // true if user called AllowMultithreading(). }; // class CuDevice @@ -214,13 +308,38 @@ class CuTimer: public Timer { // This function is declared as a more convenient way to get the CUDA device handle for use // in the CUBLAS v2 API, since we so frequently need to access it. -inline cublasHandle_t GetCublasHandle() { return CuDevice::Instantiate().GetHandle(); } +inline cublasHandle_t GetCublasHandle() { return CuDevice::Instantiate().GetCublasHandle(); } // A more convenient way to get the handle to use cuSPARSE APIs. inline cusparseHandle_t GetCusparseHandle() { return CuDevice::Instantiate().GetCusparseHandle(); } -} // namespace + +} // namespace kaldi #endif // HAVE_CUDA -#endif +namespace kaldi { + +/** + The function SynchronizeGpu(), which for convenience is defined whether or + not we have compiled for CUDA, is intended to be called in places where threads + need to be synchronized. + + It just launches a no-op kernel into the legacy default stream. This will + have the effect that it will run after any kernels previously launched from + any stream(*), and before kernels that will later be launched from any stream(*). + (*) does not apply to non-blocking streams. 
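As a concrete, hypothetical illustration of the synchronization pattern described above: a worker thread fills a CuMatrix and calls SynchronizeGpu() right before signalling a semaphore, so that work queued on the worker's per-thread stream is visible before the main thread reads the matrix. kaldi::Semaphore is assumed to live in util/kaldi-semaphore.h (adjust the include if your tree differs); the sizes and the use of --use-gpu=optional are arbitrary choices for the sketch.

#include <thread>
#include "cudamatrix/cu-device.h"
#include "cudamatrix/cu-matrix.h"
#include "util/kaldi-semaphore.h"

namespace kaldi {
void Producer(CuMatrix<BaseFloat> *m, Semaphore *ready) {
  m->Resize(512, 512);
  m->SetRandn();                 // queued on this thread's per-thread CUDA stream.
#if HAVE_CUDA == 1
  if (CuDevice::Instantiate().Enabled())
    SynchronizeGpu();            // make the queued work visible to other streams.
#endif
  ready->Signal();               // now the consumer may proceed.
}
}  // namespace kaldi

int main(int argc, char *argv[]) {
  using namespace kaldi;
#if HAVE_CUDA == 1
  CuDevice::Instantiate().AllowMultithreading();
  CuDevice::Instantiate().SelectGpuId("optional");  // falls back to CPU if no GPU.
#endif
  CuMatrix<BaseFloat> m;
  Semaphore ready;
  std::thread t(Producer, &m, &ready);
  ready.Wait();                  // after this, m is safe to read from this thread.
  KALDI_LOG << "Sum of producer's matrix: " << m.Sum();
  t.join();
  return 0;
}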
+ + Note: at the time of writing we never call SynchronizeGpu() from binary-level + code because it hasn't become necessary yet; the only program that might have + multiple threads actually using the GPU is rnnlm-train (if the user were to + invoke it with the ,bg option for loading training examples); but the only + CUDA invocation the RnnlmExample::Read() function uses (via + CuMatrix::Read()), is cudaMemcpy, which is synchronous already. + +*/ +void SynchronizeGpu(); + +} // namespace kaldi + +#endif // KALDI_CUDAMATRIX_CU_DEVICE_H_ diff --git a/src/cudamatrix/cu-kernels-ansi.h b/src/cudamatrix/cu-kernels-ansi.h index ebbcb9da5ff..a61bb601e8e 100644 --- a/src/cudamatrix/cu-kernels-ansi.h +++ b/src/cudamatrix/cu-kernels-ansi.h @@ -790,6 +790,10 @@ void cuda_uncompress_uint8(dim3 Gr, dim3 Bl, BaseFloat *dest, MatrixDim dim, const uint8_t *src, int src_stride, float scale); +// Launches a kernel that does nothing, explicitly using the legacy default stream; +// this will synchronize all CUDA streams (except for non-blocking streams) on the +// device. +void cuda_legacy_noop(); } // extern "C" diff --git a/src/cudamatrix/cu-kernels.cu b/src/cudamatrix/cu-kernels.cu index 4101d5ba52f..515412ca398 100644 --- a/src/cudamatrix/cu-kernels.cu +++ b/src/cudamatrix/cu-kernels.cu @@ -28,7 +28,7 @@ #include #include #include "cudamatrix/cu-kernels-ansi.h" - +#include /*********************************************************************** @@ -958,6 +958,7 @@ static void _trace_mat_mat(const Real* A, const Real* B, MatrixDim dA, Real trans[TileDim][TileDim + 1]; Real sum[CU1DBLOCK]; } smem; + // linear thread id; const int32_cuda tid = threadIdx.y * blockDim.x + threadIdx.x; const int32_cuda grid_height = gridDim.y * TileDim; @@ -1021,6 +1022,7 @@ static void _trace_mat_mat(const Real* A, const Real* B, MatrixDim dA, if (tid == 0) { value[blockIdx.y * gridDim.x + blockIdx.x] = smem.sum[0]; } + } // _trace_mat_mat_trans reduce the partial sum to @@ -1030,6 +1032,7 @@ __global__ static void _trace_mat_mat_trans(const Real* A, const Real* B, MatrixDim dA, int B_stride, Real* value) { __shared__ Real ssum[CU1DBLOCK]; + // linear thread id; const int32_cuda tid = threadIdx.y * blockDim.x + threadIdx.x; const int32_cuda j = blockIdx.x * blockDim.x + threadIdx.x; @@ -1046,7 +1049,7 @@ static void _trace_mat_mat_trans(const Real* A, const Real* B, MatrixDim dA, } ssum[tid] = tsum; __syncthreads(); - + // Block reduce # pragma unroll for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) { @@ -2485,6 +2488,8 @@ template __global__ static void _softmax_reduce(Real*y, const Real*x, MatrixDim d, int src_stride) { __shared__ Real smem[CU1DBLOCK]; + typedef cub::BlockReduce BlockReduceT; + __shared__ typename BlockReduceT::TempStorage temp_storage; const int i = blockIdx.x; const int x_start = i * src_stride; const int y_start = i * d.stride; @@ -2496,24 +2501,9 @@ static void _softmax_reduce(Real*y, const Real*x, MatrixDim d, int src_stride) { for (int j = tid; j < d.cols; j += CU1DBLOCK) { tmax = fmax(tmax, x[x_start + j]); } - smem[tid] = tmax; - __syncthreads(); - - // reduce to 2x warpSize elements per row -# pragma unroll - for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) { - if (tid < shift) { - smem[tid] = fmax(smem[tid], smem[tid + shift]); - } - __syncthreads(); - } - - // reduce to 1 element per row - if (tid < warpSize) { -# pragma unroll - for (int shift = warpSize; shift > 0; shift >>= 1) { - smem[tid] = fmax(smem[tid], smem[tid + shift]); - } + tmax = 
BlockReduceT(temp_storage).Reduce(tmax, cub::Max()); + if (tid == 0) { + smem[0] = tmax; } // broadcast max to all threads @@ -2526,24 +2516,9 @@ static void _softmax_reduce(Real*y, const Real*x, MatrixDim d, int src_stride) { for (int j = tid; j < d.cols; j += CU1DBLOCK) { tsum += exp(x[x_start + j] - max); } - smem[tid] = tsum; - __syncthreads(); - - // reduce to 2x warpSize elements per row -# pragma unroll - for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) { - if (tid < shift) { - smem[tid] += smem[tid + shift]; - } - __syncthreads(); - } - - // reduce to 1 element per row - if (tid < warpSize) { -# pragma unroll - for (int shift = warpSize; shift > 0; shift >>= 1) { - smem[tid] += smem[tid + shift]; - } + tsum = BlockReduceT(temp_storage).Sum(tsum); + if (tid == 0) { + smem[0] = tsum; } // broadcast sum to all threads @@ -2577,43 +2552,28 @@ static void _normalize_per_row(Real *y, int y_stride, const Real *x, const int i = blockIdx.x; const int tid = threadIdx.x; const Real* x_row = x + i * x_d.stride; - __shared__ Real ssum[CU1DBLOCK]; + + typedef cub::BlockReduce BlockReduceT; + __shared__ typename BlockReduceT::TempStorage temp_storage; + + __shared__ Real stddev_div_target_rms; + __shared__ Real scale; // Reduce x_j^2 to CU1DBLOCK elements per row Real tsum = Real(0); for (int j = tid; j < x_d.cols; j += CU1DBLOCK) { tsum += x_row[j] * x_row[j]; } - ssum[tid] = tsum; + tsum = BlockReduceT(temp_storage).Sum(tsum); __syncthreads(); - // Tree reduce to 2x warpSize elements per row -# pragma unroll - for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) { - if (tid < shift) - ssum[tid] += ssum[tid + shift]; - __syncthreads(); - } - - // Reduce last warp to 1 element per row. - // Threads implicitly synchronized within a warp. - if (tid < warpSize) { -# pragma unroll - for (int shift = warpSize; shift > 0; shift >>= 1) { - ssum[tid] += ssum[tid + shift]; - } - } - - const Real kSquaredNormFloor = 1.3552527156068805425e-20; // 2^-66 if (tid == 0) { - ssum[0] = sqrt( - fmax(ssum[0] / (target_rms * target_rms * x_d.cols), kSquaredNormFloor)); + const Real kSquaredNormFloor = 1.3552527156068805425e-20; // 2^-66 + stddev_div_target_rms = sqrt( + fmax(tsum / (target_rms * target_rms * x_d.cols), kSquaredNormFloor)); + scale = Real(1) / stddev_div_target_rms; } - - // Broadcast floored stddev to all threads. 
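For readers unfamiliar with cub, here is a self-contained sketch (not one of the patch's kernels) of the cub::BlockReduce pattern these kernels now use: each thread accumulates a strided partial result over one row, BlockReduce combines the partials, and only thread 0 holds the block-wide aggregate. The kernel name and the block size of 256 are arbitrary choices for the example, not the patch's CU1DBLOCK.

#include <cub/cub.cuh>

// One block per row; each thread forms a strided partial sum over the row,
// then cub::BlockReduce combines the partials and thread 0 writes the result.
template <typename Real, int kBlockSize>
__global__ static void _example_row_sum(const Real *in, int num_cols,
                                        int stride, Real *row_sums) {
  typedef cub::BlockReduce<Real, kBlockSize> BlockReduceT;
  __shared__ typename BlockReduceT::TempStorage temp_storage;
  const int row = blockIdx.x, tid = threadIdx.x;
  Real tsum = Real(0);
  for (int j = tid; j < num_cols; j += kBlockSize)
    tsum += in[row * stride + j];
  tsum = BlockReduceT(temp_storage).Sum(tsum);  // only thread 0 gets the total.
  if (tid == 0)
    row_sums[row] = tsum;
}

// Example launch, one block of 256 threads per row (256 must match kBlockSize):
//   _example_row_sum<float, 256><<<num_rows, 256>>>(d_in, num_cols, stride, d_row_sums);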
__syncthreads(); - const Real stddev_div_target_rms = ssum[0]; - const Real scale = Real(1) / stddev_div_target_rms; // Store normalized input to output Real* y_row = y + i * y_stride; @@ -2626,7 +2586,6 @@ static void _normalize_per_row(Real *y, int y_stride, const Real *x, } } - template __global__ static void _diff_normalize_per_row(Real *id, int id_stride, const Real *iv, @@ -2722,6 +2681,8 @@ __global__ static void _log_softmax_reduce(Real* y, const Real* x, MatrixDim y_dim, int x_stride) { __shared__ Real smem[CU1DBLOCK]; + typedef cub::BlockReduce BlockReduceT; + __shared__ typename BlockReduceT::TempStorage temp_storage; const int i = blockIdx.x; const int x_start = i * x_stride; const int y_start = i * y_dim.stride; @@ -2733,23 +2694,9 @@ static void _log_softmax_reduce(Real* y, const Real* x, MatrixDim y_dim, for (int j = tid; j < y_dim.cols; j += CU1DBLOCK) { tmax = fmax(tmax, x[x_start + j]); } - smem[tid] = tmax; - __syncthreads(); - - // reduce to 2x warpSize elements per row -# pragma unroll - for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) { - if (tid < shift) { - smem[tid] = fmax(smem[tid], smem[tid + shift]); - } - __syncthreads(); - } - - // reduce to 1 element per row - if (tid < warpSize) { - for (int shift = warpSize; shift > 0; shift >>= 1) { - smem[tid] = fmax(smem[tid], smem[tid + shift]); - } + tmax = BlockReduceT(temp_storage).Reduce(tmax, cub::Max()); + if (tid == 0) { + smem[0] = tmax; } // broadcast max to all threads @@ -2762,23 +2709,9 @@ static void _log_softmax_reduce(Real* y, const Real* x, MatrixDim y_dim, for (int j = tid; j < y_dim.cols; j += CU1DBLOCK) { tsum += exp(x[x_start + j] - max); } - smem[tid] = tsum; - __syncthreads(); - - // reduce to 2x warpSize elements per row -# pragma unroll - for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) { - if (tid < shift) { - smem[tid] += smem[tid + shift]; - } - __syncthreads(); - } - - // reduce to 1 element per row - if (tid < warpSize) { - for (int shift = warpSize; shift > 0; shift >>= 1) { - smem[tid] += smem[tid + shift]; - } + tsum = BlockReduceT(temp_storage).Sum(tsum); + if (tid == 0) { + smem[0] = tsum; } // broadcast sum to all threads @@ -3024,6 +2957,9 @@ static void _diff_softmax(Real* x, const MatrixDim dim, const Real* value, const int value_stride, const Real* diff, const int diff_stride) { __shared__ Real ssum[CU1DBLOCK]; + typedef cub::BlockReduce BlockReduceT; + __shared__ typename BlockReduceT::TempStorage temp_storage; + const int tid = threadIdx.x; const int i = blockIdx.x; const int value_start = i * value_stride; @@ -3035,24 +2971,9 @@ static void _diff_softmax(Real* x, const MatrixDim dim, const Real* value, for (int j = tid; j < dim.cols; j += CU1DBLOCK) { tsum += value[value_start + j] * diff[diff_start + j]; } - ssum[tid] = tsum; - __syncthreads(); - - // Tree reduce to 2x warpSize elements. -# pragma unroll - for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) { - if (tid < shift) { - ssum[tid] += ssum[tid + shift]; - } - __syncthreads(); - } - - // Warp reduce to 1 element. Threads implicitly synchronized within a warp. 
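As a plain-CPU reference for what _log_softmax_reduce computes per row (illustrative only, not code from the patch), the usual max-subtraction trick gives y_j = x_j - max_i x_i - log(sum_i exp(x_i - max)):

#include <algorithm>
#include <cmath>
#include <vector>

std::vector<double> LogSoftmaxRow(const std::vector<double> &x) {
  double max = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (size_t j = 0; j < x.size(); j++)
    sum += std::exp(x[j] - max);       // exponents are <= 0, so no overflow.
  double log_sum = std::log(sum);
  std::vector<double> y(x.size());
  for (size_t j = 0; j < x.size(); j++)
    y[j] = x[j] - max - log_sum;
  return y;
}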
- if (tid < warpSize) { -# pragma unroll - for (int shift = warpSize; shift > 0; shift >>= 1) { - ssum[tid] += ssum[tid + shift]; - } + tsum = BlockReduceT(temp_storage).Sum(tsum); + if (tid == 0) { + ssum[0] = tsum; } // Broadcast result to all threads @@ -3078,6 +2999,8 @@ static void _diff_log_softmax(const MatrixDim in_deriv_dim, Real* in_deriv) { __shared__ Real ssum[CU1DBLOCK]; + typedef cub::BlockReduce BlockReduceT; + __shared__ typename BlockReduceT::TempStorage temp_storage; const int tid = threadIdx.x; const int i = blockIdx.x; const int out_value_start = i * out_value_stride; @@ -3089,24 +3012,9 @@ static void _diff_log_softmax(const MatrixDim in_deriv_dim, for (int j = tid; j < in_deriv_dim.cols; j += CU1DBLOCK) { tsum += out_deriv[out_deriv_start + j]; } - ssum[tid] = tsum; - __syncthreads(); - - // Tree reduce to 2x warpSize elements. -# pragma unroll - for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) { - if (tid < shift) { - ssum[tid] += ssum[tid + shift]; - } - __syncthreads(); - } - - // Warp reduce to 1 element. Threads implicitly synchronized within a warp. - if (tid < warpSize) { -# pragma unroll - for (int shift = warpSize; shift > 0; shift >>= 1) { - ssum[tid] += ssum[tid + shift]; - } + tsum = BlockReduceT(temp_storage).Sum(tsum); + if (tid == 0) { + ssum[0] = tsum; } // Broadcast result to all threads @@ -3699,7 +3607,9 @@ static void _cuda_uncompress(BaseFloat *dest, MatrixDim dim, } } - +__global__ +static void _noop_kernel() { +} /*********************************************************************** * ANSI-C wrappers of CUDA kernels @@ -5459,3 +5369,10 @@ void cuda_uncompress_int16(dim3 Gr, dim3 Bl, BaseFloat *dest, int src_stride, float scale) { _cuda_uncompress<<>>(dest, dim, src, src_stride, scale); } + + +// Launches a kernel that does nothing, explicitly using the legacy default stream; +// this will synchronize all threads without blocking. 
+void cuda_legacy_noop() { + _noop_kernel<<<1, 1, 0, cudaStreamLegacy>>>(); +} diff --git a/src/cudamatrix/cu-math-test.cc b/src/cudamatrix/cu-math-test.cc index 09255c9587b..022742ed29f 100644 --- a/src/cudamatrix/cu-math-test.cc +++ b/src/cudamatrix/cu-math-test.cc @@ -545,6 +545,50 @@ static void UnitTestCuMathNormalizePerRow() { } } + +template +static void UnitTestCuMathNormalizePerRow_v2() { + + int row = 128; + int col = 1024; + + Matrix Hi(row,col); + Matrix Ho(row,col); + Hi.SetRandn(); + Hi.Scale(5.0); + Hi.ApplyFloor(0.0); // like ReLU, + + CuMatrix Di(row, col); + CuMatrix Do(row, col); + Di.CopyFromMat(Hi); + + Real target_rms = 0.3456; + bool add_log_stddev = false; + const Real kSquaredNormFloor = 1.35525271560688e-20; // 2^-66 + + //gpu + cu::NormalizePerRow(Di, target_rms, add_log_stddev, &Do); + + //cpu + { + MatrixBase& in(Hi); + MatrixBase& out(Ho); + Real target_rms=0.3456; + Vector in_norm(in.NumRows()); + Real d_scaled = in.NumCols() * target_rms * target_rms; + in_norm.AddDiagMat2(1.0 / d_scaled, in, kNoTrans, 0.0); + in_norm.ApplyFloor(kSquaredNormFloor); + in_norm.ApplyPow(-0.5); + out.CopyFromMat(in); + out.MulRowsVec(in_norm); + } + + Matrix Ho2(Do); + // here the BUG was detected (by processing big-enough matrix), + AssertEqual(Ho,Ho2,0.00001); +} + + template static void UnitTestCuDiffNormalizePerRow() { for (int32 i = 0; i < 2; i++) { @@ -660,6 +704,7 @@ template void CudaMathUnitTest() { UnitTestEnsureNonzero(); UnitTestBackpropLstmNonlinearity(); UnitTestCuMathNormalizePerRow(); + UnitTestCuMathNormalizePerRow_v2(); UnitTestCuDiffNormalizePerRow(); } @@ -673,9 +718,9 @@ int main() { for (; loop < 2; loop++) { CuDevice::Instantiate().SetDebugStrideMode(true); if (loop == 0) - CuDevice::Instantiate().SelectGpuId("no"); // -1 means no GPU + CuDevice::Instantiate().SelectGpuId("no"); // 0 means no GPU else - CuDevice::Instantiate().SelectGpuId("yes"); // -2 .. automatic selection + CuDevice::Instantiate().SelectGpuId("yes"); // 1 .. automatic selection #endif srand(time(NULL)); kaldi::CudaMathUnitTest(); diff --git a/src/cudamatrix/cu-matrix-inl.h b/src/cudamatrix/cu-matrix-inl.h index 9b7a707d2e5..0e182d4e72a 100644 --- a/src/cudamatrix/cu-matrix-inl.h +++ b/src/cudamatrix/cu-matrix-inl.h @@ -36,6 +36,7 @@ inline CuSubMatrix::CuSubMatrix(const CuMatrixBase &mat, // initializer, so nothing to do. } else { KALDI_ASSERT(row_offset >= 0 && col_offset >= 0 && + num_rows >= 0 && num_cols >= 0 && row_offset + num_rows <= mat.num_rows_ && col_offset + num_cols <= mat.num_cols_); this->data_ = mat.data_ + static_cast(col_offset) + @@ -68,5 +69,3 @@ inline CuSubMatrix::CuSubMatrix(const Real *data, } // namespace kaldi #endif - - diff --git a/src/cudamatrix/cu-matrix-test.cc b/src/cudamatrix/cu-matrix-test.cc index 01030bb8353..46bc6ea0cb2 100644 --- a/src/cudamatrix/cu-matrix-test.cc +++ b/src/cudamatrix/cu-matrix-test.cc @@ -1327,6 +1327,8 @@ template static void UnitTestCuMatrixSetMatMatDivMat() { B.SetRandn(); C.SetRandn(); + C.ApplyFloor(0.01); // make sure there are no zeros. + M.SetMatMatDivMat(A,B,C); ref.AddMatMatElements(1.0, A, B, 0.0); ref.DivElements(C); @@ -2620,7 +2622,11 @@ static int32 DoubleFactorial(int32 i) { template static void UnitTestCuMatrixSetRandn() { - { // First test consistency when called twice. + + if (false) { + // This block tests consistency when called twice. + // It has been disabled since we added multi-threaded testing, + // since consistency wouldn't be expected if other threads were running. 
int32 dimM = 100 + Rand() % 200, dimN = 100 + Rand() % 200; Matrix M(dimM, dimN), N(dimM, dimN); srand(104); @@ -3040,16 +3046,38 @@ template void CudaMatrixUnitTest() { int main() { SetVerboseLevel(1); int32 loop = 0; + bool test_threads = true; + // num_threads only matters if test_threads == true. Don't make it + // to large, because it will affect CPU usage if you are using CPU. + int32 num_threads = 4; + + #if HAVE_CUDA == 1 for (loop = 0; loop < 2; loop++) { CuDevice::Instantiate().SetDebugStrideMode(true); + if (test_threads) + CuDevice::Instantiate().AllowMultithreading(); if (loop == 0) CuDevice::Instantiate().SelectGpuId("no"); else CuDevice::Instantiate().SelectGpuId("yes"); #endif - kaldi::CudaMatrixUnitTest(); + if (test_threads) { + KALDI_LOG << "Doing matrix unit test with " + << num_threads << " threads."; + std::vector threads; + for (int32 i = 0; i < num_threads - 1; i++) + threads.push_back(new std::thread(kaldi::CudaMatrixUnitTest)); + // the last thread running is the main thread. + kaldi::CudaMatrixUnitTest(); + for (size_t i = 0; i < threads.size(); i++) { + threads[i]->join(); + delete threads[i]; + } + } else { + kaldi::CudaMatrixUnitTest(); + } #if HAVE_CUDA == 1 if (CuDevice::Instantiate().DoublePrecisionSupported()) { diff --git a/src/cudamatrix/cu-matrix.cc b/src/cudamatrix/cu-matrix.cc index beccd9dc4a5..1f09ff278ce 100644 --- a/src/cudamatrix/cu-matrix.cc +++ b/src/cudamatrix/cu-matrix.cc @@ -229,8 +229,10 @@ void CuMatrixBase::CopyFromMat(const CuMatrixBase &M, MatrixIndexT dst_pitch = stride_ * sizeof(Real); MatrixIndexT src_pitch = M.Stride() * sizeof(Real); MatrixIndexT width = M.NumCols() * sizeof(Real); - CU_SAFE_CALL(cudaMemcpy2D(data_, dst_pitch, M.data_, src_pitch, - width, M.num_rows_, cudaMemcpyDeviceToDevice)); + CU_SAFE_CALL( + cudaMemcpy2DAsync(data_, dst_pitch, M.data_, src_pitch, + width, M.num_rows_, cudaMemcpyDeviceToDevice, + cudaStreamPerThread)); } else { if (trans == kNoTrans) { dim3 dimGrid, dimBlock; @@ -319,8 +321,10 @@ void CuMatrixBase::CopyFromMat(const MatrixBase &src, MatrixIndexT dst_pitch = stride_*sizeof(Real); MatrixIndexT src_pitch = src.Stride()*sizeof(Real); MatrixIndexT width = src.NumCols()*sizeof(Real); - CU_SAFE_CALL(cudaMemcpy2D(data_, dst_pitch, src.Data(), src_pitch, - width, src.NumRows(), cudaMemcpyHostToDevice)); + CU_SAFE_CALL(cudaMemcpy2DAsync(data_, dst_pitch, src.Data(), src_pitch, + width, src.NumRows(), cudaMemcpyHostToDevice, + cudaStreamPerThread)); + cudaStreamSynchronize(cudaStreamPerThread); CuDevice::Instantiate().AccuProfile("CuMatrixBase::CopyFromMat(from CPU)", tim); } else { @@ -2286,14 +2290,15 @@ void CuMatrixBase::CopyRowsFromVec(const CuVectorBase &v) { if (v.Dim() == num_rows_*num_cols_) { if (stride_ == num_cols_) { const Real* v_data = v.Data(); - CU_SAFE_CALL(cudaMemcpy(data_, v_data, - sizeof(Real)*num_rows_*num_cols_, - cudaMemcpyDeviceToDevice)); + CU_SAFE_CALL( + cudaMemcpyAsync(data_, v_data, sizeof(Real)*num_rows_*num_cols_, + cudaMemcpyDeviceToDevice, cudaStreamPerThread)); } else { - CU_SAFE_CALL(cudaMemcpy2D(data_, stride_ * sizeof(Real), v.Data(), - num_cols_*sizeof(Real), num_cols_*sizeof(Real), - num_rows_, - cudaMemcpyDeviceToDevice)); + CU_SAFE_CALL( + cudaMemcpy2DAsync(data_, stride_ * sizeof(Real), v.Data(), + num_cols_*sizeof(Real), num_cols_*sizeof(Real), + num_rows_, cudaMemcpyDeviceToDevice, + cudaStreamPerThread)); } } else if (v.Dim() == num_cols_) { dim3 dimGrid, dimBlock; diff --git a/src/cudamatrix/cu-matrixdim.h b/src/cudamatrix/cu-matrixdim.h index 
dab7bd40eb2..74912dad6e3 100644 --- a/src/cudamatrix/cu-matrixdim.h +++ b/src/cudamatrix/cu-matrixdim.h @@ -26,16 +26,10 @@ /* * Typedefs needed for ANSI-C interface of CUDA wrappers */ -#ifdef _MSC_VER - typedef unsigned __int32 uint32_cuda; - typedef __int32 int32_cuda; - typedef __int32 MatrixIndexT_cuda; // you'd have to change this if you changed MatrixIndexT from int32. -#else - #include - typedef uint32_t uint32_cuda; - typedef int32_t int32_cuda; - typedef int32_t MatrixIndexT_cuda; // you'd have to change this if you changed MatrixIndexT from int32. -#endif +#include +typedef uint32_t uint32_cuda; +typedef int32_t int32_cuda; +typedef int32_t MatrixIndexT_cuda; // you'd have to change this if you changed MatrixIndexT from int32. template struct MatrixElement { diff --git a/src/cudamatrix/cu-packed-matrix.cc b/src/cudamatrix/cu-packed-matrix.cc index 64f8afe0616..7581b043ae0 100644 --- a/src/cudamatrix/cu-packed-matrix.cc +++ b/src/cudamatrix/cu-packed-matrix.cc @@ -143,8 +143,9 @@ void CuPackedMatrix::CopyFromPacked(const CuPackedMatrix &src) { size_t nr = static_cast(num_rows_), num_bytes = ((nr * (nr+1)) / 2) * sizeof(Real); - CU_SAFE_CALL(cudaMemcpy(data_, src.data_, num_bytes, - cudaMemcpyDeviceToDevice)); + CU_SAFE_CALL( + cudaMemcpyAsync(data_, src.data_, num_bytes, cudaMemcpyDeviceToDevice, + cudaStreamPerThread)); CuDevice::Instantiate().AccuProfile("CuPackedMatrix::CopyFromPacked1", tim); } else @@ -247,7 +248,9 @@ void CuPackedMatrix::SetZero() { size_t nr = static_cast(num_rows_), num_bytes = ((nr * (nr+1)) / 2) * sizeof(Real); - CU_SAFE_CALL(cudaMemset(reinterpret_cast(this->data_), 0, num_bytes)); + CU_SAFE_CALL(cudaMemsetAsync(reinterpret_cast(this->data_), 0, + num_bytes, cudaStreamPerThread)); + CU_SAFE_CALL(cudaStreamSynchronize(cudaStreamPerThread)); CuDevice::Instantiate().AccuProfile("CuPackedMatrix::SetZero", tim); } else #endif diff --git a/src/cudamatrix/cu-value.h b/src/cudamatrix/cu-value.h index 2245ff01200..cab0a3235d7 100644 --- a/src/cudamatrix/cu-value.h +++ b/src/cudamatrix/cu-value.h @@ -22,7 +22,7 @@ #ifndef KALDI_CUDAMATRIX_CU_VALUE_H_ #define KALDI_CUDAMATRIX_CU_VALUE_H_ -#include +#include "cudamatrix/cu-device.h" namespace kaldi { @@ -39,7 +39,9 @@ class CuValue { inline CuValue operator = (const CuValue &other) { #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { - CU_SAFE_CALL(cudaMemcpy(data_, other.data_, sizeof(Real), cudaMemcpyDeviceToDevice)); + CU_SAFE_CALL( + cudaMemcpyAsync(data_, other.data_, sizeof(Real), + cudaMemcpyDeviceToDevice, cudaStreamPerThread)); return *this; } else #endif diff --git a/src/cudamatrix/cu-vector.cc b/src/cudamatrix/cu-vector.cc index 5ea3a236b0a..7c968c6550d 100644 --- a/src/cudamatrix/cu-vector.cc +++ b/src/cudamatrix/cu-vector.cc @@ -83,11 +83,11 @@ Real VecMatVec(const CuVectorBase &v1, const CuMatrixBase &M, const CuVectorBase &v2) { KALDI_ASSERT(v1.Dim() == M.NumRows() && M.NumCols() == v2.Dim()); if (v1.Dim() > v2.Dim()) { // do v2*M first - CuVector v2M(v1.Dim(), kUndefined); + CuVector v2M(v1.Dim()); v2M.AddMatVec(1.0, M, kNoTrans, v2, 0.0); return VecVec(v2M, v1); } else { // do v1*M first - CuVector v1M(v2.Dim(), kUndefined); + CuVector v1M(v2.Dim()); v1M.AddMatVec(1.0, M, kTrans, v1, 0.0); return VecVec(v1M, v2); } @@ -167,14 +167,16 @@ void CuVectorBase::CopyRowsFromMat(const CuMatrixBase &mat) { if (dim_ == 0) return; CuTimer tim; if (mat.Stride() == mat.NumCols() && mat.NumRows() != 0) { - CU_SAFE_CALL(cudaMemcpy(data_, mat.Data(), sizeof(Real)*dim_, - cudaMemcpyDeviceToDevice)); + 
CU_SAFE_CALL( + cudaMemcpyAsync(data_, mat.Data(), sizeof(Real)*dim_, + cudaMemcpyDeviceToDevice, cudaStreamPerThread)); } else { Real* vec_data = data_; for (MatrixIndexT r = 0; r < mat.NumRows(); r++) { - CU_SAFE_CALL(cudaMemcpy(vec_data, mat.RowData(r), - sizeof(Real) * mat.NumCols(), - cudaMemcpyDeviceToDevice)); + CU_SAFE_CALL(cudaMemcpyAsync(vec_data, mat.RowData(r), + sizeof(Real) * mat.NumCols(), + cudaMemcpyDeviceToDevice, + cudaStreamPerThread)); vec_data += mat.NumCols(); } } @@ -1049,7 +1051,9 @@ void CuVectorBase::CopyFromVec(const CuVectorBase &src) { if (CuDevice::Instantiate().Enabled()) { if (dim_ == 0) return; CuTimer tim; - CU_SAFE_CALL(cudaMemcpy(data_, src.data_, src.dim_ * sizeof(Real), cudaMemcpyDeviceToDevice)); + CU_SAFE_CALL( + cudaMemcpyAsync(data_, src.data_, src.dim_ * sizeof(Real), + cudaMemcpyDeviceToDevice, cudaStreamPerThread)); CuDevice::Instantiate().AccuProfile(__func__, tim); } else #endif @@ -1068,7 +1072,9 @@ void CuVectorBase::SetZero() { KALDI_ASSERT(dim_>=0); KALDI_ASSERT(data_!=NULL); CuTimer tim; - CU_SAFE_CALL(cudaMemset(data_, 0, dim_*sizeof(Real))); + CU_SAFE_CALL(cudaMemsetAsync(data_, 0, dim_*sizeof(Real), + cudaStreamPerThread)); + CU_SAFE_CALL(cudaStreamSynchronize(cudaStreamPerThread)); CuDevice::Instantiate().AccuProfile("CuVector::SetZero", tim); } else #endif diff --git a/src/decoder/Makefile b/src/decoder/Makefile index e997cf3c3c4..020fe358fe9 100644 --- a/src/decoder/Makefile +++ b/src/decoder/Makefile @@ -7,13 +7,13 @@ TESTFILES = OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \ lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \ - decoder-wrappers.o + decoder-wrappers.o grammar-fst.o decodable-matrix.o LIBNAME = kaldi-decoder -ADDLIBS = ../lat/kaldi-lat.a ../hmm/kaldi-hmm.a \ +ADDLIBS = ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/decoder/biglm-faster-decoder.h b/src/decoder/biglm-faster-decoder.h index ad8157f8a0f..a6b99fba95e 100644 --- a/src/decoder/biglm-faster-decoder.h +++ b/src/decoder/biglm-faster-decoder.h @@ -122,8 +122,9 @@ class BiglmFasterDecoder { // will be nonempty). fst_out->DeleteStates(); Token *best_tok = NULL; - Weight best_final; // only set if is_final == true. The final-prob corresponding - // to the best final token (i.e. the one with best weight best_weight, below). + Weight best_final = Weight::Zero(); // set only if is_final == true. The + // final-prob corresponding to the best final token (i.e. the one with best + // weight best_weight, below). bool is_final = ReachedFinal(); if (!is_final) { for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) diff --git a/src/decoder/decodable-matrix.cc b/src/decoder/decodable-matrix.cc new file mode 100644 index 00000000000..3cc7b87f2d7 --- /dev/null +++ b/src/decoder/decodable-matrix.cc @@ -0,0 +1,107 @@ +// decoder/decodable-matrix.cc + +// Copyright 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/decodable-matrix.h" + +namespace kaldi { + +DecodableMatrixMapped::DecodableMatrixMapped( + const TransitionModel &tm, + const MatrixBase &likes, + int32 frame_offset): + trans_model_(tm), likes_(&likes), likes_to_delete_(NULL), + frame_offset_(frame_offset) { + stride_ = likes.Stride(); + raw_data_ = likes.Data() - (stride_ * frame_offset); + + if (likes.NumCols() != tm.NumPdfs()) + KALDI_ERR << "Mismatch, matrix has " + << likes.NumCols() << " rows but transition-model has " + << tm.NumPdfs() << " pdf-ids."; +} + +DecodableMatrixMapped::DecodableMatrixMapped( + const TransitionModel &tm, const Matrix *likes, + int32 frame_offset): + trans_model_(tm), likes_(likes), likes_to_delete_(likes), + frame_offset_(frame_offset) { + stride_ = likes->Stride(); + raw_data_ = likes->Data() - (stride_ * frame_offset_); + if (likes->NumCols() != tm.NumPdfs()) + KALDI_ERR << "Mismatch, matrix has " + << likes->NumCols() << " rows but transition-model has " + << tm.NumPdfs() << " pdf-ids."; +} + + +BaseFloat DecodableMatrixMapped::LogLikelihood(int32 frame, int32 tid) { + int32 pdf_id = trans_model_.TransitionIdToPdfFast(tid); +#ifdef KALDI_PARANOID + return (*likes_)(frame - frame_offset_, pdf_id); +#else + return raw_data_[frame * stride_ + pdf_id]; +#endif +} + +int32 DecodableMatrixMapped::NumFramesReady() const { + return frame_offset_ + likes_->NumRows(); +} + +bool DecodableMatrixMapped::IsLastFrame(int32 frame) const { + KALDI_ASSERT(frame < NumFramesReady()); + return (frame == NumFramesReady() - 1); +} + +// Indices are one-based! This is for compatibility with OpenFst. +int32 DecodableMatrixMapped::NumIndices() const { + return trans_model_.NumTransitionIds(); +} + +DecodableMatrixMapped::~DecodableMatrixMapped() { + delete likes_to_delete_; +} + + +void DecodableMatrixMappedOffset::AcceptLoglikes( + Matrix *loglikes, int32 frames_to_discard) { + if (loglikes->NumRows() == 0) return; + KALDI_ASSERT(loglikes->NumCols() == trans_model_.NumPdfs()); + KALDI_ASSERT(frames_to_discard <= loglikes_.NumRows() && + frames_to_discard >= 0); + if (frames_to_discard == loglikes_.NumRows()) { + loglikes_.Swap(loglikes); + loglikes->Resize(0, 0); + } else { + int32 old_rows_kept = loglikes_.NumRows() - frames_to_discard, + new_num_rows = old_rows_kept + loglikes->NumRows(); + Matrix new_loglikes(new_num_rows, loglikes->NumCols()); + new_loglikes.RowRange(0, old_rows_kept).CopyFromMat( + loglikes_.RowRange(frames_to_discard, old_rows_kept)); + new_loglikes.RowRange(old_rows_kept, loglikes->NumRows()).CopyFromMat( + *loglikes); + loglikes_.Swap(&new_loglikes); + } + frame_offset_ += frames_to_discard; + stride_ = loglikes_.Stride(); + raw_data_ = loglikes_.Data() - (frame_offset_ * stride_); +} + + + +} // end namespace kaldi. 
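A hedged usage sketch of the chunked-decoding pattern that AcceptLoglikes() above is designed for: log-likelihood chunks arrive one at a time and are appended, and frames_to_discard tells the object how many of the oldest frames the decoder no longer needs. The surrounding helper and the way chunks are supplied are hypothetical; only the DecodableMatrixMappedOffset calls come from this patch.

#include <vector>
#include "decoder/decodable-matrix.h"

namespace kaldi {
// Hypothetical helper: feeds pre-computed log-likelihood chunks (each with
// trans_model.NumPdfs() columns) to the decodable object; in real code the
// chunks would come from a neural-net computation, with a decoder advanced
// between calls.
void FeedChunks(const TransitionModel &trans_model,
                const std::vector<Matrix<BaseFloat> > &chunks) {
  DecodableMatrixMappedOffset decodable(trans_model);
  int32 frames_to_discard = 0;  // zero is always safe; it just keeps more frames.
  for (size_t i = 0; i < chunks.size(); i++) {
    Matrix<BaseFloat> chunk(chunks[i]);  // AcceptLoglikes() may destroy its input.
    decodable.AcceptLoglikes(&chunk, frames_to_discard);
    // ... advance the decoder up to decodable.NumFramesReady() here; once the
    // decoder no longer needs the oldest frames, a nonzero frames_to_discard
    // can be passed on the next call to keep memory bounded.
  }
  decodable.InputIsFinished();
}
}  // namespace kaldi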
diff --git a/src/decoder/decodable-matrix.h b/src/decoder/decodable-matrix.h index de70ea82753..475638a35af 100644 --- a/src/decoder/decodable-matrix.h +++ b/src/decoder/decodable-matrix.h @@ -26,14 +26,14 @@ #include "base/kaldi-common.h" #include "hmm/transition-model.h" #include "itf/decodable-itf.h" +#include "matrix/kaldi-matrix.h" namespace kaldi { class DecodableMatrixScaledMapped: public DecodableInterface { public: - // This constructor creates an object that will not delete "likes" - // when done. + // This constructor creates an object that will not delete "likes" when done. DecodableMatrixScaledMapped(const TransitionModel &tm, const Matrix &likes, BaseFloat scale): trans_model_(tm), likes_(&likes), @@ -55,7 +55,7 @@ class DecodableMatrixScaledMapped: public DecodableInterface { KALDI_ERR << "DecodableMatrixScaledMapped: mismatch, matrix has " << likes->NumCols() << " rows but transition-model has " << tm.NumPdfs() << " pdf-ids."; - } + } virtual int32 NumFramesReady() const { return likes_->NumRows(); } @@ -66,7 +66,7 @@ class DecodableMatrixScaledMapped: public DecodableInterface { // Note, frames are numbered from zero. virtual BaseFloat LogLikelihood(int32 frame, int32 tid) { - return scale_ * (*likes_)(frame, trans_model_.TransitionIdToPdf(tid)); + return scale_ * (*likes_)(frame, trans_model_.TransitionIdToPdfFast(tid)); } // Indices are one-based! This is for compatibility with OpenFst. @@ -83,6 +83,59 @@ class DecodableMatrixScaledMapped: public DecodableInterface { KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixScaledMapped); }; +/** + This is like DecodableMatrixScaledMapped, but it doesn't support an acoustic + scale, and it does support a frame offset, whereby you can state that the + first row of 'likes' is actually the n'th row of the matrix of available + log-likelihoods. It's useful if the neural net output comes in chunks for + different frame ranges. + + Note: DecodableMatrixMappedOffset solves the same problem in a slightly + different way, where you use the same decodable object. This one, unlike + DecodableMatrixMappedOffset, is compatible with when the loglikes are in a + SubMatrix. + */ +class DecodableMatrixMapped: public DecodableInterface { + public: + // This constructor creates an object that will not delete "likes" when done. + // the frame_offset is the frame the row 0 of 'likes' corresponds to, would be + // greater than one if this is not the first chunk of likelihoods. + DecodableMatrixMapped(const TransitionModel &tm, + const MatrixBase &likes, + int32 frame_offset = 0); + + // This constructor creates an object that will delete "likes" + // when done. + DecodableMatrixMapped(const TransitionModel &tm, + const Matrix *likes, + int32 frame_offset = 0); + + virtual int32 NumFramesReady() const; + + virtual bool IsLastFrame(int32 frame) const; + + virtual BaseFloat LogLikelihood(int32 frame, int32 tid); + + // Note: these indices are 1-based. + virtual int32 NumIndices() const; + + virtual ~DecodableMatrixMapped(); + + private: + const TransitionModel &trans_model_; // for tid to pdf mapping + const MatrixBase *likes_; + const Matrix *likes_to_delete_; + int32 frame_offset_; + + // raw_data_ and stride_ are a kind of fast look-aside for 'likes_', to be + // used when KALDI_PARANOID is false. 
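A standalone toy illustration (not Kaldi code) of the look-aside trick described in the comment above: raw_data_ is set to point stride_ * frame_offset_ elements before the real data, so the absolute frame index can be used directly, without subtracting the offset on every LogLikelihood() call. As in the patch, the shifted pointer is never dereferenced outside the valid rows.

#include <cassert>
#include <vector>

int main() {
  const int stride = 3, num_rows = 4, frame_offset = 10;  // rows hold frames 10..13.
  std::vector<float> data(num_rows * stride);
  for (int r = 0; r < num_rows; r++)
    for (int c = 0; c < stride; c++)
      data[r * stride + c] = 100.0f * r + c;
  // Shift the base pointer back by frame_offset rows, as the patch does:
  const float *raw_data = data.data() - stride * frame_offset;
  int frame = 12, pdf_id = 2;  // absolute frame number, as the decoder supplies it.
  assert(raw_data[frame * stride + pdf_id] ==
         data[(frame - frame_offset) * stride + pdf_id]);
  return 0;
}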
+ const BaseFloat *raw_data_; + int32 stride_; + + KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixMapped); +}; + + /** This decodable class returns log-likes stored in a matrix; it supports repeatedly writing to the matrix and setting a time-offset representing the @@ -91,68 +144,51 @@ class DecodableMatrixScaledMapped: public DecodableInterface { code will call SetLoglikes() each time more log-likelihods are available. If you try to access a log-likelihood that's no longer available because the frame index is less than the current offset, it is of course an error. + + See also DecodableMatrixMapped, which supports the same type of thing but + with a different interface where you are expected to re-construct the + object each time you want to decode. */ class DecodableMatrixMappedOffset: public DecodableInterface { public: DecodableMatrixMappedOffset(const TransitionModel &tm): - trans_model_(tm), frame_offset_(0), input_is_finished_(false) { } - - + trans_model_(tm), frame_offset_(0), input_is_finished_(false) { } virtual int32 NumFramesReady() { return frame_offset_ + loglikes_.NumRows(); } // this is not part of the generic Decodable interface. int32 FirstAvailableFrame() { return frame_offset_; } - + + // Logically, this function appends 'loglikes' (interpreted as newly available + // frames) to the log-likelihoods stored in the class. + // // This function is destructive of the input "loglikes" because it may // under some circumstances do a shallow copy using Swap(). This function // appends loglikes to any existing likelihoods you've previously supplied. - // frames_to_discard, if nonzero, will discard that number of previously - // available frames, from the left, advancing FirstAvailableFrame() by - // a number equal to frames_to_discard. You should only set frames_to_discard - // to nonzero if you know your decoder won't want to access the loglikes - // for older frames. void AcceptLoglikes(Matrix *loglikes, - int32 frames_to_discard) { - if (loglikes->NumRows() == 0) return; - KALDI_ASSERT(loglikes->NumCols() == trans_model_.NumPdfs()); - KALDI_ASSERT(frames_to_discard <= loglikes_.NumRows() && - frames_to_discard >= 0); - if (frames_to_discard == loglikes_.NumRows()) { - loglikes_.Swap(loglikes); - loglikes->Resize(0, 0); - } else { - int32 old_rows_kept = loglikes_.NumRows() - frames_to_discard, - new_num_rows = old_rows_kept + loglikes->NumRows(); - Matrix new_loglikes(new_num_rows, loglikes->NumCols()); - new_loglikes.RowRange(0, old_rows_kept).CopyFromMat( - loglikes_.RowRange(frames_to_discard, old_rows_kept)); - new_loglikes.RowRange(old_rows_kept, loglikes->NumRows()).CopyFromMat( - *loglikes); - loglikes_.Swap(&new_loglikes); - } - frame_offset_ += frames_to_discard; - } + int32 frames_to_discard); void InputIsFinished() { input_is_finished_ = true; } - + virtual int32 NumFramesReady() const { return loglikes_.NumRows() + frame_offset_; } - + virtual bool IsLastFrame(int32 frame) const { KALDI_ASSERT(frame < NumFramesReady()); return (frame == NumFramesReady() - 1 && input_is_finished_); } virtual BaseFloat LogLikelihood(int32 frame, int32 tid) { - int32 index = frame - frame_offset_; - KALDI_ASSERT(index >= 0 && index < loglikes_.NumRows()); - return loglikes_(index, trans_model_.TransitionIdToPdf(tid)); + int32 pdf_id = trans_model_.TransitionIdToPdfFast(tid); +#ifdef KALDI_PARANOID + return loglikes_(frame - frame_offset_, pdf_id); +#else + // This does no checking, so will be faster. 
+ return raw_data_[frame * stride_ + pdf_id]; +#endif } - - virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); } // nothing special to do in destructor. @@ -162,6 +198,15 @@ class DecodableMatrixMappedOffset: public DecodableInterface { Matrix loglikes_; int32 frame_offset_; bool input_is_finished_; + + // 'raw_data_' and 'stride_' are intended as a fast look-aside which is an + // alternative to accessing data_. raw_data_ is a faked version of + // data_->Data() as if it started from frame zero rather than frame_offset_. + // This simplifies the code of LogLikelihood(), in cases where KALDI_PARANOID + // is not defined. + BaseFloat *raw_data_; + int32 stride_; + KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixMappedOffset); }; @@ -171,20 +216,20 @@ class DecodableMatrixScaled: public DecodableInterface { DecodableMatrixScaled(const Matrix &likes, BaseFloat scale): likes_(likes), scale_(scale) { } - + virtual int32 NumFramesReady() const { return likes_.NumRows(); } - + virtual bool IsLastFrame(int32 frame) const { KALDI_ASSERT(frame < NumFramesReady()); return (frame == NumFramesReady() - 1); } - + // Note, frames are numbered from zero. virtual BaseFloat LogLikelihood(int32 frame, int32 index) { - if (index > likes_.NumCols() || index <= 0 || + if (index > likes_.NumCols() || index <= 0 || frame < 0 || frame >= likes_.NumRows()) - KALDI_ERR << "Invalid (frame, index - 1) = (" - << frame << ", " << index - 1 << ") for matrix of size " + KALDI_ERR << "Invalid (frame, index - 1) = (" + << frame << ", " << index - 1 << ") for matrix of size " << likes_.NumRows() << " x " << likes_.NumCols(); return scale_ * likes_(frame, index - 1); } @@ -197,8 +242,6 @@ class DecodableMatrixScaled: public DecodableInterface { BaseFloat scale_; KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixScaled); }; - - } // namespace kaldi #endif // KALDI_DECODER_DECODABLE_MATRIX_H_ diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index 150d9e513a8..ff573c74d15 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -19,6 +19,8 @@ #include "decoder/decoder-wrappers.h" #include "decoder/faster-decoder.h" +#include "decoder/lattice-faster-decoder.h" +#include "decoder/grammar-fst.h" #include "lat/lattice-functions.h" namespace kaldi { @@ -195,8 +197,9 @@ DecodeUtteranceLatticeFasterClass::~DecodeUtteranceLatticeFasterClass() { // Takes care of output. Returns true on success. +template bool DecodeUtteranceLatticeFaster( - LatticeFasterDecoder &decoder, // not const but is really an input. + LatticeFasterDecoderTpl &decoder, // not const but is really an input. DecodableInterface &decodable, // not const but is really an input. const TransitionModel &trans_model, const fst::SymbolTable *word_syms, @@ -292,6 +295,38 @@ bool DecodeUtteranceLatticeFaster( return true; } +// Instantiate the template above for the two required FST types. 
+template bool DecodeUtteranceLatticeFaster( + LatticeFasterDecoderTpl > &decoder, + DecodableInterface &decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); + +template bool DecodeUtteranceLatticeFaster( + LatticeFasterDecoderTpl &decoder, + DecodableInterface &decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); + + // Takes care of output. Returns true on success. bool DecodeUtteranceLatticeSimple( LatticeSimpleDecoder &decoder, // not const but is really an input. @@ -347,7 +382,7 @@ bool DecodeUtteranceLatticeSimple( for (size_t i = 0; i < words.size(); i++) { std::string s = word_syms->Find(words[i]); if (s == "") - KALDI_ERR << "Word-id " << words[i] <<" not in symbol table."; + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; std::cerr << s << ' '; } std::cerr << '\n'; diff --git a/src/decoder/decoder-wrappers.h b/src/decoder/decoder-wrappers.h index 38ddcb40f50..fc81137f356 100644 --- a/src/decoder/decoder-wrappers.h +++ b/src/decoder/decoder-wrappers.h @@ -95,8 +95,13 @@ void ModifyGraphForCarefulAlignment( /// other obvious place to put it. If determinize == false, it writes to /// lattice_writer, else to compact_lattice_writer. The writers for /// alignments and words will only be written to if they are open. +/// +/// Caution: this will only link correctly if FST is either fst::Fst, +/// or fst::GrammarFst, as the template function is defined in the .cc file and +/// only instantiated for those two types. +template bool DecodeUtteranceLatticeFaster( - LatticeFasterDecoder &decoder, // not const but is really an input. + LatticeFasterDecoderTpl &decoder, // not const but is really an input. DecodableInterface &decodable, // not const but is really an input. const TransitionModel &trans_model, const fst::SymbolTable *word_syms, @@ -110,6 +115,7 @@ bool DecodeUtteranceLatticeFaster( LatticeWriter *lattice_writer, double *like_ptr); // puts utterance's likelihood in like_ptr on success. + /// This class basically does the same job as the function /// DecodeUtteranceLatticeFaster, but in a way that allows us /// to build a multi-threaded command line program more easily. diff --git a/src/decoder/grammar-fst.cc b/src/decoder/grammar-fst.cc new file mode 100644 index 00000000000..1b79e7b5521 --- /dev/null +++ b/src/decoder/grammar-fst.cc @@ -0,0 +1,1031 @@ +// decoder/grammar-fst.cc + +// Copyright 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/grammar-fst.h" +#include "fstext/grammar-context-fst.h" + +namespace fst { + + +GrammarFst::GrammarFst( + int32 nonterm_phones_offset, + std::shared_ptr > top_fst, + const std::vector > > > &ifsts): + nonterm_phones_offset_(nonterm_phones_offset), + top_fst_(top_fst), + ifsts_(ifsts) { + Init(); +} + +void GrammarFst::Init() { + KALDI_ASSERT(nonterm_phones_offset_ > 1); + InitNonterminalMap(); + entry_arcs_.resize(ifsts_.size()); + if (!ifsts_.empty()) { + // We call this mostly so that if something is wrong with the input FSTs, the + // problem will be detected sooner rather than later. + // There would be no problem if we were to call InitEntryArcs(i) + // for all 0 <= i < ifsts_size(), but we choose to call it + // lazily on demand, to save startup time if the number of nonterminals + // is large. + InitEntryArcs(0); + } + InitInstances(); +} + +GrammarFst::~GrammarFst() { + Destroy(); +} + +void GrammarFst::Destroy() { + for (size_t i = 0; i < instances_.size(); i++) { + FstInstance &instance = instances_[i]; + std::unordered_map::const_iterator + iter = instance.expanded_states.begin(), + end = instance.expanded_states.end(); + for (; iter != end; ++iter) { + ExpandedState *e = iter->second; + delete e; + } + } + top_fst_ = NULL; + ifsts_.clear(); + nonterminal_map_.clear(); + entry_arcs_.clear(); + instances_.clear(); +} + + +void GrammarFst::DecodeSymbol(Label label, + int32 *nonterminal_symbol, + int32 *left_context_phone) { + // encoding_multiple will normally equal 1000 (but may be a multiple of 1000 + // if there are a lot of phones); kNontermBigNumber is 10000000. + int32 big_number = static_cast(kNontermBigNumber), + nonterm_phones_offset = nonterm_phones_offset_, + encoding_multiple = GetEncodingMultiple(nonterm_phones_offset); + // The following assertion should be optimized out as the condition is + // statically known. 
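  // (A hedged worked example with made-up indexes: with encoding_multiple ==
  // 1000, an ilabel that encodes nonterminal-symbol 304 together with
  // left-context phone 35 would be 10000000 + 304 * 1000 + 35 = 10304035;
  // the arithmetic in this function recovers 304 via
  // (10304035 - 10000000) / 1000 and 35 via 10304035 % 1000.)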
+ KALDI_ASSERT(big_number % static_cast(kNontermMediumNumber) == 0); + + *nonterminal_symbol = (label - big_number) / encoding_multiple; + *left_context_phone = label % encoding_multiple; + if (*nonterminal_symbol <= nonterm_phones_offset || + *left_context_phone == 0 || *left_context_phone > + nonterm_phones_offset + static_cast(kNontermBos)) + KALDI_ERR << "Decoding invalid label " << label + << ": code error or invalid --nonterm-phones-offset?"; + +} + +void GrammarFst::InitNonterminalMap() { + nonterminal_map_.clear(); + for (size_t i = 0; i < ifsts_.size(); i++) { + int32 nonterminal = ifsts_[i].first; + if (nonterminal_map_.count(nonterminal)) + KALDI_ERR << "Nonterminal symbol " << nonterminal + << " is paired with two FSTs."; + if (nonterminal < GetPhoneSymbolFor(kNontermUserDefined)) + KALDI_ERR << "Nonterminal symbol " << nonterminal + << " in input pairs, was expected to be >= " + << GetPhoneSymbolFor(kNontermUserDefined); + nonterminal_map_[nonterminal] = static_cast(i); + } +} + + +void GrammarFst::InitEntryArcs(int32 i) { + KALDI_ASSERT(static_cast(i) < ifsts_.size()); + const ConstFst &fst = *(ifsts_[i].second); + InitEntryOrReentryArcs(fst, fst.Start(), + GetPhoneSymbolFor(kNontermBegin), + &(entry_arcs_[i])); +} + +void GrammarFst::InitInstances() { + KALDI_ASSERT(instances_.empty()); + instances_.resize(1); + instances_[0].ifst_index = -1; + instances_[0].fst = top_fst_.get(); + instances_[0].parent_instance = -1; + instances_[0].parent_state = -1; +} + +void GrammarFst::InitEntryOrReentryArcs( + const ConstFst &fst, + int32 entry_state, + int32 expected_nonterminal_symbol, + std::unordered_map *phone_to_arc) { + phone_to_arc->clear(); + ArcIterator > aiter(fst, entry_state); + int32 arc_index = 0; + for (; !aiter.Done(); aiter.Next(), ++arc_index) { + const StdArc &arc = aiter.Value(); + int32 nonterminal, left_context_phone; + if (arc.ilabel <= (int32)kNontermBigNumber) { + if (entry_state == fst.Start()) { + KALDI_ERR << "There is something wrong with the graph; did you forget to " + "add #nonterm_begin and #nonterm_end to the non-top-level FSTs " + "before compiling?"; + } else { + KALDI_ERR << "There is something wrong with the graph; re-entry state is " + "not as anticipated."; + } + } + DecodeSymbol(arc.ilabel, &nonterminal, &left_context_phone); + if (nonterminal != expected_nonterminal_symbol) { + KALDI_ERR << "Expected arcs from this state to have nonterminal-symbol " + << expected_nonterminal_symbol << ", but got " + << nonterminal; + } + std::pair p(left_context_phone, arc_index); + if (!phone_to_arc->insert(p).second) { + // If it was not successfully inserted in the phone_to_arc map, it means + // there were two arcs with the same left-context phone, which does not + // make sense; that's an error, likely a code error (or an error when the + // input FSTs were generated). 
+ KALDI_ERR << "Two arcs had the same left-context phone."; + } + } +} + +GrammarFst::ExpandedState *GrammarFst::ExpandState( + int32 instance_id, BaseStateId state_id) { + int32 big_number = kNontermBigNumber; + const ConstFst &fst = *(instances_[instance_id].fst); + ArcIterator > aiter(fst, state_id); + KALDI_ASSERT(!aiter.Done() && aiter.Value().ilabel > big_number && + "Something is not right; did you call PrepareForGrammarFst()?"); + + const StdArc &arc = aiter.Value(); + int32 encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_), + nonterminal = (arc.ilabel - big_number) / encoding_multiple; + if (nonterminal == GetPhoneSymbolFor(kNontermBegin) || + nonterminal == GetPhoneSymbolFor(kNontermReenter)) { + KALDI_ERR << "Encountered unexpected type of nonterminal while " + "expanding state."; + } else if (nonterminal == GetPhoneSymbolFor(kNontermEnd)) { + return ExpandStateEnd(instance_id, state_id); + } else if (nonterminal >= GetPhoneSymbolFor(kNontermUserDefined)) { + return ExpandStateUserDefined(instance_id, state_id); + } else { + KALDI_ERR << "Encountered unexpected type of nonterminal " + << nonterminal << " while expanding state."; + } + return NULL; // Suppress compiler warning +} + + +// static inline +void GrammarFst::CombineArcs(const StdArc &leaving_arc, + const StdArc &arriving_arc, + float cost_correction, + StdArc *arc) { + // The following assertion shouldn't fail; we ensured this in + // PrepareForGrammarFst(), search for 'olabel_problem'. + KALDI_ASSERT(leaving_arc.olabel == 0); + // 'leaving_arc' leaves one fst, and 'arriving_arcs', conceptually arrives in + // another. This code merges the information of the two arcs to make a + // cross-FST arc. The ilabel information is discarded as it was only intended + // for the consumption of the GrammarFST code. + arc->ilabel = 0; + arc->olabel = arriving_arc.olabel; + // conceptually, arc->weight = + // Times(Times(leaving_arc.weight, arriving_arc.weight), Weight(cost_correction)). + // The below might be a bit faster, I hope-- avoiding checking. + arc->weight = Weight(cost_correction + leaving_arc.weight.Value() + + arriving_arc.weight.Value()); + arc->nextstate = arriving_arc.nextstate; +} + +GrammarFst::ExpandedState *GrammarFst::ExpandStateEnd( + int32 instance_id, BaseStateId state_id) { + if (instance_id == 0) + KALDI_ERR << "Did not expect #nonterm_end symbol in FST-instance 0."; + const FstInstance &instance = instances_[instance_id]; + int32 parent_instance_id = instance.parent_instance; + const ConstFst &fst = *(instance.fst); + const FstInstance &parent_instance = instances_[parent_instance_id]; + const ConstFst &parent_fst = *(parent_instance.fst); + + ExpandedState *ans = new ExpandedState; + ans->dest_fst_instance = parent_instance_id; + + // parent_aiter is the arc-iterator in the state we return to. We'll Seek() + // to a different position 'parent_aiter' for each arc leaving this state. + // (actually we expect just one arc to leave this state). + ArcIterator > parent_aiter(parent_fst, + instance.parent_state); + + // for explanation of cost_correction, see documentation for CombineArcs(). 
+ float num_reentry_arcs = instances_[instance_id].parent_reentry_arcs.size(), + cost_correction = -log(num_reentry_arcs); + + ArcIterator > aiter(fst, state_id); + + for (; !aiter.Done(); aiter.Next()) { + const StdArc &leaving_arc = aiter.Value(); + int32 this_nonterminal, left_context_phone; + DecodeSymbol(leaving_arc.ilabel, &this_nonterminal, + &left_context_phone); + KALDI_ASSERT(this_nonterminal == GetPhoneSymbolFor(kNontermEnd) && + ">1 nonterminals from a state; did you use " + "PrepareForGrammarFst()?"); + std::unordered_map::const_iterator reentry_iter = + instances_[instance_id].parent_reentry_arcs.find(left_context_phone), + reentry_end = instances_[instance_id].parent_reentry_arcs.end(); + if (reentry_iter == reentry_end) { + KALDI_ERR << "FST with index " << instance.ifst_index + << " ends with left-context-phone " << left_context_phone + << " but parent FST does not support that left-context " + "at the return point."; + } + size_t parent_arc_index = static_cast(reentry_iter->second); + parent_aiter.Seek(parent_arc_index); + const StdArc &arriving_arc = parent_aiter.Value(); + // 'arc' will combine the information on 'leaving_arc' and 'arriving_arc', + // except that the ilabel will be set to zero. + if (leaving_arc.olabel != 0) { + // If the following fails it would maybe indicate you hadn't called + // PrepareForGrammarFst(), or there was an error in that, because + // we made sure the leaving arc does not have an olabel. Search + // in that code for 'olabel_problem' for more details. + KALDI_ERR << "Leaving arc has zero olabel."; + } + StdArc arc; + CombineArcs(leaving_arc, arriving_arc, cost_correction, &arc); + ans->arcs.push_back(arc); + } + return ans; +} + +int32 GrammarFst::GetChildInstanceId(int32 instance_id, int32 nonterminal, + int32 state) { + int64 encoded_pair = (static_cast(nonterminal) << 32) + state; + // 'new_instance_id' is the instance-id we'd assign if we had to create a new one. + // We try to add it at once, to avoid having to do an extra map lookup in case + // it wasn't there and we did need to add it. + int32 child_instance_id = instances_.size(); + { + std::pair p(encoded_pair, child_instance_id); + std::pair::const_iterator, bool> ans = + instances_[instance_id].child_instances.insert(p); + if (!ans.second) { + // The pair was not inserted, which means the key 'encoded_pair' did exist in the + // map. Return the value in the map. + child_instance_id = ans.first->second; + return child_instance_id; + } + } + // If we reached this point, we did successfully insert 'child_instance_id' into + // the map, because the key didn't exist. That means we have to actually create + // the instance. + instances_.resize(child_instance_id + 1); + const FstInstance &parent_instance = instances_[instance_id]; + FstInstance &child_instance = instances_[child_instance_id]; + + // Work out the ifst_index for this nonterminal. 
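  // (i.e. the position of this nonterminal's HCLG in the ifsts_ array, as
  // recorded in nonterminal_map_ by InitNonterminalMap().)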
+ std::unordered_map::const_iterator iter = + nonterminal_map_.find(nonterminal); + if (iter == nonterminal_map_.end()) { + KALDI_ERR << "Nonterminal " << nonterminal << " was requested, but " + "there is no FST for it."; + } + int32 ifst_index = iter->second; + child_instance.ifst_index = ifst_index; + child_instance.fst = ifsts_[ifst_index].second.get(); + child_instance.parent_instance = instance_id; + child_instance.parent_state = state; + InitEntryOrReentryArcs(*(parent_instance.fst), state, + GetPhoneSymbolFor(kNontermReenter), + &(child_instance.parent_reentry_arcs)); + return child_instance_id; +} + +GrammarFst::ExpandedState *GrammarFst::ExpandStateUserDefined( + int32 instance_id, BaseStateId state_id) { + const ConstFst &fst = *(instances_[instance_id].fst); + ArcIterator > aiter(fst, state_id); + + ExpandedState *ans = new ExpandedState; + int32 dest_fst_instance = -1; // We'll set it in the loop. + // and->dest_fst_instance will be set to this. + + for (; !aiter.Done(); aiter.Next()) { + const StdArc &leaving_arc = aiter.Value(); + int32 nonterminal, left_context_phone; + DecodeSymbol(leaving_arc.ilabel, &nonterminal, + &left_context_phone); + int32 child_instance_id = GetChildInstanceId(instance_id, + nonterminal, + leaving_arc.nextstate); + if (dest_fst_instance < 0) { + dest_fst_instance = child_instance_id; + } else if (dest_fst_instance != child_instance_id) { + KALDI_ERR << "Same state leaves to different FST instances " + "(Did you use PrepareForGrammarFst()?)"; + } + const FstInstance &child_instance = instances_[child_instance_id]; + const ConstFst &child_fst = *(child_instance.fst); + int32 child_ifst_index = child_instance.ifst_index; + std::unordered_map &entry_arcs = entry_arcs_[child_ifst_index]; + if (entry_arcs.empty()) + InitEntryArcs(child_ifst_index); + // for explanation of cost_correction, see documentation for CombineArcs(). + float num_entry_arcs = entry_arcs.size(), + cost_correction = -log(num_entry_arcs); + + // Get the arc-index for the arc leaving the start-state of child FST that + // corresponds to this phonetic context. 
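  // (entry_arcs_[child_ifst_index] was filled in by InitEntryArcs(); it maps
  // each possible left-context phone, or #nonterm_bos, to the index of the
  // #nonterm_begin arc out of the child FST's start state that accepts it.)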
+ std::unordered_map::const_iterator entry_iter = + entry_arcs.find(left_context_phone); + if (entry_iter == entry_arcs.end()) { + KALDI_ERR << "FST for nonterminal " << nonterminal + << " does not have an entry point for left-context-phone " + << left_context_phone; + } + int32 arc_index = entry_iter->second; + ArcIterator > child_aiter(child_fst, child_fst.Start()); + child_aiter.Seek(arc_index); + const StdArc &arriving_arc = child_aiter.Value(); + StdArc arc; + CombineArcs(leaving_arc, arriving_arc, cost_correction, &arc); + ans->arcs.push_back(arc); + } + ans->dest_fst_instance = dest_fst_instance; + return ans; +} + + +void GrammarFst::Write(std::ostream &os, bool binary) const { + using namespace kaldi; + if (!binary) + KALDI_ERR << "GrammarFst::Write only supports binary mode."; + int32 format = 1, + num_ifsts = ifsts_.size(); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, format); + WriteBasicType(os, binary, num_ifsts); + WriteBasicType(os, binary, nonterm_phones_offset_); + + std::string stream_name("unknown"); + FstWriteOptions wopts(stream_name); + top_fst_->Write(os, wopts); + + for (int32 i = 0; i < num_ifsts; i++) { + int32 nonterminal = ifsts_[i].first; + WriteBasicType(os, binary, nonterminal); + ifsts_[i].second->Write(os, wopts); + } + WriteToken(os, binary, ""); +} + +static ConstFst *ReadConstFstFromStream(std::istream &is) { + fst::FstHeader hdr; + std::string stream_name("unknown"); + if (!hdr.Read(is, stream_name)) + KALDI_ERR << "Reading FST: error reading FST header"; + FstReadOptions ropts("", &hdr); + ConstFst *ans = ConstFst::Read(is, ropts); + if (!ans) + KALDI_ERR << "Could not read ConstFst from stream."; + return ans; +} + + + +void GrammarFst::Read(std::istream &is, bool binary) { + using namespace kaldi; + if (!binary) + KALDI_ERR << "GrammarFst::Read only supports binary mode."; + if (top_fst_ != NULL) + Destroy(); + int32 format = 1, num_ifsts; + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &format); + if (format != 1) + KALDI_ERR << "This version of the code cannot read this GrammarFst, " + "update your code."; + ReadBasicType(is, binary, &num_ifsts); + ReadBasicType(is, binary, &nonterm_phones_offset_); + top_fst_ = std::shared_ptr >(ReadConstFstFromStream(is)); + for (int32 i = 0; i < num_ifsts; i++) { + int32 nonterminal; + ReadBasicType(is, binary, &nonterminal); + std::shared_ptr > + this_fst(ReadConstFstFromStream(is)); + ifsts_.push_back(std::pair > >( + nonterminal, this_fst)); + } + Init(); +} + + +/** + This utility function input-determinizes a specified state s of the FST + 'fst'. (This input-determinizes while treating epsilon as a real symbol, + although for the application we expect to use it, there won't be epsilons). + + What this function does is: for any symbol i that appears as the ilabel of + more than one arc leaving state s of FST 'fst', it creates an additional + state, it creates a new state t with epsilon-input transitions leaving it for + each of those multiple arcs leaving state s; it deletes the original arcs + leaving state s; and it creates a single arc leaving state s to the newly + created state with the ilabel i on it. It sets the weights as necessary to + preserve equivalence and also to ensure that if, prior to this modification, + the FST was stochastic when cast to the log semiring (see + IsStochasticInLog()), it still will be. I.e. when interpreted as + negative logprobs, the weight from state s to t would be the sum of + the weights on the original arcs leaving state s. 
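   As a concrete illustration (with made-up costs): if state s has two arcs
   with ilabel i, one to state a with cost 1.2 and one to state b with cost
   2.3, then after this function runs, s has a single arc with ilabel i and
   cost c = -log(exp(-1.2) + exp(-2.3)) to a new state t, and t has two
   input-epsilon arcs, to a with cost 1.2 - c and to b with cost 2.3 - c
   (the original olabels move onto these input-epsilon arcs), so every
   complete path keeps its original cost.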
+ + This is used as a very cheap solution when preparing FSTs for the grammar + decoder, to ensure that there is only one entry-state to the sub-FST for each + phonetic left-context; this keeps the grammar-FST code (i.e. the code that + stitches them together) simple. Of course it will tend to introduce + unnecessary epsilons, and if we were careful we might be able to remove + some of those, but this wouldn't have a substantial impact on overall + decoder performance so we don't bother. + */ +static void InputDeterminizeSingleState(StdArc::StateId s, + VectorFst *fst) { + bool was_input_deterministic = true; + typedef StdArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Label Label; + typedef Arc::Weight Weight; + + struct InfoForIlabel { + std::vector arc_indexes; // indexes of all arcs with this ilabel + float tot_cost; // total cost of all arcs leaving state s for this + // ilabel, summed as if they were negative log-probs. + StateId new_state; // state-id of new state, if any, that we have created + // to remove duplicate symbols with this ilabel. + InfoForIlabel(): new_state(-1) { } + }; + + std::unordered_map label_map; + + size_t arc_index = 0; + for (ArcIterator > aiter(*fst, s); + !aiter.Done(); aiter.Next(), ++arc_index) { + const Arc &arc = aiter.Value(); + InfoForIlabel &info = label_map[arc.ilabel]; + if (info.arc_indexes.empty()) { + info.tot_cost = arc.weight.Value(); + } else { + info.tot_cost = -kaldi::LogAdd(-info.tot_cost, -arc.weight.Value()); + was_input_deterministic = false; + } + info.arc_indexes.push_back(arc_index); + } + + if (was_input_deterministic) + return; // Nothing to do. + + // 'new_arcs' will contain the modified list of arcs + // leaving state s + std::vector new_arcs; + new_arcs.reserve(arc_index); + arc_index = 0; + for (ArcIterator > aiter(*fst, s); + !aiter.Done(); aiter.Next(), ++arc_index) { + const Arc &arc = aiter.Value(); + Label ilabel = arc.ilabel; + InfoForIlabel &info = label_map[ilabel]; + if (info.arc_indexes.size() == 1) { + new_arcs.push_back(arc); // no changes needed + } else { + if (info.new_state < 0) { + info.new_state = fst->AddState(); + // add arc from state 's' to newly created state. + new_arcs.push_back(Arc(ilabel, 0, Weight(info.tot_cost), + info.new_state)); + } + // add arc from new state to original destination of this arc. + fst->AddArc(info.new_state, Arc(0, arc.olabel, + Weight(arc.weight.Value() - info.tot_cost), + arc.nextstate)); + } + } + fst->DeleteArcs(s); + for (size_t i = 0; i < new_arcs.size(); i++) + fst->AddArc(s, new_arcs[i]); +} + + +// This class contains the implementation of the function +// PrepareForGrammarFst(), which is declared in grammar-fst.h. +class GrammarFstPreparer { + public: + using FST = VectorFst; + using Arc = StdArc; + using StateId = Arc::StateId; + using Label = Arc::Label; + using Weight = Arc::Weight; + + GrammarFstPreparer(int32 nonterm_phones_offset, + VectorFst *fst): + nonterm_phones_offset_(nonterm_phones_offset), + fst_(fst), orig_num_states_(fst->NumStates()), + simple_final_state_(kNoStateId) { } + + void Prepare() { + if (fst_->Start() == kNoStateId) { + KALDI_ERR << "FST has no states."; + } + for (StateId s = 0; s < fst_->NumStates(); s++) { + if (IsSpecialState(s)) { + if (NeedEpsilons(s)) { + InsertEpsilonsForState(s); + // This state won't be treated as a 'special' state any more; + // all 'special' arcs (arcs with ilabels >= kNontermBigNumber) + // have been moved and now leave from newly created states that + // this state transitions to via epsilons arcs. 
+ } else { + // OK, state s is a special state. + FixArcsToFinalStates(s); + MaybeAddFinalProbToState(s); + // The following ensures that the start-state of sub-FSTs only has + // a single arc per left-context phone (the graph-building recipe can + // end up creating more than one if there were disambiguation symbols, + // e.g. for langauge model backoff). + if (s == fst_->Start() && IsEntryState(s)) + InputDeterminizeSingleState(s, fst_); + } + } + } + StateId num_new_states = fst_->NumStates() - orig_num_states_; + KALDI_LOG << "Added " << num_new_states << " new states while " + "preparing for grammar FST."; + } + + private: + + // Returns true if state 's' has at least one arc coming out of it with a + // special nonterminal-related ilabel on it (i.e. an ilabel >= + // kNontermBigNumber), and false otherwise. + bool IsSpecialState(StateId s) const; + + // This function verifies that state s does not currently have any + // final-prob (crashes if that fails); then, if the arcs leaving s have + // nonterminal symbols kNontermEnd or user-defined nonterminals (>= + // kNontermUserDefined), it adds a final-prob with cost given by + // KALDI_GRAMMAR_FST_SPECIAL_WEIGHT to the state. + // + // State s is required to be a 'special state', i.e. have special symbols on + // arcs leaving it, and the function assumes (since it will already + // have been checked) that the arcs leaving s, if there are more than + // one, all correspond to the same nonterminal symbol. + void MaybeAddFinalProbToState(StateId s); + + + // This function does some checking for 'special states', that they have + // certain expected properties, and also detects certain problematic + // conditions that we need to fix. It returns true if we need to + // modify this state (by adding input-epsilon arcs), and false otherwise. + bool NeedEpsilons(StateId s) const; + + // Returns true if state s (which is expected to be the start state, although we + // don't check this) has arcs with nonterminal symbols #nonterm_begin. + bool IsEntryState(StateId s) const; + + // Fixes any final-prob-related problems with this state. The problem we aim + // to fix is that there may be arcs with nonterminal symbol #nonterm_end which + // transition from this state to a state with non-unit final prob. This + // function assimilates that final-prob into the arc leaving from this state, + // by making the arc transition to a new state with unit final-prob, and + // incorporating the original final-prob into the arc's weight. + // + // The purpose of this is to keep the GrammarFst code simple. + // + // It would have been more efficient to do this in CheckProperties(), but + // doing it this way is clearer; and the extra time taken here will be tiny. + void FixArcsToFinalStates(StateId s); + + + // This struct represents a category of arcs that are allowed to leave from + // the same 'special state'. If a special state has arcs leaving it that + // are in more than one category, it will need to be split up into + // multiple states connected by epsilons. + // + // The 'nonterminal' and 'nextstate' have to do with ensuring that all + // arcs leaving a particular FST state transition to the same FST instance + // (which, in turn, helps to keep the ArcIterator code efficient). + // + // The 'olabel' has to do with ensuring that arcs with user-defined + // nonterminals or kNontermEnd have no olabels on them. 
This is a requirement + // of the CombineArcs() function of GrammarFst, because it needs to combine + // two olabels into one so we need to know that at least one of the olabels is + // always epsilon. + struct ArcCategory { + int32 nonterminal; // The nonterminal symbol #nontermXXX encoded into the ilabel, + // or 0 if the ilabel was other.nonterminal) return false; + if (nextstate < other.nextstate) return true; + else if (nextstate > other.nextstate) return false; + return olabel < other.olabel; + } + }; + + // This function, which is used in CheckProperties() and + // InsertEpsilonsForState(), works out the categrory of the arc; see + // documentation of struct ArcCategory for more details. + void GetCategoryOfArc(const Arc &arc, + ArcCategory *arc_category) const; + + + // This will be called for 'special states' that need to be split up. + // Non-special arcs leaving this state will stay here. For each + // category of special arcs (see ArcCategory for details), a new + // state will be created and those arcs will leave from that state + // instead; for each such state, an input-epsilon arc will leave this state + // for that state. For more details, see the code. + void InsertEpsilonsForState(StateId s); + + inline int32 GetPhoneSymbolFor(enum NonterminalValues n) const { + return nonterm_phones_offset_ + static_cast(n); + } + + int32 nonterm_phones_offset_; + VectorFst *fst_; + StateId orig_num_states_; + // If needed we may add a 'simple final state' to fst_, which has unit + // final-prob. This is used when we ensure that states with kNontermExit on + // them transition to a state with unit final-prob, so we don't need to + // look at the final-prob when expanding states. + StateId simple_final_state_; +}; + +bool GrammarFstPreparer::IsSpecialState(StateId s) const { + if (fst_->Final(s).Value() == KALDI_GRAMMAR_FST_SPECIAL_WEIGHT) { + // TODO: find a way to detect if it was a coincidence, or not make it an + // error, because in principle a user-defined grammar could contain this + // special cost. + KALDI_WARN << "It looks like you are calling PrepareForGrammarFst twice."; + } + for (ArcIterator aiter(*fst_, s ); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel >= kNontermBigNumber) // 1 million + return true; + } + return false; +} + +bool GrammarFstPreparer::IsEntryState(StateId s) const { + int32 big_number = kNontermBigNumber, + encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_); + + for (ArcIterator aiter(*fst_, s ); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + int32 nonterminal = (arc.ilabel - big_number) / + encoding_multiple; + // we check that at least one has label with nonterminal equal to #nonterm_begin... + // in fact they will all have this value if at least one does, and this was checked + // in NeedEpsilons(). + if (nonterminal == GetPhoneSymbolFor(kNontermBegin)) + return true; + } + return false; +} + + +bool GrammarFstPreparer::NeedEpsilons(StateId s) const { + + // See the documentation for GetCategoryOfArc() for explanation of what these are. + std::set categories; + + if (fst_->Final(s) != Weight::Zero()) { + // A state having a final-prob is considered the same as it having + // a non-nonterminal arc out of it.. this would be like a transition + // within the same FST. 
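    // (In effect the final-prob gets its own 'normal' category here, so a
    // state that is final and also has special arcs leaving it will end up
    // with more than one category and will therefore be split up.)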
+ ArcCategory category; + category.nonterminal = 0; + category.nextstate = kNoStateId; + category.olabel = 0; + categories.insert(category); + } + + int32 big_number = kNontermBigNumber, + encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_); + + for (ArcIterator aiter(*fst_, s ); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + ArcCategory category; + GetCategoryOfArc(arc, &category); + categories.insert(category); + + // the rest of this block is just checking. + int32 nonterminal = category.nonterminal; + + if (nonterminal >= GetPhoneSymbolFor(kNontermUserDefined)) { + // Check that the destination state of this arc has arcs with + // kNontermReenter on them. We'll separately check that such states + // don't have other types of arcs leaving them (search for + // kNontermReenter below), so it's sufficient to check the first arc. + ArcIterator next_aiter(*fst_, arc.nextstate); + if (next_aiter.Done()) + KALDI_ERR << "Destination state of a user-defined nonterminal " + "has no arcs leaving it."; + const Arc &next_arc = next_aiter.Value(); + int32 next_nonterminal = (next_arc.ilabel - big_number) / + encoding_multiple; + if (next_nonterminal != GetPhoneSymbolFor(kNontermReenter)) { + KALDI_ERR << "Expected arcs with user-defined nonterminals to be " + "followed by arcs with kNontermReenter."; + } + } + if (nonterminal == GetPhoneSymbolFor(kNontermBegin) && + s != fst_->Start()) { + KALDI_ERR << "#nonterm_begin symbol is present but this is not the " + "first state. Did you do fstdeterminizestar while compiling?"; + } + if (nonterminal == GetPhoneSymbolFor(kNontermEnd)) { + if (fst_->NumArcs(arc.nextstate) != 0 || + fst_->Final(arc.nextstate) == Weight::Zero()) { + KALDI_ERR << "Arc with kNontermEnd is not the final arc."; + } + } + } + if (categories.size() > 1) { + // This state has arcs leading to multiple FST instances. + // Do some checking to see that there is nothing really unexpected in + // there. + for (std::set::const_iterator + iter = categories.begin(); + iter != categories.end(); ++iter) { + int32 nonterminal = iter->nonterminal; + if (nonterminal == nonterm_phones_offset_ + kNontermBegin || + nonterminal == nonterm_phones_offset_ + kNontermReenter) + // we don't expect any state which has symbols like (kNontermBegin:p1) + // on arcs coming out of it, to also have other types of symbol. The + // same goes for kNontermReenter. + KALDI_ERR << "We do not expect states with arcs of type " + "kNontermBegin/kNontermReenter coming out of them, to also have " + "other types of arc."; + } + } + // the first half of the || below relates to olabels on arcs with either + // user-defined nonterminals or #nonterm_end (which would become 'leaving_arc' + // in the CombineArcs() function of GrammarFst). That function does not allow + // nonzero olabels on 'leaving_arc', which would be a problem if the + // 'arriving' arc had nonzero olabels, so we solve this by introducing + // input-epsilon arcs and putting the olabels on them instead. 
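  // (So even a state whose special arcs all fall into one single category
  // still needs splitting if that category carries a nonzero olabel; that is
  // the first half of the expression below.)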
+ bool need_epsilons = (categories.size() == 1 && + categories.begin()->olabel != 0) || + categories.size() > 1; + return need_epsilons; +} + +void GrammarFstPreparer::FixArcsToFinalStates(StateId s) { + int32 encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_), + big_number = kNontermBigNumber; + for (MutableArcIterator aiter(fst_, s ); !aiter.Done(); aiter.Next()) { + Arc arc = aiter.Value(); + if (arc.ilabel < big_number) + continue; + int32 nonterminal = (arc.ilabel - big_number) / encoding_multiple; + if (nonterminal == GetPhoneSymbolFor(kNontermEnd)) { + KALDI_ASSERT(fst_->NumArcs(arc.nextstate) == 0 && + fst_->Final(arc.nextstate) != Weight::Zero()); + if (fst_->Final(arc.nextstate) == Weight::One()) + continue; // There is no problem to fix. + if (simple_final_state_ == kNoStateId) { + simple_final_state_ = fst_->AddState(); + fst_->SetFinal(simple_final_state_, Weight::One()); + } + arc.weight = Times(arc.weight, fst_->Final(arc.nextstate)); + arc.nextstate = simple_final_state_; + aiter.SetValue(arc); + } + } +} + +void GrammarFstPreparer::MaybeAddFinalProbToState(StateId s) { + if (fst_->Final(s) != Weight::Zero()) { + // Something went wrong and it will require some debugging. In Prepare(), + // if we detected that the special state had a nonzero final-prob, we + // would have inserted epsilons to remove it, so there may be a bug in + // this class's code. + KALDI_ERR << "State already final-prob."; + } + ArcIterator aiter(*fst_, s ); + KALDI_ASSERT(!aiter.Done()); + const Arc &arc = aiter.Value(); + int32 encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_), + big_number = kNontermBigNumber, + nonterminal = (arc.ilabel - big_number) / encoding_multiple; + KALDI_ASSERT(nonterminal >= GetPhoneSymbolFor(kNontermBegin)); + if (nonterminal == GetPhoneSymbolFor(kNontermEnd) || + nonterminal >= GetPhoneSymbolFor(kNontermUserDefined)) { + fst_->SetFinal(s, Weight(KALDI_GRAMMAR_FST_SPECIAL_WEIGHT)); + } +} + +void GrammarFstPreparer::GetCategoryOfArc( + const Arc &arc, ArcCategory *arc_category) const { + int32 encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_), + big_number = kNontermBigNumber; + + int32 ilabel = arc.ilabel; + if (ilabel < big_number) { + arc_category->nonterminal = 0; + arc_category->nextstate = kNoStateId; + arc_category->olabel = 0; + } else { + int32 nonterminal = (ilabel - big_number) / encoding_multiple; + arc_category->nonterminal = nonterminal; + if (nonterminal <= nonterm_phones_offset_) { + KALDI_ERR << "Problem decoding nonterminal symbol " + "(wrong --nonterm-phones-offset option?), ilabel=" + << ilabel; + } + if (nonterminal >= GetPhoneSymbolFor(kNontermUserDefined)) { + // This is a user-defined symbol. + arc_category->nextstate = arc.nextstate; + arc_category->olabel = arc.olabel; + } else { + arc_category->nextstate = kNoStateId; + if (nonterminal == GetPhoneSymbolFor(kNontermEnd)) + arc_category->olabel = arc.olabel; + else + arc_category->olabel = 0; + } + } +} + + +void GrammarFstPreparer::InsertEpsilonsForState(StateId s) { + // Maps from category of arc, to a pair: + // the StateId is the state corresponding to that category. + // the float is the cost on the arc leading to that state; + // we compute the value that corresponds to the sum of the probabilities + // of the leaving arcs, bearing in mind that p = exp(-cost). + // We don't insert the arc-category whose 'nonterminal' is 0 here (i.e. the + // category for normal arcs); those arcs stay at this state. 
+ std::map > category_to_state; + + // This loop sets up 'category_to_state'. + for (fst::ArcIterator aiter(*fst_, s); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + ArcCategory category; + GetCategoryOfArc(arc, &category); + int32 nonterminal = category.nonterminal; + if (nonterminal == 0) + continue; + if (nonterminal == GetPhoneSymbolFor(kNontermBegin) || + nonterminal == GetPhoneSymbolFor(kNontermReenter)) { + KALDI_ERR << "Something went wrong; did not expect to insert epsilons " + "for this type of state."; + } + auto iter = category_to_state.find(category); + if (iter == category_to_state.end()) { + StateId new_state = fst_->AddState(); + float cost = arc.weight.Value(); + category_to_state[category] = std::pair(new_state, cost); + } else { + std::pair &p = iter->second; + p.second = -kaldi::LogAdd(-p.second, -arc.weight.Value()); + } + } + + KALDI_ASSERT(!category_to_state.empty()); // would be a code error. + + // 'arcs_from_this_state' is a place to put arcs that will put on this state + // after we delete all its existing arcs. + std::vector arcs_from_this_state; + arcs_from_this_state.reserve(fst_->NumArcs(s) + category_to_state.size()); + + // add arcs corresponding to transitions to the newly created states, to + // 'arcs_from_this_state' + for (std::map >::const_iterator + iter = category_to_state.begin(); iter != category_to_state.end(); + ++iter) { + const ArcCategory &category = iter->first; + StateId new_state = iter->second.first; + float cost = iter->second.second; + Arc arc; + arc.ilabel = 0; + arc.olabel = category.olabel; + arc.weight = Weight(cost); + arc.nextstate = new_state; + arcs_from_this_state.push_back(arc); + } + + // Now add to 'arcs_from_this_state', and to the newly created states, + // arcs corresponding to each of the arcs that were originally leaving + // this state. + for (fst::ArcIterator aiter(*fst_, s); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + ArcCategory category; + GetCategoryOfArc(arc, &category); + int32 nonterminal = category.nonterminal; + if (nonterminal == 0) { // this arc remains unchanged; we'll put it back later. + arcs_from_this_state.push_back(arc); + continue; + } + auto iter = category_to_state.find(category); + KALDI_ASSERT(iter != category_to_state.end()); + Arc new_arc; + new_arc.ilabel = arc.ilabel; + if (arc.olabel == category.olabel) { + new_arc.olabel = 0; // the olabel went on the epsilon-input arc. + } else { + KALDI_ASSERT(category.olabel == 0); + new_arc.olabel = arc.olabel; + } + StateId new_state = iter->second.first; + float epsilon_arc_cost = iter->second.second; + new_arc.weight = Weight(arc.weight.Value() - epsilon_arc_cost); + new_arc.nextstate = arc.nextstate; + fst_->AddArc(new_state, new_arc); + } + + fst_->DeleteArcs(s); + for (size_t i = 0; i < arcs_from_this_state.size(); i++) { + fst_->AddArc(s, arcs_from_this_state[i]); + } + // leave the final-prob on this state as it was before. +} + + +void PrepareForGrammarFst(int32 nonterm_phones_offset, + VectorFst *fst) { + GrammarFstPreparer p(nonterm_phones_offset, fst); + p.Prepare(); +} + +void CopyToVectorFst(GrammarFst *grammar_fst, + VectorFst *vector_fst) { + typedef GrammarFstArc::StateId GrammarStateId; // int64 + typedef StdArc::StateId StdStateId; // int + typedef StdArc::Label Label; + typedef StdArc::Weight Weight; + + std::vector > queue; + std::unordered_map state_map; + + vector_fst->DeleteStates(); + state_map[grammar_fst->Start()] = vector_fst->AddState(); // state 0. 
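  // The loop below is a simple work-list expansion: 'queue' (used as a stack)
  // holds pairs of (GrammarFst state, corresponding VectorFst state) whose
  // arcs have not yet been copied, and 'state_map' ensures each GrammarFst
  // state is allocated exactly one VectorFst state.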
+ vector_fst->SetStart(0); + + queue.push_back( + std::pair(grammar_fst->Start(), 0)); + + while (!queue.empty()) { + std::pair p = queue.back(); + queue.pop_back(); + GrammarStateId grammar_state = p.first; + StdStateId std_state = p.second; + vector_fst->SetFinal(std_state, grammar_fst->Final(grammar_state)); + ArcIterator aiter(*grammar_fst, grammar_state); + for (; !aiter.Done(); aiter.Next()) { + const GrammarFstArc &grammar_arc = aiter.Value(); + StdArc std_arc; + std_arc.ilabel = grammar_arc.ilabel; + std_arc.olabel = grammar_arc.olabel; + std_arc.weight = grammar_arc.weight; + GrammarStateId next_grammar_state = grammar_arc.nextstate; + StdStateId next_std_state; + std::unordered_map::const_iterator + state_iter = state_map.find(next_grammar_state); + if (state_iter == state_map.end()) { + next_std_state = vector_fst->AddState(); + state_map[next_grammar_state] = next_std_state; + queue.push_back(std::pair( + next_grammar_state, next_std_state)); + } else { + next_std_state = state_iter->second; + } + std_arc.nextstate = next_std_state; + vector_fst->AddArc(std_state, std_arc); + } + } +} + + + +} // end namespace fst diff --git a/src/decoder/grammar-fst.h b/src/decoder/grammar-fst.h new file mode 100644 index 00000000000..cfbfcad4ec6 --- /dev/null +++ b/src/decoder/grammar-fst.h @@ -0,0 +1,642 @@ +// decoder/grammar-fst.h + +// Copyright 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_GRAMMAR_FST_H_ +#define KALDI_DECODER_GRAMMAR_FST_H_ + +/** + For an extended explanation of the framework of which grammar-fsts are a + part, please see \ref grammar (i.e. ../doc/grammar.dox). + + This header implements a special FST type which we use in that framework; + it is a lightweight wrapper which stitches together several FSTs and makes + them look, to the decoder code, like a single FST. It is a bit like + OpenFst's Replace() function, but with support for left-biphone context. + */ + + + +#include "fst/fstlib.h" +#include "fstext/grammar-context-fst.h" + +namespace fst { + + +// GrammarFstArc is an FST Arc type which differs from the normal StdArc type by +// having the state-id be 64 bits, enough to store two indexes: the higher 32 +// bits for the FST-instance index, and the lower 32 bits for the state within +// that FST-instance. +// Obviously this leads to very high-numbered state indexes, which might be +// a problem in some circumstances, but the decoder code doesn't store arrays +// indexed by state, it uses hashes, so this isn't a problem. +struct GrammarFstArc { + typedef fst::TropicalWeight Weight; + typedef int Label; // OpenFst's StdArc uses int; this is for compatibility. 
+ typedef int64 StateId; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + GrammarFstArc() {} + + GrammarFstArc(Label ilabel, Label olabel, Weight weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::move(weight)), + nextstate(nextstate) {} +}; + +#define KALDI_GRAMMAR_FST_SPECIAL_WEIGHT 4096.0 + +class GrammarFst; + +// Declare that we'll be overriding class ArcIterator for class GrammarFst. +// This wouldn't work if we were fully using the OpenFst framework, +// e.g. if we had GrammarFst inherit from class Fst. +template<> class ArcIterator; + + +/** + GrammarFst is an FST that is 'stitched together' from multiple FSTs, that can + recursively incorporate each other. (This is limited to left-biphone + phonetic context). This class does not inherit from fst::Fst and does not + support its full interface-- only the parts that are necessary for the + decoder to work when templated on it. + + The basic interface is inspired by OpenFst's 'ReplaceFst' (see its + replace.h), except that this handles left-biphone phonetic context, which + requires, essentially, having multiple exit-points and entry-points for + sub-FSTs that represent nonterminals in the grammar; and multiple return + points whenever we invoke a nonterminal. For more information + see \ref grammar (i.e. ../doc/grammar.dox). + + THREAD SAFETY: you can't use this object from multiple threads; you should + create lightweight copies of this object using the copy constructor, + e.g. `new GrammarFst(this_grammar_fst)`, if you want to decode from multiple + threads using the same GrammarFst. +*/ +class GrammarFst { + public: + typedef GrammarFstArc Arc; + typedef TropicalWeight Weight; + + // StateId is actually int64. The high-order 32 bits are interpreted as an + // instance_id, i.e. and index into the instances_ vector; the low-order 32 + // bits are the state index in the FST instance. + typedef Arc::StateId StateId; + + // The StateId of the individual FST instances (int, currently). + typedef StdArc::StateId BaseStateId; + + typedef Arc::Label Label; + + + /** + Constructor. This constructor is very lightweight; the only immediate work + it does is to iterate over the arcs in the start states of the provided + FSTs in order to set up the appropriate entry points. + + For simplicity (to avoid templates), we limit the input FSTs to be of type + ConstFst; this limitation could be removed later if needed. You + can always construct a ConstFst if you have another StdArc-based + FST type. If the FST was read from disk, it may already be of type + ConstFst, and dynamic_cast might be sufficient to convert the type. + + @param [in] nonterm_phones_offset The integer id of the symbol + "#nonterm_bos" in phones.txt. + @param [in] top_fst The top-level FST of the grammar, which will + usually invoke the fsts in 'ifsts'. The fsts in 'ifsts' may + also invoke each other recursively. Even left-recursion is + allowed, although if it is with zero cost, it may blow up when you + decode. When an FST invokes another, the invocation point will + have sequences of two special symbols which would be decoded as: + (#nonterm:foo,p1) (#nonterm_reenter,p2) + where p1 and p2 (which may be real phones or #nonterm_bos) + represent the phonetic left-context that we enter, and leave, the + sub-graph with respectively. + @param [in] ifsts ifsts is a list of pairs (nonterminal-symbol, + the HCLG.fst corresponding to that symbol). 
The nonterminal + symbols must be among the user-specified nonterminals in + phones.txt, i.e. the things with names like "#nonterm:foo" and + "#nonterm:bar" in phones.txt. Also no nonterminal may appear more + than once in 'fsts'. ifsts may be empty, even though that doesn't + make much sense. + */ + GrammarFst( + int32 nonterm_phones_offset, + std::shared_ptr > top_fst, + const std::vector > > > &ifsts); + + /// Copy constructor. Useful because this object is not thread safe so cannot + /// be used by multiple parallel decoder threads, but it is lightweight and + /// can copy it without causing the stored FSTs to be copied. + GrammarFst(const GrammarFst &other) = default; + + /// This constructor should only be used prior to calling Read(). + GrammarFst() { } + + // This Write function allows you to dump a GrammarFst to disk as a single + // object. It only supports binary mode, but the option is allowed for + // compatibility with other Kaldi read/write functions (it will crash if + // binary == false). + void Write(std::ostream &os, bool binary) const; + + // Reads the format that Write() outputs. Will crash if binary == false. + void Read(std::istream &os, bool binary); + + StateId Start() const { + // the top 32 bits of the 64-bit state-id will be zero, because the + // top FST instance has instance-id = 0. + return static_cast(top_fst_->Start()); + } + + Weight Final(StateId s) const { + // If the fst-id (top 32 bits of s) is nonzero, this state is not final, + // because we need to return to the top-level FST before we can be final. + if (s != static_cast(static_cast(s))) { + return Weight::Zero(); + } else { + BaseStateId base_state = static_cast(s); + Weight ans = top_fst_->Final(base_state); + if (ans.Value() == KALDI_GRAMMAR_FST_SPECIAL_WEIGHT) { + return Weight::Zero(); + } else { + return ans; + } + } + } + + // This is called in LatticeFasterDecoder. As an implementation shortcut, if + // the state is an expanded state, we return 1, meaning 'yes, there are input + // epsilons'; the calling code doesn't actually care about the exact number. + inline size_t NumInputEpsilons(StateId s) const { + // Compare with the constructor of ArcIterator. + int32 instance_id = s >> 32; + BaseStateId base_state = static_cast(s); + const GrammarFst::FstInstance &instance = instances_[instance_id]; + const ConstFst *base_fst = instance.fst; + if (base_fst->Final(base_state).Value() != KALDI_GRAMMAR_FST_SPECIAL_WEIGHT) { + return base_fst->NumInputEpsilons(base_state); + } else { + return 1; + } + } + + inline std::string Type() const { return "grammar"; } + + ~GrammarFst(); + private: + + struct ExpandedState; + + friend class ArcIterator; + + // sets up nonterminal_map_. + void InitNonterminalMap(); + + // sets up entry_arcs_[i]. We do this only on demand, as each one is + // accessed, so that if there are a lot of nonterminals, this object doesn't + // too much work when it is initialized. Each call to this function only + // takes time O(number of left-context phones), which is quite small, but we'd + // like to avoid that if possible. + void InitEntryArcs(int32 i); + + // sets up instances_ with the top-level instance. + void InitInstances(); + + // Does the initialization tasks after nonterm_phones_offset_, + // top_fsts_ and ifsts_ have been set up + void Init(); + + // clears everything. 
+ void Destroy(); + + /* + This utility function sets up a map from "left-context phone", meaning + either a phone index or the index of the symbol #nonterm_bos, to + an arc-index leaving a particular state in an FST (i.e. an index + that we could use to Seek() to the matching arc). + + @param [in] fst The FST that is being entered (or reentered) + @param [in] entry_state The state in 'fst' which is being entered + (or reentered); will be fst.Start() if it's being + entered. It must have arcs with ilabels decodable as + (nonterminal_symbol, left_context_phone). Will either be the + start state (if 'nonterminal_symbol' corresponds to + #nonterm_begin), or an internal state (if 'nonterminal_symbol' + corresponds to #nonterm_reenter). The arc-indexes of those + arcs will be the values we set in 'phone_to_arc' + @param [in] nonterminal_symbol The index in phones.txt of the + nonterminal symbol we expect to be encoded in the ilabels + of the arcs leaving 'entry_state'. Will either correspond + to #nonterm_begin or #nonterm_reenter. + @param [out] phone_to_arc We output the map from left_context_phone + to the arc-index (i.e. the index we'd have to Seek() to + in an arc-iterator set up for the state 'entry_state). + */ + void InitEntryOrReentryArcs( + const ConstFst &fst, + int32 entry_state, + int32 nonterminal_symbol, + std::unordered_map *phone_to_arc); + + + inline int32 GetPhoneSymbolFor(enum NonterminalValues n) { + return nonterm_phones_offset_ + static_cast(n); + } + /** + Decodes an ilabel into a pair (nonterminal, left_context_phone). Crashes + if something went wrong or ilabel did not represent that (e.g. was less + than kNontermBigNumber). + + @param [in] the ilabel to be decoded. Note: the type 'Label' will in practice be int. + @param [out] The nonterminal part of the ilabel after decoding. + Will be a value greater than nonterm_phones_offset_. + @param [out] The left-context-phone part of the ilabel after decoding. + Will either be a phone index, or the symbol corresponding + to #nonterm_bos (meaning no left-context as we are at + the beginning of the sequence). + */ + void DecodeSymbol(Label label, + int32 *nonterminal_symbol, + int32 *left_context_phone); + + + // This function creates and returns an ExpandedState corresponding to a + // particular state-id in the FstInstance for this instance_id. It is called + // when we have determined that an ExpandedState needs to be created and that + // it is not currently present. It creates and returns it; the calling code + // needs to add it to the expanded_states map for its FST instance. + ExpandedState *ExpandState(int32 instance_id, BaseStateId state_id); + + // Called from ExpandState() when the nonterminal type on the arcs is + // #nonterm_end, this implements ExpandState() for that case. + ExpandedState *ExpandStateEnd(int32 instance_id, BaseStateId state_id); + + // Called from ExpandState() when the nonterminal type on the arcs is a + // user-defined nonterminal, this implements ExpandState() for that case. + ExpandedState *ExpandStateUserDefined(int32 instance_id, BaseStateId state_id); + + // Called from ExpandStateUserDefined(), this function attempts to look up the + // pair (nonterminal, state) in the map + // instances_[instance_id].child_instances. If it exists (because this + // return-state has been expanded before), it returns the value it found; + // otherwise it creates the child-instance and returns its newly created + // instance-id. 
+ inline int32 GetChildInstanceId(int32 instance_id, int32 nonterminal, + int32 state); + + /** + Called while expanding states, this function combines information from two + arcs: one leaving one sub-fst and one arriving in another sub-fst. + + @param [in] leaving_arc The arc leaving the first FST; must have + zero olabel. The ilabel will have a nonterminal symbol + like #nonterm:foo or #nonterm_end on it, encoded with a + phonetic context, but we ignore the ilabel. + @param [in] arriving_arc The arc arriving in the second FST. + It will have an ilabel consisted of either #nonterm_begin + or #nonterm_enter combined with a left-context phone, + but we ignore the ilabel. + @param [in] cost_correction A correction term that we add to the + cost of the arcs. This basically cancels out the + "1/num_options" part of the weight that we added in L.fst + when we put in all the phonetic-context options. We + did that to keep the FST stochastic, so that if we ever + pushed the weights, it wouldn't lead to weird effects. + This takes out that correction term... things will + still sum to one in the appropriate way, because in fact + when we cross these FST boundaries we only take one + specific phonetic context, rather than all possibilities. + @param [out] arc The arc that we output. Will have: + - weight equal to the product of the input arcs' weights, + times a weight constructed from 'cost_correction'. + - olabel equal to arriving_arc.olabel (leaving_arc's olabel + will be zero). + - ilabel equal to zero (we discard both ilabels, they are + not transition-ids but special symbols). + - nextstate equal to the nextstate of arriving_arc. + */ + static inline void CombineArcs(const StdArc &leaving_arc, + const StdArc &arriving_arc, + float cost_correction, + StdArc *arc); + + /** Called from the ArcIterator constructor when we encounter an FST state with + nonzero final-prob, this function first looks up this state_id in + 'expanded_states' member of the corresponding FstInstance, and returns it + if already present; otherwise it populates the 'expanded_states' map with + something for this state_id and returns the value. + */ + inline ExpandedState *GetExpandedState(int32 instance_id, + BaseStateId state_id) { + std::unordered_map &expanded_states = + instances_[instance_id].expanded_states; + + std::unordered_map::iterator iter = + expanded_states.find(state_id); + if (iter != expanded_states.end()) { + return iter->second; + } else { + ExpandedState *ans = ExpandState(instance_id, state_id); + // Don't use the reference 'expanded_states'; it could have been + // invalidated. + instances_[instance_id].expanded_states[state_id] = ans; + return ans; + } + } + + /** + Represents an expanded state in an FstInstance. We expand states whenever + we encounter states with a final-cost equal to + KALDI_GRAMMAR_FST_SPECIAL_WEIGHT (4096.0). The function + PrepareGrammarFst() makes sure to add this special final-cost on states + that have special arcs leaving them. */ + struct ExpandedState { + // The final-prob for expanded states is always zero; to avoid + // corner cases, we ensure this via adding epsilon arcs where + // needed. + + // fst-instance index of destination state (we will have ensured previously + // that this is the same for all outgoing arcs). + int32 dest_fst_instance; + + // List of arcs out of this state, where the 'nextstate' element will be the + // lower-order 32 bits of the destination state and the higher order bits + // will be given by 'dest_fst_instance'. 
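    // (Hedged example with made-up numbers: an arc whose destination is state
    // 420 of FST-instance 3 corresponds to the 64-bit decoder state
    // (3 << 32) + 420 = 12884902308; the GrammarFst code recovers the
    // instance as s >> 32 and the base state as static_cast<int32>(s).)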
We do it this way, instead of + // constructing a vector, in order to simplify the ArcIterator code and + // avoid unnecessary branches in loops over arcs. + // We guarantee that this 'arcs' array will always be nonempty; this + // is to avoid certain hassles on Windows with automated bounds-checking. + std::vector arcs; + }; + + + // An FstInstance is a copy of an FST. The instance numbered zero is for + // top_fst_, and (to state it approximately) whenever any FST instance invokes + // another FST a new instance will be generated on demand. + struct FstInstance { + // ifst_index is the index into the ifsts_ vector that corresponds to this + // FST instance, or -1 if this is the top-level instance. + int32 ifst_index; + + // Pointer to the FST corresponding to this instance: it will equal top_fst_ + // if ifst_index == -1, or ifsts_[ifst_index].second otherwise. + const ConstFst *fst; + + // 'expanded_states', which will be populated on demand as states in this + // FST instance are accessed, will only contain entries for states in this + // FST that the final-prob's value equal to + // KALDI_GRAMMAR_FST_SPECIAL_WEIGHT. (That final-prob value is used as a + // kind of signal to this code that the state needs expansion). + std::unordered_map expanded_states; + + // 'child_instances', which is populated on demand as states in this FST + // instance are accessed, is logically a map from pair (nonterminal_index, + // return_state) to instance_id. When we encounter an arc in our FST with a + // user-defined nonterminal indexed 'nonterminal_index' on its ilabel, and + // with 'return_state' as its nextstate, we look up that pair + // (nonterminal_index, return_state) in this map to see whether there already + // exists an FST instance for that. If it exists then the transition goes to + // that FST instance; if not, then we create a new one. The 'return_state' + // that's part of the key in this map would be the same as the 'parent_state' + // in that child FST instance, and of course the 'parent_instance' in + // that child FST instance would be the instance_id of this instance. + // + // In most cases each return_state would only have a single + // nonterminal_index, making the 'nonterminal_index' in the key *usually* + // redundant, but in principle it could happen that two user-defined + // nonterminals might share the same return-state. + std::unordered_map child_instances; + + // The instance-id of the FST we return to when we are done with this one + // (or -1 if this is the top-level FstInstance so there is nowhere to + // return). + int32 parent_instance; + + // The state in the FST of 'parent_instance' at which we expanded this FST + // instance, and to which we return (actually we return to the next-states + // of arcs out of 'parent_state'). + int32 parent_state; + + // 'parent_reentry_arcs' is a map from left-context-phone (i.e. either a + // phone index or #nonterm_bos), to an arc-index, which we could use to + // Seek() in an arc-iterator for state parent_state in the FST-instance + // 'parent_instance'. It's set up when we create this FST instance. (The + // arcs used to enter this instance are not located here, they can be + // located in entry_arcs_[instance_id]). We make use of reentry_arcs when + // we expand states in this FST that have #nonterm_end on their arcs, + // leading to final-states, which signal a return to the parent + // FST-instance. + std::unordered_map parent_reentry_arcs; + }; + + // The integer id of the symbol #nonterm_bos in phones.txt. 
+ int32 nonterm_phones_offset_; + + // The top-level FST passed in by the user; contains the start state and + // final-states, and may invoke FSTs in 'ifsts_' (which can also invoke + // each other recursively). + std::shared_ptr > top_fst_; + + // A list of pairs (nonterm, fst), where 'nonterm' is a user-defined + // nonterminal symbol as numbered in phones.txt (e.g. #nonterm:foo), and + // 'fst' is the corresponding FST. + std::vector > > > ifsts_; + + // Maps from the user-defined nonterminals like #nonterm:foo as numbered + // in phones.txt, to the corresponding index into 'ifsts_', i.e. the ifst_index. + std::unordered_map nonterminal_map_; + + // entry_arcs_ will have the same dimension as ifsts_. Each entry_arcs_[i] + // is a map from left-context phone (i.e. either a phone-index or + // #nonterm_bos) to the corresponding arc-index leaving the start-state in + // the FST 'ifsts_[i].second'. + // We populate this only on demand as each one is needed (except for the + // first one, which we populate immediately as a kind of sanity check). + // Doing it on-demand prevents this object's initialization from being + // nontrivial in the case where there are a lot of nonterminals. + std::vector > entry_arcs_; + + // The FST instances. Initially it is a vector with just one element + // representing top_fst_, and it will be populated with more elements on + // demand. An instance_id refers to an index into this vector. + std::vector instances_; +}; + + +/** + This is the overridden template for class ArcIterator for GrammarFst. This + is only used in the decoder, and the GrammarFst is not a "real" FST (it just + has a very similar-looking interface), so we don't need to implement all the + functionality that the regular ArcIterator has. + */ +template <> +class ArcIterator { + public: + using Arc = typename GrammarFst::Arc; + using BaseArc = StdArc; + using StateId = typename Arc::StateId; // int64 + using BaseStateId = typename StdArc::StateId; // int + using ExpandedState = GrammarFst::ExpandedState; + + // Caution: uses const_cast to evade const rules on GrammarFst. This is for + // compatibility with how things work in OpenFst. + inline ArcIterator(const GrammarFst &fst_in, StateId s) { + GrammarFst &fst = const_cast(fst_in); + // 'instance_id' is the high order bits of the state. + int32 instance_id = s >> 32; + // 'base_state' is low order bits of the state. It's important to + // explicitly say int32 below, not BaseStateId == int, which might on some + // compilers be a 64-bit type. + BaseStateId base_state = static_cast(s); + const GrammarFst::FstInstance &instance = fst.instances_[instance_id]; + const ConstFst *base_fst = instance.fst; + if (base_fst->Final(base_state).Value() != KALDI_GRAMMAR_FST_SPECIAL_WEIGHT) { + // A normal state + dest_instance_ = instance_id; + base_fst->InitArcIterator(s, &data_); + i_ = 0; + } else { + // A special state + ExpandedState *expanded_state = fst.GetExpandedState(instance_id, + base_state); + dest_instance_ = expanded_state->dest_fst_instance; + // it's ok to leave the other members of data_ uninitialized, as they will + // never be interrogated. + data_.arcs = &(expanded_state->arcs[0]); + data_.narcs = expanded_state->arcs.size(); + i_ = 0; + } + // Ideally we want to call CopyArcToTemp() now, but we rely on the fact that + // the calling code needs to call Done() before accessing Value(); we call + // CopyArcToTemp() from Done(). 
Of course this is slightly against the + // semantics of Done(), but it's more efficient to have Done() call + // CopyArcToTemp() than this function or Next(), as Done() already has to + // test that the arc-iterator has not reached the end. + } + + inline bool Done() { + if (i_ < data_.narcs) { + CopyArcToTemp(); + return false; + } else { + return true; + } + } + + inline void Next() { + i_++; + // Note: logically, at this point we should do: + // if (i_ < data_.size) + // CopyArcToTemp(); + // Instead we move this CopyArcToTemp() invocation into Done(), which we + // know will always be called after Next() and before Value(), because the + // user has no other way of knowing whether the iterator is still valid. + // This is for efficiency. + } + + inline const Arc &Value() const { return arc_; } + + private: + + inline void CopyArcToTemp() { + const StdArc &src = data_.arcs[i_]; + arc_.ilabel = src.ilabel; + arc_.olabel = src.olabel; + arc_.weight = src.weight; + arc_.nextstate = (static_cast(dest_instance_) << 32) | + src.nextstate; + } + + // The members of 'data_' that we use are: + // const Arc *arcs; + // size_t narcs; + ArcIteratorData data_; + + + int32 dest_instance_; // The index of the FstInstance that we transition to from + // this state. + size_t i_; // i_ is the index into the 'arcs' pointer. + + Arc arc_; // 'Arc' is the current arc in the GrammarFst, that this iterator + // is pointing to. It will be a copy of data_.arcs[i], except with + // the 'nextstate' modified to encode dest_instance_ in the higher + // order bits. Making a copy is of course unnecessary for the most + // part, but Value() needs to return a reference; we rely on the + // compiler to optimize out any unnecessary moves of data. +}; + +/** + This function copies a GrammarFst to a VectorFst (intended mostly for testing + and comparison purposes). GrammarFst doesn't actually inherit from class + Fst, so we can't just construct an FST from the GrammarFst. + + grammar_fst gets expanded by this call, and although we could make it a const + reference (because the ArcIterator does actually use const_cast), we make it + a non-const pointer to emphasize that this call does change grammar_fst. + */ +void CopyToVectorFst(GrammarFst *grammar_fst, + VectorFst *vector_fst); + +/** + This function prepares 'ifst' for use in GrammarFst: it ensures that it has + the expected properties, changing it slightly as needed. 'ifst' is expected + to be a fully compiled HCLG graph that is intended to be used in GrammarFst. + The user will most likely want to copy it to the ConstFst type after calling + this function. + + The following describes what this function does, and the reasons why + it has to do these things: + + - To keep the ArcIterator code simple (to avoid branches in loops), even + for expanded states we store the destination fst-instance index + separately per state, not per arc. This requires that any transitions + across FST boundaries from a single FST must be to a single destination + FST (for a given source state). We fix this problem by introducing + epsilon arcs and new states whenever we find a state that would cause a + problem for the above. + - In order to signal to the GrammarFst code that a particular state has + cross-FST-boundary transitions, we set the final-prob to a nonzero value + on that state. Specifically, we use a weight with Value() == 4096.0. + When the GrammarFst code sees that value it knows that it was not a + 'real' final-prob. 
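// Illustrative sketch (not part of this patch): the client-side view of the
// ArcIterator<GrammarFst> defined above.  A GrammarFst state-id is 64 bits,
// with the FST-instance index in the high-order 32 bits and the state within
// that instance's ConstFst<StdArc> in the low-order 32 bits, exactly as the
// iterator's constructor decodes it; Value() is only consulted after Done()
// has returned false, which is the contract described above.  The helper
// names below are this sketch's own, not part of the GrammarFst interface.
#include <cstdint>

#include "decoder/grammar-fst.h"  // the header added by this patch

// Pack/unpack the (instance, base-state) pair the same way the iterator does.
inline int64_t PackGrammarState(int32_t instance_id, int32_t base_state) {
  return (static_cast<int64_t>(instance_id) << 32) |
      static_cast<uint32_t>(base_state);
}
inline int32_t GrammarStateInstance(int64_t s) { return s >> 32; }
inline int32_t GrammarStateBase(int64_t s) { return static_cast<int32_t>(s); }

// Counts how many arcs out of 's' stay inside the same FST instance, using the
// usual OpenFst loop shape (Done() is checked before Value()).
inline int NumArcsWithinInstance(const fst::GrammarFst &fst, int64_t s) {
  int num_within = 0;
  for (fst::ArcIterator<fst::GrammarFst> aiter(fst, s);
       !aiter.Done(); aiter.Next()) {
    const fst::GrammarFst::Arc &arc = aiter.Value();
    // arc.nextstate is itself a packed 64-bit (instance, base-state) id.
    if (GrammarStateInstance(arc.nextstate) == GrammarStateInstance(s))
      ++num_within;
  }
  return num_within;
}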
Prior to doing this we ensure, by adding epsilon + transitions as needed, that the state did not previously have a + final-prob. + - For arcs that are final arcs in an FST that represents a nonterminal + (these arcs would have #nonterm_exit on them), we ensure that the + states that they transition to have unit final-prob (i.e. final-prob + equal to One()), by incorporating any final-prob into the arc itself. + This avoids the GrammarFst code having to inspect those final-probs + when expanding states. + + @param [in] nonterm_phones_offset The integer id of + the symbols #nonterm_bos in the phones.txt file. + @param [in,out] fst The FST to be (slightly) modified. + */ +void PrepareForGrammarFst(int32 nonterm_phones_offset, + VectorFst *fst); + + +} // end namespace fst + + +#endif diff --git a/src/decoder/lattice-faster-decoder.cc b/src/decoder/lattice-faster-decoder.cc index b837d836a70..2bc8c7cdef4 100644 --- a/src/decoder/lattice-faster-decoder.cc +++ b/src/decoder/lattice-faster-decoder.cc @@ -1,7 +1,7 @@ // decoder/lattice-faster-decoder.cc // Copyright 2009-2012 Microsoft Corporation Mirko Hannemann -// 2013-2014 Johns Hopkins University (Author: Daniel Povey) +// 2013-2018 Johns Hopkins University (Author: Daniel Povey) // 2014 Guoguo Chen // 2018 Zhehuai Chen @@ -20,40 +20,40 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. -// Note on svn: this file is "upstream" from lattice-faster-online-decoder.cc, and -// changes in this file should be merged into lattice-faster-online-decoder.cc, -// after committing the changes to this file, using the command -// svn merge ^/sandbox/online/src/decoder/lattice-faster-decoder.cc lattice-faster-online-decoder.cc - #include "decoder/lattice-faster-decoder.h" #include "lat/lattice-functions.h" namespace kaldi { // instantiate this class once for each thing you have to decode. -LatticeFasterDecoder::LatticeFasterDecoder(const fst::Fst &fst, - const LatticeFasterDecoderConfig &config): - fst_(fst), delete_fst_(false), config_(config), num_toks_(0) { +template +LatticeFasterDecoderTpl::LatticeFasterDecoderTpl( + const FST &fst, + const LatticeFasterDecoderConfig &config): + fst_(&fst), delete_fst_(false), config_(config), num_toks_(0) { config.Check(); toks_.SetSize(1000); // just so on the first frame we do something reasonable. } -LatticeFasterDecoder::LatticeFasterDecoder(const LatticeFasterDecoderConfig &config, - fst::Fst *fst): - fst_(*fst), delete_fst_(true), config_(config), num_toks_(0) { +template +LatticeFasterDecoderTpl::LatticeFasterDecoderTpl( + const LatticeFasterDecoderConfig &config, FST *fst): + fst_(fst), delete_fst_(true), config_(config), num_toks_(0) { config.Check(); toks_.SetSize(1000); // just so on the first frame we do something reasonable. 
} -LatticeFasterDecoder::~LatticeFasterDecoder() { +template +LatticeFasterDecoderTpl::~LatticeFasterDecoderTpl() { DeleteElems(toks_.Clear()); ClearActiveTokens(); - if (delete_fst_) delete &(fst_); + if (delete_fst_) delete fst_; } -void LatticeFasterDecoder::InitDecoding() { +template +void LatticeFasterDecoderTpl::InitDecoding() { // clean up from last time: DeleteElems(toks_.Clear()); cost_offsets_.clear(); @@ -62,20 +62,21 @@ void LatticeFasterDecoder::InitDecoding() { num_toks_ = 0; decoding_finalized_ = false; final_costs_.clear(); - StateId start_state = fst_.Start(); + StateId start_state = fst_->Start(); KALDI_ASSERT(start_state != fst::kNoStateId); active_toks_.resize(1); - Token *start_tok = new Token(0.0, 0.0, NULL, NULL); + Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL); active_toks_[0].toks = start_tok; toks_.Insert(start_state, start_tok); num_toks_++; - ProcessNonemittingWrapper(config_.beam); + ProcessNonemitting(config_.beam); } // Returns true if any kind of traceback is available (not necessarily from // a final state). It should only very rarely return false; this indicates // an unusual search error. -bool LatticeFasterDecoder::Decode(DecodableInterface *decodable) { +template +bool LatticeFasterDecoderTpl::Decode(DecodableInterface *decodable) { InitDecoding(); // We use 1-based indexing for frames in this decoder (if you view it in @@ -85,8 +86,8 @@ bool LatticeFasterDecoder::Decode(DecodableInterface *decodable) { while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { if (NumFramesDecoded() % config_.prune_interval == 0) PruneActiveTokens(config_.lattice_beam * config_.prune_scale); - BaseFloat cost_cutoff = ProcessEmittingWrapper(decodable); - ProcessNonemittingWrapper(cost_cutoff); + BaseFloat cost_cutoff = ProcessEmitting(decodable); + ProcessNonemitting(cost_cutoff); } FinalizeDecoding(); @@ -97,7 +98,8 @@ bool LatticeFasterDecoder::Decode(DecodableInterface *decodable) { // Outputs an FST corresponding to the single best path through the lattice. -bool LatticeFasterDecoder::GetBestPath(Lattice *olat, +template +bool LatticeFasterDecoderTpl::GetBestPath(Lattice *olat, bool use_final_probs) const { Lattice raw_lat; GetRawLattice(&raw_lat, use_final_probs); @@ -105,10 +107,12 @@ bool LatticeFasterDecoder::GetBestPath(Lattice *olat, return (olat->NumStates() != 0); } -// Outputs an FST corresponding to the raw, state-level -// tracebacks. 
-bool LatticeFasterDecoder::GetRawLattice(Lattice *ofst, - bool use_final_probs) const { + +// Outputs an FST corresponding to the raw, state-level lattice +template +bool LatticeFasterDecoderTpl::GetRawLattice( + Lattice *ofst, + bool use_final_probs) const { typedef LatticeArc Arc; typedef Arc::StateId StateId; typedef Arc::Weight Weight; @@ -159,11 +163,11 @@ bool LatticeFasterDecoder::GetRawLattice(Lattice *ofst, for (int32 f = 0; f <= num_frames; f++) { for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; - for (ForwardLink *l = tok->links; + for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { - unordered_map::const_iterator iter = - tok_map.find(l->next_tok); + typename unordered_map::const_iterator + iter = tok_map.find(l->next_tok); StateId nextstate = iter->second; KALDI_ASSERT(iter != tok_map.end()); BaseFloat cost_offset = 0.0; @@ -178,8 +182,8 @@ bool LatticeFasterDecoder::GetRawLattice(Lattice *ofst, } if (f == num_frames) { if (use_final_probs && !final_costs.empty()) { - unordered_map::const_iterator iter = - final_costs.find(tok); + typename unordered_map::const_iterator + iter = final_costs.find(tok); if (iter != final_costs.end()) ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); } else { @@ -195,8 +199,9 @@ bool LatticeFasterDecoder::GetRawLattice(Lattice *ofst, // This function is now deprecated, since now we do determinization from outside // the LatticeFasterDecoder class. Outputs an FST corresponding to the // lattice-determinized lattice (one path per word sequence). -bool LatticeFasterDecoder::GetLattice(CompactLattice *ofst, - bool use_final_probs) const { +template +bool LatticeFasterDecoderTpl::GetLattice(CompactLattice *ofst, + bool use_final_probs) const { Lattice raw_fst; GetRawLattice(&raw_fst, use_final_probs); Invert(&raw_fst); // make it so word labels are on the input. @@ -217,7 +222,8 @@ bool LatticeFasterDecoder::GetLattice(CompactLattice *ofst, return (ofst->NumStates() != 0); } -void LatticeFasterDecoder::PossiblyResizeHash(size_t num_toks) { +template +void LatticeFasterDecoderTpl::PossiblyResizeHash(size_t num_toks) { size_t new_sz = static_cast(static_cast(num_toks) * config_.hash_ratio); if (new_sz > toks_.Size()) { @@ -256,8 +262,10 @@ void LatticeFasterDecoder::PossiblyResizeHash(size_t num_toks) { // for the current frame. [note: it's inserted if necessary into hash toks_ // and also into the singly linked list of tokens active on this frame // (whose head is at active_toks_[frame]). -inline LatticeFasterDecoder::Token *LatticeFasterDecoder::FindOrAddToken( - StateId state, int32 frame_plus_one, BaseFloat tot_cost, bool *changed) { +template +inline Token* LatticeFasterDecoderTpl::FindOrAddToken( + StateId state, int32 frame_plus_one, BaseFloat tot_cost, + Token *backpointer, bool *changed) { // Returns the Token pointer. Sets "changed" (if non-NULL) to true // if the token was newly created or the cost changed. KALDI_ASSERT(frame_plus_one < active_toks_.size()); @@ -268,7 +276,7 @@ inline LatticeFasterDecoder::Token *LatticeFasterDecoder::FindOrAddToken( // tokens on the currently final frame have zero extra_cost // as any of them could end up // on the winning path. 
- Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks); + Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer); // NULL: no forward links yet toks = new_tok; num_toks_++; @@ -279,6 +287,9 @@ inline LatticeFasterDecoder::Token *LatticeFasterDecoder::FindOrAddToken( Token *tok = e_found->val; // There is an existing Token for this state. if (tok->tot_cost > tot_cost) { // replace old token tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. + tok->SetBackpointer(backpointer); // we don't allocate a new token, the old stays linked in active_toks_ // we only replace the tot_cost // in the current frame, there are no forward links (and no extra_cost) @@ -297,7 +308,8 @@ inline LatticeFasterDecoder::Token *LatticeFasterDecoder::FindOrAddToken( // prunes outgoing links for all tokens in active_toks_[frame] // it's called by PruneActiveTokens // all links, that have link_extra_cost > lattice_beam are pruned -void LatticeFasterDecoder::PruneForwardLinks( +template +void LatticeFasterDecoderTpl::PruneForwardLinks( int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned, BaseFloat delta) { // delta is the amount by which the extra_costs must change @@ -324,7 +336,7 @@ void LatticeFasterDecoder::PruneForwardLinks( changed = false; for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL; tok = tok->next) { - ForwardLink *link, *prev_link = NULL; + ForwardLinkT *link, *prev_link = NULL; // will recompute tok_extra_cost for tok. BaseFloat tok_extra_cost = std::numeric_limits::infinity(); // tok_extra_cost is the best (min) of link_extra_cost of outgoing links @@ -338,7 +350,7 @@ void LatticeFasterDecoder::PruneForwardLinks( // through link source state and through link destination state KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN if (link_extra_cost > config_.lattice_beam) { // excise link - ForwardLink *next_link = link->next; + ForwardLinkT *next_link = link->next; if (prev_link != NULL) prev_link->next = next_link; else tok->links = next_link; delete link; @@ -373,14 +385,15 @@ void LatticeFasterDecoder::PruneForwardLinks( // PruneForwardLinksFinal is a version of PruneForwardLinks that we call // on the final frame. If there are final tokens active, it uses // the final-probs for pruning, otherwise it treats all tokens as final. -void LatticeFasterDecoder::PruneForwardLinksFinal() { +template +void LatticeFasterDecoderTpl::PruneForwardLinksFinal() { KALDI_ASSERT(!active_toks_.empty()); int32 frame_plus_one = active_toks_.size() - 1; if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. KALDI_WARN << "No tokens alive at end of file"; - typedef unordered_map::const_iterator IterType; + typedef typename unordered_map::const_iterator IterType; ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); decoding_finalized_ = true; // We call DeleteElems() as a nicety, not because it's really necessary; @@ -399,7 +412,7 @@ void LatticeFasterDecoder::PruneForwardLinksFinal() { changed = false; for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL; tok = tok->next) { - ForwardLink *link, *prev_link = NULL; + ForwardLinkT *link, *prev_link = NULL; // will recompute tok_extra_cost. 
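// Illustrative numbers (not taken from the patch) for the link-pruning test
// used in PruneForwardLinks() / PruneForwardLinksFinal(): a forward link
// survives only if the best complete path through it is within 'lattice_beam'
// of the overall best path, as tracked by the extra_cost bookkeeping.
#include <cstdio>

int main() {
  float tok_tot_cost = 100.0f;       // tok->tot_cost (source token)
  float link_acoustic_cost = 12.0f;  // link->acoustic_cost
  float link_graph_cost = 2.5f;      // link->graph_cost
  float next_tok_tot_cost = 110.0f;  // next_tok->tot_cost (best cost into dest)
  float next_tok_extra_cost = 3.0f;  // next_tok->extra_cost (slack of dest token)
  float lattice_beam = 8.0f;         // config_.lattice_beam

  // Same formula as in the decoder: slack of the destination token plus how
  // much worse this particular link is than the best way of reaching it.
  float link_extra_cost = next_tok_extra_cost +
      ((tok_tot_cost + link_acoustic_cost + link_graph_cost) - next_tok_tot_cost);
  // = 3.0 + (114.5 - 110.0) = 7.5 <= 8.0, so the link is kept; with
  // lattice_beam = 7.0 it would be excised.
  std::printf("link_extra_cost = %.1f, pruned = %s\n", link_extra_cost,
              link_extra_cost > lattice_beam ? "yes" : "no");
  return 0;
}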
It has a term in it that corresponds // to the "final-prob", so instead of initializing tok_extra_cost to infinity // below we set it to the difference between the (score+final_prob) of this token, @@ -425,7 +438,7 @@ void LatticeFasterDecoder::PruneForwardLinksFinal() { ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - next_tok->tot_cost); if (link_extra_cost > config_.lattice_beam) { // excise link - ForwardLink *next_link = link->next; + ForwardLinkT *next_link = link->next; if (prev_link != NULL) prev_link->next = next_link; else tok->links = next_link; delete link; @@ -457,7 +470,8 @@ void LatticeFasterDecoder::PruneForwardLinksFinal() { } // while changed } -BaseFloat LatticeFasterDecoder::FinalRelativeCost() const { +template +BaseFloat LatticeFasterDecoderTpl::FinalRelativeCost() const { if (!decoding_finalized_) { BaseFloat relative_cost; ComputeFinalCosts(NULL, &relative_cost, NULL); @@ -474,7 +488,8 @@ BaseFloat LatticeFasterDecoder::FinalRelativeCost() const { // [we don't do this in PruneForwardLinks because it would give us // a problem with dangling pointers]. // It's called by PruneActiveTokens if any forward links have been pruned -void LatticeFasterDecoder::PruneTokensForFrame(int32 frame_plus_one) { +template +void LatticeFasterDecoderTpl::PruneTokensForFrame(int32 frame_plus_one) { KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); Token *&toks = active_toks_[frame_plus_one].toks; if (toks == NULL) @@ -500,7 +515,8 @@ void LatticeFasterDecoder::PruneTokensForFrame(int32 frame_plus_one) { // that. We go backwards through the frames and stop when we reach a point // where the delta-costs are not changing (and the delta controls when we consider // a cost to have "not changed"). -void LatticeFasterDecoder::PruneActiveTokens(BaseFloat delta) { +template +void LatticeFasterDecoderTpl::PruneActiveTokens(BaseFloat delta) { int32 cur_frame_plus_one = NumFramesDecoded(); int32 num_toks_begin = num_toks_; // The index "f" below represents a "frame plus one", i.e. you'd have to subtract @@ -529,7 +545,8 @@ void LatticeFasterDecoder::PruneActiveTokens(BaseFloat delta) { << " to " << num_toks_; } -void LatticeFasterDecoder::ComputeFinalCosts( +template +void LatticeFasterDecoderTpl::ComputeFinalCosts( unordered_map *final_costs, BaseFloat *final_relative_cost, BaseFloat *final_best_cost) const { @@ -540,11 +557,12 @@ void LatticeFasterDecoder::ComputeFinalCosts( BaseFloat infinity = std::numeric_limits::infinity(); BaseFloat best_cost = infinity, best_cost_with_final = infinity; + while (final_toks != NULL) { StateId state = final_toks->key; Token *tok = final_toks->val; const Elem *next = final_toks->tail; - BaseFloat final_cost = fst_.Final(state).Value(); + BaseFloat final_cost = fst_->Final(state).Value(); BaseFloat cost = tok->tot_cost, cost_with_final = cost + final_cost; best_cost = std::min(cost, best_cost); @@ -571,8 +589,27 @@ void LatticeFasterDecoder::ComputeFinalCosts( } } -void LatticeFasterDecoder::AdvanceDecoding(DecodableInterface *decodable, - int32 max_num_frames) { +template +void LatticeFasterDecoderTpl::AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. 
+ if (fst_->Type() == "const") { + LatticeFasterDecoderTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeFasterDecoderTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && "You must call InitDecoding() before AdvanceDecoding"); int32 num_frames_ready = decodable->NumFramesReady(); @@ -589,15 +626,16 @@ void LatticeFasterDecoder::AdvanceDecoding(DecodableInterface *decodable, if (NumFramesDecoded() % config_.prune_interval == 0) { PruneActiveTokens(config_.lattice_beam * config_.prune_scale); } - BaseFloat cost_cutoff = ProcessEmittingWrapper(decodable); - ProcessNonemittingWrapper(cost_cutoff); + BaseFloat cost_cutoff = ProcessEmitting(decodable); + ProcessNonemitting(cost_cutoff); } } // FinalizeDecoding() is a version of PruneActiveTokens that we call // (optionally) on the final frame. Takes into account the final-prob of // tokens. This function used to be called PruneActiveTokensFinal(). -void LatticeFasterDecoder::FinalizeDecoding() { +template +void LatticeFasterDecoderTpl::FinalizeDecoding() { int32 final_frame_plus_one = NumFramesDecoded(); int32 num_toks_begin = num_toks_; // PruneForwardLinksFinal() prunes final frame (with final-probs), and @@ -615,7 +653,8 @@ void LatticeFasterDecoder::FinalizeDecoding() { } /// Gets the weight cutoff. Also counts the active tokens. -BaseFloat LatticeFasterDecoder::GetCutoff(Elem *list_head, size_t *tok_count, +template +BaseFloat LatticeFasterDecoderTpl::GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem) { BaseFloat best_weight = std::numeric_limits::infinity(); // positive == high cost == bad. @@ -684,8 +723,9 @@ BaseFloat LatticeFasterDecoder::GetCutoff(Elem *list_head, size_t *tok_count, } } -template -BaseFloat LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) { +template +BaseFloat LatticeFasterDecoderTpl::ProcessEmitting( + DecodableInterface *decodable) { KALDI_ASSERT(active_toks_.size() > 0); int32 frame = active_toks_.size() - 1; // frame is the frame-index // (zero-based) used to get likelihoods @@ -708,8 +748,8 @@ BaseFloat LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) { // pruning "online" before having seen all tokens BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good - // dynamic range. - const FstType &fst = dynamic_cast(fst_); + // dynamic range. + // First process the best token to get a hopefully // reasonably tight bound on the next cutoff. The only @@ -718,12 +758,12 @@ BaseFloat LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) { StateId state = best_elem->key; Token *tok = best_elem->val; cost_offset = - tok->tot_cost; - for (fst::ArcIterator aiter(fst, state); + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); if (arc.ilabel != 0) { // propagate.. 
- BaseFloat new_weight = arc.weight.Value() + cost_offset - + BaseFloat new_weight = arc.weight.Value() + cost_offset - decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost; if (new_weight + adaptive_beam < next_cutoff) next_cutoff = new_weight + adaptive_beam; @@ -745,7 +785,7 @@ BaseFloat LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) { StateId state = e->key; Token *tok = e->val; if (tok->tot_cost <= cur_cutoff) { - for (fst::ArcIterator aiter(fst, state); + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); @@ -761,12 +801,12 @@ BaseFloat LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) { // Note: the frame indexes into active_toks_ are one-based, // hence the + 1. Token *next_tok = FindOrAddToken(arc.nextstate, - frame + 1, tot_cost, NULL); + frame + 1, tot_cost, tok, NULL); // NULL: no change indicator needed // Add ForwardLink from tok to next_tok (put on head of list tok->links) - tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel, - graph_cost, ac_cost, tok->links); + tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, + graph_cost, ac_cost, tok->links); } } // for all arcs } @@ -776,31 +816,26 @@ BaseFloat LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) { return next_cutoff; } -template BaseFloat LatticeFasterDecoder::ProcessEmitting>( - DecodableInterface *decodable); -template BaseFloat LatticeFasterDecoder::ProcessEmitting>( - DecodableInterface *decodable); -template BaseFloat LatticeFasterDecoder::ProcessEmitting>( - DecodableInterface *decodable); - -BaseFloat LatticeFasterDecoder::ProcessEmittingWrapper(DecodableInterface *decodable) { - if (fst_.Type() == "const") { - return LatticeFasterDecoder::ProcessEmitting>(decodable); - } else if (fst_.Type() == "vector") { - return LatticeFasterDecoder::ProcessEmitting>(decodable); - } else { - return LatticeFasterDecoder::ProcessEmitting>(decodable); +// static inline +template +void LatticeFasterDecoderTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; } + tok->links = NULL; } -template -void LatticeFasterDecoder::ProcessNonemitting(BaseFloat cutoff) { + +template +void LatticeFasterDecoderTpl::ProcessNonemitting(BaseFloat cutoff) { KALDI_ASSERT(!active_toks_.empty()); int32 frame = static_cast(active_toks_.size()) - 2; // Note: "frame" is the time-index we just processed, or -1 if // we are processing the nonemitting transitions before the // first frame (called from InitDecoding()). - const FstType &fst = dynamic_cast(fst_); // Processes nonemitting arcs for one frame. Propagates within toks_. // Note-- this queue structure is is not very optimal as @@ -809,15 +844,20 @@ void LatticeFasterDecoder::ProcessNonemitting(BaseFloat cutoff) { // problem did not improve overall speed. 
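// Illustrative sketch (not part of this patch): the behavioural change in
// ProcessNonemitting() below is that the queue is only given states that have
// at least one input-epsilon arc (fst_->NumInputEpsilons(state) != 0), since
// any other state has nothing for this function to expand.  A stand-alone
// model of that gating; the function name is this sketch's own.
#include <vector>

#include <fst/fst.h>

template <typename FST>
void SeedEpsilonQueue(const FST &fst,
                      const std::vector<typename FST::Arc::StateId> &active,
                      std::vector<typename FST::Arc::StateId> *queue) {
  queue->clear();
  for (size_t i = 0; i < active.size(); i++)
    if (fst.NumInputEpsilons(active[i]) != 0)  // same test as in the patch below
      queue->push_back(active[i]);
}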
KALDI_ASSERT(queue_.empty()); - for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) - queue_.push_back(e->key); - if (queue_.empty()) { + + if (toks_.GetList() == NULL) { if (!warned_) { KALDI_WARN << "Error, no surviving tokens: frame is " << frame; warned_ = true; } } + for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) { + StateId state = e->key; + if (fst_->NumInputEpsilons(state) != 0) + queue_.push_back(state); + } + while (!queue_.empty()) { StateId state = queue_.back(); queue_.pop_back(); @@ -830,9 +870,9 @@ void LatticeFasterDecoder::ProcessNonemitting(BaseFloat cutoff) { // because we're about to regenerate them. This is a kind // of non-optimality (remember, this is the simple decoder), // but since most states are emitting it's not a huge issue. - tok->DeleteForwardLinks(); // necessary when re-visiting + DeleteForwardLinks(tok); // necessary when re-visiting tok->links = NULL; - for (fst::ArcIterator aiter(fst, state); + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); @@ -843,50 +883,37 @@ void LatticeFasterDecoder::ProcessNonemitting(BaseFloat cutoff) { bool changed; Token *new_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, - &changed); + tok, &changed); - tok->links = new ForwardLink(new_tok, 0, arc.olabel, - graph_cost, 0, tok->links); + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); // "changed" tells us whether the new token has a different // cost from before, or is new [if so, add into queue]. - if (changed) queue_.push_back(arc.nextstate); + if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0) + queue_.push_back(arc.nextstate); } } } // for all arcs } // while queue not empty } -template void LatticeFasterDecoder::ProcessNonemitting>( - BaseFloat cutoff); -template void LatticeFasterDecoder::ProcessNonemitting>( - BaseFloat cutoff); -template void LatticeFasterDecoder::ProcessNonemitting>( - BaseFloat cutoff); - -void LatticeFasterDecoder::ProcessNonemittingWrapper(BaseFloat cost_cutoff) { - if (fst_.Type() == "const") { - return LatticeFasterDecoder::ProcessNonemitting>(cost_cutoff); - } else if (fst_.Type() == "vector") { - return LatticeFasterDecoder::ProcessNonemitting>(cost_cutoff); - } else { - return LatticeFasterDecoder::ProcessNonemitting>(cost_cutoff); - } -} -void LatticeFasterDecoder::DeleteElems(Elem *list) { +template +void LatticeFasterDecoderTpl::DeleteElems(Elem *list) { for (Elem *e = list, *e_tail; e != NULL; e = e_tail) { e_tail = e->tail; toks_.Delete(e); } } -void LatticeFasterDecoder::ClearActiveTokens() { // a cleanup routine, at utt end/begin +template +void LatticeFasterDecoderTpl::ClearActiveTokens() { // a cleanup routine, at utt end/begin for (size_t i = 0; i < active_toks_.size(); i++) { // Delete all tokens alive on this frame, and any forward // links they may have. 
for (Token *tok = active_toks_[i].toks; tok != NULL; ) { - tok->DeleteForwardLinks(); + DeleteForwardLinks(tok); Token *next_tok = tok->next; delete tok; num_toks_--; @@ -898,10 +925,11 @@ void LatticeFasterDecoder::ClearActiveTokens() { // a cleanup routine, at utt en } // static -void LatticeFasterDecoder::TopSortTokens(Token *tok_list, - std::vector *topsorted_list) { +template +void LatticeFasterDecoderTpl::TopSortTokens( + Token *tok_list, std::vector *topsorted_list) { unordered_map token2pos; - typedef unordered_map::iterator IterType; + typedef typename unordered_map::iterator IterType; int32 num_toks = 0; for (Token *tok = tok_list; tok != NULL; tok = tok->next) num_toks++; @@ -918,7 +946,7 @@ void LatticeFasterDecoder::TopSortTokens(Token *tok_list, for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { Token *tok = iter->first; int32 pos = iter->second; - for (ForwardLink *link = tok->links; link != NULL; link = link->next) { + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { if (link->ilabel == 0) { // We only need to consider epsilon links, since non-epsilon links // transition between frames and this function only needs to sort a list @@ -943,16 +971,16 @@ void LatticeFasterDecoder::TopSortTokens(Token *tok_list, for (loop_count = 0; !reprocess.empty() && loop_count < max_loop; ++loop_count) { std::vector reprocess_vec; - for (unordered_set::iterator iter = reprocess.begin(); + for (typename unordered_set::iterator iter = reprocess.begin(); iter != reprocess.end(); ++iter) reprocess_vec.push_back(*iter); reprocess.clear(); - for (std::vector::iterator iter = reprocess_vec.begin(); + for (typename std::vector::iterator iter = reprocess_vec.begin(); iter != reprocess_vec.end(); ++iter) { Token *tok = *iter; int32 pos = token2pos[tok]; // Repeat the processing we did above (for comments, see above). - for (ForwardLink *link = tok->links; link != NULL; link = link->next) { + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { if (link->ilabel == 0) { IterType following_iter = token2pos.find(link->next_tok); if (following_iter != token2pos.end()) { @@ -975,4 +1003,17 @@ void LatticeFasterDecoder::TopSortTokens(Token *tok_list, (*topsorted_list)[iter->second] = iter->first; } +// Instantiate the template for the combination of token types and FST types +// that we'll need. +template class LatticeFasterDecoderTpl, decoder::StdToken>; +template class LatticeFasterDecoderTpl, decoder::StdToken >; +template class LatticeFasterDecoderTpl, decoder::StdToken >; +template class LatticeFasterDecoderTpl; + +template class LatticeFasterDecoderTpl , decoder::BackpointerToken>; +template class LatticeFasterDecoderTpl, decoder::BackpointerToken >; +template class LatticeFasterDecoderTpl, decoder::BackpointerToken >; +template class LatticeFasterDecoderTpl; + + } // end namespace kaldi. diff --git a/src/decoder/lattice-faster-decoder.h b/src/decoder/lattice-faster-decoder.h index 9c6ddd67acd..5f8c0778723 100644 --- a/src/decoder/lattice-faster-decoder.h +++ b/src/decoder/lattice-faster-decoder.h @@ -20,10 +20,6 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. -// Note: this file is "upstream" from lattice-faster-online-decoder.h, -// and changes in this file should be made to lattice-faster-online-decoder.h, -// if applicable. 
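// The explicit instantiations at the end of the .cc file above, written out
// with their template arguments, presumably read as follows (a reconstruction
// based on the FST and token types named elsewhere in this patch, not copied
// verbatim):
//
//   template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>, decoder::StdToken>;
//   template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::StdToken>;
//   template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::StdToken>;
//   template class LatticeFasterDecoderTpl<fst::GrammarFst, decoder::StdToken>;
//
//   template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>, decoder::BackpointerToken>;
//   // ...and likewise for VectorFst, ConstFst and GrammarFst with
//   // decoder::BackpointerToken.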
- #ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_ #define KALDI_DECODER_LATTICE_FASTER_DECODER_H_ @@ -35,6 +31,7 @@ #include "fstext/fstext-lib.h" #include "lat/determinize-lattice-pruned.h" #include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" namespace kaldi { @@ -86,32 +83,165 @@ struct LatticeFasterDecoderConfig { } void Check() const { KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 + && min_active <= max_active && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0 && prune_scale > 0.0 && prune_scale < 1.0); } }; +namespace decoder { +// We will template the decoder on the token type as well as the FST type; this +// is a mechanism so that we can use the same underlying decoder code for +// versions of the decoder that support quickly getting the best path +// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also +// those that do not (LatticeFasterDecoder). + + +// ForwardLinks are the links from a token to a token on the next frame. +// or sometimes on the current frame (for input-epsilon links). +template +struct ForwardLink { + using Label = fst::StdArc::Label; + + Token *next_tok; // the next token [or NULL if represents final-state] + Label ilabel; // ilabel on arc + Label olabel; // olabel on arc + BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.) + BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc + ForwardLink *next; // next in singly-linked list of forward arcs (arcs + // in the state-level lattice) from a token. + inline ForwardLink(Token *next_tok, Label ilabel, Label olabel, + BaseFloat graph_cost, BaseFloat acoustic_cost, + ForwardLink *next): + next_tok(next_tok), ilabel(ilabel), olabel(olabel), + graph_cost(graph_cost), acoustic_cost(acoustic_cost), + next(next) { } +}; -/** A bit more optimized version of the lattice decoder. - See \ref lattices_generation \ref decoders_faster and \ref decoders_simple - for more information. - */ -class LatticeFasterDecoder { - public: - typedef fst::StdArc Arc; - typedef Arc::Label Label; - typedef Arc::StateId StateId; - typedef Arc::Weight Weight; - // instantiate this class once for each thing you have to decode. - LatticeFasterDecoder(const fst::Fst &fst, - const LatticeFasterDecoderConfig &config); +struct StdToken { + using ForwardLinkT = ForwardLink; + using Token = StdToken; + + // Standard token type for LatticeFasterDecoder. Each active HCLG + // (decoding-graph) state on each frame has one token. + + // tot_cost is the total (LM + acoustic) cost from the beginning of the + // utterance up to this point. (but see cost_offset_, which is subtracted + // to keep it in a good numerical range). + BaseFloat tot_cost; + + // exta_cost is >= 0. After calling PruneForwardLinks, this equals the + // minimum difference between the cost of the best path that this link is a + // part of, and the cost of the absolute best path, under the assumption that + // any of the currently active states at the decoding front may eventually + // succeed (e.g. if you were to take the currently active states one by one + // and compute this difference, and then take the minimum). + BaseFloat extra_cost; + + // 'links' is the head of singly-linked list of ForwardLinks, which is what we + // use for lattice generation. + ForwardLinkT *links; + + //'next' is the next in the singly-linked list of tokens for this frame. 
+  Token *next;
+
+  // This function does nothing and should be optimized out; it's needed
+  // so we can share the regular LatticeFasterDecoderTpl code and the code
+  // for LatticeFasterOnlineDecoder that supports fast traceback.
+  inline void SetBackpointer (Token *backpointer) { }
+
+  // This constructor just ignores the 'backpointer' argument.  That argument is
+  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
+  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
+  // fast way to obtain the best path).
+  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
+                  Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) { }
+};
+
+struct BackpointerToken {
+  using ForwardLinkT = ForwardLink<BackpointerToken>;
+  using Token = BackpointerToken;
+
+  // BackpointerToken is like StdToken, but also stores a 'backpointer' to the
+  // best preceding token; this is what LatticeFasterOnlineDecoderTpl uses for
+  // fast best-path traceback.  Each active HCLG (decoding-graph) state on each
+  // frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals
+  // the minimum difference between the cost of the best path that this token
+  // is a part of, and the cost of the absolute best path, under the assumption
+  // that any of the currently active states at the decoding front may
+  // eventually succeed (e.g. if you were to take the currently active states
+  // one by one and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
+  // use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  BackpointerToken *next;
+
+  // Best preceding BackpointerToken (could be a token on this frame, connected
+  // to this via an epsilon transition, or on a previous frame).  This is only
+  // required for an efficient GetBestPath function in
+  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
+  // (the "links" list is what stores the forward links, for that).
+  Token *backpointer;
+
+  inline void SetBackpointer (Token *backpointer) {
+    this->backpointer = backpointer;
+  }
+
+  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
+                          Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
+      backpointer(backpointer) { }
+};
+
+} // namespace decoder
+
+
-  // This version of the initializer "takes ownership" of the fst,
-  // and will delete it when this object is destroyed.
-  LatticeFasterDecoder(const LatticeFasterDecoderConfig &config,
-                       fst::Fst<fst::StdArc> *fst);
+/** This is the "normal" lattice-generating decoder.
+    See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
+    for more information.
+    The decoder is templated on the FST type and the token type.  The token type
+    will normally be StdToken, but may also be BackpointerToken, which is used to
+    support quick lookup of the current best path (see
+    lattice-faster-online-decoder.h).
+
+    The FST you invoke this decoder with is expected to equal
+    fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst.
If you invoke it with + FST == StdFst and it notices that the actual FST type is + fst::VectorFst or fst::ConstFst, the decoder object + will internally cast itself to one that is templated on those more specific + types; this is an optimization for speed. + */ +template +class LatticeFasterDecoderTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeFasterDecoderTpl(const FST &fst, + const LatticeFasterDecoderConfig &config); + + // This version of the constructor takes ownership of the fst, and will delete + // it when this object is destroyed. + LatticeFasterDecoderTpl(const LatticeFasterDecoderConfig &config, + FST *fst); void SetOptions(const LatticeFasterDecoderConfig &config) { config_ = config; @@ -121,7 +251,7 @@ class LatticeFasterDecoder { return config_; } - ~LatticeFasterDecoder(); + ~LatticeFasterDecoderTpl(); /// Decodes until there are no more frames left in the "decodable" object.. /// note, this may block waiting for input if the "decodable" object blocks. @@ -151,8 +281,13 @@ class LatticeFasterDecoder { /// of the graph then it will include those as final-probs, else /// it will treat all final-probs as one. /// The raw lattice will be topologically sorted. - bool GetRawLattice(Lattice *ofst, - bool use_final_probs = true) const; + /// + /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, + /// which also supports a pruning beam, in case for some reason + /// you want it pruned tighter than the regular lattice beam. + /// We could put that here in future needed. + bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const; + /// [Deprecated, users should now use GetRawLattice and determinize it @@ -207,53 +342,13 @@ class LatticeFasterDecoder { // whenever we call ProcessEmitting(). inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } - private: - // ForwardLinks are the links from a token to a token on the next frame. - // or sometimes on the current frame (for input-epsilon links). - struct Token; - struct ForwardLink { - Token *next_tok; // the next token [or NULL if represents final-state] - Label ilabel; // ilabel on link. - Label olabel; // olabel on link. - BaseFloat graph_cost; // graph cost of traversing link (contains LM, etc.) - BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing link - ForwardLink *next; // next in singly-linked list of forward links from a - // token. - inline ForwardLink(Token *next_tok, Label ilabel, Label olabel, - BaseFloat graph_cost, BaseFloat acoustic_cost, - ForwardLink *next): - next_tok(next_tok), ilabel(ilabel), olabel(olabel), - graph_cost(graph_cost), acoustic_cost(acoustic_cost), - next(next) { } - }; + protected: + // we make things protected instead of private, as code in + // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the + // internals. - // Token is what's resident in a particular state at a particular time. - // In this decoder a Token actually contains *forward* links. - // When first created, a Token just has the (total) cost. We add forward - // links from it when we process the next frame. - struct Token { - BaseFloat tot_cost; // would equal weight.Value()... cost up to this point. - BaseFloat extra_cost; // >= 0. 
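// Illustrative sketch (not part of this patch): driving the templated decoder
// through the public interface declared above.  The caller is assumed to
// provide the decodable object (any DecodableInterface implementation) and the
// compiled HCLG graph; the function name is this sketch's own.
#include "decoder/lattice-faster-decoder.h"
#include "itf/decodable-itf.h"
#include "fst/fstlib.h"

bool DecodeOneUtterance(const fst::ConstFst<fst::StdArc> &hclg,
                        const kaldi::LatticeFasterDecoderConfig &config,
                        kaldi::DecodableInterface *decodable,
                        kaldi::Lattice *raw_lat) {
  // Instantiating directly on ConstFst<StdArc> skips the run-time type
  // dispatch that the generic Fst<StdArc> instantiation performs in
  // AdvanceDecoding() (see the .cc changes earlier in this patch).
  kaldi::LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>,
                                 kaldi::decoder::StdToken> decoder(hclg, config);
  if (!decoder.Decode(decodable))  // InitDecoding() + all frames + FinalizeDecoding()
    return false;                  // unusual search error: no traceback at all
  return decoder.GetRawLattice(raw_lat, true);  // raw state-level lattice, with final-probs
}
// For streaming use, the caller would instead call InitDecoding() once, then
// AdvanceDecoding() as frames arrive, and finish with FinalizeDecoding(),
// exactly as the .cc file above lays out.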
This is used in pruning away tokens. - // there is a comment in lattice-faster-decoder.cc explaining this; - // search for "a note on the definition of extra_cost". - - ForwardLink *links; // Head of singly linked list of ForwardLinks - - Token *next; // Next in list of tokens for this frame. - - inline Token(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLink *links, - Token *next): - tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) { } - inline void DeleteForwardLinks() { - ForwardLink *l = links, *m; - while (l != NULL) { - m = l->next; - delete l; - l = m; - } - links = NULL; - } - }; + // Deletes the elements of the singly linked list tok->links. + inline static void DeleteForwardLinks(Token *tok); // head of per-frame list of Tokens (list is in topological order), // and something saying whether we ever pruned it using PruneForwardLinks. @@ -265,7 +360,13 @@ class LatticeFasterDecoder { must_prune_tokens(true) { } }; - typedef HashList::Elem Elem; + using Elem = typename HashList::Elem; + // Equivalent to: + // struct Elem { + // StateId key; + // Token *val; + // Elem *tail; + // }; void PossiblyResizeHash(size_t num_toks); @@ -277,8 +378,11 @@ class LatticeFasterDecoder { // index plus one, which is used to index into the active_toks_ array. // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the // token was newly created or the cost changed. + // If Token == StdToken, the 'backpointer' argument has no purpose (and will + // hopefully be optimized out). inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, - BaseFloat tot_cost, bool *changed); + BaseFloat tot_cost, Token *backpointer, + bool *changed); // prunes outgoing links for all tokens in active_toks_[frame] // it's called by PruneActiveTokens @@ -338,20 +442,15 @@ class LatticeFasterDecoder { BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem); - /// Processes emitting arcs for one frame. Propagates from prev_toks_ to cur_toks_. - /// Returns the cost cutoff for subsequent ProcessNonemitting() to use. - /// Templated on FST type for speed; called via ProcessEmittingWrapper(). - template BaseFloat ProcessEmitting(DecodableInterface *decodable); - - BaseFloat ProcessEmittingWrapper(DecodableInterface *decodable); + /// Processes emitting arcs for one frame. Propagates from prev_toks_ to + /// cur_toks_. Returns the cost cutoff for subsequent ProcessNonemitting() to + /// use. + BaseFloat ProcessEmitting(DecodableInterface *decodable); /// Processes nonemitting (epsilon) arcs for one frame. Called after /// ProcessEmitting() on each frame. The cost cutoff is computed by the /// preceding ProcessEmitting(). - /// the templated design is similar to ProcessEmitting() - template void ProcessNonemitting(BaseFloat cost_cutoff); - - void ProcessNonemittingWrapper(BaseFloat cost_cutoff); + void ProcessNonemitting(BaseFloat cost_cutoff); // HashList defined in ../util/hash-list.h. It actually allows us to maintain // more than one list (e.g. for current and previous frames), but only one of @@ -367,9 +466,13 @@ class LatticeFasterDecoder { // must_prune_tokens). std::vector queue_; // temp variable used in ProcessNonemitting, std::vector tmp_array_; // used in GetCutoff. - // make it class member to avoid internal new/delete. - const fst::Fst &fst_; + + // fst_ is a pointer to the FST we are decoding from. + const FST *fst_; + // delete_fst_ is true if the pointer fst_ needs to be deleted when this + // object is destroyed. 
bool delete_fst_; + std::vector cost_offsets_; // This contains, for each // frame, an offset that was added to the acoustic log-likelihoods on that // frame in order to keep everything in a nice dynamic range i.e. close to @@ -416,9 +519,11 @@ class LatticeFasterDecoder { void ClearActiveTokens(); - KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoder); + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderTpl); }; +typedef LatticeFasterDecoderTpl LatticeFasterDecoder; + } // end namespace kaldi. diff --git a/src/decoder/lattice-faster-online-decoder.cc b/src/decoder/lattice-faster-online-decoder.cc index 0a921438f94..ca0058155dd 100644 --- a/src/decoder/lattice-faster-online-decoder.cc +++ b/src/decoder/lattice-faster-online-decoder.cc @@ -29,85 +29,17 @@ namespace kaldi { -// instantiate this class once for each thing you have to decode. -LatticeFasterOnlineDecoder::LatticeFasterOnlineDecoder( - const fst::Fst &fst, - const LatticeFasterDecoderConfig &config): - fst_(fst), delete_fst_(false), config_(config), num_toks_(0) { - config.Check(); - toks_.SetSize(1000); // just so on the first frame we do something reasonable. -} - - -LatticeFasterOnlineDecoder::LatticeFasterOnlineDecoder(const LatticeFasterDecoderConfig &config, - fst::Fst *fst): - fst_(*fst), delete_fst_(true), config_(config), num_toks_(0) { - config.Check(); - toks_.SetSize(1000); // just so on the first frame we do something reasonable. -} - - -LatticeFasterOnlineDecoder::~LatticeFasterOnlineDecoder() { - DeleteElems(toks_.Clear()); - ClearActiveTokens(); - if (delete_fst_) delete &(fst_); -} - -void LatticeFasterOnlineDecoder::InitDecoding() { - // clean up from last time: - DeleteElems(toks_.Clear()); - cost_offsets_.clear(); - ClearActiveTokens(); - warned_ = false; - num_toks_ = 0; - decoding_finalized_ = false; - final_costs_.clear(); - StateId start_state = fst_.Start(); - KALDI_ASSERT(start_state != fst::kNoStateId); - active_toks_.resize(1); - Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL); - active_toks_[0].toks = start_tok; - toks_.Insert(start_state, start_tok); - num_toks_++; - ProcessNonemittingWrapper(config_.beam); -} - -// Returns true if any kind of traceback is available (not necessarily from -// a final state). It should only very rarely return false; this indicates -// an unusual search error. -bool LatticeFasterOnlineDecoder::Decode(DecodableInterface *decodable) { - InitDecoding(); - - // We use 1-based indexing for frames in this decoder (if you view it in - // terms of features), but note that the decodable object uses zero-based - // numbering, which we have to correct for when we call it. - - while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { - if (NumFramesDecoded() % config_.prune_interval == 0) - PruneActiveTokens(config_.lattice_beam * config_.prune_scale); - BaseFloat cost_cutoff = ProcessEmittingWrapper(decodable); // Note: the value returned by - ProcessNonemittingWrapper(cost_cutoff); - } - FinalizeDecoding(); - - // Returns true if we have any kind of traceback available (not necessarily - // to the end state; query ReachedFinal() for that). 
- return !active_toks_.empty() && active_toks_.back().toks != NULL; -} - - - - - -bool LatticeFasterOnlineDecoder::TestGetBestPath(bool use_final_probs) const { +template +bool LatticeFasterOnlineDecoderTpl::TestGetBestPath( + bool use_final_probs) const { Lattice lat1; { Lattice raw_lat; - GetRawLattice(&raw_lat, use_final_probs); + this->GetRawLattice(&raw_lat, use_final_probs); ShortestPath(raw_lat, &lat1); } Lattice lat2; - GetBestPath(&lat2, use_final_probs); + GetBestPath(&lat2, use_final_probs); BaseFloat delta = 0.1; int32 num_paths = 1; if (!fst::RandEquivalent(lat1, lat2, num_paths, delta, rand())) { @@ -120,8 +52,9 @@ bool LatticeFasterOnlineDecoder::TestGetBestPath(bool use_final_probs) const { // Outputs an FST corresponding to the single best path through the lattice. -bool LatticeFasterOnlineDecoder::GetBestPath(Lattice *olat, - bool use_final_probs) const { +template +bool LatticeFasterOnlineDecoderTpl::GetBestPath(Lattice *olat, + bool use_final_probs) const { olat->DeleteStates(); BaseFloat final_graph_cost; BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost); @@ -141,94 +74,98 @@ bool LatticeFasterOnlineDecoder::GetBestPath(Lattice *olat, return true; } - -// Outputs an FST corresponding to the raw, state-level -// tracebacks. -bool LatticeFasterOnlineDecoder::GetRawLattice(Lattice *ofst, - bool use_final_probs) const { - typedef LatticeArc Arc; - typedef Arc::StateId StateId; - typedef Arc::Weight Weight; - typedef Arc::Label Label; - - // Note: you can't use the old interface (Decode()) if you want to - // get the lattice with use_final_probs = false. You'd have to do - // InitDecoding() and then AdvanceDecoding(). - if (decoding_finalized_ && !use_final_probs) +template +typename LatticeFasterOnlineDecoderTpl::BestPathIterator LatticeFasterOnlineDecoderTpl::BestPathEnd( + bool use_final_probs, + BaseFloat *final_cost_out) const { + if (this->decoding_finalized_ && !use_final_probs) KALDI_ERR << "You cannot call FinalizeDecoding() and then call " - << "GetRawLattice() with use_final_probs == false"; + << "BestPathEnd() with use_final_probs == false"; + KALDI_ASSERT(this->NumFramesDecoded() > 0 && + "You cannot call BestPathEnd if no frames were decoded."); unordered_map final_costs_local; const unordered_map &final_costs = - (decoding_finalized_ ? final_costs_ : final_costs_local); - if (!decoding_finalized_ && use_final_probs) - ComputeFinalCosts(&final_costs_local, NULL, NULL); + (this->decoding_finalized_ ? this->final_costs_ :final_costs_local); + if (!this->decoding_finalized_ && use_final_probs) + this->ComputeFinalCosts(&final_costs_local, NULL, NULL); - ofst->DeleteStates(); - // num-frames plus one (since frames are one-based, and we have - // an extra frame for the start-state). - int32 num_frames = active_toks_.size() - 1; - KALDI_ASSERT(num_frames > 0); - const int32 bucket_count = num_toks_/2 + 3; - unordered_map tok_map(bucket_count); - // First create all states. - std::vector token_list; - for (int32 f = 0; f <= num_frames; f++) { - if (active_toks_[f].toks == NULL) { - KALDI_WARN << "GetRawLattice: no tokens active on frame " << f - << ": not producing lattice.\n"; - return false; + // Singly linked list of tokens on last frame (access list through "next" + // pointer). 
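// Illustrative sketch (not part of this patch): walking the best path directly
// with the BestPathEnd() / TraceBackBestPath() pair that GetBestPath() above is
// built on.  This assumes, as the surrounding code suggests, that the online
// decoder is templated on the FST type only; the arcs come out end-first, so a
// caller wanting forward order must reverse the vector.  The function name is
// this sketch's own.
#include <vector>

#include "decoder/lattice-faster-online-decoder.h"
#include "lat/kaldi-lattice.h"

template <typename FST>
void CollectBestPathArcsReversed(
    const kaldi::LatticeFasterOnlineDecoderTpl<FST> &decoder,
    std::vector<kaldi::LatticeArc> *reversed_arcs) {
  reversed_arcs->clear();
  kaldi::BaseFloat final_cost;
  // Requires that some frames have already been decoded (see the assert above).
  typename kaldi::LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator iter =
      decoder.BestPathEnd(true, &final_cost);  // use final-probs if available
  while (!iter.Done()) {
    kaldi::LatticeArc arc;
    iter = decoder.TraceBackBestPath(iter, &arc);
    reversed_arcs->push_back(arc);  // epsilon steps come out with ilabel == 0
  }
}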
+ BaseFloat best_cost = std::numeric_limits::infinity(); + BaseFloat best_final_cost = 0; + Token *best_tok = NULL; + for (Token *tok = this->active_toks_.back().toks; + tok != NULL; tok = tok->next) { + BaseFloat cost = tok->tot_cost, final_cost = 0.0; + if (use_final_probs && !final_costs.empty()) { + // if we are instructed to use final-probs, and any final tokens were + // active on final frame, include the final-prob in the cost of the token. + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) { + final_cost = iter->second; + cost += final_cost; + } else { + cost = std::numeric_limits::infinity(); + } + } + if (cost < best_cost) { + best_cost = cost; + best_tok = tok; + best_final_cost = final_cost; } - TopSortTokens(active_toks_[f].toks, &token_list); - for (size_t i = 0; i < token_list.size(); i++) - if (token_list[i] != NULL) - tok_map[token_list[i]] = ofst->AddState(); } - // The next statement sets the start state of the output FST. Because we - // topologically sorted the tokens, state zero must be the start-state. - ofst->SetStart(0); - - KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:" - << tok_map.bucket_count() << " load:" << tok_map.load_factor() - << " max:" << tok_map.max_load_factor(); - // Now create all arcs. - for (int32 f = 0; f <= num_frames; f++) { - for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { - StateId cur_state = tok_map[tok]; - for (ForwardLink *l = tok->links; - l != NULL; - l = l->next) { - unordered_map::const_iterator iter = - tok_map.find(l->next_tok); - StateId nextstate = iter->second; - KALDI_ASSERT(iter != tok_map.end()); - BaseFloat cost_offset = 0.0; - if (l->ilabel != 0) { // emitting.. - KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); - cost_offset = cost_offsets_[f]; - } - Arc arc(l->ilabel, l->olabel, - Weight(l->graph_cost, l->acoustic_cost - cost_offset), - nextstate); - ofst->AddArc(cur_state, arc); - } - if (f == num_frames) { - if (use_final_probs && !final_costs.empty()) { - unordered_map::const_iterator iter = - final_costs.find(tok); - if (iter != final_costs.end()) - ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); - } else { - ofst->SetFinal(cur_state, LatticeWeight::One()); + if (best_tok == NULL) { // this should not happen, and is likely a code error or + // caused by infinities in likelihoods, but I'm not making + // it a fatal error for now. + KALDI_WARN << "No final token found."; + } + if (final_cost_out) + *final_cost_out = best_final_cost; + return BestPathIterator(best_tok, this->NumFramesDecoded() - 1); +} + + +template +typename LatticeFasterOnlineDecoderTpl::BestPathIterator LatticeFasterOnlineDecoderTpl::TraceBackBestPath( + BestPathIterator iter, LatticeArc *oarc) const { + KALDI_ASSERT(!iter.Done() && oarc != NULL); + Token *tok = static_cast(iter.tok); + int32 cur_t = iter.frame, ret_t = cur_t; + if (tok->backpointer != NULL) { + ForwardLinkT *link; + for (link = tok->backpointer->links; + link != NULL; link = link->next) { + if (link->next_tok == tok) { // this is the link to "tok" + oarc->ilabel = link->ilabel; + oarc->olabel = link->olabel; + BaseFloat graph_cost = link->graph_cost, + acoustic_cost = link->acoustic_cost; + if (link->ilabel != 0) { + KALDI_ASSERT(static_cast(cur_t) < this->cost_offsets_.size()); + acoustic_cost -= this->cost_offsets_[cur_t]; + ret_t--; } + oarc->weight = LatticeWeight(graph_cost, acoustic_cost); + break; } } + if (link == NULL) { // Did not find correct link. 
+ KALDI_ERR << "Error tracing best-path back (likely " + << "bug in token-pruning algorithm)"; + } + } else { + oarc->ilabel = 0; + oarc->olabel = 0; + oarc->weight = LatticeWeight::One(); // zero costs. } - return (ofst->NumStates() > 0); + return BestPathIterator(tok->backpointer, ret_t); } -bool LatticeFasterOnlineDecoder::GetRawLatticePruned( +template +bool LatticeFasterOnlineDecoderTpl::GetRawLatticePruned( Lattice *ofst, bool use_final_probs, BaseFloat beam) const { @@ -240,57 +177,58 @@ bool LatticeFasterOnlineDecoder::GetRawLatticePruned( // Note: you can't use the old interface (Decode()) if you want to // get the lattice with use_final_probs = false. You'd have to do // InitDecoding() and then AdvanceDecoding(). - if (decoding_finalized_ && !use_final_probs) + if (this->decoding_finalized_ && !use_final_probs) KALDI_ERR << "You cannot call FinalizeDecoding() and then call " << "GetRawLattice() with use_final_probs == false"; unordered_map final_costs_local; const unordered_map &final_costs = - (decoding_finalized_ ? final_costs_ : final_costs_local); - if (!decoding_finalized_ && use_final_probs) - ComputeFinalCosts(&final_costs_local, NULL, NULL); + (this->decoding_finalized_ ? this->final_costs_ : final_costs_local); + if (!this->decoding_finalized_ && use_final_probs) + this->ComputeFinalCosts(&final_costs_local, NULL, NULL); ofst->DeleteStates(); // num-frames plus one (since frames are one-based, and we have // an extra frame for the start-state). - int32 num_frames = active_toks_.size() - 1; + int32 num_frames = this->active_toks_.size() - 1; KALDI_ASSERT(num_frames > 0); for (int32 f = 0; f <= num_frames; f++) { - if (active_toks_[f].toks == NULL) { - KALDI_WARN << "GetRawLattice: no tokens active on frame " << f + if (this->active_toks_[f].toks == NULL) { + KALDI_WARN << "No tokens active on frame " << f << ": not producing lattice.\n"; return false; } } - unordered_map tok_map; std::queue > tok_queue; // First initialize the queue and states. Put the initial state on the queue; // this is the last token in the list active_toks_[0].toks. - for (Token *tok = active_toks_[0].toks; tok != NULL; tok = tok->next) { + for (Token *tok = this->active_toks_[0].toks; + tok != NULL; tok = tok->next) { if (tok->next == NULL) { tok_map[tok] = ofst->AddState(); ofst->SetStart(tok_map[tok]); std::pair tok_pair(tok, 0); // #frame = 0 tok_queue.push(tok_pair); } - } - + } + // Next create states for "good" tokens while (!tok_queue.empty()) { std::pair cur_tok_pair = tok_queue.front(); tok_queue.pop(); Token *cur_tok = cur_tok_pair.first; int32 cur_frame = cur_tok_pair.second; - KALDI_ASSERT(cur_frame >= 0 && cur_frame <= cost_offsets_.size()); - - unordered_map::const_iterator iter = + KALDI_ASSERT(cur_frame >= 0 && + cur_frame <= this->cost_offsets_.size()); + + typename unordered_map::const_iterator iter = tok_map.find(cur_tok); KALDI_ASSERT(iter != tok_map.end()); StateId cur_state = iter->second; - for (ForwardLink *l = cur_tok->links; + for (ForwardLinkT *l = cur_tok->links; l != NULL; l = l->next) { Token *next_tok = l->next_tok; @@ -304,7 +242,8 @@ bool LatticeFasterOnlineDecoder::GetRawLatticePruned( } else { nextstate = tok_map[next_tok]; } - BaseFloat cost_offset = (l->ilabel != 0 ? cost_offsets_[cur_frame] : 0); + BaseFloat cost_offset = (l->ilabel != 0 ? 
+ this->cost_offsets_[cur_frame] : 0); Arc arc(l->ilabel, l->olabel, Weight(l->graph_cost, l->acoustic_cost - cost_offset), nextstate); @@ -313,11 +252,11 @@ bool LatticeFasterOnlineDecoder::GetRawLatticePruned( } if (cur_frame == num_frames) { if (use_final_probs && !final_costs.empty()) { - unordered_map::const_iterator iter = + typename unordered_map::const_iterator iter = final_costs.find(cur_tok); if (iter != final_costs.end()) ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); - } else { + } else { ofst->SetFinal(cur_state, LatticeWeight::One()); } } @@ -326,841 +265,12 @@ bool LatticeFasterOnlineDecoder::GetRawLatticePruned( } -void LatticeFasterOnlineDecoder::PossiblyResizeHash(size_t num_toks) { - size_t new_sz = static_cast(static_cast(num_toks) - * config_.hash_ratio); - if (new_sz > toks_.Size()) { - toks_.SetSize(new_sz); - } -} - -// FindOrAddToken either locates a token in hash of toks_, -// or if necessary inserts a new, empty token (i.e. with no forward links) -// for the current frame. [note: it's inserted if necessary into hash toks_ -// and also into the singly linked list of tokens active on this frame -// (whose head is at active_toks_[frame]). -inline LatticeFasterOnlineDecoder::Token *LatticeFasterOnlineDecoder::FindOrAddToken( - StateId state, int32 frame_plus_one, BaseFloat tot_cost, - Token *backpointer, bool *changed) { - // Returns the Token pointer. Sets "changed" (if non-NULL) to true - // if the token was newly created or the cost changed. - KALDI_ASSERT(frame_plus_one < active_toks_.size()); - Token *&toks = active_toks_[frame_plus_one].toks; - Elem *e_found = toks_.Find(state); - if (e_found == NULL) { // no such token presently. - const BaseFloat extra_cost = 0.0; - // tokens on the currently final frame have zero extra_cost - // as any of them could end up - // on the winning path. - Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer); - // NULL: no forward links yet - toks = new_tok; - num_toks_++; - toks_.Insert(state, new_tok); - if (changed) *changed = true; - return new_tok; - } else { - Token *tok = e_found->val; // There is an existing Token for this state. - if (tok->tot_cost > tot_cost) { // replace old token - tok->tot_cost = tot_cost; - tok->backpointer = backpointer; - // we don't allocate a new token, the old stays linked in active_toks_ - // we only replace the tot_cost - // in the current frame, there are no forward links (and no extra_cost) - // only in ProcessNonemitting we have to delete forward links - // in case we visit a state for the second time - // those forward links, that lead to this replaced token before: - // they remain and will hopefully be pruned later (PruneForwardLinks...) - if (changed) *changed = true; - } else { - if (changed) *changed = false; - } - return tok; - } -} - -// prunes outgoing links for all tokens in active_toks_[frame] -// it's called by PruneActiveTokens -// all links, that have link_extra_cost > lattice_beam are pruned -void LatticeFasterOnlineDecoder::PruneForwardLinks( - int32 frame_plus_one, bool *extra_costs_changed, - bool *links_pruned, BaseFloat delta) { - // delta is the amount by which the extra_costs must change - // If delta is larger, we'll tend to go back less far - // toward the beginning of the file. 
- // extra_costs_changed is set to true if extra_cost was changed for any token - // links_pruned is set to true if any link in any token was pruned - - *extra_costs_changed = false; - *links_pruned = false; - KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); - if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. - if (!warned_) { - KALDI_WARN << "No tokens alive [doing pruning].. warning first " - "time only for each utterance\n"; - warned_ = true; - } - } - - // We have to iterate until there is no more change, because the links - // are not guaranteed to be in topological order. - bool changed = true; // difference new minus old extra cost >= delta ? - while (changed) { - changed = false; - for (Token *tok = active_toks_[frame_plus_one].toks; - tok != NULL; tok = tok->next) { - ForwardLink *link, *prev_link = NULL; - // will recompute tok_extra_cost for tok. - BaseFloat tok_extra_cost = std::numeric_limits::infinity(); - // tok_extra_cost is the best (min) of link_extra_cost of outgoing links - for (link = tok->links; link != NULL; ) { - // See if we need to excise this link... - Token *next_tok = link->next_tok; - BaseFloat link_extra_cost = next_tok->extra_cost + - ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - - next_tok->tot_cost); // difference in brackets is >= 0 - // link_exta_cost is the difference in score between the best paths - // through link source state and through link destination state - KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN - if (link_extra_cost > config_.lattice_beam) { // excise link - ForwardLink *next_link = link->next; - if (prev_link != NULL) prev_link->next = next_link; - else tok->links = next_link; - delete link; - link = next_link; // advance link but leave prev_link the same. - *links_pruned = true; - } else { // keep the link and update the tok_extra_cost if needed. - if (link_extra_cost < 0.0) { // this is just a precaution. - if (link_extra_cost < -0.01) - KALDI_WARN << "Negative extra_cost: " << link_extra_cost; - link_extra_cost = 0.0; - } - if (link_extra_cost < tok_extra_cost) - tok_extra_cost = link_extra_cost; - prev_link = link; // move to next link - link = link->next; - } - } // for all outgoing links - if (fabs(tok_extra_cost - tok->extra_cost) > delta) - changed = true; // difference new minus old is bigger than delta - tok->extra_cost = tok_extra_cost; - // will be +infinity or <= lattice_beam_. - // infinity indicates, that no forward link survived pruning - } // for all Token on active_toks_[frame] - if (changed) *extra_costs_changed = true; - - // Note: it's theoretically possible that aggressive compiler - // optimizations could cause an infinite loop here for small delta and - // high-dynamic-range scores. - } // while changed -} - -// PruneForwardLinksFinal is a version of PruneForwardLinks that we call -// on the final frame. If there are final tokens active, it uses -// the final-probs for pruning, otherwise it treats all tokens as final. -void LatticeFasterOnlineDecoder::PruneForwardLinksFinal() { - KALDI_ASSERT(!active_toks_.empty()); - int32 frame_plus_one = active_toks_.size() - 1; - - if (active_toks_[frame_plus_one].toks == NULL ) // empty list; should not happen. 
- KALDI_WARN << "No tokens alive at end of file\n"; - - typedef unordered_map::const_iterator IterType; - ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); - decoding_finalized_ = true; - // We call DeleteElems() as a nicety, not because it's really necessary; - // otherwise there would be a time, after calling PruneTokensForFrame() on the - // final frame, when toks_.GetList() or toks_.Clear() would contain pointers - // to nonexistent tokens. - DeleteElems(toks_.Clear()); - - // Now go through tokens on this frame, pruning forward links... may have to - // iterate a few times until there is no more change, because the list is not - // in topological order. This is a modified version of the code in - // PruneForwardLinks, but here we also take account of the final-probs. - bool changed = true; - BaseFloat delta = 1.0e-05; - while (changed) { - changed = false; - for (Token *tok = active_toks_[frame_plus_one].toks; - tok != NULL; tok = tok->next) { - ForwardLink *link, *prev_link = NULL; - // will recompute tok_extra_cost. It has a term in it that corresponds - // to the "final-prob", so instead of initializing tok_extra_cost to infinity - // below we set it to the difference between the (score+final_prob) of this token, - // and the best such (score+final_prob). - BaseFloat final_cost; - if (final_costs_.empty()) { - final_cost = 0.0; - } else { - IterType iter = final_costs_.find(tok); - if (iter != final_costs_.end()) - final_cost = iter->second; - else - final_cost = std::numeric_limits::infinity(); - } - BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; - // tok_extra_cost will be a "min" over either directly being final, or - // being indirectly final through other links, and the loop below may - // decrease its value: - for (link = tok->links; link != NULL; ) { - // See if we need to excise this link... - Token *next_tok = link->next_tok; - BaseFloat link_extra_cost = next_tok->extra_cost + - ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - - next_tok->tot_cost); - if (link_extra_cost > config_.lattice_beam) { // excise link - ForwardLink *next_link = link->next; - if (prev_link != NULL) prev_link->next = next_link; - else tok->links = next_link; - delete link; - link = next_link; // advance link but leave prev_link the same. - } else { // keep the link and update the tok_extra_cost if needed. - if (link_extra_cost < 0.0) { // this is just a precaution. - if (link_extra_cost < -0.01) - KALDI_WARN << "Negative extra_cost: " << link_extra_cost; - link_extra_cost = 0.0; - } - if (link_extra_cost < tok_extra_cost) - tok_extra_cost = link_extra_cost; - prev_link = link; - link = link->next; - } - } - // prune away tokens worse than lattice_beam above best path. This step - // was not necessary in the non-final case because then, this case - // showed up as having no forward links. Here, the tok_extra_cost has - // an extra component relating to the final-prob. - if (tok_extra_cost > config_.lattice_beam) - tok_extra_cost = std::numeric_limits::infinity(); - // to be pruned in PruneTokensForFrame - - if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) - changed = true; - tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. 
- } - } // while changed - -} - -BaseFloat LatticeFasterOnlineDecoder::FinalRelativeCost() const { - if (!decoding_finalized_) { - BaseFloat relative_cost; - ComputeFinalCosts(NULL, &relative_cost, NULL); - return relative_cost; - } else { - // we're not allowed to call that function if FinalizeDecoding() has - // been called; return a cached value. - return final_relative_cost_; - } -} - - -// Prune away any tokens on this frame that have no forward links. -// [we don't do this in PruneForwardLinks because it would give us -// a problem with dangling pointers]. -// It's called by PruneActiveTokens if any forward links have been pruned -void LatticeFasterOnlineDecoder::PruneTokensForFrame(int32 frame_plus_one) { - KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); - Token *&toks = active_toks_[frame_plus_one].toks; - if (toks == NULL) - KALDI_WARN << "No tokens alive [doing pruning]\n"; - Token *tok, *next_tok, *prev_tok = NULL; - for (tok = toks; tok != NULL; tok = next_tok) { - next_tok = tok->next; - if (tok->extra_cost == std::numeric_limits::infinity()) { - // token is unreachable from end of graph; (no forward links survived) - // excise tok from list and delete tok. - if (prev_tok != NULL) prev_tok->next = tok->next; - else toks = tok->next; - delete tok; - num_toks_--; - } else { // fetch next Token - prev_tok = tok; - } - } -} - -// Go backwards through still-alive tokens, pruning them, starting not from -// the current frame (where we want to keep all tokens) but from the frame before -// that. We go backwards through the frames and stop when we reach a point -// where the delta-costs are not changing (and the delta controls when we consider -// a cost to have "not changed"). -void LatticeFasterOnlineDecoder::PruneActiveTokens(BaseFloat delta) { - int32 cur_frame_plus_one = NumFramesDecoded(); - int32 num_toks_begin = num_toks_; - // The index "f" below represents a "frame plus one", i.e. you'd have to subtract - // one to get the corresponding index for the decodable object. - for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { - // Reason why we need to prune forward links in this situation: - // (1) we have never pruned them (new TokenList) - // (2) we have not yet pruned the forward links to the next f, - // after any of those tokens have changed their extra_cost. 
- if (active_toks_[f].must_prune_forward_links) { - bool extra_costs_changed = false, links_pruned = false; - PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); - if (extra_costs_changed && f > 0) // any token has changed extra_cost - active_toks_[f-1].must_prune_forward_links = true; - if (links_pruned) // any link was pruned - active_toks_[f].must_prune_tokens = true; - active_toks_[f].must_prune_forward_links = false; // job done - } - if (f+1 < cur_frame_plus_one && // except for last f (no forward links) - active_toks_[f+1].must_prune_tokens) { - PruneTokensForFrame(f+1); - active_toks_[f+1].must_prune_tokens = false; - } - } - KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin - << " to " << num_toks_; -} - -void LatticeFasterOnlineDecoder::ComputeFinalCosts( - unordered_map *final_costs, - BaseFloat *final_relative_cost, - BaseFloat *final_best_cost) const { - KALDI_ASSERT(!decoding_finalized_); - if (final_costs != NULL) - final_costs->clear(); - const Elem *final_toks = toks_.GetList(); - BaseFloat infinity = std::numeric_limits::infinity(); - BaseFloat best_cost = infinity, - best_cost_with_final = infinity; - while (final_toks != NULL) { - StateId state = final_toks->key; - Token *tok = final_toks->val; - const Elem *next = final_toks->tail; - BaseFloat final_cost = fst_.Final(state).Value(); - BaseFloat cost = tok->tot_cost, - cost_with_final = cost + final_cost; - best_cost = std::min(cost, best_cost); - best_cost_with_final = std::min(cost_with_final, best_cost_with_final); - if (final_costs != NULL && final_cost != infinity) - (*final_costs)[tok] = final_cost; - final_toks = next; - } - if (final_relative_cost != NULL) { - if (best_cost == infinity && best_cost_with_final == infinity) { - // Likely this will only happen if there are no tokens surviving. - // This seems the least bad way to handle it. - *final_relative_cost = infinity; - } else { - *final_relative_cost = best_cost_with_final - best_cost; - } - } - if (final_best_cost != NULL) { - if (best_cost_with_final != infinity) { // final-state exists. - *final_best_cost = best_cost_with_final; - } else { // no final-state exists. - *final_best_cost = best_cost; - } - } -} - - -LatticeFasterOnlineDecoder::BestPathIterator LatticeFasterOnlineDecoder::BestPathEnd( - bool use_final_probs, - BaseFloat *final_cost_out) const { - if (decoding_finalized_ && !use_final_probs) - KALDI_ERR << "You cannot call FinalizeDecoding() and then call " - << "BestPathEnd() with use_final_probs == false"; - KALDI_ASSERT(NumFramesDecoded() > 0 && - "You cannot call BestPathEnd if no frames were decoded."); - - unordered_map final_costs_local; - - const unordered_map &final_costs = - (decoding_finalized_ ? final_costs_ : final_costs_local); - if (!decoding_finalized_ && use_final_probs) - ComputeFinalCosts(&final_costs_local, NULL, NULL); - - // Singly linked list of tokens on last frame (access list through "next" - // pointer). - BaseFloat best_cost = std::numeric_limits::infinity(); - BaseFloat best_final_cost = 0; - Token *best_tok = NULL; - for (Token *tok = active_toks_.back().toks; tok != NULL; tok = tok->next) { - BaseFloat cost = tok->tot_cost, final_cost = 0.0; - if (use_final_probs && !final_costs.empty()) { - // if we are instructed to use final-probs, and any final tokens were - // active on final frame, include the final-prob in the cost of the token. 
- unordered_map::const_iterator iter = final_costs.find(tok); - if (iter != final_costs.end()) { - final_cost = iter->second; - cost += final_cost; - } else { - cost = std::numeric_limits::infinity(); - } - } - if (cost < best_cost) { - best_cost = cost; - best_tok = tok; - best_final_cost = final_cost; - } - } - if (best_tok == NULL) { // this should not happen, and is likely a code error or - // caused by infinities in likelihoods, but I'm not making - // it a fatal error for now. - KALDI_WARN << "No final token found."; - } - if (final_cost_out) - *final_cost_out = best_final_cost; - return BestPathIterator(best_tok, NumFramesDecoded() - 1); -} - - -LatticeFasterOnlineDecoder::BestPathIterator LatticeFasterOnlineDecoder::TraceBackBestPath( - BestPathIterator iter, LatticeArc *oarc) const { - KALDI_ASSERT(!iter.Done() && oarc != NULL); - Token *tok = static_cast(iter.tok); - int32 cur_t = iter.frame, ret_t = cur_t; - if (tok->backpointer != NULL) { - ForwardLink *link; - for (link = tok->backpointer->links; - link != NULL; link = link->next) { - if (link->next_tok == tok) { // this is the link to "tok" - oarc->ilabel = link->ilabel; - oarc->olabel = link->olabel; - BaseFloat graph_cost = link->graph_cost, - acoustic_cost = link->acoustic_cost; - if (link->ilabel != 0) { - KALDI_ASSERT(static_cast(cur_t) < cost_offsets_.size()); - acoustic_cost -= cost_offsets_[cur_t]; - ret_t--; - } - oarc->weight = LatticeWeight(graph_cost, acoustic_cost); - break; - } - } - if (link == NULL) { // Did not find correct link. - KALDI_ERR << "Error tracing best-path back (likely " - << "bug in token-pruning algorithm)"; - } - } else { - oarc->ilabel = 0; - oarc->olabel = 0; - oarc->weight = LatticeWeight::One(); // zero costs. - } - return BestPathIterator(tok->backpointer, ret_t); -} - - -void LatticeFasterOnlineDecoder::AdvanceDecoding(DecodableInterface *decodable, - int32 max_num_frames) { - KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && - "You must call InitDecoding() before AdvanceDecoding"); - int32 num_frames_ready = decodable->NumFramesReady(); - // num_frames_ready must be >= num_frames_decoded, or else - // the number of frames ready must have decreased (which doesn't - // make sense) or the decodable object changed between calls - // (which isn't allowed). - KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); - int32 target_frames_decoded = num_frames_ready; - if (max_num_frames >= 0) - target_frames_decoded = std::min(target_frames_decoded, - NumFramesDecoded() + max_num_frames); - while (NumFramesDecoded() < target_frames_decoded) { - if (NumFramesDecoded() % config_.prune_interval == 0) { - PruneActiveTokens(config_.lattice_beam * config_.prune_scale); - } - // note: ProcessEmitting() increments NumFramesDecoded(). - BaseFloat cost_cutoff = ProcessEmittingWrapper(decodable); - ProcessNonemittingWrapper(cost_cutoff); - } -} - - -// FinalizeDecoding() is a version of PruneActiveTokens that we call -// (optionally) on the final frame. Takes into account the final-prob of -// tokens. This function used to be called PruneActiveTokensFinal(). -void LatticeFasterOnlineDecoder::FinalizeDecoding() { - int32 final_frame_plus_one = NumFramesDecoded(); - int32 num_toks_begin = num_toks_; - // PruneForwardLinksFinal() prunes final frame (with final-probs), and - // sets decoding_finalized_. - PruneForwardLinksFinal(); - for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { - bool b1, b2; // values not used. 
- BaseFloat dontcare = 0.0; // delta of zero means we must always update - PruneForwardLinks(f, &b1, &b2, dontcare); - PruneTokensForFrame(f + 1); - } - PruneTokensForFrame(0); - KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin - << " to " << num_toks_; -} - -/// Gets the weight cutoff. Also counts the active tokens. -BaseFloat LatticeFasterOnlineDecoder::GetCutoff(Elem *list_head, size_t *tok_count, - BaseFloat *adaptive_beam, Elem **best_elem) { - BaseFloat best_weight = std::numeric_limits::infinity(); - // positive == high cost == bad. - size_t count = 0; - if (config_.max_active == std::numeric_limits::max() && - config_.min_active == 0) { - for (Elem *e = list_head; e != NULL; e = e->tail, count++) { - BaseFloat w = static_cast(e->val->tot_cost); - if (w < best_weight) { - best_weight = w; - if (best_elem) *best_elem = e; - } - } - if (tok_count != NULL) *tok_count = count; - if (adaptive_beam != NULL) *adaptive_beam = config_.beam; - return best_weight + config_.beam; - } else { - tmp_array_.clear(); - for (Elem *e = list_head; e != NULL; e = e->tail, count++) { - BaseFloat w = e->val->tot_cost; - tmp_array_.push_back(w); - if (w < best_weight) { - best_weight = w; - if (best_elem) *best_elem = e; - } - } - if (tok_count != NULL) *tok_count = count; - - BaseFloat beam_cutoff = best_weight + config_.beam, - min_active_cutoff = std::numeric_limits::infinity(), - max_active_cutoff = std::numeric_limits::infinity(); - - KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() - << " is " << tmp_array_.size(); - - if (tmp_array_.size() > static_cast(config_.max_active)) { - std::nth_element(tmp_array_.begin(), - tmp_array_.begin() + config_.max_active, - tmp_array_.end()); - max_active_cutoff = tmp_array_[config_.max_active]; - } - if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. - if (adaptive_beam) - *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; - return max_active_cutoff; - } - if (tmp_array_.size() > static_cast(config_.min_active)) { - if (config_.min_active == 0) min_active_cutoff = best_weight; - else { - std::nth_element(tmp_array_.begin(), - tmp_array_.begin() + config_.min_active, - tmp_array_.size() > static_cast(config_.max_active) ? - tmp_array_.begin() + config_.max_active : - tmp_array_.end()); - min_active_cutoff = tmp_array_[config_.min_active]; - } - } - - if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. - if (adaptive_beam) - *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; - return min_active_cutoff; - } else { - *adaptive_beam = config_.beam; - return beam_cutoff; - } - } -} - - -template -BaseFloat LatticeFasterOnlineDecoder::ProcessEmitting( - DecodableInterface *decodable) { - KALDI_ASSERT(active_toks_.size() > 0); - int32 frame = active_toks_.size() - 1; // frame is the frame-index - // (zero-based) used to get likelihoods - // from the decodable object. - active_toks_.resize(active_toks_.size() + 1); - - Elem *final_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_ - // in simple-decoder.h. Removes the Elems from - // being indexed in the hash in toks_. - Elem *best_elem = NULL; - BaseFloat adaptive_beam; - size_t tok_cnt; - BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem); - PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough. 
- - BaseFloat next_cutoff = std::numeric_limits::infinity(); - // pruning "online" before having seen all tokens - - BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good - // dynamic range. - const FstType &fst = dynamic_cast(fst_); - - // First process the best token to get a hopefully - // reasonably tight bound on the next cutoff. The only - // products of the next block are "next_cutoff" and "cost_offset". - if (best_elem) { - StateId state = best_elem->key; - Token *tok = best_elem->val; - cost_offset = - tok->tot_cost; - for (fst::ArcIterator aiter(fst, state); - !aiter.Done(); - aiter.Next()) { - const Arc &arc = aiter.Value(); - if (arc.ilabel != 0) { // propagate.. - BaseFloat new_weight = arc.weight.Value() + cost_offset - - decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost; - if (new_weight + adaptive_beam < next_cutoff) - next_cutoff = new_weight + adaptive_beam; - } - } - } - - // Store the offset on the acoustic likelihoods that we're applying. - // Could just do cost_offsets_.push_back(cost_offset), but we - // do it this way as it's more robust to future code changes. - cost_offsets_.resize(frame + 1, 0.0); - cost_offsets_[frame] = cost_offset; - - // the tokens are now owned here, in final_toks, and the hash is empty. - // 'owned' is a complex thing here; the point is we need to call DeleteElem - // on each elem 'e' to let toks_ know we're done with them. - for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) { - // loop this way because we delete "e" as we go. - StateId state = e->key; - Token *tok = e->val; - if (tok->tot_cost <= cur_cutoff) { - for (fst::ArcIterator aiter(fst, state); - !aiter.Done(); - aiter.Next()) { - const Arc &arc = aiter.Value(); - if (arc.ilabel != 0) { // propagate.. - BaseFloat ac_cost = cost_offset - - decodable->LogLikelihood(frame, arc.ilabel), - graph_cost = arc.weight.Value(), - cur_cost = tok->tot_cost, - tot_cost = cur_cost + ac_cost + graph_cost; - if (tot_cost > next_cutoff) continue; - else if (tot_cost + adaptive_beam < next_cutoff) - next_cutoff = tot_cost + adaptive_beam; // prune by best current token - // Note: the frame indexes into active_toks_ are one-based, - // hence the + 1. 
- Token *next_tok = FindOrAddToken(arc.nextstate, - frame + 1, tot_cost, tok, NULL); - // NULL: no change indicator needed - - // Add ForwardLink from tok to next_tok (put on head of list tok->links) - tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel, - graph_cost, ac_cost, tok->links); - } - } // for all arcs - } - e_tail = e->tail; - toks_.Delete(e); // delete Elem - } - return next_cutoff; -} - -template BaseFloat LatticeFasterOnlineDecoder:: - ProcessEmitting>(DecodableInterface *decodable); -template BaseFloat LatticeFasterOnlineDecoder:: - ProcessEmitting>(DecodableInterface *decodable); -template BaseFloat LatticeFasterOnlineDecoder:: - ProcessEmitting>(DecodableInterface *decodable); - -BaseFloat LatticeFasterOnlineDecoder::ProcessEmittingWrapper( - DecodableInterface *decodable) { - if (fst_.Type() == "const") { - return LatticeFasterOnlineDecoder:: - ProcessEmitting>(decodable); - } else if (fst_.Type() == "vector") { - return LatticeFasterOnlineDecoder:: - ProcessEmitting>(decodable); - } else { - return LatticeFasterOnlineDecoder:: - ProcessEmitting>(decodable); - } -} - -template -void LatticeFasterOnlineDecoder::ProcessNonemitting(BaseFloat cutoff) { - KALDI_ASSERT(!active_toks_.empty()); - int32 frame = static_cast(active_toks_.size()) - 2; - // Note: "frame" is the time-index we just processed, or -1 if - // we are processing the nonemitting transitions before the - // first frame (called from InitDecoding()). - const FstType &fst = dynamic_cast(fst_); - - // Processes nonemitting arcs for one frame. Propagates within toks_. - // Note-- this queue structure is is not very optimal as - // it may cause us to process states unnecessarily (e.g. more than once), - // but in the baseline code, turning this vector into a set to fix this - // problem did not improve overall speed. - - KALDI_ASSERT(queue_.empty()); - for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) - queue_.push_back(e->key); - if (queue_.empty()) { - if (!warned_) { - KALDI_WARN << "Error, no surviving tokens: frame is " << frame; - warned_ = true; - } - } - - while (!queue_.empty()) { - StateId state = queue_.back(); - queue_.pop_back(); - - Token *tok = toks_.Find(state)->val; // would segfault if state not in toks_ but this can't happen. - BaseFloat cur_cost = tok->tot_cost; - if (cur_cost > cutoff) // Don't bother processing successors. - continue; - // If "tok" has any existing forward links, delete them, - // because we're about to regenerate them. This is a kind - // of non-optimality (remember, this is the simple decoder), - // but since most states are emitting it's not a huge issue. - tok->DeleteForwardLinks(); // necessary when re-visiting - tok->links = NULL; - for (fst::ArcIterator aiter(fst, state); - !aiter.Done(); - aiter.Next()) { - const Arc &arc = aiter.Value(); - if (arc.ilabel == 0) { // propagate nonemitting only... - BaseFloat graph_cost = arc.weight.Value(), - tot_cost = cur_cost + graph_cost; - if (tot_cost < cutoff) { - bool changed; - - Token *new_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, - tok, &changed); - - tok->links = new ForwardLink(new_tok, 0, arc.olabel, - graph_cost, 0, tok->links); - - // "changed" tells us whether the new token has a different - // cost from before, or is new [if so, add into queue]. 
- if (changed) queue_.push_back(arc.nextstate); - } - } - } // for all arcs - } // while queue not empty -} - -template void LatticeFasterOnlineDecoder:: - ProcessNonemitting>(BaseFloat cutoff); -template void LatticeFasterOnlineDecoder:: - ProcessNonemitting>(BaseFloat cutoff); -template void LatticeFasterOnlineDecoder:: - ProcessNonemitting>(BaseFloat cutoff); - -void LatticeFasterOnlineDecoder::ProcessNonemittingWrapper( - BaseFloat cost_cutoff) { - if (fst_.Type() == "const") { - return LatticeFasterOnlineDecoder:: - ProcessNonemitting>(cost_cutoff); - } else if (fst_.Type() == "vector") { - return LatticeFasterOnlineDecoder:: - ProcessNonemitting>(cost_cutoff); - } else { - return LatticeFasterOnlineDecoder:: - ProcessNonemitting>(cost_cutoff); - } -} - -void LatticeFasterOnlineDecoder::DeleteElems(Elem *list) { - for (Elem *e = list, *e_tail; e != NULL; e = e_tail) { - // Token::TokenDelete(e->val); - e_tail = e->tail; - toks_.Delete(e); - } -} - -void LatticeFasterOnlineDecoder::ClearActiveTokens() { // a cleanup routine, at utt end/begin - for (size_t i = 0; i < active_toks_.size(); i++) { - // Delete all tokens alive on this frame, and any forward - // links they may have. - for (Token *tok = active_toks_[i].toks; tok != NULL; ) { - tok->DeleteForwardLinks(); - Token *next_tok = tok->next; - delete tok; - num_toks_--; - tok = next_tok; - } - } - active_toks_.clear(); - KALDI_ASSERT(num_toks_ == 0); -} - -// static -void LatticeFasterOnlineDecoder::TopSortTokens(Token *tok_list, - std::vector *topsorted_list) { - unordered_map token2pos; - typedef unordered_map::iterator IterType; - int32 num_toks = 0; - for (Token *tok = tok_list; tok != NULL; tok = tok->next) - num_toks++; - int32 cur_pos = 0; - // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. - // This is likely to be in closer to topological order than - // if we had given them ascending order, because of the way - // new tokens are put at the front of the list. - for (Token *tok = tok_list; tok != NULL; tok = tok->next) - token2pos[tok] = num_toks - ++cur_pos; - - unordered_set reprocess; - - for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { - Token *tok = iter->first; - int32 pos = iter->second; - for (ForwardLink *link = tok->links; link != NULL; link = link->next) { - if (link->ilabel == 0) { - // We only need to consider epsilon links, since non-epsilon links - // transition between frames and this function only needs to sort a list - // of tokens from a single frame. - IterType following_iter = token2pos.find(link->next_tok); - if (following_iter != token2pos.end()) { // another token on this frame, - // so must consider it. - int32 next_pos = following_iter->second; - if (next_pos < pos) { // reassign the position of the next Token. - following_iter->second = cur_pos++; - reprocess.insert(link->next_tok); - } - } - } - } - // In case we had previously assigned this token to be reprocessed, we can - // erase it from that set because it's "happy now" (we just processed it). - reprocess.erase(tok); - } - - size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles. 
- for (loop_count = 0; - !reprocess.empty() && loop_count < max_loop; ++loop_count) { - std::vector reprocess_vec; - for (unordered_set::iterator iter = reprocess.begin(); - iter != reprocess.end(); ++iter) - reprocess_vec.push_back(*iter); - reprocess.clear(); - for (std::vector::iterator iter = reprocess_vec.begin(); - iter != reprocess_vec.end(); ++iter) { - Token *tok = *iter; - int32 pos = token2pos[tok]; - // Repeat the processing we did above (for comments, see above). - for (ForwardLink *link = tok->links; link != NULL; link = link->next) { - if (link->ilabel == 0) { - IterType following_iter = token2pos.find(link->next_tok); - if (following_iter != token2pos.end()) { - int32 next_pos = following_iter->second; - if (next_pos < pos) { - following_iter->second = cur_pos++; - reprocess.insert(link->next_tok); - } - } - } - } - } - } - KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding " - "graph (this is not allowed!)"); - - topsorted_list->clear(); - topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between. - for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) - (*topsorted_list)[iter->second] = iter->first; -} - +// Instantiate the template for the FST types that we'll need. +template class LatticeFasterOnlineDecoderTpl >; +template class LatticeFasterOnlineDecoderTpl >; +template class LatticeFasterOnlineDecoderTpl >; +template class LatticeFasterOnlineDecoderTpl; } // end namespace kaldi. diff --git a/src/decoder/lattice-faster-online-decoder.h b/src/decoder/lattice-faster-online-decoder.h index 6cf0503d891..69bf8b6d98d 100644 --- a/src/decoder/lattice-faster-online-decoder.h +++ b/src/decoder/lattice-faster-online-decoder.h @@ -34,23 +34,46 @@ #include "fstext/fstext-lib.h" #include "lat/determinize-lattice-pruned.h" #include "lat/kaldi-lattice.h" -// Use the same configuration class as LatticeFasterDecoder. #include "decoder/lattice-faster-decoder.h" namespace kaldi { -/** LatticeFasterOnlineDecoder is as LatticeFasterDecoder but also supports an - efficient way to get the best path (see the function BestPathEnd()), which - is useful in endpointing. +/** LatticeFasterOnlineDecoderTpl is as LatticeFasterDecoderTpl but also + supports an efficient way to get the best path (see the function + BestPathEnd()), which is useful in endpointing and in situations where you + might want to frequently access the best path. + + This is only templated on the FST type, since the Token type is required to + be BackpointerToken. Actually it only makes sense to instantiate + LatticeFasterDecoderTpl with Token == BackpointerToken if you do so indirectly via + this child class. */ -class LatticeFasterOnlineDecoder { +template +class LatticeFasterOnlineDecoderTpl: + public LatticeFasterDecoderTpl { public: - typedef fst::StdArc Arc; - typedef Arc::Label Label; - typedef Arc::StateId StateId; - typedef Arc::Weight Weight; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using Token = decoder::BackpointerToken; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. 
+ LatticeFasterOnlineDecoderTpl(const FST &fst, + const LatticeFasterDecoderConfig &config): + LatticeFasterDecoderTpl(fst, config) { } + + // This version of the initializer takes ownership of 'fst', and will delete + // it when this object is destroyed. + LatticeFasterOnlineDecoderTpl(const LatticeFasterDecoderConfig &config, + FST *fst): + LatticeFasterDecoderTpl(config, fst) { } + struct BestPathIterator { void *tok; @@ -64,42 +87,10 @@ class LatticeFasterOnlineDecoder { bool Done() { return tok == NULL; } }; - // instantiate this class once for each thing you have to decode. - LatticeFasterOnlineDecoder(const fst::Fst &fst, - const LatticeFasterDecoderConfig &config); - - // This version of the initializer "takes ownership" of the fst, - // and will delete it when this object is destroyed. - LatticeFasterOnlineDecoder(const LatticeFasterDecoderConfig &config, - fst::Fst *fst); - - - void SetOptions(const LatticeFasterDecoderConfig &config) { - config_ = config; - } - - const LatticeFasterDecoderConfig &GetOptions() const { - return config_; - } - - ~LatticeFasterOnlineDecoder(); - - /// Decodes until there are no more frames left in the "decodable" object.. - /// note, this may block waiting for input if the "decodable" object blocks. - /// Returns true if any kind of traceback is available (not necessarily from a - /// final state). - bool Decode(DecodableInterface *decodable); - - - /// says whether a final-state was active on the last frame. If it was not, the - /// lattice (or traceback) will end with states that are not final-states. - bool ReachedFinal() const { - return FinalRelativeCost() != std::numeric_limits::infinity(); - } /// Outputs an FST corresponding to the single best path through the lattice. /// This is quite efficient because it doesn't get the entire raw lattice and find - /// the best path through it; insterad, it uses the BestPathEnd and BestPathIterator + /// the best path through it; instead, it uses the BestPathEnd and BestPathIterator /// so it basically traces it back through the lattice. /// Returns true if result is nonempty (using the return status is deprecated, /// it will become void). If "use_final_probs" is true AND we reached the @@ -135,16 +126,8 @@ class LatticeFasterOnlineDecoder { BestPathIterator TraceBackBestPath( BestPathIterator iter, LatticeArc *arc) const; - /// Outputs an FST corresponding to the raw, state-level - /// tracebacks. Returns true if result is nonempty. - /// If "use_final_probs" is true AND we reached the final-state - /// of the graph then it will include those as final-probs, else - /// it will treat all final-probs as one. - /// The raw lattice will be topologically sorted. - bool GetRawLattice(Lattice *ofst, - bool use_final_probs = true) const; - /// Behaves the same like GetRawLattice but only processes tokens whose + /// Behaves the same as GetRawLattice but only processes tokens whose /// extra_cost is smaller than the best-cost plus the specified beam. /// It is only worthwhile to call this function if beam is less than /// the lattice_beam specified in the config; otherwise, it would @@ -153,271 +136,10 @@ class LatticeFasterOnlineDecoder { bool use_final_probs, BaseFloat beam) const; - - /// InitDecoding initializes the decoding, and should only be used if you - /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to - /// call this. You can also call InitDecoding if you have already decoded an - /// utterance and want to start with a new utterance. 
- void InitDecoding(); - - /// This will decode until there are no more frames ready in the decodable - /// object. You can keep calling it each time more frames become available. - /// If max_num_frames is specified, it specifies the maximum number of frames - /// the function will decode before returning. - void AdvanceDecoding(DecodableInterface *decodable, - int32 max_num_frames = -1); - - /// This function may be optionally called after AdvanceDecoding(), when you - /// do not plan to decode any further. It does an extra pruning step that - /// will help to prune the lattices output by GetRawLattice more accurately, - /// particularly toward the end of the utterance. It does this by using the - /// final-probs in pruning (if any final-state survived); it also does a final - /// pruning step that visits all states (the pruning that is done during - /// decoding may fail to prune states that are within kPruningScale = 0.1 - /// outside of the beam). If you call this, you cannot call AdvanceDecoding - /// again (it will fail), and you cannot call GetRawLattice() and related - /// functions with use_final_probs = false. Used to be called - /// PruneActiveTokensFinal(). - void FinalizeDecoding(); - - /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives - /// more information. It returns the difference between the best (final-cost - /// plus cost) of any token on the final frame, and the best cost of any token - /// on the final frame. If it is infinity it means no final-states were - /// present on the final frame. It will usually be nonnegative. If it not - /// too positive (e.g. < 5 is my first guess, but this is not tested) you can - /// take it as a good indication that we reached the final-state with - /// reasonable likelihood. - BaseFloat FinalRelativeCost() const; - - // Returns the number of frames decoded so far. The value returned changes - // whenever we call ProcessEmitting(). - inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } - - private: - // ForwardLinks are the links from a token to a token on the next frame. - // or sometimes on the current frame (for input-epsilon links). - struct Token; - struct ForwardLink { - Token *next_tok; // the next token [or NULL if represents final-state] - Label ilabel; // ilabel on link. - Label olabel; // olabel on link. - BaseFloat graph_cost; // graph cost of traversing link (contains LM, etc.) - BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing link - ForwardLink *next; // next in singly-linked list of forward links from a - // token. - inline ForwardLink(Token *next_tok, Label ilabel, Label olabel, - BaseFloat graph_cost, BaseFloat acoustic_cost, - ForwardLink *next): - next_tok(next_tok), ilabel(ilabel), olabel(olabel), - graph_cost(graph_cost), acoustic_cost(acoustic_cost), - next(next) { } - }; - - // Token is what's resident in a particular state at a particular time. - // In this decoder a Token actually contains *forward* links. - // When first created, a Token just has the (total) cost. We add forward - // links from it when we process the next frame. - struct Token { - BaseFloat tot_cost; // would equal weight.Value()... cost up to this point. - BaseFloat extra_cost; // >= 0. 
After calling PruneForwardLinks, this equals - // the minimum difference between the cost of the best path, and the cost of - // this is on, and the cost of the absolute best path, under the assumption - // that any of the currently active states at the decoding front may - // eventually succeed (e.g. if you were to take the currently active states - // one by one and compute this difference, and then take the minimum). - - ForwardLink *links; // Head of singly linked list of ForwardLinks - - Token *next; // Next in list of tokens for this frame. - - Token *backpointer; // best preceding Token (could be on this frame or a - // previous frame). This is only required for an - // efficient GetBestPath function, it plays no part in - // the lattice generation (the "links" list is what - // stores the forward links, for that). - - inline Token(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLink *links, - Token *next, Token *backpointer): - tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next), - backpointer(backpointer) { } - inline void DeleteForwardLinks() { - ForwardLink *l = links, *m; - while (l != NULL) { - m = l->next; - delete l; - l = m; - } - links = NULL; - } - }; - - // head of per-frame list of Tokens (list is in topological order), - // and something saying whether we ever pruned it using PruneForwardLinks. - struct TokenList { - Token *toks; - bool must_prune_forward_links; - bool must_prune_tokens; - TokenList(): toks(NULL), must_prune_forward_links(true), - must_prune_tokens(true) { } - }; - - typedef HashList::Elem Elem; - - void PossiblyResizeHash(size_t num_toks); - - // FindOrAddToken either locates a token in hash of toks_, or if necessary - // inserts a new, empty token (i.e. with no forward links) for the current - // frame. [note: it's inserted if necessary into hash toks_ and also into the - // singly linked list of tokens active on this frame (whose head is at - // active_toks_[frame]). The frame_plus_one argument is the acoustic frame - // index plus one, which is used to index into the active_toks_ array. - // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the - // token was newly created or the cost changed. - inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, - BaseFloat tot_cost, Token *backpointer, - bool *changed); - - // prunes outgoing links for all tokens in active_toks_[frame] - // it's called by PruneActiveTokens - // all links, that have link_extra_cost > lattice_beam are pruned - // delta is the amount by which the extra_costs must change - // before we set *extra_costs_changed = true. - // If delta is larger, we'll tend to go back less far - // toward the beginning of the file. - // extra_costs_changed is set to true if extra_cost was changed for any token - // links_pruned is set to true if any link in any token was pruned - void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed, - bool *links_pruned, - BaseFloat delta); - - // This function computes the final-costs for tokens active on the final - // frame. It outputs to final-costs, if non-NULL, a map from the Token* - // pointer to the final-prob of the corresponding state, for all Tokens - // that correspond to states that have final-probs. This map will be - // empty if there were no final-probs. 
It outputs to - // final_relative_cost, if non-NULL, the difference between the best - // forward-cost including the final-prob cost, and the best forward-cost - // without including the final-prob cost (this will usually be positive), or - // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which - // outputs this quanitity]. It outputs to final_best_cost, if - // non-NULL, the lowest for any token t active on the final frame, of - // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in - // the graph of the state corresponding to token t, or the best of - // forward-cost[t] if there were no final-probs active on the final frame. - // You cannot call this after FinalizeDecoding() has been called; in that - // case you should get the answer from class-member variables. - void ComputeFinalCosts(unordered_map *final_costs, - BaseFloat *final_relative_cost, - BaseFloat *final_best_cost) const; - - // PruneForwardLinksFinal is a version of PruneForwardLinks that we call - // on the final frame. If there are final tokens active, it uses - // the final-probs for pruning, otherwise it treats all tokens as final. - void PruneForwardLinksFinal(); - - // Prune away any tokens on this frame that have no forward links. - // [we don't do this in PruneForwardLinks because it would give us - // a problem with dangling pointers]. - // It's called by PruneActiveTokens if any forward links have been pruned - void PruneTokensForFrame(int32 frame_plus_one); - - - // Go backwards through still-alive tokens, pruning them if the - // forward+backward cost is more than lat_beam away from the best path. It's - // possible to prove that this is "correct" in the sense that we won't lose - // anything outside of lat_beam, regardless of what happens in the future. - // delta controls when it considers a cost to have changed enough to continue - // going backward and propagating the change. larger delta -> will recurse - // less far. - void PruneActiveTokens(BaseFloat delta); - - /// Gets the weight cutoff. Also counts the active tokens. - BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, - BaseFloat *adaptive_beam, Elem **best_elem); - - /// Processes emitting arcs for one frame. Propagates from prev_toks_ to cur_toks_. - /// Returns the cost cutoff for subsequent ProcessNonemitting() to use. - /// Templated on FST type for speed; called via ProcessEmittingWrapper(). - template BaseFloat ProcessEmitting(DecodableInterface *decodable); - - BaseFloat ProcessEmittingWrapper(DecodableInterface *decodable); - - /// Processes nonemitting (epsilon) arcs for one frame. Called after - /// ProcessEmitting() on each frame. The cost cutoff is computed by the - /// preceding ProcessEmitting(). - /// the templated design is similar to ProcessEmitting() - template void ProcessNonemitting(BaseFloat cost_cutoff); - - void ProcessNonemittingWrapper(BaseFloat cost_cutoff); - - // HashList defined in ../util/hash-list.h. It actually allows us to maintain - // more than one list (e.g. for current and previous frames), but only one of - // them at a time can be indexed by StateId. It is indexed by frame-index - // plus one, where the frame-index is zero-based, as used in decodable object. - // That is, the emitting probs of frame t are accounted for in tokens at - // toks_[t+1]. The zeroth frame is for nonemitting transition at the start of - // the graph. 
- HashList toks_; - - std::vector active_toks_; // Lists of tokens, indexed by - // frame (members of TokenList are toks, must_prune_forward_links, - // must_prune_tokens). - std::vector queue_; // temp variable used in ProcessNonemitting, - std::vector tmp_array_; // used in GetCutoff. - // make it class member to avoid internal new/delete. - const fst::Fst &fst_; - bool delete_fst_; - std::vector cost_offsets_; // This contains, for each - // frame, an offset that was added to the acoustic log-likelihoods on that - // frame in order to keep everything in a nice dynamic range i.e. close to - // zero, to reduce roundoff errors. - LatticeFasterDecoderConfig config_; - int32 num_toks_; // current total #toks allocated... - bool warned_; - - /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, - /// calling this is optional]. If true, it's forbidden to decode more. Also, - /// if this is set, then the output of ComputeFinalCosts() is in the next - /// three variables. The reason we need to do this is that after - /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some - /// of the tokens on the last frame are freed, so we free the list from toks_ - /// to avoid having dangling pointers hanging around. - bool decoding_finalized_; - /// For the meaning of the next 3 variables, see the comment for - /// decoding_finalized_ above., and ComputeFinalCosts(). - unordered_map final_costs_; - BaseFloat final_relative_cost_; - BaseFloat final_best_cost_; - - // There are various cleanup tasks... the the toks_ structure contains - // singly linked lists of Token pointers, where Elem is the list type. - // It also indexes them in a hash, indexed by state (this hash is only - // maintained for the most recent frame). toks_.Clear() - // deletes them from the hash and returns the list of Elems. The - // function DeleteElems calls toks_.Delete(elem) for each elem in - // the list, which returns ownership of the Elem to the toks_ structure - // for reuse, but does not delete the Token pointer. The Token pointers - // are reference-counted and are ultimately deleted in PruneTokensForFrame, - // but are also linked together on each frame by their own linked-list, - // using the "next" pointer. We delete them manually. - void DeleteElems(Elem *list); - - // This function takes a singly linked list of tokens for a single frame, and - // outputs a list of them in topological order (it will crash if no such order - // can be found, which will typically be due to decoding graphs with epsilon - // cycles, which are not allowed). Note: the output list may contain NULLs, - // which the caller should pass over; it just happens to be more efficient for - // the algorithm to output a list that contains NULLs. - static void TopSortTokens(Token *tok_list, - std::vector *topsorted_list); - - void ClearActiveTokens(); - - - KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoder); + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoderTpl); }; +typedef LatticeFasterOnlineDecoderTpl LatticeFasterOnlineDecoder; } // end namespace kaldi. 
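For context, a minimal usage sketch of the refactored online decoder follows; it is not part of the patch, and the decode_fst / decodable objects are placeholders assumed to come from the usual Kaldi feature and model setup. It touches only the interface kept or added above: the backward-compatible LatticeFasterOnlineDecoder typedef, the InitDecoding / AdvanceDecoding / FinalizeDecoding members inherited from LatticeFasterDecoderTpl, and the online-specific BestPathEnd / TraceBackBestPath pair whose implementations are templated in the .cc change above.

```cpp
// Hypothetical sketch (not from the patch); decode_fst and decodable are
// assumed to be set up elsewhere in the usual way.
#include <vector>
#include "decoder/lattice-faster-online-decoder.h"

namespace kaldi {

void SketchDecodeBestPath(const fst::Fst<fst::StdArc> &decode_fst,
                          DecodableInterface *decodable) {
  LatticeFasterDecoderConfig config;
  config.beam = 13.0;
  config.lattice_beam = 6.0;

  // The typedef at the end of the header keeps existing call sites compiling
  // unchanged (it instantiates the template for fst::Fst<fst::StdArc>).
  LatticeFasterOnlineDecoder decoder(decode_fst, config);

  decoder.InitDecoding();
  decoder.AdvanceDecoding(decodable);  // May be called repeatedly as frames arrive.
  decoder.FinalizeDecoding();          // Optional extra pruning at utterance end.

  // The online-specific part: trace back the single best path without
  // building the full raw lattice.
  BaseFloat final_cost;
  LatticeFasterOnlineDecoder::BestPathIterator iter =
      decoder.BestPathEnd(true, &final_cost);  // true: use final-probs.
  std::vector<LatticeArc> arcs_reverse;        // Arcs come out last-to-first.
  while (!iter.Done()) {
    LatticeArc arc;
    iter = decoder.TraceBackBestPath(iter, &arc);
    arcs_reverse.push_back(arc);
  }
  // Reversing arcs_reverse and linking up states yields the best path;
  // GetBestPath(&lat) wraps exactly this loop.
}

}  // namespace kaldi
```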
diff --git a/src/decoder/lattice-simple-decoder.cc b/src/decoder/lattice-simple-decoder.cc index 79d6d5288be..f2b16782827 100644 --- a/src/decoder/lattice-simple-decoder.cc +++ b/src/decoder/lattice-simple-decoder.cc @@ -564,7 +564,9 @@ void LatticeSimpleDecoder::ProcessNonemitting() { for (unordered_map::iterator iter = cur_toks_.begin(); iter != cur_toks_.end(); ++iter) { - queue.push_back(iter->first); + StateId state = iter->first; + if (fst_.NumInputEpsilons(state) != 0) + queue.push_back(state); best_cost = std::min(best_cost, iter->second->tot_cost); } if (queue.empty()) { @@ -604,7 +606,7 @@ void LatticeSimpleDecoder::ProcessNonemitting() { // "changed" tells us whether the new token has a different // cost from before, or is new [if so, add into queue]. - if (changed) + if (changed && fst_.NumInputEpsilons(arc.nextstate) != 0) queue.push_back(arc.nextstate); } } diff --git a/src/decoder/training-graph-compiler.cc b/src/decoder/training-graph-compiler.cc index 8b28ad2d11f..191d02f1720 100644 --- a/src/decoder/training-graph-compiler.cc +++ b/src/decoder/training-graph-compiler.cc @@ -1,5 +1,7 @@ // decoder/training-graph-compiler.cc -// Copyright 2009-2011 Microsoft Corporation + +// Copyright 2009-2011 Microsoft Corporation +// 2018 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -40,15 +42,15 @@ TrainingGraphCompiler::TrainingGraphCompiler(const TransitionModel &trans_model, KALDI_ERR << "Disambiguation symbol " << disambig_syms_[i] << " is also a phone."; - int32 subseq_symbol = 1 + phone_syms.back(); - if (!disambig_syms_.empty() && subseq_symbol <= disambig_syms_.back()) - subseq_symbol = 1 + disambig_syms_.back(); + subsequential_symbol_ = 1 + phone_syms.back(); + if (!disambig_syms_.empty() && subsequential_symbol_ <= disambig_syms_.back()) + subsequential_symbol_ = 1 + disambig_syms_.back(); { int32 N = ctx_dep.ContextWidth(), P = ctx_dep.CentralPosition(); if (P != N-1) - AddSubsequentialLoop(subseq_symbol, lex_fst_); // This is needed for + AddSubsequentialLoop(subsequential_symbol_, lex_fst_); // This is needed for // systems with right-context or we will not successfully compose // with C. } @@ -80,25 +82,19 @@ bool TrainingGraphCompiler::CompileGraph(const fst::VectorFst &word KALDI_ASSERT(phone2word_fst.Start() != kNoStateId); - ContextFst *cfst = NULL; - { // make cfst [ it's expanded on the fly ] - const std::vector &phone_syms = trans_model_.GetPhones(); // needed to create context fst. - int32 subseq_symbol = phone_syms.back() + 1; - if (!disambig_syms_.empty() && subseq_symbol <= disambig_syms_.back()) - subseq_symbol = 1 + disambig_syms_.back(); - - cfst = new ContextFst(subseq_symbol, - phone_syms, - disambig_syms_, - ctx_dep_.ContextWidth(), - ctx_dep_.CentralPosition()); - } + const std::vector &phone_syms = trans_model_.GetPhones(); // needed to create context fst. + + // inv_cfst will be expanded on the fly, as needed. + InverseContextFst inv_cfst(subsequential_symbol_, + phone_syms, + disambig_syms_, + ctx_dep_.ContextWidth(), + ctx_dep_.CentralPosition()); - VectorFst ctx2word_fst; - ComposeContextFst(*cfst, phone2word_fst, &ctx2word_fst); - // ComposeContextFst is like Compose but faster for this particular Fst type. - // [and doesn't expand too many arcs in the ContextFst.] + VectorFst ctx2word_fst; + ComposeDeterministicOnDemandInverse(phone2word_fst, &inv_cfst, &ctx2word_fst); + // now ctx2word_fst is C * LG, assuming phone2word_fst is written as LG. 
KALDI_ASSERT(ctx2word_fst.Start() != kNoStateId); HTransducerConfig h_cfg; @@ -106,7 +102,7 @@ bool TrainingGraphCompiler::CompileGraph(const fst::VectorFst &word std::vector disambig_syms_h; // disambiguation symbols on // input side of H. - VectorFst *H = GetHTransducer(cfst->ILabelInfo(), + VectorFst *H = GetHTransducer(inv_cfst.IlabelInfo(), ctx_dep_, trans_model_, h_cfg, @@ -142,7 +138,6 @@ bool TrainingGraphCompiler::CompileGraph(const fst::VectorFst &word &trans2word_fst); delete H; - delete cfst; return true; } @@ -173,19 +168,14 @@ bool TrainingGraphCompiler::CompileGraphs( out_fsts->resize(word_fsts.size(), NULL); if (word_fsts.empty()) return true; - ContextFst *cfst = NULL; - { // make cfst [ it's expanded on the fly ] - const std::vector &phone_syms = trans_model_.GetPhones(); // needed to create context fst. - int32 subseq_symbol = phone_syms.back() + 1; - if (!disambig_syms_.empty() && subseq_symbol <= disambig_syms_.back()) - subseq_symbol = 1 + disambig_syms_.back(); - - cfst = new ContextFst(subseq_symbol, - phone_syms, - disambig_syms_, - ctx_dep_.ContextWidth(), - ctx_dep_.CentralPosition()); - } + const std::vector &phone_syms = trans_model_.GetPhones(); // needed to create context fst. + + // inv_cfst will be expanded on the fly, as needed. + InverseContextFst inv_cfst(subsequential_symbol_, + phone_syms, + disambig_syms_, + ctx_dep_.ContextWidth(), + ctx_dep_.CentralPosition()); for (size_t i = 0; i < word_fsts.size(); i++) { VectorFst phone2word_fst; @@ -196,10 +186,8 @@ bool TrainingGraphCompiler::CompileGraphs( "Perhaps you have words missing in your lexicon?"); VectorFst ctx2word_fst; - ComposeContextFst(*cfst, phone2word_fst, &ctx2word_fst); - // ComposeContextFst is like Compose but faster for this particular Fst type. - // [and doesn't expand too many arcs in the ContextFst.] - + ComposeDeterministicOnDemandInverse(phone2word_fst, &inv_cfst, &ctx2word_fst); + // now ctx2word_fst is C * LG, assuming phone2word_fst is written as LG. KALDI_ASSERT(ctx2word_fst.Start() != kNoStateId); (*out_fsts)[i] = ctx2word_fst.Copy(); // For now this contains the FST with symbols @@ -210,7 +198,7 @@ bool TrainingGraphCompiler::CompileGraphs( h_cfg.transition_scale = opts_.transition_scale; std::vector disambig_syms_h; - VectorFst *H = GetHTransducer(cfst->ILabelInfo(), + VectorFst *H = GetHTransducer(inv_cfst.IlabelInfo(), ctx_dep_, trans_model_, h_cfg, @@ -247,7 +235,6 @@ bool TrainingGraphCompiler::CompileGraphs( } delete H; - delete cfst; return true; } diff --git a/src/decoder/training-graph-compiler.h b/src/decoder/training-graph-compiler.h index 36bd62db4f7..ee56c6dfb3d 100644 --- a/src/decoder/training-graph-compiler.h +++ b/src/decoder/training-graph-compiler.h @@ -1,6 +1,7 @@ // decoder/training-graph-compiler.h -// Copyright 2009-2011 Microsoft Corporation +// Copyright 2009-2011 Microsoft Corporation +// 2018 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -23,6 +24,7 @@ #include "hmm/transition-model.h" #include "fst/fstlib.h" #include "fstext/fstext-lib.h" +#include "tree/context-dep.h" namespace kaldi { @@ -65,14 +67,14 @@ class TrainingGraphCompiler { const TrainingGraphCompilerOptions &opts); - /// CompileGraph compiles a single training graph its input is a + // CompileGraph compiles a single training graph its input is a // weighted acceptor (G) at the word level, its output is HCLG. - // Note: G could actually be an acceptor, it would also work. 
+ // Note: G could actually be a transducer, it would also work. // This function is not const for technical reasons involving the cache. // if not for "table_compose" we could make it const. bool CompileGraph(const fst::VectorFst &word_grammar, fst::VectorFst *out_fst); - + // CompileGraphs allows you to compile a number of graphs at the same // time. This consumes more memory but is faster. bool CompileGraphs( @@ -87,8 +89,8 @@ class TrainingGraphCompiler { bool CompileGraphsFromText( const std::vector > &word_grammar, std::vector *> *out_fsts); - - + + ~TrainingGraphCompiler() { delete lex_fst_; } private: const TransitionModel &trans_model_; @@ -96,6 +98,7 @@ class TrainingGraphCompiler { fst::VectorFst *lex_fst_; // lexicon FST (an input; we take // ownership as we need to modify it). std::vector disambig_syms_; // disambig symbols (if any) in the phone + int32 subsequential_symbol_; // search in ../fstext/context-fst.h for more info. // symbol table. fst::TableComposeCache > lex_cache_; // stores matcher.. // this is one of Dan's extensions. diff --git a/src/doc/build_setup.dox b/src/doc/build_setup.dox index 86dca5cad69..47ff7e033a8 100644 --- a/src/doc/build_setup.dox +++ b/src/doc/build_setup.dox @@ -58,11 +58,9 @@ Changes that you might want to make to kaldi.mk after running "configure" are the following: - Changing the debug level: - - The default (which creates the easiest-to-debug binaries) is enabled by the options "-g -O0 -DKALDI_PARANOID". - - For faster, but still debuggable, binaries, you can change -O0 to -O1 - - If you won't need to debug the binaries, you can remove the "-g -O0 -DKALDI_PARANOID" options, which - will make it even faster. - - For maximum speed and no checking, you can replace the "-g -O0 -DKALDI_PARANOID" options with + - The default is "-O1" + - Easy to debug binaries can be enabled by uncommenting the options "-O0 -DKALDI_PARANOID". + - For maximum speed and no checking, you can replace the "-O0 -DKALDI_PARANOID" options with "-O2 -DNDEBUG" or "-O3 -DNDEBUG" - Changing the default precision - To test algorithms in double precision (e.g. if you suspect that roundoff is affecting diff --git a/src/doc/data_prep.dox b/src/doc/data_prep.dox index 89fd19ed8d4..e81032537cc 100644 --- a/src/doc/data_prep.dox +++ b/src/doc/data_prep.dox @@ -191,7 +191,7 @@ the speaker identities, you can just make the speaker-ids the same as the uttera so the format of the file would be just \ \. We have made the previous sentence bold because we have encountered people creating a "global" speaker-id. This is a bad idea because it makes cepstral mean normalization -ineffective in traning (since it's applied globally), and because it will create problems +ineffective in training (since it's applied globally), and because it will create problems when you use utils/split_data_dir.sh to split your data into pieces. There is another file that exists in some setups; it is used only occasionally and @@ -811,7 +811,7 @@ state transducers. (Note that language models would be represented as finite st acceptors, or FSAs, which can be considered as a special case of finite state transducers). The script utils/format_lm.sh deals with converting the ARPA-format language -models into an OpenFst format. Here is the usage messages of that script: +models into an OpenFst format. 
Here is the usage message of that script: \verbatim Usage: utils/format_lm.sh <lang_dir> <arpa-LM> <lexicon> <out_dir> E.g.: utils/format_lm.sh data/lang data/local/lm/foo.kn.gz data/local/dict/lexicon.txt data/lang_test @@ -838,4 +838,52 @@ E.g.: utils/format_lm_sri.sh data/lang data/local/lm/foo.kn.gz data/lang_test Converts ARPA-format language models to FSTs. Change the LM vocabulary using SRILM. \endverbatim + +\section data_prep_unknown Note on unknown words + +This is an explanation of how Kaldi deals with unknown words (words not in the +vocabulary); we are putting it on the "data preparation" page for lack of a more obvious +location. + +In many setups, \<unk\> or something similar will be present in the +LM as long as the data that you used to train the LM had words that were not +in the vocabulary you used to train the LM, +because language modeling toolkits tend to map those all to a +single special word, usually called \<unk\> or +\<UNK\>. You can look at the arpa file to figure out what it's called; it +will usually be one of those two. + + +During training, if there are words in the text file in your data +directory that are not in the words.txt in the lang directory that +you are using, Kaldi will map them to a special word that's specified in the +lang directory in the file data/lang/oov.txt; it will usually be +either \<unk\>, \<UNK\> or maybe +\<SPOKEN_NOISE\>. This word will have been chosen by the user +(i.e., you), and supplied to prepare_lang.sh as a command-line argument. +If this word has nonzero probability in the language model (which you can test +by looking at the arpa file), then it will be possible for Kaldi to recognize +this word at test time. This will often be the case if you call this word +\<unk\>, because as we mentioned above, language modeling toolkits +will often use this spelling for ``unknown word'' (which is a special word that +all out-of-vocabulary words get mapped to). Decoding output will always be limited to the +intersection of the words in the language model with the words in the lexicon.txt (or whatever file format you supplied the +lexicon in, e.g. lexiconp.txt); these words will all be present in the words.txt +in your lang directory. +So if Kaldi's "unknown word" doesn't match the LM's "unknown word", you will +simply never decode this word. In any +case, even when allowed to be decoded, this word typically won't be output very +often and in practice it doesn't tend to have much impact on WERs. + +Of course a single phone isn't a very good, or accurate, model of OOV words. In +some Kaldi setups we have example scripts with names like +local/run_unk_model.sh: e.g., see the file +tedlium/s5_r2/local/run_unk_model.sh. These scripts replace the unk +phone with a phone-level LM on phones. They make it possible to get access to +the sequence of phones in a hypothesized unknown word. Note: unknown words +should be considered an "advanced topic" in speech recognition and we discourage +beginners from looking into this topic too closely. + + + */ diff --git a/src/doc/dependencies.dox b/src/doc/dependencies.dox index 63d2658b726..d8a5591955f 100644 --- a/src/doc/dependencies.dox +++ b/src/doc/dependencies.dox @@ -113,7 +113,7 @@ - CLAPACK, the linear algebra library (we download the headers). This is useful only on systems where you don't have ATLAS and are instead compiling with CLAPACK. - - OpenBLAS: this is an alernative to ATLAS or CLAPACK. The scripts don't + - OpenBLAS: this is an alternative to ATLAS or CLAPACK. 
The scripts don't use it by default but we provide installation scripts so you can install it if you want to compare it against ATLAS (it's more actively maintained than ATLAS). diff --git a/src/doc/dnn.dox b/src/doc/dnn.dox index 5b3d2b98261..bab4658e552 100644 --- a/src/doc/dnn.dox +++ b/src/doc/dnn.dox @@ -37,7 +37,7 @@ namespace kaldi { We currently have three separate codebases for deep neural nets in Kaldi. All are still active in the sense that the up-to-date recipes refer to all of them. The first one ("nnet1") is located in code subdirectories nnet/ and - nnetbin/, and is primiarly maintained by Karel Vesely. The second is located + nnetbin/, and is primarily maintained by Karel Vesely. The second is located in code subdirectories nnet2/ and nnet2bin/, and is primarily maintained by Daniel Povey (this code was originally based on an earlier version of Karel's code, but it has been extensively rewritten). The third is located diff --git a/src/doc/dnn1.dox b/src/doc/dnn1.dox index 223b7665274..e8dcfd90d3f 100644 --- a/src/doc/dnn1.dox +++ b/src/doc/dnn1.dox @@ -35,13 +35,13 @@ show some \ref dnn1_advanced_features, and do a light introduction to the \ref d
\section dnn1_toplevel_scripts Top-level script -Let's have a look at the script egs/wsj/s5/local/nnet/run_dnn.sh. +Let's have a look at the script egs/wsj/s5/local/nnet/run_dnn.sh. This script assumes to use a single CUDA GPU, and that kaldi was compiled with CUDA (check for 'CUDA = true' in src/kaldi.mk). Also we assume that 'cuda_cmd' is set properly in egs/wsj/s5/cmd.sh either to a GPU cluster node using 'queue.pl' or to a local machine using 'run.pl'. And finally the script assumes we already have a SAT GMM system exp/tri4b and corresponding fMLLR transforms, as generated by egs/wsj/s5/run.sh. Note that for other databases the run_dnn.sh is typically in the same location s5/local/nnet/run_dnn.sh. -The script egs/wsj/s5/local/nnet/run_dnn.sh is split into several stages: +The script egs/wsj/s5/local/nnet/run_dnn.sh is split into several stages: 0. storing 40-dimensional fMLLR features to disk, steps/nnet/make_fmllr_feats.sh, this simplifies the training scripts, the 40-dimensional features are MFCC-LDA-MLLT-fMLLR with CMN @@ -100,7 +100,7 @@ Besides the DNN recipe, there are also other example scripts which can be handy:
\section dnn1_training_script_internals Training script internals -The main neural network training script steps/nnet/train.sh is invoked as: +The main neural network training script steps/nnet/train.sh is invoked as: \verbatim steps/nnet/train.sh @@ -111,11 +111,11 @@ The is used only in the special case when using LDA feature-transform The output (i.e. the trained networks and logfiles) goes into . Internally the script prepares the feature+target pipelines, generates a neural-network prototype and initialization, creates feature_transform and calls the scheduler script -steps/nnet/train_scheduler.sh, +steps/nnet/train_scheduler.sh, which runs the training epochs and controls the learning rate. -While looking inside steps/nnet/train.sh we see: +While looking inside steps/nnet/train.sh we see: 1. CUDA is required, the scripts exit if no GPU was detected or was CUDA not compiled in (one can still use '--skip-cuda-check true' to run on CPU, but it is 10-20x slower) @@ -165,12 +165,12 @@ $ cat exp/dnn5b_pretrain-dbn_dnn/nnet.proto 7. the network is initialized by : \ref nnet-initialize.cc , the DBN gets prepended in the next step using \ref nnet-concat.cc -8. finally the training gets called by running scheduler script steps/nnet/train_scheduler.sh +8. finally the training gets called by running scheduler script steps/nnet/train_scheduler.sh Note : both neural networks and feature transforms can be viewed by \ref nnet-info.cc, or shown in ascii by \ref nnet-copy.cc -While looking inside steps/nnet/train_scheduler.sh we see: +While looking inside steps/nnet/train_scheduler.sh we see: the initial cross-validation run and the main for-loop over $iter which runs the epochs and controls the learning rate. Typically, the train_scheduler.sh is called from train.sh. - the default learning-rate scheduling is based on the relative improvement of the objective function: @@ -310,7 +310,7 @@ AddMat 174.307s AddMatMat 1922.11s \endverbatim - Running steps/nnet/train_scheduler.sh directly: + Running steps/nnet/train_scheduler.sh directly: - The script train_scheduler.sh can be called outside train.sh, it allows to override the default NN-input and NN-target streams, which can be handy. - However the script assumes everything is set-up correctly, and there are almost no sanity checks, which makes it suitable for more advanced users only. - It is highly recommended to have a look at how train_scheduler.sh is usually called before trying to call it directly. diff --git a/src/doc/get_version_info.sh b/src/doc/get_version_info.sh index c11fb7f805e..3b9b8e1f2fe 100755 --- a/src/doc/get_version_info.sh +++ b/src/doc/get_version_info.sh @@ -42,8 +42,12 @@ fi # Note: when you add new tuples here you'll also want to add ndew # \htmlinclude directives in versions.dox. -for tuple in "5.0 5.0 c160a9883" "5.1 5.1 2145519961" "5.2 5.2 393ef73caa93" "5.3 5.3 131cdd4cb544" \ - "5.4 master be969d7baf04"; do +for tuple in "5.0 5.0 c160a9883" "5.1 5.1 2145519961" "5.2 5.2 393ef73caa93" "5.3 5.3 db28650346ba07" \ + "5.4 5.4 be969d7baf04" "5.5 master 7aab92b7c"; do + if [ $(echo $tuple | wc -w) != 3 ]; then + echo "$0: tuple should have 3 fields: '$tuple'" + exit 1 + fi major_minor_number=$(echo $tuple | awk '{print $1}') # e.g. 5.0 branch=$(echo $tuple | awk '{print $2}') # e.g. 
'master', or '5.1' (it's a branch name) first_commit=$(echo $tuple | awk '{print $3}') diff --git a/src/doc/grammar.dox b/src/doc/grammar.dox new file mode 100644 index 00000000000..30396041d22 --- /dev/null +++ b/src/doc/grammar.dox @@ -0,0 +1,545 @@ +// doc/grammar.dox + + +// Copyright 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +namespace kaldi { + +/** + + \page grammar Support for grammars and graphs with on-the-fly parts. + + + This page explains our support for dynamically created grammars and graphs with + extra parts that you want to be able to compile quickly (like words you want to + add to the lexicon; contact lists; things like that). We have used the word + "grammar" as an easily searchable term for this framework, but this is not the + only way to implement grammars in Kaldi. If you have a smallish, fixed grammar + it would probably be much easier to create an FST (G.fst) directly from the + grammar (ensuring it is determinizable by means of disambiguation symbols if + necessary), and use the normal graph creation recipe. This framework is + specifically for cases where you have a compelling need to pre-compile the HCLG.fst + for various sub-parts and have them dynamically stitched together (typically to + avoid recompiling large graphs at runtime). + + This framework is limited to work only with left-biphone models. This is + without loss of performance, because our best models (chain models) already use + left-biphone context. + + + \section grammar_replace Relation to OpenFst's 'Replace()' operation + + The design of these tools is inspired by OpenFst's "Replace()" operation, as implemented + by its command-line tool fstreplace. The basic idea is illustrated by its usage message: +\verbatim +Recursively replaces FST arcs with other FST(s). + + Usage: fstreplace root.fst rootlabel [rule1.fst label1 ...] [out.fst] +\endverbatim + Below is a very trivial example of using fstreplace; it just replaces the olabel 5 + in the top-level FST with 6. +\verbatim +# (echo 0 1 0 5; echo 1 0) | fstcompile > top.fst +# (echo 0 1 0 6; echo 1 0) | fstcompile > x.fst +# fstreplace top.fst 1000 x.fst 5 | fstprint +0 1 0 0 +1 2 0 6 +2 3 0 0 +3 +\endverbatim + The framework of these tools is similar, in that at the G.fst level there are + symbols that will end up getting replaced by other FSTs. Most of the + complexity has to do with the need to handle phonetic context-- and this is + the reason why we can't just use the existing Replace() operation or its + on-demand equivalent. + + A slight difference in the interface of our tools versus fstreplace is that in our + tools, the top-level FST (corresponding to the 1st arg of fstreplace) does not + have a symbol assigned to it and thus cannot be "replaced into" any + FST.
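A further sketch (illustrative only; it extends the toy example above and is not from the fstreplace documentation) of the recursion mentioned in the usage message: here x.fst itself emits a label (6) that is replaced by a third FST.
\verbatim
# x.fst emits olabel 6, which is in turn replaced by y.fst:
(echo 0 1 0 5; echo 1 0) | fstcompile > top.fst
(echo 0 1 0 6; echo 1 0) | fstcompile > x.fst
(echo 0 1 0 7; echo 1 0) | fstcompile > y.fst
fstreplace top.fst 1000 x.fst 5 y.fst 6 | fstprint
\endverbatim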
+ + + \section grammar_overview Overview of the framework + + To explain how this works, we'll take the "contact list" scenario, where you want to + build a large language model with a nonterminal, say \#nonterm:contact_list in it, + and at recognition time you quickly build some kind of small LM representing + the contact list (possibly with previously unseen words), and compile that graph. + Both the "big graph" and the "small graph" are fully compiled down to the HCLG level. + The GrammarFst code "stitches them together" at decode time. The way this is + accomplished is by putting special ilabels in the two HCLGs that the GrammarFst + code knows how to interpret. That is: most ilabels in the HCLGs correspond to + transition-ids, but there are "special ilabels" with values over ten million that + the GrammarFst code knows how to interpret, and it uses them to stitch together + the FSTs, in a way that's related to OpenFst's Replace() operation, but is a little + more complicated due to the need to get the phonetic context right. (It only supports + left-biphone context, to keep the complexity manageable). + + The GrammarFst has an interface very similar to OpenFst's "Fst" type-- + sufficiently similar that the decoder can use it as a drop-in replacement for a + normal FST-- but it does not actually inherit from any OpenFst type; this is to + simplify the implementation and give us more freedom in designing it. The + decoders that use GrammarFst are templated on the FST type, and we use + GrammarFst as the template argument when we want to decode with them. + + The StateId used in the GrammarFst code is a 64-bit StateId, which we interpret + as a pair of 32-bit integers. The high-order bits are the "fst instance" and the + low-order bits are the state in that "fst instance". In the contact-list example, + fst-instance zero would be the top-level graph, and there would potentially be + a new fst-instance, numbered 1, 2, ..., for each time the \#nonterm:contact_list nonterminal + appears in the big language model. However, these are only generated on demand + as those parts of the graph are actually accessed. The GrammarFst is a + lightweight object that does very little work at startup. It is designed to be + as fast as possible in the "normal case" when we are not crossing FST + boundaries, and are just traversing inside a single FST. The GrammarFst code + needs a fast-to-evaluate "signal" that tells it to do something special for a + particular FST state. We let the final-probabilities be that signal: that is, + each time we initialize an ArcIterator, the GrammarFst code tests + whether the final-prob has a special value or not. If it has that special value + (4096.0), then the GrammarFst code does a little bit of extra work to see whether + it needs to expand the state, and to look up a previously expanded + version of the state (or expand it if it wasn't already present). By "expanding" + the state we mean computing the vector of arcs leaving it. + + + The FST compilation process-- i.e. the process of going from G.fst to HCLG.fst-- + is a little different when we intend to support grammars. That is, we need to + extend some of the tools used in compilation to work correctly with certain + special symbols that we introduce. The differences are explained below. + + + \subsection grammar_example_scripts Where to find the example scripts + + The top-level example scripts for this setup are in egs/mini_librispeech/s5; + see the scripts local/grammar/simple_demo.sh and local/grammar/extend_vocab_demo.sh.
+ There are also versions of these scripts that use silence probabilities, in + local/grammar/simple_demo_silprobs.sh and local/grammar/extend_vocab_demo_silprobs.sh. + (Actually the workflow is exactly the same in the silprob and no-silprob versions + of the scripts; we created those different versions for testing purposes, as those + demo scripts also help us test the correctness of the code). + + + \section grammar_symtabs Symbol tables and special symbols + + When using this framework, we need to add certain extra symbols to the words.txt + and phones.txt symbol tables. These extra symbols represent certain special + symbols intrinsic to the framework, plus the user-defined nonterminal symbols. + In the following example the user-defined special symbols are \#nonterm:foo + and \#nonterm:bar. +\verbatim +tail words.txt +ZZZ 8431 +#0 8432 +#nonterm_begin 8434 +#nonterm_end 8435 +#nonterm:foo 8437 +#nonterm:bar 8438 +\endverbatim + The phones.txt contains a couple more symbols: +\verbatim +tail phones.txt +Z_S 243 +#0 244 +#1 245 +#2 246 +#nonterm_bos 247 +#nonterm_begin 248 +#nonterm_end 249 +#nonterm_reenter 250 +#nonterm:foo 251 +#nonterm:bar 252 +\endverbatim + The user should never need to explicitly add these symbols to the words.txt and + phones.txt files; they are automatically added by utils/prepare_lang.sh. All the user + has to do is to create the file 'nonterminals.txt' in the 'dict dir' (the directory + containing the dictionary, as validated by validate_dict_dir.pl). + + The C++ code never directly interacts with the nonterminal symbols in + words.txt; that is all done at the script level (e.g. creating L.fst), and the + C++ code only interacts with the nonterminal symbols in phones.txt. Therefore + there are no particularly strong constraints on the symbols in words.txt if you + are prepared to modify the scripts or create "LG.fst"-type graphs directly. + There are some constraints on the order of these symbols in phones.txt: in that file, + the inbuilt symbols (the ones without a colon) must be in the order shown, + the user-defined nonterminals must directly follow them, and there must be no + phones numbered higher than the nonterminal-related symbols (although higher-numbered + disambiguation symbols are allowed). + + Some binaries accept an option --nonterm-phones-offset, which tells them + where to find the nonterminal symbols. This should always be equal to the + integer id of the symbol \#nonterm_bos in phones.txt. In the above example + it would be --nonterm-phones-offset=247. + + \section grammar_special_g Special symbols in G.fst + + If you are using this framework you will be creating several graphs, so there + may be several copies of G.fst (and the intermediate and fully compiled + versions thereof). All of them are allowed to include sub-graphs via + nonterminals, and this can be done recursively; it is OK if the fully + compiled graph is infinite, because it is only expanded on demand. + + If you want to include a particular nonterminal (say the one for + \#nonterm:foo), you have to include that symbol \#nonterm:foo on the input + side of G.fst. As to what you include on the output side: that's up to you, as + the framework doesn't care, but bear in mind that symbols without + pronunciations may cause problems for lattice word alignment.
Note to more + advanced users: the program lattice-align-words won't work if there are output + symbols in HCLG.fst that don't have any pronunciation, but the alternative + solution lattice-align-words-lexicon will still work, as long as you add + entries for those words with empty pronunciations, in align_lexicon.int; the + entries will be of the form 200007 200007, assuming 200007 is the integer id + of the word with the empty pronunciation. The script prepare_lang.sh adds + these entries for you. + + For graphs which are not top-level graphs, all ilabel sequences in + G.fst should begin with the special symbol \#nonterm_begin and end with + \#nonterm_end. This can be accomplished via fstconcat from the command + line, or by just adding them directly as you create the graph. These + symbols will later be involved in selecting the correct phonetic context when we + enter the compiled HCLG.fst. + + For some applications, such as the contact-list scenario where you are adding + new vocabulary items, it may be easier to skip creating G.fst and just create + LG.fst manually; this won't be hard to do once you know its expected structure. + The example script local/grammar/extend_vocab_demo.sh in egs/mini_librispeech/s5/ + may be a good reference for this, even if you don't plan to actually use those + scripts in production. + + + \section grammar_special_lg Special symbols in LG.fst + + Before we describe what L.fst does with the special symbols, + we will state what we expect LG.fst to contain after composition. All the + special symbols are on the ilabels of LG.fst. + + Let us define the set of "left-context phones" as the set of phones that can + end a word, plus the optional silence, plus the special symbol \#nonterm_bos. + This is the set of phones that can possibly appear as the left-context when we + are beginning a word, plus \#nonterm_bos as a stand-in for the beginning-of-sequence + context where no previous phone was seen. We will italicize the phrase + left-context phones when we use it, to emphasize that it has a special meaning. + + For non-top-level graphs only: + + - All ilabel sequences in the FST must begin with \#nonterm_begin followed by each possible + left-context phone, i.e. parallel arcs enumerating all possible phonetic + left-contexts that could precede this nonterminal. + + In non-word-position-dependent systems we can just let this set be all phones; + in word-position-dependent systems it can be all phones except word-internal + and word-begin phones, i.e. all phones except those that look like XX_B + and XX_I. If the set of possible left contexts is known to be smaller, it may + be more efficient to make this a smaller set. In addition to real phones, + we include \#nonterm_bos in this set, which represents the phonetic + context we encounter at the start of an utterance. + + - All ilabel sequences must end with \#nonterm_end. + + Whenever a nonterminal is invoked, whether from a top-level or non-top-level + graph, the ilabels in LG.fst will be, for example, \#nonterm:foo followed by, + in parallel, all possible left-context phones. These left-context phones get added + by L.fst. + + \section grammar_special_l Special symbols in L.fst + + This section explains which sequences involving special symbols we need to + add to L.fst, in order to compile an LG.fst with the desired properties from G.fst.
+ The things we describe below are implemented by + utils/lang/make_lexicon_fst.py and utils/lang/make_lexicon_fst_silprob.py, + and are activated when you provide the --left-context-phones and --nonterminals + options. These scripts are automatically called from prepare_lang.sh when it sees the + file nonterminals.txt in the input dictionary directory. + + Let the loop-state of L.fst be the state in L.fst with very high out-degree, + from which all the words leave (and return). + + + The lexicon needs to include, in addition to the normal things: + + - A sequence starting at the start state and ending at the loop-state, with + olabel \#nonterm_begin and ilabels consisting of \#nonterm_begin + followed by all possible left-context phones (and \#nonterm_bos) in + parallel. + - An arc from the loop-state to a final state, with ilabel and olabel equal to \#nonterm_end. + - For each user-defined nonterminal (e.g. \#nonterm:foo) and for + \#nonterm_begin, a loop beginning and ending at the loop-state that starts with + the user-defined nonterminal, e.g. \#nonterm:foo, on the ilabel and + olabel, and then has all left-context-phones on the ilabel only. + + In order to keep LG.fst as stochastic as possible (i.e. as "sum-to-one" as possible + in probabilistic terms), when we have states from which arcs leave containing + all left-context phones we add a cost equal to the log of the number of + left-context phones. This will allow us to push the weights later + on in the graph-building procedure, without causing strange effects that would + be harmful to decoding speed and accuracy. When the graphs actually get spliced + together, all but one of the alternative paths for "all possible left-context + phones" will be disallowed; and at that point we will cancel out the cost of + log(number of left-context phones). This happens in the function + GrammarFst::CombineArcs(). + + Note that the above means that each sub-graph corresponding to + a user-defined nonterminal will allow optional silence after the nonterminal + but not before it. This is consistent with the way the nonterminal is invoked + from the higher-level graph, and generates exactly one optional silence between each pair of + "real" words, plus one at the beginning and end of the top-level graph. This equivalence + is something we test at the end of the example script + egs/mini_librispeech/s5/local/grammar/simple_demo.sh. + Users should bear all this in mind if they are going to construct these sub-graphs + manually at the LG.fst level rather than using the provided scripts. + + \subsection grammar_silprobs Interaction with 'silprobs' + + In the versions of the lexicons that have word-specific silence probabilities +(see this paper for explanation) + there are actually two versions of the loop state, one for after silence + and one for after nonsilence. + When using 'silprobs', each word has a word-specific cost at its beginning and end that + is associated with the transition to/from nonsilence and silence respectively (where by + "silence" we specifically mean the optional silence added by the lexicon, not silence phones + in a more general sense). + + Please refer to utils/lang/make_lexicon_fst_silprob.py for the + details of how we handle nonterminal symbols in combination with these types of + graphs. We will just share the top-level idea here, which is this: when we + enter the HCLG.fst for the nonterminal, and when we return from it, we 'know' the + identity of the immediately preceding phone.
(That is how this framework works; read + further if you find this surprising). We use that information to implement + the 'silprob' idea without having to give the FST additional entry + points; basically, if the left-context phone was the optional-silence phone, we + go to the state in L.fst that we would have been in after seeing optional silence. + This will do the right thing in the normal case. In the specific configuration + where you were not using word-position-dependent phones (c.f. the --position-dependent-phones + option of prepare_lang.sh) and where there are words in your lexicon that end with + the optional-silence phone (e.g. SIL), this will not quite do the right thing, + but we don't expect that this difference will be particularly significant in any real-world + use cases. + + \section grammar_special_clg Special symbols in CLG.fst + + First, some background: the symbols on the input of CLG.fst (i.e. the ilabels) have an interpretation + given by what we call the ilabel_info. This is explained more in \ref tree_ilabel. Programs + that consume CLG.fst always also consume the ilabel_info, which is a std::vector<std::vector<int32> >. + For a particular ilabel, say 1536, ilabel_info[1536] = { 5, 21 } is a vector of integers representing + a phone-in-context. E.g. this would represent the phone 21 with a left-context of 5. + Disambiguation symbols also appear on the input of CLG.fst, and they are represented in the ilabel_info + as a 1-dimensional vector like { -104 } containing the negative of the disambiguation symbol's + integer id. + + The special symbols we add to the input of CLG.fst to support the grammar-decoding framework + always correspond to pairs of symbols, + specifically pairs (\#nontermXXX, left-context phone), where \#nontermXXX is any + of the symbols \#nonterm_begin, \#nonterm_end, \#nonterm_reenter, or user-defined + nonterminals like \#nonterm:foo. The ilabel-info for these special symbols will be + pairs like {-104, 21} where the first element is the negative of the \#nontermXXX symbol + and the second is the left-context phone. The negation makes it easy to distinguish these + ilabel_info entries from regular phones-in-context. + + The special symbols in CLG.fst will be as follows. + + The following special symbols may appear in any CLG graph, top-level or not: + - When any graph invokes a sub-graph, there will be an arc with an ilabel + (\#nonterm:foo, left-context-phone) representing the + user-specified nonterminal and the actual left-context, which will be + followed by arcs with ilabels of the form (\#nonterm_reenter, + left-context-phone), for all left-context phones. + + For non-top-level CLG graphs only: + - These graphs will begin with ilabels representing pairs (\#nonterm_begin, left-context-phone), + representing all potential left-contexts. + - They will end with ilabels (\#nonterm_end, left-context-phone), representing + actual left-contexts. + + + \subsection grammar_special_c Special symbols in C.fst + + First, background. Since this framework only supports left-biphone + context, the states of C.fst correspond to the left context phone, and the + ilabels on the transitions correspond to biphones (plus self-loops for + disambiguation symbols). + + Next, what we are trying to accomplish. C.fst needs to do the following + (describing how it needs to change sequences in LG.fst into sequences in CLG.fst): + + - It needs to change the sequence \#nonterm_begin p1 (where p1 is a left-context-phone) + to a single symbol representing the pair (\#nonterm_begin, p1).
+ - It needs to change the symbol \#nonterm_end to a single symbol representing + the pair (\#nonterm_end, left-context-phone), where left-context-phone + represents the current phonetic left-context. + - For each user-defined nonterminal, e.g. \#nonterm:foo, it needs to change + the sequence \#nonterm:foo p1 (where p1 is a left-context-phone) + to a sequence of two symbols representing the pairs (\#nonterm:foo, p0) and + (\#nonterm_reenter, p1) respectively. Here, p0 represents the phone that + immediately preceded the symbol \#nonterm:foo. + + In order to implement the above, we augment the state-space of C.fst by adding + three new states: + + - One which we transition to when the olabel is + \#nonterm_begin. + - One which we transition to when we see any user-defined + symbol \#nonterm:foo. + - One which we transition to when the olabel is \#nonterm_end. + + In order to avoid changing the main context-fst code, we implement this in a + special class fst::InverseLeftBiphoneContextFst which implements these extensions + and which only supports the left-biphone case. See that code for more + details (search for "state space" in grammar-context-fst.h). + + + \section grammar_special_hclg Special symbols in HCLG.fst + + The special symbols in the HCLG.fst graphs will represent the same thing as + those in CLG.fst graphs, discussed above; but their representation in integer + form is different. + + Firstly, some background. At the input of CLG.fst the symbols are indexes + into an ilabel_info table. At the input of HCLG.fst the symbols, in general, + represent transition-ids-- and also disambiguation symbols, but those + are removed after determinization. The point is that HCLG.fst does not come with + a table like the ilabel_info that gives us the interpretation of symbols, + so we need to use an encoding that allows us to combine two integers into one. + + We choose a representation of the special symbols in HCLG.fst that avoids + clashing with the transition-ids and which makes it relatively painless to + decode the symbols to find what they represent. The representation of + a pair (\#nonterm:XXX, left-context-phone) is, + in the typical case: +\verbatim + hclg_ilabel = 10000000 + 1000 * nonterm_xxx + left_context_phone +\endverbatim + where of course nonterm_xxx and left_context_phone are the corresponding + symbol-ids in phones.txt. Actually, in place of + the "1000" above we use the smallest multiple of 1000 that is greater than the value passed to the + --nonterm-phones-offset option; this allows us to handle large phone sets while also being fairly + human-readable. + + + \subsection grammar_special_h Special symbols in H.fst + + Since H.fst only needs to change the integer representation of the special + symbols but otherwise leaves them unchanged, the changes to it are quite trivial. + H.fst has a high-out-degree state which we will refer to as the loop-state. + We just need to add a self-loop arc at the loop-state for each of the special + symbols referred to in the ilabel_info. The ilabel and olabel + are different since the integer encodings are different. + + + \section grammar_decoder The decoder + + + The current approach to decoding with grammars + is to wrap up the entire thing as an FST so that the same decoding code as + before can be used. That is, we just invoke the decoder with a different FST. + We use 64-bit state-ids, so that we can let the higher-order 32 bits encode the "fst instance" + and the lower-order bits encode the state within that instance.
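To make these two encodings concrete, here is a small illustrative sketch of the arithmetic only (the variable names are made up for illustration; the real constants and helpers live in the grammar-fst and grammar-context-fst code):
\verbatim
# Illustrative arithmetic, not a toolkit script.
nonterm_phones_offset=247    # id of #nonterm_bos in the example phones.txt above
# smallest multiple of 1000 strictly greater than the offset:
encoding_multiple=$(( 1000 * (nonterm_phones_offset / 1000 + 1) ))
nonterm=251                  # #nonterm:foo in the example phones.txt above
left_context_phone=21
hclg_ilabel=$(( 10000000 + encoding_multiple * nonterm + left_context_phone ))
echo $hclg_ilabel            # 10251021
# unpacking such a special ilabel back into its (nonterminal, left-context-phone) pair:
echo $(( (hclg_ilabel - 10000000) / encoding_multiple )) \
     $(( (hclg_ilabel - 10000000) % encoding_multiple ))  # 251 21

# 64-bit GrammarFst state-ids: high 32 bits = fst-instance, low 32 bits = state.
instance=2; state=5034
state_id=$(( (instance << 32) | state ))
echo $(( state_id >> 32 )) $(( state_id & 0xFFFFFFFF ))   # 2 5034
\endverbatim
In the real code this is of course done in C++; the sketch is only meant to show how a special ilabel or a 64-bit state-id can be recognized and unpacked.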
The fst instances + are created on the fly as states are visited. Instance 0 is always the "top-level" FST, + and we create new FST instances on the fly as needed, when we encounter arcs with + "special symbols" on them. + + The actual decoder code is the same as the regular decoder; we just template it on + a different FST type: fst::GrammarFst instead of fst::Fst. Class fst::GrammarFst does not + inherit from class fst::Fst or support its entire interface (this would have been very + complex to implement); it only supports the parts of the interface actually needed + by the decoder. + + + \subsection grammar_decoder_arc_iterator The ArcIterator of GrammarFst + + Probably the most critical part of the design is the ArcIterator + code, since the inner loop of the decoder is a loop over arcs. In order to avoid + having to copy the underlying FSTs, for "normal states" (those that don't have arcs + leaving them which enter or return from other FST instances), the ArcIterator code actually points into the + arcs of the underlying FSTs, which of course have a differently-typed 'nextstate', with 32 bits + not 64 bits. The ArcIterator also stores the higher 32 bits of the state-id, which + corresponds to the "fst instance" id, and every time you call its Next() function it + creates a new local copy of the 'current arc' it points to, which differs from the + underlying arc by having a 64-bit 'nextstate'. The overhead of copying the arc to a temporary will, + we hope, be mostly removed by compiler optimization. (In fact this does seem to be the + case: the overhead of GrammarFst decoding is about 15\% with -O0 and 5\% with -O2). + + Some states in the GrammarFst are 'special' states because they have arcs leaving them that + cross FST boundaries. For these 'special' states we have to construct the arcs separately, and + we store this information in a hash in class GrammarFst. + + To keep the decoder code fast and memory-efficient, we need to know quickly, + every time we visit a state, whether it is a "special" state or a normal state. + We don't want to do this with a big array indexed by state, because it would + take up too much memory per GrammarFst object. Instead we do it by giving a + special final-prob value to "special states" in the underlying FSTs that + GrammarFst stitches together. The ArcIterator code tests whether the + final-cost has this special value (4096.0) and if it does, it knows that it's a + "special state" and looks it up in a hash; if not, it just looks up the start + of the array of arcs for this state in the underlying FST. + + In order to avoid having any extra if-statements in the ArcIterator that would + have to be evaluated while we loop over arcs, we make sure that even "expanded + states" have vectors of arcs that use the underlying arc type (fst::StdArc) + with 32-bit state-ids. The "fst-instance" index of the destination FST is + stored separately in the ArcIterator, just as it is for normal states. This, + of course, requires that we must not have states with arcs leaving them + that transition to multiple FST instances. See the next section for how + we ensure this. + + + \section grammar_prepare Preparing FSTs for use in grammar decoding + + + The GrammarFst code has various requirements on the FSTs that it stitches together, + some of which were mentioned above. These requirements are designed to help + keep the GrammarFst code fast.
The function fst::PrepareForGrammarFst (internally implemented + by class fst::GrammarFstPreparer) ensures that these preconditions are met. The user is required + to call this preparation code prior to instantiating the GrammarFst object, so the preparation is + considered part of the graph construction; this keeps the run-time code fast. + The standard graph-construction script utils/mkgraph.sh calls this automatically (via the binary + make-grammar-fst) if it detects that you are using this framework. + + The tasks of fst::PrepareForGrammarFst include setting a final-cost of 4096.0 for + FST states that will end up being "special" states, and also making various small + changes to the HCLG.fst that ensure it has the properties needed by class fst::GrammarFst + (e.g. ensuring no state will have transitions to multiple FST instances). These + changes are mostly accomplished by inserting epsilon arcs; for details, see the + documentation of class fst::GrammarFstPreparer. + + \section grammar_olabels Output labels in GrammarFsts + + In the example scripts we provided, because we only wanted "real words" to + appear on the output side of HCLG.fst, we ensured that no special symbols of + the form \#nontermXXX appear on the output side of G.fst. However, the + graph compilation framework does allow you to include those symbols if you + want. These might be useful in certain application scenarios, where you want + to know that a particular span of words was decoded as part of a sub-grammar. + The only thing you have to be careful of is that the program + lattice-align-words (and the code underlying it) will not work if you have + words that have an empty pronunciation. That can be an issue if you need to find the + exact time-alignment of words for some reason. In those cases you should use + the alternative program lattice-align-words-lexicon (which reads a file + lexicon.int giving the pronunciation of words in your lexicon), which should + work even in this case. The prepare_lang.sh script already puts empty + pronunciation entries for symbols of the form \#nontermXXX + in lexicon.int, so the lattice-align-words-lexicon method of word alignment + should "just work" if you made the lang and graph directories using the + provided scripts. + + + + + +*/ + + +} diff --git a/src/doc/history.dox b/src/doc/history.dox index 40d46c7e32f..0813f2331cc 100644 --- a/src/doc/history.dox +++ b/src/doc/history.dox @@ -54,7 +54,8 @@ Sandeep Boda, Sandeep Reddy and Haihua Xu (who helped with coding, code cleanup and documentation); we were visited by Michael Riley (who helped us to understand OpenFst and gave some lectures on FSTs), and would like to acknowledge the help of - Honza Cernocky (for allowing us to have the workshop and helping to organize it), + Honza Cernocky (for negotiating the venue and some support for the workshop from + the Faculty of Information Technology of BUT and helping to organize it), Renata Kohlova (administration), and Tomas Kasparek (system administration). It is possible that this list of contributors contains oversights; any important omissions are unlikely to be intentional. @@ -62,13 +63,16 @@ A lot of code was written during the summer of 2010 but we still did not have a complete working system. Some of the participants of the 2010 workshop continued working to complete the toolkit and get a working set of training scripts. - The code was released on May 14th, 2011. + The code was released on May 14th, 2011, and presented to the public at ICASSP 2011 + in Prague, + + see the recordings.
Since the initial release, Kaldi has been maintained and developed to a large extent by Daniel Povey, working at Microsoft Research until early 2012 and since then at Johns Hopkins University; but also with major contributions by others: notably Karel Vesely, who developed the neural-net training framework, - and Arnab Ghoshal, who co-ordinated the acoustic modeling work early on; but + and Arnab Ghoshal, who coordinated the acoustic modeling work early on; but also other major contributors whom we do not name here because it is too hard to determine where to cut off the list; and a long tail of minor contributors; the total number of people who have contributed code or scripts or patches is diff --git a/src/doc/hmm.dox b/src/doc/hmm.dox index c410b1ba5a1..fb936bf2d25 100644 --- a/src/doc/hmm.dox +++ b/src/doc/hmm.dox @@ -98,7 +98,7 @@ numbered state of a "prototype HMM" has two variables "forward_pdf_class" and "self_loop_pdf_class". The "self_loop_pdf_class" is a kind of pdf-class that is associated with self-loop transition. It is by default identical to "forward_pdf_class", -but it can be used to define less-convectional HMM topologies +but it can be used to define less-conventional HMM topologies where the pdfs on the self-loop and forward transitions are different. The decision to allow the pdf-class on just the self-loop to be different, while not embracing a fully "arc-based" representation where the pdfs on diff --git a/src/doc/io.dox b/src/doc/io.dox index dc958f57a6f..8f3a3cc05b6 100644 --- a/src/doc/io.dox +++ b/src/doc/io.dox @@ -383,7 +383,7 @@ namespace kaldi { std::string rspecifier2 = "ark:-"; // archive read from stdin. // write to a gzipped text archive. std::string wspecifier1 = "ark,t:| gzip -c > /some/dir/foo.ark.gz"; - std::string wspecifier2 = "ark,scp:data/my.ark,data/my.ark"; + std::string wspecifier2 = "ark,scp:data/my.ark,data/my.scp"; \endcode Usually, an rspecifier or wspecifier consists of a comma-separated, unordered @@ -401,7 +401,7 @@ namespace kaldi { \endverbatim This will write an archive, and a script file with lines like "utt_id /somedir/foo.ark:1234" that specify offsets into the - archive for more efficient random access. You can then do what you like which + archive for more efficient random access. You can then do whatever you like with the script file, including breaking it up into segments, and it will behave like any other script file. Note that although the order of options before the colon doesn't generally matter, in this particular case the "ark" must come before diff --git a/src/doc/kaldi_for_dummies.dox b/src/doc/kaldi_for_dummies.dox index c04e0d0c3e9..b48d6dd8dac 100644 --- a/src/doc/kaldi_for_dummies.dox +++ b/src/doc/kaldi_for_dummies.dox @@ -71,7 +71,7 @@ and installation, - \c awk – programming language, used for searching and processing patterns in files and data streams, - \c bash – Unix shell and script programming language, - - \c grep – command-line utility for searching plain-text data sets for lines + - \c grep – command-line utility for searching plain-text datasets for lines matching a regular expression, - \c make – automatically builds executable programs and libraries from source code, @@ -87,16 +87,16 @@ If you do not have much idea about how to use GIT, please read about it: \ref tutorial_git. 
I installed Kaldi in this directory (called 'Kaldi root path'): -\c /home/{user}/kaldi-trunk +\c /home/{user}/kaldi \section kaldi_for_dummies_directories Kaldi directories structure Try to acknowledge where particular Kaldi components are placed. Also it would be nice if you read any \c README files you find. -\c kaldi-trunk - main Kaldi directory which contains: +\c kaldi - main Kaldi directory which contains: - \c egs – example scripts allowing you to quickly build ASR -systems for over 30 popular speech corporas (documentation is attached for each +systems for over 30 popular speech corpora (documentation is attached for each project), - \c misc – additional tools and supplies, not needed for proper Kaldi functionality, @@ -127,7 +127,7 @@ train it, test it and get some decoding results.

Your first task

Something to begin with - create a folder \c digits in -\c kaldi-trunk/egs/ directory. This is a place where you will put all +\c kaldi/egs/ directory. This is a place where you will put all the stuff related to your project. \section kaldi_for_dummies_data Data preparation @@ -136,34 +136,34 @@ the stuff related to your project. I assume that you want to set up an ASR system, basing on your own audio data. For example - let it be a set of 100 files. File format is WAV. Each file -contains 3 spoken digits recorded in english language, one by one. Each of +contains 3 spoken digits recorded in English language, one by one. Each of these audio files is named in a recognizable way (e.g. \c 1_5_6.wav, which in my pattern means that the spoken sentence is 'one, five, six') and placed in the recognizable folder representing particular speaker during a particular recording session (there may be a situation that you have recordings of the same person but in two different quality/noise environments - put these -in separate folders). So to sum up, my exemplary data set looks like this: +in separate folders). So to sum up, my exemplary dataset looks like this: - 10 different speakers (ASR systems must be trained and tested on different speakers, the more speakers you have the better), - each speaker says 10 sentences, - - 100 senteces/utterances (in 100 *.wav files placed in 10 folders related to + - 100 sentences/utterances (in 100 *.wav files placed in 10 folders related to particular speakers - 10 *.wav files in each folder), - 300 words (digits from zero to nine), - each sentence/utterance consist of 3 words. -Whatever your first data set is, adjust my example to your particular case. Be -careful with big data sets and complex grammars - start with something simple. +Whatever your first dataset is, adjust my example to your particular case. Be +careful with big datasets and complex grammars - start with something simple. Sentences that contain only digits are perfect in this case.

Task

-Go to \c kaldi-trunk/egs/digits directory and create -\c digits_audio folder. In \c kaldi-trunk/egs/digits/digits_audio +Go to \c kaldi/egs/digits directory and create +\c digits_audio folder. In \c kaldi/egs/digits/digits_audio create two folders: \c train and \c test. Select one speaker -of your choice to represent testing data set. Use this speaker's 'speakerID' as -a name for an another new folder in \c kaldi-trunk/egs/digits/digits_audio/test +of your choice to represent testing dataset. Use this speaker's 'speakerID' as +a name for an another new folder in \c kaldi/egs/digits/digits_audio/test directory. Then put there all the audio files related to that person. Put the rest (9 speakers) into \c train folder - this will be your training -data set. Also create subfolders for each speaker. +dataset. Also create subfolders for each speaker. \subsection kaldi_for_dummies_acoustic Acoustic data @@ -174,14 +174,14 @@ section as well) can be considered as a text file with some number of strings (each string in a new line). These strings need to be sorted. If you will encounter any sorting issues you can use Kaldi scripts for checking (\c utils/validate_data_dir.sh) and fixing (\c utils/fix_data_dir.sh) data order. -And for you information - \c utils directory will be attached to your project in +And for your information - \c utils directory will be attached to your project in \ref kaldi_for_dummies_tools "Tools attachment" section.

Task

-In \c kaldi-trunk/egs/digits directory, create a folder \c data. Then create +In \c kaldi/egs/digits directory, create a folder \c data. Then create \c test and \c train subfolders inside. Create in each subfolder following files (so you have files named in the same way in \c test and \c train subfolders -but they relate to two different data sets that you created before): +but they relate to two different datasets that you created before): a.) \c spk2gender
This file informs about speakers gender. As we assumed, 'speakerID' is a unique @@ -207,9 +207,9 @@ for examples below). Pattern: \verbatim -dad_4_4_2 /home/{user}/kaldi-trunk/egs/digits/digits_audio/train/dad/4_4_2.wav -july_1_2_5 /home/{user}/kaldi-trunk/egs/digits/digits_audio/train/july/1_2_5.wav -july_6_8_3 /home/{user}/kaldi-trunk/egs/digits/digits_audio/train/july/6_8_3.wav +dad_4_4_2 /home/{user}/kaldi/egs/digits/digits_audio/train/dad/4_4_2.wav +july_1_2_5 /home/{user}/kaldi/egs/digits/digits_audio/train/july/1_2_5.wav +july_6_8_3 /home/{user}/kaldi/egs/digits/digits_audio/train/july/6_8_3.wav # and so on... \endverbatim @@ -236,8 +236,8 @@ july_6_8_3 july \endverbatim e.) \c corpus.txt
-This file has a slightly different directory. In \c kaldi-trunk/egs/digits/data -create another folder \c local. In \c kaldi-trunk/egs/digits/data/local create a +This file has a slightly different directory. In \c kaldi/egs/digits/data +create another folder \c local. In \c kaldi/egs/digits/data/local create a file \c corpus.txt which should contain every single utterance transcription that can occur in your ASR system (in our case it will be 100 lines from 100 audio files). @@ -252,14 +252,14 @@ four four two \subsection kaldi_for_dummies_language Language data -This section relates to language modelling files that also need to be considered +This section relates to language modeling files that also need to be considered as 'must be done'. Look for the syntax details here: \ref data_prep (each file is precisely described). Also feel free to read some examples in other \c egs scripts. Now is the perfect time.
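If you would rather generate the acoustic data files described above than type them by hand, the following is a minimal, hypothetical Python sketch (it is not part of Kaldi or of this tutorial, and the speaker-to-gender mapping in it is a made-up example you must edit). It walks one of the \c digits_audio subfolders and writes sorted \c wav.scp, \c text, \c utt2spk and \c spk2gender files, printing the transcriptions so you can collect them into \c corpus.txt:

\code{.py}
#!/usr/bin/env python3
# prepare_digits_data.py -- hypothetical helper, not part of Kaldi.
# Walks digits_audio/{train,test} (one folder per speaker, files like
# 4_4_2.wav) and writes wav.scp, text, utt2spk and spk2gender in the
# chosen data directory; the transcriptions are printed for corpus.txt.
import os
import sys

DIGITS = ["zero", "one", "two", "three", "four",
          "five", "six", "seven", "eight", "nine"]
# Gender cannot be guessed from the folder names, so fill this in by hand;
# these entries are only examples.
SPK2GENDER = {"dad": "m", "july": "f"}

def prepare(audio_dir, data_dir):
    os.makedirs(data_dir, exist_ok=True)
    wav_scp, text, utt2spk, spk2gender, corpus = [], [], [], [], []
    for spk in sorted(os.listdir(audio_dir)):
        spk_dir = os.path.join(audio_dir, spk)
        if not os.path.isdir(spk_dir):
            continue
        spk2gender.append("%s %s" % (spk, SPK2GENDER.get(spk, "m")))
        for wav in sorted(os.listdir(spk_dir)):
            if not wav.endswith(".wav"):
                continue
            base = wav[:-4]                      # e.g. "4_4_2"
            utt = "%s_%s" % (spk, base)          # e.g. "dad_4_4_2"
            words = " ".join(DIGITS[int(d)] for d in base.split("_"))
            wav_scp.append("%s %s" % (utt, os.path.abspath(os.path.join(spk_dir, wav))))
            text.append("%s %s" % (utt, words))
            utt2spk.append("%s %s" % (utt, spk))
            corpus.append(words)
    # Kaldi expects these files to be sorted.
    for name, lines in [("wav.scp", wav_scp), ("text", text),
                        ("utt2spk", utt2spk), ("spk2gender", spk2gender)]:
        with open(os.path.join(data_dir, name), "w") as f:
            f.write("\n".join(sorted(lines)) + "\n")
    return corpus

if __name__ == "__main__":
    # e.g.: python3 prepare_digits_data.py digits_audio/train data/train
    sentences = prepare(sys.argv[1], sys.argv[2])
    # Collect these lines (train + test) into data/local/corpus.txt yourself.
    print("\n".join(sentences))
\endcode

The matching \c spk2utt file can be produced from \c utt2spk with \c utils/utt2spk_to_spk2utt.pl, and \c utils/validate_data_dir.sh will report anything that is missing or unsorted.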

Task

-In \c kaldi-trunk/egs/digits/data/local directory, create a folder \c dict. In -\c kaldi-trunk/egs/digits/data/local/dict create following files: +In \c kaldi/egs/digits/data/local directory, create a folder \c dict. In +\c kaldi/egs/digits/data/local/dict create following files: a.) \c lexicon.txt
This file contains every word from your dictionary with its 'phone @@ -337,19 +337,19 @@ complete. You need to add necessary Kaldi tools that are widely used in exemplary scripts.

Task

-From \c kaldi-trunk/egs/wsj/s5 copy two folders (with the whole content) - +From \c kaldi/egs/wsj/s5 copy two folders (with the whole content) - \c utils and \c steps - and put them in your -\c kaldi-trunk/egs/digits directory. You can also create links to these +\c kaldi/egs/digits directory. You can also create links to these directories. You may find such links in, for example, -\c kaldi-trunk/egs/voxforge/s5. +\c kaldi/egs/voxforge/s5. \subsection kaldi_for_dummies_scoring Scoring script This script will help you to get decoding results.

Task

-From \c kaldi-trunk/egs/voxforge/s5/local copy the script \c score.sh into -similar location in your project (\c kaldi-trunk/egs/digits/local). +From \c kaldi/egs/voxforge/s5/local copy the script \c score.sh into +similar location in your project (\c kaldi/egs/digits/local). \subsection kaldi_for_dummies_srilm SRILM installation @@ -358,7 +358,7 @@ example - SRI Language Modeling Toolkit (SRILM).

Task

For detailed installation instructions go to -\c kaldi-trunk/tools/install_srilm.sh (read all comments inside). +\c kaldi/tools/install_srilm.sh (read all comments inside). \subsection kaldi_for_dummies_configuration Configuration files @@ -366,8 +366,8 @@ It is not necessary to create configuration files but it can be a good habit for future.

Task

-In \c kaldi-trunk/egs/digits create a folder \c conf. Inside -\c kaldi-trunk/egs/digits/conf create two files (for some configuration +In \c kaldi/egs/digits create a folder \c conf. Inside +\c kaldi/egs/digits/conf create two files (for some configuration modifications in decoding and mfcc feature extraction processes - taken from \c /egs/voxforge): @@ -395,10 +395,10 @@ decided to use two different training methods: - TRI1 - simple triphone training (first triphone pass). These two methods are enough to show noticable differences in decoding results -using only digits lexicon and small training data set. +using only digits lexicon and small training dataset.

Task

-In \c kaldi-trunk/egs/digits directory create 3 scripts: +In \c kaldi/egs/digits directory create 3 scripts: a.) \c cmd.sh
\code{.sh} @@ -416,7 +416,7 @@ export KALDI_ROOT=`pwd`/../.. export PATH=$PWD/utils/:$KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$PWD:$PATH # Defining audio data directory (modify it for your installation directory!) -export DATA_ROOT="/home/{user}/kaldi-trunk/egs/digits/digits_audio" +export DATA_ROOT="/home/{user}/kaldi/egs/digits/digits_audio" # Enable SRILM . $KALDI_ROOT/tools/env.sh @@ -432,7 +432,7 @@ c.) \c run.sh . ./path.sh || exit 1 . ./cmd.sh || exit 1 -nj=1 # number of parallel jobs - 1 is perfect for such a small data set +nj=1 # number of parallel jobs - 1 is perfect for such a small dataset lm_order=1 # language model order (n-gram quantity) - 1 is enough for digits grammar # Safety mechanism (possible running this script with modified arguments) @@ -564,7 +564,7 @@ Now all you have to do is to run \c run.sh script. If I have made any mistakes in this tutorial, logs from the terminal should guide you how to deal with it. Besides the fact that you will notice some decoding results in the terminal -window, go to newly made \c kaldi-trunk/egs/digits/exp. You may notice there +window, go to newly made \c kaldi/egs/digits/exp. You may notice there folders with \c mono and \c tri1 results as well - directories structure are the same. Go to \c mono/decode directory. Here you may find result files (named in a wer_{number} way). Logs for decoding process may be found in \c log @@ -575,7 +575,7 @@ folder (same directory). This is just an example. The point of this short tutorial is to show you how to create 'anything' in Kaldi and to get a better understanding of how to think while using this toolkit. Personally I started with looking for tutorials made -by the Kaldi authors/developers. After succesful Kaldi installation I launched +by the Kaldi authors/developers. After successful Kaldi installation I launched some example scripts (Yesno, Voxforge, LibriSpeech - they are relatively easy and have free acoustic/language data to download - I used these three as a base for my own scripts). @@ -586,7 +586,7 @@ There are two very useful sections for beginners inside:
a.) \ref tutorial - almost 'step by step' tutorial on how to set up an ASR system; up to some point this can be done without RM dataset. It is good to read it,
-b.) \ref data_prep - very detailed explaination of how to use your own data +b.) \ref data_prep - very detailed explanation of how to use your own data in Kaldi. More useful links about Kaldi I found:
diff --git a/src/doc/mainpage.dox b/src/doc/mainpage.dox index c6a3468a5d0..88fefbd8e02 100644 --- a/src/doc/mainpage.dox +++ b/src/doc/mainpage.dox @@ -66,6 +66,7 @@ - \subpage graph - \subpage graph_recipe_test - \subpage graph_recipe_train + - \subpage grammar - \subpage fst_algo - \subpage decoders - \subpage lattices diff --git a/src/doc/online_decoding.dox b/src/doc/online_decoding.dox index 799bfb5895f..9bcc2575be1 100644 --- a/src/doc/online_decoding.dox +++ b/src/doc/online_decoding.dox @@ -438,6 +438,89 @@ and downloadable models that can be used with online nnet3 decoding, please see http://kaldi-asr.org/models.html (the first model there, the ASPIRE model, includes instructions in a README file). +\subsection online_decoding_nnet3_tcp TCP server for nnet3 online decoding + +The program to run the TCP server is online2-tcp-nnet3-decode-faster located in the +~/src/online2bin folder. The usage is as follows: + +\verbatim +online2-tcp-nnet3-decode-faster <nnet3-in> <fst-in> <word-symbol-table> +\endverbatim + +For example: + +\verbatim +online2-tcp-nnet3-decode-faster model/final.mdl graph/HCLG.fst graph/words.txt +\endverbatim + +The word symbol table is mandatory (unlike other nnet3 online decoding programs) because +the server outputs word strings. Endpointing is mandatory to make the operation of the +program reasonable. Other, non-standard options include: + - port-num - the port the server listens on (by default 5050) + - samp-freq - sampling frequency of audio (usually 8000 for telephony and 16000 for other uses) + - chunk-length - length of signal being processed by decoder at each step + - output-period - how often we check for changes in the decoding (i.e. output refresh rate, default 1s) + - num-threads-startup - number of threads used when initializing iVector extractor + - read-timeout - if the program doesn't receive data during this timeout, the server terminates the connection. + Use -1 to disable this feature. + +The TCP protocol simply takes RAW signal on input (16-bit signed integer +encoding at the chosen sampling frequency) and outputs simple text using the following +logic: + - each refresh period (output-period argument) the current state of decoding is output + - each line is terminated by '\r' + - once an utterance boundary is detected due to endpointing a '\n' char is output + +Each output string (delimited by '\r') should be treated as uncertain and can change +entirely until the utterance delimiter ('\n') is sent. The delimiter chars are chosen +specifically in order to make the output look neat in the terminal. It is possible to +use it with other interfaces and a web demo (HTML/JS AudioAPI+WebSockets) exists. + +To run the program from the terminal you can use one of the following commands. First, +make sure the server is running and accepting connections. Using the Aspire models, the +command should look like this: +\verbatim +online2-tcp-nnet3-decode-faster --samp-freq=8000 --frames-per-chunk=20 --extra-left-context-initial=0 + --frame-subsampling-factor=3 --config=model/conf/online.conf --min-active=200 --max-active=7000 + --beam=15.0 --lattice-beam=6.0 --acoustic-scale=1.0 --port-num=5050 model/final.mdl graph/HCLG.fst graph/words.txt +\endverbatim + +Note that in order to make the communication as simple as possible, the server has to accept +any data on input and cannot figure out when the stream is over. It will therefore not +be able to terminate the connection and it is the client's responsibility to disconnect +when it is ready to do so.
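To make the protocol concrete, here is a minimal, hypothetical Python client (only a sketch, not part of Kaldi). It assumes the server is listening on localhost:5050 and that audio.raw already holds mono 16-bit signed samples at the expected sampling frequency (for example produced with the sox command shown further below); it streams the audio in a background thread and prints partial hypotheses (terminated by '\r') and final ones (terminated by '\n') as they arrive:

\code{.py}
import socket
import threading

HOST, PORT = "localhost", 5050   # assumed server address and port
AUDIO = "audio.raw"              # raw mono 16-bit signed samples (assumed)
CHUNK = 4096                     # bytes of audio sent per write

def send_audio(sock):
    # Stream the raw audio to the server, then stop writing; the session
    # then ends when we disconnect or when --read-timeout expires.
    with open(AUDIO, "rb") as f:
        while True:
            data = f.read(CHUNK)
            if not data:
                break
            sock.sendall(data)
    sock.shutdown(socket.SHUT_WR)

def main():
    sock = socket.create_connection((HOST, PORT))
    threading.Thread(target=send_audio, args=(sock,), daemon=True).start()
    buf = b""
    while True:
        data = sock.recv(4096)
        if not data:
            break                        # server closed the connection
        buf += data
        # '\r' ends a partial (still changing) hypothesis,
        # '\n' ends a final one (an endpoint was detected).
        while True:
            cr, lf = buf.find(b"\r"), buf.find(b"\n")
            candidates = [i for i in (cr, lf) if i >= 0]
            if not candidates:
                break
            idx = min(candidates)
            line, sep, buf = buf[:idx], buf[idx:idx + 1], buf[idx + 1:]
            tag = "final:  " if sep == b"\n" else "partial:"
            print(tag, line.decode("utf-8", errors="replace"))
    sock.close()

if __name__ == "__main__":
    main()
\endcode

The same byte stream could come from sox or a microphone instead of a file; the only point here is how the '\r' and '\n' delimiters are interpreted on the client side.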
As a fallback for certain situations, the read-timeout option +was added, which will automatically disconnect if a chosen number of seconds has passed. +Keep in mind that this is not an ideal solution and it's a better idea to design your +client to properly disconnect the connection when necessary. + +For testing purposes, we will use the netcat program. We will also use sox to re-encode the +files properly from any source. Netcat has an issue that, similarly to what was stated above +about the server, it cannot always interpret the data and usually it won't automatically +disconnect the TCP connection. To get around this, we will use the '-N' switch, which kills +the connection once streaming of the file is complete, but this can have a small side effect of +not reading the whole output from the Kaldi server if the disconnect comes too fast. Just +keep this in mind if you intend to use any of these programs in a production environment. + +To send a WAV file into the server, it first needs to be decoded into raw audio, then it can be +sent to the socket: +\verbatim +sox audio.wav -t raw -c 1 -b 16 -r 8k -e signed-integer - | nc -N localhost 5050 +\endverbatim + +It is possible to play the audio (almost) simultaneously with decoding. It may require installing the +'pv' program (used to throttle the signal into Kaldi at the same speed as the playback): + +\verbatim +sox audio.wav -t raw -c 1 -b 16 -r 8k -e signed-integer - | \ + tee >(play -t raw -r 8k -e signed-integer -b 16 -c 1 -q -) | \ + pv -L 16000 -q | nc -N localhost 5050 +\endverbatim + +Finally, it is possible to send audio from the microphone directly into the server: + +\verbatim +rec -r 8k -e signed-integer -c 1 -b 16 -t raw -q - | nc -N localhost 5050 +\endverbatim */ diff --git a/src/doc/tutorial_looking.dox b/src/doc/tutorial_looking.dox index 420abfc9bce..831d721c7eb 100644 --- a/src/doc/tutorial_looking.dox +++ b/src/doc/tutorial_looking.dox @@ -171,7 +171,7 @@ making sure have their normal values, begin with KALDI_. This is a precaution to avoid future conflicts with other codebases (since \#defines don't limit themselves to the kaldi namespace). Notice the style of the function names: LikeThis(). Our style is generally based on - this one , + this one , to conform with OpenFst, but there are some differences. To see other elements of the style, which will help you to understand Kaldi @@ -190,7 +190,7 @@ It prints out the usage, which should give you a generic idea of how Kaldi progr are called. Note that while there is a --config option that can be used to pass a configuration file, in general Kaldi is not as config-driven as HTK and these files are not widely used. You will see a --binary option. In general, Kaldi file -formats come in both binary and test forms, and the --binary option controls how +formats come in both binary and text forms, and the --binary option controls how they are written. However, this only controls how single objects (e.g. acoustic models) are written. For whole collections of objects (e.g. collections of feature files), there is a different mechanism that we will come to later.
diff --git a/src/doc/tutorial_prereqs.dox b/src/doc/tutorial_prereqs.dox index 82079a281b9..72b1fcf8ad8 100644 --- a/src/doc/tutorial_prereqs.dox +++ b/src/doc/tutorial_prereqs.dox @@ -51,7 +51,7 @@ The most difficult part of the installation process relates to the math library ATLAS; if this is not already installed as a library on your system you will have to compile it, and this requires that CPU throttling be turned off, which - may require root priveleges. We provide scripts and detailed instructions for + may require root privileges. We provide scripts and detailed instructions for all installation steps. When scripts fail, read the output carefully because it tries to provide guidance as to how to fix problems. Please inform us if there are problems at any point, however minor; see \ref other. diff --git a/src/doc/tutorial_running.dox b/src/doc/tutorial_running.dox index f977348a3cb..d639cd4e664 100644 --- a/src/doc/tutorial_running.dox +++ b/src/doc/tutorial_running.dox @@ -115,14 +115,14 @@ Now go back to the data directory and change directory to /train. Then execute t \verbatim head text -head spk2gender.map +head spk2gender head spk2utt head utt2spk head wav.scp \endverbatim - text - This file contains mappings between utterances and utterance ids which will be used by Kaldi. This file will be turned into an integer format-- still a text file, but with the words replaced with integers. -- spk2gender.map - This file contains mappings between speakers and their gender. This also acts as a list of unique users involved in training. +- spk2gender - This file contains mappings between speakers and their gender. This also acts as a list of unique users involved in training. - spk2utt - This is a mapping between the speaker identifiers and all the utterance identifiers associated with the speaker. - utt2spk - This is a one-to-one mapping between utterance ids and the corresponding speaker identifiers. - wav.scp - This file is actually read directly by Kaldi programs when doing feature extraction. Look at the file again. It is parsed as a set of key-value pairs, where the key is the first string on each line. The value is a kind of "extended filename", and you can guess how it works. Since it is for reading we will refer to this type of string as an "rxfilename" (for writing we use the term wxfilename). See \ref io_sec_xfilename if you are curious. Note that although we use the extension .scp, this is not a script file in the HTK sense (i.e. it is not viewed as an extension to the command-line arguments). @@ -383,7 +383,7 @@ do copy-tree --binary=false exp/mono/tree - | less \endverbatim Note that this is a monophone "tree" so it is very trivial-- it -does not have any "splits". Although this tree format was not indended to be +does not have any "splits". Although this tree format was not intended to be very human-readable, we have received a number of queries about the tree format so we will explain it. The rest of this paragraph can be skipped over by the casual reader. After "ToPdf", the tree file contains an object of the @@ -442,7 +442,7 @@ Type \verbatim grep Overall exp/mono/log/acc.{?,??}.{?,??}.log \endverbatim -You can see the acoustic likelihods on each iteration. Next look at one of the files +You can see the acoustic likelihoods on each iteration. Next look at one of the files exp/mono/log/update.*.log to see what kind of information is in the update log. When the monophone training is finished, we can test the monophone decoding. 
Before decoding, we have to create the decode graph. Type: @@ -505,7 +505,7 @@ gmm-decode-faster \endverbatim to see the usage message, and match up the arguments with what you see in the log file. Recall that "rspecifier" is one of those strings that specifies how to read a table, -and "wspecifier" specifies how to write one. Look carefuly at these arguments and try +and "wspecifier" specifies how to write one. Look carefully at these arguments and try to figure out what they mean. Look at the rspecifier that corresponds to the features, and try to understand it (this one has spaces inside, so Kaldi prints it out with single quotes around it so that you could paste it into the shell and the program would run as intended). diff --git a/src/doc/tutorial_setup.dox b/src/doc/tutorial_setup.dox index 11d97a945f9..13f5e3e9c74 100644 --- a/src/doc/tutorial_setup.dox +++ b/src/doc/tutorial_setup.dox @@ -34,16 +34,16 @@ Assuming Git is installed, to get the latest code you can type \verbatim - git clone https://github.com/kaldi-asr/kaldi.git kaldi-trunk --origin golden + git clone https://github.com/kaldi-asr/kaldi.git \endverbatim - Then cd to kaldi-trunk. Look at the INSTALL file and follow the instructions + Then cd to kaldi. Look at the INSTALL file and follow the instructions (it points you to two subdirectories). Look carefully at the output of the installation scripts, as they try to guide you what to do. Some installation errors are non-fatal, and the installation scripts will tell you so (i.e. there are some things it installs which are nice to have but are not really needed). The "best-case" scenario is that you do: \verbatim - cd kaldi-trunk/tools/; make; cd ../src; ./configure; make + cd kaldi/tools/; make; cd ../src; ./configure; make \endverbatim and everything will just work; however, if this does not happen there are fallback plans (e.g. you may have to install some package on your machine, or run diff --git a/src/doc/versions.dox b/src/doc/versions.dox index d12b8621ccd..08e2c2bbda7 100644 --- a/src/doc/versions.dox +++ b/src/doc/versions.dox @@ -28,7 +28,7 @@ \section versions_scheme Versioning scheme - During its lifetime, Kaldi has has three different versioning methods. + During its lifetime, Kaldi has three different versioning methods. Originally Kaldi was a subversion (svn)-based project, and was hosted on Sourceforge. Then Kaldi was moved to github, and for some time the only version-number available was the git hash of the commit. @@ -61,7 +61,7 @@ This is the first major/minor version number after introducing the versioning scheme. The latest revision of version 5.0 is saved as branch "5.0" on github. - Below are patches corresponding to minor version numbers 5.0.x. + Below are commits corresponding to minor version numbers 5.0.x. \htmlinclude 5.0.html @@ -89,7 +89,7 @@ The latest revision of version 5.1 is saved as branch "5.1" on github. - Below are patches corresponding to minor version numbers 5.1.x. + Below are commits corresponding to minor version numbers 5.1.x. \htmlinclude 5.1.html @@ -110,7 +110,7 @@ The latest revision of version 5.2 is saved as branch "5.2" on github. - Below are patches corresponding to minor version numbers 5.1.x. + Below are commits corresponding to minor version numbers 5.2.x. \htmlinclude 5.2.html @@ -121,30 +121,52 @@ - Create a nnet3-based setup for RNN language models (i.e.
recurrent and neural net based language models) - Some extentions to the core of the nnet3 framework to support constant values and - scalar multiplication without dedicated compoennts. + scalar multiplication without dedicated components. - Below are patches corresponding to minor version numbers 5.3.x. + Below are commits corresponding to minor version numbers 5.3.x. \htmlinclude 5.3.html \subsection versions_versions_54 Version 5.4 - Version 5.4 is the current master branch. The main changes that were made between + The main changes that were made between the end of 5.3.x and the start of the 5.4 branch include: - Some code changes in the nnet3 codebase, for speed and memory efficiency. - Various simplifications and code reorganizations in the nnet3 code. - - Support for a new kind of factorized TDNN which gives substantially better + - Support for a new kind of factorized TDNN (TDNN-F) which gives substantially better results than our old TDNN recipe, and is even better than our old TDNN+LSTM recipe. A good example of this is in egs/swbd/s5c/local/chain/tuning/run_tdnn_lstm_1n.sh. Some nnet3 code changes were needed for this as well (mostly: support for constraining a matrix to have orthonormal rows). - Below are patches corresponding to minor version numbers 5.4.x. + Some of the larger changes that were made while 5.4 was the major version number include: + - Improvements to handwriting recognition and OCR recipes, including BPE (word-piece) encoding. + - An updated version of the TDNN-F configuration, including ResNet-style bypass, + which is now the default in many recipes. (it's called tdnnf-layer in xconfigs). + - A rewrite of the CUDA memory allocator to be based on a small number of large regions + (since with newer drivers and hardware, allocation speed was becoming a bottleneck). + - A decoder speedup (make use of OpenFst's NumInputEpsilons() function). + + + Below are commits corresponding to minor version numbers 5.4.x. \htmlinclude 5.4.html + \subsection versions_versions_55 Version 5.5 + + + Version 5.5 is the current master branch. The change that was made between the end of + 5.4 and the start of 5.5 is support for \ref grammar grammar decoding; this allows support for things like + the "contact list scenario" where you want to use a dynamically changing contact list in + a larger, fixed decoding graph. + + Below are commits corresponding to minor version numbers 5.5.x. + + + \htmlinclude 5.5.html + */ diff --git a/src/feat/Makefile b/src/feat/Makefile index 2af9da2ec59..dcd029f7f94 100644 --- a/src/feat/Makefile +++ b/src/feat/Makefile @@ -16,7 +16,7 @@ OBJFILES = feature-functions.o feature-mfcc.o feature-plp.o feature-fbank.o \ LIBNAME = kaldi-feat ADDLIBS = ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/feat/feature-common-inl.h b/src/feat/feature-common-inl.h index 546f272e821..26127a4dc4d 100644 --- a/src/feat/feature-common-inl.h +++ b/src/feat/feature-common-inl.h @@ -33,44 +33,29 @@ void OfflineFeatureTpl::ComputeFeatures( Matrix *output) { KALDI_ASSERT(output != NULL); BaseFloat new_sample_freq = computer_.GetFrameOptions().samp_freq; - if (sample_freq == new_sample_freq) + if (sample_freq == new_sample_freq) { Compute(wave, vtln_warp, output); - else { - if (new_sample_freq < sample_freq) { - if (! 
computer_.GetFrameOptions().allow_downsample) + } else { + if (new_sample_freq < sample_freq && + ! computer_.GetFrameOptions().allow_downsample) KALDI_ERR << "Waveform and config sample Frequency mismatch: " << sample_freq << " .vs " << new_sample_freq - << " ( use --allow_downsample=true option to allow " + << " (use --allow-downsample=true to allow " << " downsampling the waveform)."; - - // Downsample the waveform. - Vector downsampled_wave(wave); - DownsampleWaveForm(sample_freq, wave, - new_sample_freq, &downsampled_wave); - Compute(downsampled_wave, vtln_warp, output); - } else - KALDI_ERR << "The waveform is allowed to get downsampled." - << "New sample Frequency " << new_sample_freq - << " is larger than waveform original sampling frequency " - << sample_freq; - + else if (new_sample_freq > sample_freq && + ! computer_.GetFrameOptions().allow_upsample) + KALDI_ERR << "Waveform and config sample Frequency mismatch: " + << sample_freq << " .vs " << new_sample_freq + << " (use --allow-upsample=true option to allow " + << " upsampling the waveform)."; + // Resample the waveform. + Vector resampled_wave(wave); + ResampleWaveform(sample_freq, wave, + new_sample_freq, &resampled_wave); + Compute(resampled_wave, vtln_warp, output); } } -template -void OfflineFeatureTpl::ComputeFeatures( - const VectorBase &wave, - BaseFloat sample_freq, - BaseFloat vtln_warp, - Matrix *output) const { - OfflineFeatureTpl temp(*this); - // This const version of ComputeFeatures() is a wrapper that - // calls the non-const ComputeFeatures() on a temporary object - // that is a copy of *this. It is not as efficient because of the - // overhead of copying *this. - temp.ComputeFeatures(wave, vtln_warp, output); -} - template void OfflineFeatureTpl::Compute( const VectorBase &wave, diff --git a/src/feat/feature-common.h b/src/feat/feature-common.h index 1c83aed8ea9..45911cef585 100644 --- a/src/feat/feature-common.h +++ b/src/feat/feature-common.h @@ -152,16 +152,6 @@ class OfflineFeatureTpl { BaseFloat sample_freq, BaseFloat vtln_warp, Matrix *output); - /** - This const version of ComputeFeatures() is a wrapper that - calls the non-const ComputeFeatures() on a temporary object - that is a copy of *this. It is not as efficient because of the - overhead of copying *this. - */ - void ComputeFeatures(const VectorBase &wave, - BaseFloat sample_freq, - BaseFloat vtln_warp, - Matrix *output) const; int32 Dim() const { return computer_.Dim(); } diff --git a/src/feat/feature-fbank.cc b/src/feat/feature-fbank.cc index c54069696b5..10f7e67d607 100644 --- a/src/feat/feature-fbank.cc +++ b/src/feat/feature-fbank.cc @@ -82,8 +82,8 @@ void FbankComputer::Compute(BaseFloat signal_log_energy, // Compute energy after window function (not the raw one). if (opts_.use_energy && !opts_.raw_energy) - signal_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), - std::numeric_limits::min())); + signal_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), + std::numeric_limits::min())); if (srfft_ != NULL) // Compute FFT using split-radix algorithm. srfft_->Compute(signal_frame->Data(), true); @@ -108,7 +108,7 @@ void FbankComputer::Compute(BaseFloat signal_log_energy, mel_banks.Compute(power_spectrum, &mel_energies); if (opts_.use_log_fbank) { // Avoid log of zero (which should be prevented anyway by dithering). - mel_energies.ApplyFloor(std::numeric_limits::epsilon()); + mel_energies.ApplyFloor(std::numeric_limits::epsilon()); mel_energies.ApplyLog(); // take the log. 
} diff --git a/src/feat/feature-fbank.h b/src/feat/feature-fbank.h index 41ef2eef50a..724d7d148dc 100644 --- a/src/feat/feature-fbank.h +++ b/src/feat/feature-fbank.h @@ -53,7 +53,7 @@ struct FbankOptions { // this seems to be common for 16khz-sampled data, // but for 8khz-sampled data, 15 may be better. use_energy(false), - energy_floor(0.0), // not in log scale: a small value e.g. 1.0e-10 + energy_floor(0.0), raw_energy(true), htk_compat(false), use_log_fbank(true), @@ -65,7 +65,9 @@ struct FbankOptions { opts->Register("use-energy", &use_energy, "Add an extra dimension with energy to the FBANK output."); opts->Register("energy-floor", &energy_floor, - "Floor on energy (absolute, not relative) in FBANK computation"); + "Floor on energy (absolute, not relative) in FBANK computation. " + "Only makes a difference if --use-energy=true; only necessary if " + "--dither=0.0. Suggested values: 0.1 or 1.0"); opts->Register("raw-energy", &raw_energy, "If true, compute energy before preemphasis and windowing"); opts->Register("htk-compat", &htk_compat, "If true, put energy last. " diff --git a/src/feat/feature-mfcc.cc b/src/feat/feature-mfcc.cc index 122ba1b100d..899988c2822 100644 --- a/src/feat/feature-mfcc.cc +++ b/src/feat/feature-mfcc.cc @@ -35,8 +35,8 @@ void MfccComputer::Compute(BaseFloat signal_log_energy, const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); if (opts_.use_energy && !opts_.raw_energy) - signal_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), - std::numeric_limits::min())); + signal_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), + std::numeric_limits::min())); if (srfft_ != NULL) // Compute FFT using the split-radix algorithm. srfft_->Compute(signal_frame->Data(), true); @@ -51,7 +51,7 @@ void MfccComputer::Compute(BaseFloat signal_log_energy, mel_banks.Compute(power_spectrum, &mel_energies_); // avoid log of zero (which should be prevented anyway by dithering). - mel_energies_.ApplyFloor(std::numeric_limits::epsilon()); + mel_energies_.ApplyFloor(std::numeric_limits::epsilon()); mel_energies_.ApplyLog(); // take the log. feature->SetZero(); // in case there were NaNs. diff --git a/src/feat/feature-mfcc.h b/src/feat/feature-mfcc.h index d1d2b8f9d09..66c52e89821 100644 --- a/src/feat/feature-mfcc.h +++ b/src/feat/feature-mfcc.h @@ -40,7 +40,8 @@ struct MfccOptions { MelBanksOptions mel_opts; int32 num_ceps; // e.g. 13: num cepstral coeffs, counting zero. bool use_energy; // use energy; else C0 - BaseFloat energy_floor; + BaseFloat energy_floor; // 0 by default; set to a value like 1.0 or 0.1 if + // you disable dithering. bool raw_energy; // If true, compute energy before preemphasis and windowing BaseFloat cepstral_lifter; // Scaling factor on cepstra for HTK compatibility. // if 0.0, no liftering is done. @@ -53,7 +54,7 @@ struct MfccOptions { // but for 8khz-sampled data, 15 may be better. num_ceps(13), use_energy(true), - energy_floor(0.0), // not in log scale: a small value e.g. 1.0e-10 + energy_floor(0.0), raw_energy(true), cepstral_lifter(22.0), htk_compat(false) {} @@ -66,7 +67,9 @@ struct MfccOptions { opts->Register("use-energy", &use_energy, "Use energy (not C0) in MFCC computation"); opts->Register("energy-floor", &energy_floor, - "Floor on energy (absolute, not relative) in MFCC computation"); + "Floor on energy (absolute, not relative) in MFCC computation. " + "Only makes a difference if --use-energy=true; only necessary if " + "--dither=0.0. 
Suggested values: 0.1 or 1.0"); opts->Register("raw-energy", &raw_energy, "If true, compute energy before preemphasis and windowing"); opts->Register("cepstral-lifter", &cepstral_lifter, diff --git a/src/feat/feature-plp.cc b/src/feat/feature-plp.cc index 719e55dd6da..8f4a7d66161 100644 --- a/src/feat/feature-plp.cc +++ b/src/feat/feature-plp.cc @@ -124,8 +124,8 @@ void PlpComputer::Compute(BaseFloat signal_log_energy, if (opts_.use_energy && !opts_.raw_energy) - signal_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), - std::numeric_limits::min())); + signal_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), + std::numeric_limits::min())); if (srfft_ != NULL) // Compute FFT using split-radix algorithm. srfft_->Compute(signal_frame->Data(), true); @@ -159,8 +159,8 @@ void PlpComputer::Compute(BaseFloat signal_log_energy, BaseFloat residual_log_energy = ComputeLpc(autocorr_coeffs_, &lpc_coeffs_); - residual_log_energy = std::max(residual_log_energy, - std::numeric_limits::min()); + residual_log_energy = std::max(residual_log_energy, + std::numeric_limits::min()); Lpc2Cepstrum(opts_.lpc_order, lpc_coeffs_.Data(), raw_cepstrum_.Data()); feature->Range(1, opts_.num_ceps - 1).CopyFromVec( diff --git a/src/feat/feature-plp.h b/src/feat/feature-plp.h index d7deab07ec1..958c5706e89 100644 --- a/src/feat/feature-plp.h +++ b/src/feat/feature-plp.h @@ -61,7 +61,7 @@ struct PlpOptions { lpc_order(12), num_ceps(13), use_energy(true), - energy_floor(0.0), // not in log scale: a small value e.g. 1.0e-10 + energy_floor(0.0), raw_energy(true), compress_factor(0.33333), cepstral_lifter(22), @@ -78,7 +78,9 @@ struct PlpOptions { opts->Register("use-energy", &use_energy, "Use energy (not C0) for zeroth PLP feature"); opts->Register("energy-floor", &energy_floor, - "Floor on energy (absolute, not relative) in PLP computation"); + "Floor on energy (absolute, not relative) in PLP computation. " + "Only makes a difference if --use-energy=true; only necessary if " + "--dither=0.0. Suggested values: 0.1 or 1.0"); opts->Register("raw-energy", &raw_energy, "If true, compute energy before preemphasis and windowing"); opts->Register("compress-factor", &compress_factor, diff --git a/src/feat/feature-spectrogram.cc b/src/feat/feature-spectrogram.cc index 953f38fc54f..d2daa7aa829 100644 --- a/src/feat/feature-spectrogram.cc +++ b/src/feat/feature-spectrogram.cc @@ -54,8 +54,8 @@ void SpectrogramComputer::Compute(BaseFloat signal_log_energy, // Compute energy after window function (not the raw one) if (!opts_.raw_energy) - signal_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), - std::numeric_limits::epsilon())); + signal_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), + std::numeric_limits::epsilon())); if (srfft_ != NULL) // Compute FFT using split-radix algorithm. 
srfft_->Compute(signal_frame->Data(), true); @@ -67,7 +67,7 @@ void SpectrogramComputer::Compute(BaseFloat signal_log_energy, SubVector power_spectrum(*signal_frame, 0, signal_frame->Dim() / 2 + 1); - power_spectrum.ApplyFloor(std::numeric_limits::epsilon()); + power_spectrum.ApplyFloor(std::numeric_limits::epsilon()); power_spectrum.ApplyLog(); feature->CopyFromVec(power_spectrum); diff --git a/src/feat/feature-spectrogram.h b/src/feat/feature-spectrogram.h index ec318556f24..9aeb68c8df8 100644 --- a/src/feat/feature-spectrogram.h +++ b/src/feat/feature-spectrogram.h @@ -41,13 +41,17 @@ struct SpectrogramOptions { bool raw_energy; // If true, compute energy before preemphasis and windowing SpectrogramOptions() : - energy_floor(0.0), // not in log scale: a small value e.g. 1.0e-10 + energy_floor(0.0), raw_energy(true) {} void Register(OptionsItf *opts) { frame_opts.Register(opts); opts->Register("energy-floor", &energy_floor, - "Floor on energy (absolute, not relative) in Spectrogram computation"); + "Floor on energy (absolute, not relative) in Spectrogram " + "computation. Caution: this floor is applied to the zeroth " + "component, representing the total signal energy. The " + "floor on the individual spectrogram elements is fixed at " + "std::numeric_limits::epsilon()."); opts->Register("raw-energy", &raw_energy, "If true, compute energy before preemphasis and windowing"); } diff --git a/src/feat/feature-window.cc b/src/feat/feature-window.cc index 98afe1849e9..c5d4cc29831 100644 --- a/src/feat/feature-window.cc +++ b/src/feat/feature-window.cc @@ -144,8 +144,8 @@ void ProcessWindow(const FrameExtractionOptions &opts, window->Add(-window->Sum() / frame_length); if (log_energy_pre_window != NULL) { - BaseFloat energy = std::max(VecVec(*window, *window), - std::numeric_limits::epsilon()); + BaseFloat energy = std::max(VecVec(*window, *window), + std::numeric_limits::epsilon()); *log_energy_pre_window = Log(energy); } @@ -219,20 +219,4 @@ void ExtractWindow(int64 sample_offset, ProcessWindow(opts, window_function, &frame, log_energy_pre_window); } -void ExtractWaveformRemainder(const VectorBase &wave, - const FrameExtractionOptions &opts, - Vector *wave_remainder) { - int32 frame_shift = opts.WindowShift(); - int32 num_frames = NumFrames(wave.Dim(), opts); - // offset is the amount at the start that has been extracted. - int32 offset = num_frames * frame_shift; - KALDI_ASSERT(wave_remainder != NULL); - int32 remaining_len = wave.Dim() - offset; - wave_remainder->Resize(remaining_len); - KALDI_ASSERT(remaining_len >= 0); - if (remaining_len > 0) - wave_remainder->CopyFromVec(SubVector(wave, offset, remaining_len)); -} - - } // namespace kaldi diff --git a/src/feat/feature-window.h b/src/feat/feature-window.h index a897c6fa4b0..c9172521d7c 100644 --- a/src/feat/feature-window.h +++ b/src/feat/feature-window.h @@ -40,14 +40,16 @@ struct FrameExtractionOptions { BaseFloat preemph_coeff; // Preemphasis coefficient. bool remove_dc_offset; // Subtract mean of wave before FFT. std::string window_type; // e.g. Hamming window - bool round_to_power_of_two; - BaseFloat blackman_coeff; - bool snip_edges; - bool allow_downsample; // May be "hamming", "rectangular", "povey", "hanning", "blackman" // "povey" is a window I made to be similar to Hamming but to go to zero at the // edges, it's pow((0.5 - 0.5*cos(n/N*2*pi)), 0.85) // I just don't think the Hamming window makes sense as a windowing function. 
+ bool round_to_power_of_two; + BaseFloat blackman_coeff; + bool snip_edges; + bool allow_downsample; + bool allow_upsample; + int max_feature_vectors; FrameExtractionOptions(): samp_freq(16000), frame_shift_ms(10.0), @@ -59,7 +61,10 @@ struct FrameExtractionOptions { round_to_power_of_two(true), blackman_coeff(0.42), snip_edges(true), - allow_downsample(false) { } + allow_downsample(false), + allow_upsample(false), + max_feature_vectors(-1) + { } void Register(OptionsItf *opts) { opts->Register("sample-frequency", &samp_freq, @@ -71,7 +76,9 @@ struct FrameExtractionOptions { "Coefficient for use in signal preemphasis"); opts->Register("remove-dc-offset", &remove_dc_offset, "Subtract mean from waveform on each frame"); - opts->Register("dither", &dither, "Dithering constant (0.0 means no dither)"); + opts->Register("dither", &dither, "Dithering constant (0.0 means no dither). " + "If you turn this off, you should set the --energy-floor " + "option, e.g. to 1.0 or 0.1"); opts->Register("window-type", &window_type, "Type of window " "(\"hamming\"|\"hanning\"|\"povey\"|\"rectangular\"" "|\"blackmann\")"); @@ -88,6 +95,13 @@ struct FrameExtractionOptions { opts->Register("allow-downsample", &allow_downsample, "If true, allow the input waveform to have a higher frequency than " "the specified --sample-frequency (and we'll downsample)."); + opts->Register("max-feature-vectors", &max_feature_vectors, + "Memory optimization. If larger than 0, periodically remove feature " + "vectors so that only this number of the latest feature vectors is " + "retained."); + opts->Register("allow-upsample", &allow_upsample, + "If true, allow the input waveform to have a lower frequency than " + "the specified --sample-frequency (and we'll upsample)."); } int32 WindowShift() const { return static_cast(samp_freq * 0.001 * frame_shift_ms); @@ -202,15 +216,6 @@ void ExtractWindow(int64 sample_offset, BaseFloat *log_energy_pre_window = NULL); -// ExtractWaveformRemainder is useful if the waveform is coming in segments. -// It extracts the bit of the waveform at the end of this block that you -// would have to append the next bit of waveform to, if you wanted to have -// the same effect as everything being in one big block. 
-void ExtractWaveformRemainder(const VectorBase &wave, - const FrameExtractionOptions &opts, - Vector *wave_remainder); - - /// @} End of "addtogroup feat" } // namespace kaldi diff --git a/src/feat/mel-computations.h b/src/feat/mel-computations.h index 5df36c8cb90..7053da54f3a 100644 --- a/src/feat/mel-computations.h +++ b/src/feat/mel-computations.h @@ -63,7 +63,7 @@ struct MelBanksOptions { opts->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins"); opts->Register("high-freq", &high_freq, - "High cutoff frequency for mel bins (if < 0, offset from Nyquist)"); + "High cutoff frequency for mel bins (if <= 0, offset from Nyquist)"); opts->Register("vtln-low", &vtln_low, "Low inflection point in piecewise linear VTLN warping function"); opts->Register("vtln-high", &vtln_high, diff --git a/src/feat/online-feature-test.cc b/src/feat/online-feature-test.cc index e3a1d5f99f3..7ba6c7c32be 100644 --- a/src/feat/online-feature-test.cc +++ b/src/feat/online-feature-test.cc @@ -375,6 +375,45 @@ void TestOnlineAppendFeature() { } } +void TestRecyclingVector() { + RecyclingVector full_vec; + RecyclingVector shrinking_vec(10); + for (int i = 0; i != 100; ++i) { + Vector data(1); + data.Set(i); + full_vec.PushBack(new Vector(data)); + shrinking_vec.PushBack(new Vector(data)); + } + KALDI_ASSERT(full_vec.Size() == 100); + KALDI_ASSERT(shrinking_vec.Size() == 100); + + // full_vec should contain everything + for (int i = 0; i != 100; ++i) { + Vector *data = full_vec.At(i); + KALDI_ASSERT(data != nullptr); + KALDI_ASSERT((*data)(0) == static_cast(i)); + } + + // shrinking_vec may throw an exception for the first 90 elements + int caught_exceptions = 0; + for (int i = 0; i != 90; ++i) { + try { + shrinking_vec.At(i); + } catch (const std::runtime_error &) { + ++caught_exceptions; + } + } + // it may actually store a bit more elements for performance efficiency considerations + KALDI_ASSERT(caught_exceptions >= 80); + + // shrinking_vec should contain the last 10 elements + for (int i = 90; i != 100; ++i) { + Vector *data = shrinking_vec.At(i); + KALDI_ASSERT(data != nullptr); + KALDI_ASSERT((*data)(0) == static_cast(i)); + } +} + } // end namespace kaldi int main() { @@ -387,6 +426,7 @@ int main() { TestOnlinePlp(); TestOnlineTransform(); TestOnlineAppendFeature(); + TestRecyclingVector(); } std::cout << "Test OK.\n"; } diff --git a/src/feat/online-feature.cc b/src/feat/online-feature.cc index 267a4724580..a60e7fb8d61 100644 --- a/src/feat/online-feature.cc +++ b/src/feat/online-feature.cc @@ -24,50 +24,136 @@ namespace kaldi { -template +RecyclingVector::RecyclingVector(int items_to_hold): + items_to_hold_(items_to_hold == 0 ? -1 : items_to_hold), + first_available_index_(0) { +} + +RecyclingVector::~RecyclingVector() { + for (auto *item : items_) { + delete item; + } +} + +Vector *RecyclingVector::At(int index) const { + if (index < first_available_index_) { + KALDI_ERR << "Attempted to retrieve feature vector that was " + "already removed by the RecyclingVector (index = " + << index << "; " + << "first_available_index = " << first_available_index_ << "; " + << "size = " << Size() << ")"; + } + // 'at' does size checking. 
+ return items_.at(index - first_available_index_); +} + +void RecyclingVector::PushBack(Vector *item) { + if (items_.size() == items_to_hold_) { + delete items_.front(); + items_.pop_front(); + ++first_available_index_; + } + items_.push_back(item); +} + +int RecyclingVector::Size() const { + return first_available_index_ + items_.size(); +} + +template void OnlineGenericBaseFeature::GetFrame(int32 frame, VectorBase *feat) { - // 'at' does size checking. - feat->CopyFromVec(*(features_.at(frame))); + feat->CopyFromVec(*(features_.At(frame))); }; -template +template OnlineGenericBaseFeature::OnlineGenericBaseFeature( const typename C::Options &opts): computer_(opts), window_function_(computer_.GetFrameOptions()), + features_(opts.frame_opts.max_feature_vectors), input_finished_(false), waveform_offset_(0) { } -template -void OnlineGenericBaseFeature::AcceptWaveform(BaseFloat sampling_rate, - const VectorBase &waveform) { + +template +void OnlineGenericBaseFeature::MaybeCreateResampler( + BaseFloat sampling_rate) { BaseFloat expected_sampling_rate = computer_.GetFrameOptions().samp_freq; - if (sampling_rate != expected_sampling_rate) + + if (resampler_ != nullptr) { + KALDI_ASSERT(resampler_->GetInputSamplingRate() == sampling_rate); + KALDI_ASSERT(resampler_->GetOutputSamplingRate() == expected_sampling_rate); + } else if (((sampling_rate > expected_sampling_rate) && + !computer_.GetFrameOptions().allow_downsample) || + ((sampling_rate > expected_sampling_rate) && + !computer_.GetFrameOptions().allow_upsample)) { + resampler_.reset(new LinearResample( + sampling_rate, expected_sampling_rate, + std::min(sampling_rate / 2, expected_sampling_rate / 2), 6)); + } else if (sampling_rate != expected_sampling_rate) { KALDI_ERR << "Sampling frequency mismatch, expected " - << expected_sampling_rate << ", got " << sampling_rate; - if (waveform.Dim() == 0) + << expected_sampling_rate << ", got " << sampling_rate + << "\nPerhaps you want to use the options " + "--allow_{upsample,downsample}"; + } +} + +template +void OnlineGenericBaseFeature::InputFinished() { + if (resampler_ != nullptr) { + Vector appended_wave; + Vector resampled_wave; + resampler_->Resample(appended_wave, true, &resampled_wave); + + if (waveform_remainder_.Dim() != 0) + appended_wave.Range(0, waveform_remainder_.Dim()) + .CopyFromVec(waveform_remainder_); + appended_wave.Range(waveform_remainder_.Dim(), resampled_wave.Dim()) + .CopyFromVec(resampled_wave); + waveform_remainder_.Swap(&appended_wave); + } + input_finished_ = true; + ComputeFeatures(); +} + +template +void OnlineGenericBaseFeature::AcceptWaveform( + BaseFloat sampling_rate, const VectorBase &original_waveform) { + if (original_waveform.Dim() == 0) return; // Nothing to do. if (input_finished_) KALDI_ERR << "AcceptWaveform called after InputFinished() was called."; - // append 'waveform' to 'waveform_remainder_.' 
- Vector appended_wave(waveform_remainder_.Dim() + waveform.Dim()); + + Vector appended_wave; + Vector resampled_wave; + + const VectorBase *waveform; + + MaybeCreateResampler(sampling_rate); + if (resampler_ == nullptr) { + waveform = &original_waveform; + } else { + resampler_->Resample(original_waveform, false, &resampled_wave); + waveform = &resampled_wave; + } + + appended_wave.Resize(waveform_remainder_.Dim() + waveform->Dim()); if (waveform_remainder_.Dim() != 0) - appended_wave.Range(0, waveform_remainder_.Dim()).CopyFromVec( - waveform_remainder_); - appended_wave.Range(waveform_remainder_.Dim(), waveform.Dim()).CopyFromVec( - waveform); + appended_wave.Range(0, waveform_remainder_.Dim()) + .CopyFromVec(waveform_remainder_); + appended_wave.Range(waveform_remainder_.Dim(), waveform->Dim()) + .CopyFromVec(*waveform); waveform_remainder_.Swap(&appended_wave); ComputeFeatures(); } -template +template void OnlineGenericBaseFeature::ComputeFeatures() { const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions(); int64 num_samples_total = waveform_offset_ + waveform_remainder_.Dim(); - int32 num_frames_old = features_.size(), + int32 num_frames_old = features_.Size(), num_frames_new = NumFrames(num_samples_total, frame_opts, input_finished_); KALDI_ASSERT(num_frames_new >= num_frames_old); - features_.resize(num_frames_new, NULL); Vector window; bool need_raw_log_energy = computer_.NeedRawLogEnergy(); @@ -81,7 +167,7 @@ void OnlineGenericBaseFeature::ComputeFeatures() { // note: this online feature-extraction code does not support VTLN. BaseFloat vtln_warp = 1.0; computer_.Compute(raw_log_energy, vtln_warp, &window, this_feature); - features_[frame] = this_feature; + features_.PushBack(this_feature); } // OK, we will now discard any portion of the signal that will not be // necessary to compute frames in the future. 
@@ -110,7 +196,6 @@ template class OnlineGenericBaseFeature; template class OnlineGenericBaseFeature; template class OnlineGenericBaseFeature; - OnlineCmvnState::OnlineCmvnState(const OnlineCmvnState &other): speaker_cmvn_stats(other.speaker_cmvn_stats), global_cmvn_stats(other.global_cmvn_stats), @@ -138,12 +223,12 @@ void OnlineCmvnState::Read(std::istream &is, bool binary) { ExpectToken(is, binary, ""); } - - OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts, const OnlineCmvnState &cmvn_state, OnlineFeatureInterface *src): - opts_(opts), src_(src) { + opts_(opts), temp_stats_(2, src->Dim() + 1), + temp_feats_(src->Dim()), temp_feats_dbl_(src->Dim()), + src_(src) { SetState(cmvn_state); if (!SplitStringToIntegers(opts.skip_dims, ":", false, &skip_dims_)) KALDI_ERR << "Bad --skip-dims option (should be colon-separated list of " @@ -151,7 +236,10 @@ OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts, } OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts, - OnlineFeatureInterface *src): opts_(opts), src_(src) { + OnlineFeatureInterface *src): + opts_(opts), temp_stats_(2, src->Dim() + 1), + temp_feats_(src->Dim()), temp_feats_dbl_(src->Dim()), + src_(src) { if (!SplitStringToIntegers(opts.skip_dims, ":", false, &skip_dims_)) KALDI_ERR << "Bad --skip-dims option (should be colon-separated list of " << "integers)"; @@ -160,7 +248,7 @@ OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts, void OnlineCmvn::GetMostRecentCachedFrame(int32 frame, int32 *cached_frame, - Matrix *stats) { + MatrixBase *stats) { KALDI_ASSERT(frame >= 0); InitRingBufferIfNeeded(); // look for a cached frame on a previous frame as close as possible in time @@ -174,7 +262,7 @@ void OnlineCmvn::GetMostRecentCachedFrame(int32 frame, int32 index = t % opts_.ring_buffer_size; if (cached_stats_ring_[index].first == t) { *cached_frame = t; - *stats = cached_stats_ring_[index].second; + stats->CopyFromMat(cached_stats_ring_[index].second); return; } } @@ -182,7 +270,7 @@ void OnlineCmvn::GetMostRecentCachedFrame(int32 frame, if (n >= cached_stats_modulo_.size()) { if (cached_stats_modulo_.size() == 0) { *cached_frame = -1; - stats->Resize(2, this->Dim() + 1); + stats->SetZero(); return; } else { n = static_cast(cached_stats_modulo_.size() - 1); @@ -190,7 +278,7 @@ void OnlineCmvn::GetMostRecentCachedFrame(int32 frame, } *cached_frame = n * opts_.modulus; KALDI_ASSERT(cached_stats_modulo_[n] != NULL); - *stats = *(cached_stats_modulo_[n]); + stats->CopyFromMat(*(cached_stats_modulo_[n])); } // Initialize ring buffer for caching stats. @@ -202,7 +290,7 @@ void OnlineCmvn::InitRingBufferIfNeeded() { } } -void OnlineCmvn::CacheFrame(int32 frame, const Matrix &stats) { +void OnlineCmvn::CacheFrame(int32 frame, const MatrixBase &stats) { KALDI_ASSERT(frame >= 0); if (frame % opts_.modulus == 0) { // store in cached_stats_modulo_. 
int32 n = frame / opts_.modulus; @@ -239,18 +327,18 @@ void OnlineCmvn::ComputeStatsForFrame(int32 frame, KALDI_ASSERT(frame >= 0 && frame < src_->NumFramesReady()); int32 dim = this->Dim(), cur_frame; - Matrix stats(2, dim + 1); - GetMostRecentCachedFrame(frame, &cur_frame, &stats); + GetMostRecentCachedFrame(frame, &cur_frame, stats_out); - Vector feats(dim); - Vector feats_dbl(dim); + Vector &feats(temp_feats_); + Vector &feats_dbl(temp_feats_dbl_); while (cur_frame < frame) { cur_frame++; src_->GetFrame(cur_frame, &feats); feats_dbl.CopyFromVec(feats); - stats.Row(0).Range(0, dim).AddVec(1.0, feats_dbl); - stats.Row(1).Range(0, dim).AddVec2(1.0, feats_dbl); - stats(0, dim) += 1.0; + stats_out->Row(0).Range(0, dim).AddVec(1.0, feats_dbl); + if (opts_.normalize_variance) + stats_out->Row(1).Range(0, dim).AddVec2(1.0, feats_dbl); + (*stats_out)(0, dim) += 1.0; // it's a sliding buffer; a frame at the back may be // leaving the buffer so we have to subtract that. int32 prev_frame = cur_frame - opts_.cmn_window; @@ -258,13 +346,13 @@ void OnlineCmvn::ComputeStatsForFrame(int32 frame, // we need to subtract frame prev_f from the stats. src_->GetFrame(prev_frame, &feats); feats_dbl.CopyFromVec(feats); - stats.Row(0).Range(0, dim).AddVec(-1.0, feats_dbl); - stats.Row(1).Range(0, dim).AddVec2(-1.0, feats_dbl); - stats(0, dim) -= 1.0; + stats_out->Row(0).Range(0, dim).AddVec(-1.0, feats_dbl); + if (opts_.normalize_variance) + stats_out->Row(1).Range(0, dim).AddVec2(-1.0, feats_dbl); + (*stats_out)(0, dim) -= 1.0; } - CacheFrame(cur_frame, stats); + CacheFrame(cur_frame, (*stats_out)); } - stats_out->CopyFromMat(stats); } @@ -273,12 +361,23 @@ void OnlineCmvn::SmoothOnlineCmvnStats(const MatrixBase &speaker_stats, const MatrixBase &global_stats, const OnlineCmvnOptions &opts, MatrixBase *stats) { + if (speaker_stats.NumRows() == 2 && !opts.normalize_variance) { + // this is just for efficiency: don't operate on the variance if it's not + // needed. + int32 cols = speaker_stats.NumCols(); // dim + 1 + SubMatrix stats_temp(*stats, 0, 1, 0, cols); + SmoothOnlineCmvnStats(speaker_stats.RowRange(0, 1), + global_stats.RowRange(0, 1), + opts, &stats_temp); + return; + } int32 dim = stats->NumCols() - 1; double cur_count = (*stats)(0, dim); // If count exceeded cmn_window it would be an error in how "window_stats" // was accumulated. KALDI_ASSERT(cur_count <= 1.001 * opts.cmn_window); - if (cur_count >= opts.cmn_window) return; + if (cur_count >= opts.cmn_window) + return; if (speaker_stats.NumRows() != 0) { // if we have speaker stats.. double count_from_speaker = opts.cmn_window - cur_count, speaker_count = speaker_stats(0, dim); @@ -291,7 +390,8 @@ void OnlineCmvn::SmoothOnlineCmvnStats(const MatrixBase &speaker_stats, speaker_stats); cur_count = (*stats)(0, dim); } - if (cur_count >= opts.cmn_window) return; + if (cur_count >= opts.cmn_window) + return; if (global_stats.NumRows() != 0) { double count_from_global = opts.cmn_window - cur_count, global_count = global_stats(0, dim); @@ -311,7 +411,8 @@ void OnlineCmvn::GetFrame(int32 frame, src_->GetFrame(frame, feat); KALDI_ASSERT(feat->Dim() == this->Dim()); int32 dim = feat->Dim(); - Matrix stats(2, dim + 1); + Matrix &stats(temp_stats_); + stats.Resize(2, dim + 1, kUndefined); // Will do nothing if size was correct. if (frozen_state_.NumRows() != 0) { // the CMVN state has been frozen. 
stats.CopyFromMat(frozen_state_); } else { @@ -329,14 +430,13 @@ void OnlineCmvn::GetFrame(int32 frame, // call the function ApplyCmvn declared in ../transform/cmvn.h, which // requires a matrix. - Matrix feat_mat(1, dim); - feat_mat.Row(0).CopyFromVec(*feat); + // 1 row; num-cols == dim; stride == dim. + SubMatrix feat_mat(feat->Data(), 1, dim, dim); // the function ApplyCmvn takes a matrix, so form a one-row matrix to give it. if (opts_.normalize_mean) ApplyCmvn(stats, opts_.normalize_variance, &feat_mat); else KALDI_ASSERT(!opts_.normalize_variance); - feat->CopyFromVec(feat_mat.Row(0)); } void OnlineCmvn::Freeze(int32 cur_frame) { @@ -383,7 +483,7 @@ void OnlineCmvn::SetState(const OnlineCmvnState &cmvn_state) { int32 OnlineSpliceFrames::NumFramesReady() const { int32 num_frames = src_->NumFramesReady(); - if (num_frames > 0 && src_->IsLastFrame(num_frames-1)) + if (num_frames > 0 && src_->IsLastFrame(num_frames - 1)) return num_frames; else return std::max(0, num_frames - right_context_); @@ -430,6 +530,17 @@ void OnlineTransform::GetFrame(int32 frame, VectorBase *feat) { feat->AddMatVec(1.0, linear_term_, kNoTrans, input_feat, 1.0); } +void OnlineTransform::GetFrames( + const std::vector &frames, MatrixBase *feats) { + KALDI_ASSERT(static_cast(frames.size()) == feats->NumRows()); + int32 num_frames = feats->NumRows(), + input_dim = linear_term_.NumCols(); + Matrix input_feats(num_frames, input_dim, kUndefined); + src_->GetFrames(frames, &input_feats); + feats->CopyRowsFromVec(offset_); + feats->AddMatMat(1.0, input_feats, kNoTrans, linear_term_, kTrans, 1.0); +} + int32 OnlineDeltaFeature::Dim() const { int32 src_dim = src_->Dim(); @@ -493,6 +604,49 @@ void OnlineCacheFeature::GetFrame(int32 frame, VectorBase *feat) { } } +void OnlineCacheFeature::GetFrames( + const std::vector &frames, MatrixBase *feats) { + int32 num_frames = frames.size(); + // non_cached_frames will be the subset of 't' values in 'frames' which were + // not previously cached, which we therefore need to get from src_. + std::vector non_cached_frames; + // 'non_cached_indexes' stores the indexes 'i' into 'frames' corresponding to + // the corresponding frames in 'non_cached_frames'. + std::vector non_cached_indexes; + non_cached_frames.reserve(frames.size()); + non_cached_indexes.reserve(frames.size()); + for (int32 i = 0; i < num_frames; i++) { + int32 t = frames[i]; + if (static_cast(t) < cache_.size() && cache_[t] != NULL) { + feats->Row(i).CopyFromVec(*(cache_[t])); + } else { + non_cached_frames.push_back(t); + non_cached_indexes.push_back(i); + } + } + if (non_cached_frames.empty()) + return; + int32 num_non_cached_frames = non_cached_frames.size(), + dim = this->Dim(); + Matrix non_cached_feats(num_non_cached_frames, dim, + kUndefined); + src_->GetFrames(non_cached_frames, &non_cached_feats); + for (int32 i = 0; i < num_non_cached_frames; i++) { + int32 t = non_cached_frames[i]; + if (static_cast(t) < cache_.size() && cache_[t] != NULL) { + // We can reach this point due to repeat indexes in 'non_cached_frames'. 
+ feats->Row(non_cached_indexes[i]).CopyFromVec(*(cache_[t])); + } else { + SubVector this_feat(non_cached_feats, i); + feats->Row(non_cached_indexes[i]).CopyFromVec(this_feat); + if (static_cast(t) >= cache_.size()) + cache_.resize(t + 1, NULL); + cache_[t] = new Vector(this_feat); + } + } +} + + void OnlineCacheFeature::ClearCache() { for (size_t i = 0; i < cache_.size(); i++) delete cache_[i]; @@ -500,7 +654,6 @@ void OnlineCacheFeature::ClearCache() { } - void OnlineAppendFeature::GetFrame(int32 frame, VectorBase *feat) { KALDI_ASSERT(feat->Dim() == Dim()); diff --git a/src/feat/online-feature.h b/src/feat/online-feature.h index 11d170972fa..4f66ffef2ff 100644 --- a/src/feat/online-feature.h +++ b/src/feat/online-feature.h @@ -41,6 +41,36 @@ namespace kaldi { /// @{ +/// This class serves as a storage for feature vectors with an option to limit +/// the memory usage by removing old elements. The deleted frames indices are +/// "remembered" so that regardless of the MAX_ITEMS setting, the user always +/// provides the indices as if no deletion was being performed. +/// This is useful when processing very long recordings which would otherwise +/// cause the memory to eventually blow up when the features are not being removed. +class RecyclingVector { +public: + /// By default it does not remove any elements. + RecyclingVector(int items_to_hold = -1); + + /// The ownership is being retained by this collection - do not delete the item. + Vector *At(int index) const; + + /// The ownership of the item is passed to this collection - do not delete the item. + void PushBack(Vector *item); + + /// This method returns the size as if no "recycling" had happened, + /// i.e. equivalent to the number of times the PushBack method has been called. + int Size() const; + + ~RecyclingVector(); + +private: + std::deque*> items_; + int items_to_hold_; + int first_available_index_; +}; + + /// This is a templated class for online feature extraction; /// it's templated on a class like MfccComputer or PlpComputer /// that does the basic feature extraction. @@ -61,7 +91,7 @@ class OnlineGenericBaseFeature: public OnlineBaseFeature { return computer_.GetFrameOptions().frame_shift_ms / 1000.0f; } - virtual int32 NumFramesReady() const { return features_.size(); } + virtual int32 NumFramesReady() const { return features_.Size(); } virtual void GetFrame(int32 frame, VectorBase *feat); @@ -83,14 +113,7 @@ class OnlineGenericBaseFeature: public OnlineBaseFeature { // more waveform. This will help flush out the last frame or two // of features, in the case where snip-edges == false; it also // affects the return value of IsLastFrame(). - virtual void InputFinished() { - input_finished_ = true; - ComputeFeatures(); - } - - ~OnlineGenericBaseFeature() { - DeletePointers(&features_); - } + virtual void InputFinished(); private: // This function computes any additional feature frames that it is possible to @@ -101,13 +124,19 @@ class OnlineGenericBaseFeature: public OnlineBaseFeature { // waveform_remainder_ while incrementing waveform_offset_ by the same amount. void ComputeFeatures(); + void MaybeCreateResampler(BaseFloat sampling_rate); + C computer_; // class that does the MFCC or PLP or filterbank computation + // resampler in cases when the input sampling frequency is not equal to + // the expected sampling rate + std::unique_ptr resampler_; + FeatureWindowFunction window_function_; // features_ is the Mfcc or Plp or Fbank features that we have already computed. 
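The RecyclingVector declared above is what bounds memory on very long recordings: old frames are dropped from the front of a deque, but callers keep using global frame indices because the container remembers how many items it has already discarded. A rough standalone sketch of that indexing scheme (not the Kaldi class itself; it stores values instead of owned pointers and signals a recycled index with a null return):

// Keeps at most items_to_hold most-recent items; indices stay global, i.e.
// At(i) refers to the i'th item ever pushed, and returns null once that item
// has been recycled away.
#include <deque>
#include <utility>

template <typename T>
class BoundedHistory {
 public:
  explicit BoundedHistory(int items_to_hold = -1)
      : items_to_hold_(items_to_hold <= 0 ? -1 : items_to_hold),
        first_available_index_(0) {}

  void PushBack(T item) {
    if (items_to_hold_ > 0 &&
        static_cast<int>(items_.size()) == items_to_hold_) {
      items_.pop_front();        // forget the oldest item...
      first_available_index_++;  // ...but remember that we did so.
    }
    items_.push_back(std::move(item));
  }

  // Index as if nothing had ever been removed.
  const T *At(int global_index) const {
    int local = global_index - first_available_index_;
    if (local < 0 || local >= static_cast<int>(items_.size()))
      return nullptr;  // recycled, or not pushed yet
    return &items_[local];
  }

  int Size() const {  // total number of PushBack() calls so far
    return first_available_index_ + static_cast<int>(items_.size());
  }

 private:
  std::deque<T> items_;
  int items_to_hold_;
  int first_available_index_;
};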
- std::vector*> features_; + RecyclingVector features_; // True if the user has called "InputFinished()" bool input_finished_; @@ -182,7 +211,8 @@ struct OnlineCmvnOptions { // class computes the cmvn internally. smaller->more // time-efficient but less memory-efficient. Must be >= 1. int32 ring_buffer_size; // not configurable from command line; size of ring - // buffer used for caching CMVN stats. + // buffer used for caching CMVN stats. Must be >= + // modulus. std::string skip_dims; // Colon-separated list of dimensions to skip normalization // of, e.g. 13:14:15. @@ -371,10 +401,10 @@ class OnlineCmvn: public OnlineFeatureInterface { /// were cached, sets up empty stats for frame zero and returns that]. void GetMostRecentCachedFrame(int32 frame, int32 *cached_frame, - Matrix *stats); + MatrixBase *stats); /// Cache this frame of stats. - void CacheFrame(int32 frame, const Matrix &stats); + void CacheFrame(int32 frame, const MatrixBase &stats); /// Initialize ring buffer for caching stats. inline void InitRingBufferIfNeeded(); @@ -403,6 +433,12 @@ class OnlineCmvn: public OnlineFeatureInterface { // frame index. std::vector > > cached_stats_ring_; + // Some temporary variables used inside functions of this class, which + // put here to avoid reallocation. + Matrix temp_stats_; + Vector temp_feats_; + Vector temp_feats_dbl_; + OnlineFeatureInterface *src_; // Not owned here }; @@ -472,6 +508,9 @@ class OnlineTransform: public OnlineFeatureInterface { virtual void GetFrame(int32 frame, VectorBase *feat); + virtual void GetFrames(const std::vector &frames, + MatrixBase *feats); + // // Next, functions that are not in the interface. // @@ -537,6 +576,9 @@ class OnlineCacheFeature: public OnlineFeatureInterface { virtual void GetFrame(int32 frame, VectorBase *feat); + virtual void GetFrames(const std::vector &frames, + MatrixBase *feats); + virtual ~OnlineCacheFeature() { ClearCache(); } // Things that are not in the shared interface: diff --git a/src/feat/pitch-functions-test.cc b/src/feat/pitch-functions-test.cc index 098e590a8e9..0e481c18674 100644 --- a/src/feat/pitch-functions-test.cc +++ b/src/feat/pitch-functions-test.cc @@ -449,7 +449,7 @@ static void UnitTestKeeleNccfBallast() { // use pitch code with default configuration.. PitchExtractionOptions op; op.nccf_ballast = 0.05 * k; - KALDI_LOG << " nccf_ballast " << op.nccf_ballast << std::endl; + KALDI_LOG << " nccf_ballast " << op.nccf_ballast; // compute pitch. 
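Several of the changes above (the temp_stats_ / temp_feats_ members, and resizing with kUndefined so a correctly sized matrix is left untouched) are instances of one pattern: keep scratch storage as a member so the steady-state per-frame path performs no allocation. A tiny generic sketch of the pattern, with made-up names and std::vector standing in for Kaldi's matrix types:

// A member scratch buffer reused across calls: it is resized only when the
// required size changes, so the steady-state per-frame path does not allocate
// (std::vector::resize never shrinks capacity, much as Resize(..., kUndefined)
// above is a no-op when the size already matches).
#include <cstddef>
#include <vector>

class FrameAccumulator {
 public:
  void ProcessFrame(const std::vector<float> &feat) {
    scratch_.resize(feat.size());                  // usually a no-op
    for (std::size_t d = 0; d < feat.size(); d++)
      scratch_[d] = static_cast<double>(feat[d]);  // e.g. float -> double copy
    // ... accumulate statistics from scratch_ here ...
  }

 private:
  std::vector<double> scratch_;  // persists between calls
};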
Matrix m; ComputeKaldiPitch(op, waveform, &m); @@ -493,7 +493,7 @@ static void UnitTestPitchExtractionSpeed() { double tot_time = timer.Elapsed(), speech_time = test_num * waveform.Dim() / wave.SampFreq(); KALDI_LOG << " Pitch extraction time per second of speech is " - << (tot_time / speech_time) << " seconds " << std::endl; + << (tot_time / speech_time) << " seconds."; } } static void UnitTestPitchExtractorCompareKeele() { diff --git a/src/feat/resample.cc b/src/feat/resample.cc index 518685d85c8..11f4c62bf1c 100644 --- a/src/feat/resample.cc +++ b/src/feat/resample.cc @@ -302,7 +302,7 @@ void ArbitraryResample::Resample(const VectorBase &input, VectorBase *output) const { KALDI_ASSERT(input.Dim() == num_samples_in_ && output->Dim() == weights_.size()); - + int32 output_dim = output->Dim(); for (int32 i = 0; i < output_dim; i++) { SubVector input_part(input, first_index_[i], weights_[i].Dim()); @@ -365,13 +365,13 @@ BaseFloat ArbitraryResample::FilterFunc(BaseFloat t) const { return filter * window; } -void DownsampleWaveForm(BaseFloat orig_freq, const VectorBase &wave, - BaseFloat new_freq, Vector *new_wave) { - KALDI_ASSERT(new_freq < orig_freq); - BaseFloat lowpass_cutoff = 0.99 * 0.5 * new_freq; +void ResampleWaveform(BaseFloat orig_freq, const VectorBase &wave, + BaseFloat new_freq, Vector *new_wave) { + BaseFloat min_freq = std::min(orig_freq, new_freq); + BaseFloat lowpass_cutoff = 0.99 * 0.5 * min_freq; int32 lowpass_filter_width = 6; - LinearResample signal_downsampler(orig_freq, new_freq, - lowpass_cutoff, lowpass_filter_width); - signal_downsampler.Resample(wave, true, new_wave); + LinearResample resampler(orig_freq, new_freq, + lowpass_cutoff, lowpass_filter_width); + resampler.Resample(wave, true, new_wave); } } // namespace kaldi diff --git a/src/feat/resample.h b/src/feat/resample.h index cc3e5064863..e0b4688c99b 100644 --- a/src/feat/resample.h +++ b/src/feat/resample.h @@ -40,7 +40,7 @@ namespace kaldi { /** \file[resample.h] - + This header contains declarations of classes for resampling signals. The normal cases of resampling a signal are upsampling and downsampling (increasing and decreasing the sample rate of a signal, respectively), @@ -51,7 +51,7 @@ namespace kaldi { The input signal is always evenly spaced, say sampled with frequency S, and we assume the original signal was band-limited to S/2 or lower. The n'th input sample x_n (with n = 0, 1, ...) is interpreted as the original - signal's value at time n/S. + signal's value at time n/S. For resampling, it is convenient to view the input signal as a continuous function x(t) of t, where each sample x_n becomes a delta function @@ -73,14 +73,14 @@ namespace kaldi { means we window the sinc function out to its first zero on the left and right, w = 2 means the second zero, and so on; we normally choose w to be at least two. We call this num_zeros, not w, in the code. - + Convolving the signal x(t) with this windowed filter h(t) = f(t)g(t) and evaluating the resulting signal s(t) at an arbitrary time t is easy: we have \f[ s(t) = 1/S \sum_n x_n h(t - n/S) \f]. (note: the sign of t - n/S might be wrong, but it doesn't matter as the filter and window are symmetric). This is true for arbitrary values of t. What the class ArbitraryResample does - is to allow you to evaluate the signal for specified values of t. + is to allow you to evaluate the signal for specified values of t. */ @@ -90,7 +90,7 @@ namespace kaldi { don't have to be linearly spaced. 
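Since ResampleWaveform() now takes the lowpass cutoff from the smaller of the two rates, the same call covers both downsampling and upsampling. A hedged usage sketch; the template arguments, which the listing above does not show, are assumed to be BaseFloat as elsewhere in feat/:

// Upsampling 8 kHz audio to 16 kHz with the renamed wrapper; the identical
// call also downsamples, since the cutoff is 0.99 of the Nyquist of the lower
// of the two frequencies.
#include "feat/resample.h"
#include "matrix/kaldi-vector.h"

void UpsampleTo16k(const kaldi::VectorBase<kaldi::BaseFloat> &wave_8k,
                   kaldi::Vector<kaldi::BaseFloat> *wave_16k) {
  kaldi::ResampleWaveform(8000.0, wave_8k, 16000.0, wave_16k);
}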
The low-pass filter cutoff "filter_cutoff_hz" should be less than half the sample rate; "num_zeros" should probably be at least two preferably more; higher numbers give - sharper filters but will be less efficient. + sharper filters but will be less efficient. */ class ArbitraryResample { public: @@ -115,7 +115,7 @@ class ArbitraryResample { /// This version of the Resample function processes just /// one vector. void Resample(const VectorBase &input, - VectorBase *output) const; + VectorBase *output) const; private: void SetIndexes(const Vector &sample_points); @@ -185,6 +185,10 @@ class LinearResample { /// Resample(x, y, true) for the last piece. Call it unnecessarily between /// signals will not do any harm. void Reset(); + + //// Return the input and output sampling rates (for checks, for example) + inline int32 GetInputSamplingRate() { return samp_rate_in_; } + inline int32 GetOutputSamplingRate() { return samp_rate_out_; } private: /// This function outputs the number of output samples we will output /// for a signal with "input_num_samp" input samples. If flush == true, @@ -248,20 +252,35 @@ class LinearResample { ///< previously seen input signal. }; -/// Downsample a waveform. This is a convenience wrapper for the -/// class 'LinearResample'. -/// The low-pass filter cutoff used in 'LinearResample' is 0.99 of half of the -/// new_freq and num_zeros is 6. -/// The downsampling results is also checked wit sox resampling toolkit. -/// Sox design is inspired by Laurent De Soras' paper, -/// https://ccrma.stanford.edu/~jos/resample/Implementation.html -/// It designs low pass filter using pass-band, stop-band, Nyquist freq -/// and stop-band attenuation. -/// e.g. The mainlob for Hanning window is 4pi/M, where the main-lobe width is -/// equal to (pass-band-freq - stop-band-freq). -/// Also the cutoff frequency is equal to (pass-band-freq - stop-band-freq). -void DownsampleWaveForm(BaseFloat orig_freq, const VectorBase &wave, - BaseFloat new_freq, Vector *new_wave); +/** + Downsample or upsample a waveform. This is a convenience wrapper for the + class 'LinearResample'. + The low-pass filter cutoff used in 'LinearResample' is 0.99 of the Nyquist, + where the Nyquist is half of the minimum of (orig_freq, new_freq). The + resampling is done with a symmetric FIR filter with N_z (number of zeros) + as 6. + + We compared the downsampling results with those from the sox resampling + toolkit. + Sox's design is inspired by Laurent De Soras' paper, + https://ccrma.stanford.edu/~jos/resample/Implementation.html + + Note: we expect that while orig_freq and new_freq are of type BaseFloat, they + are actually required to have exact integer values (like 16000 or 8000) with + a ratio between them that can be expressed as a rational number with + reasonably small integer factors. +*/ +void ResampleWaveform(BaseFloat orig_freq, const VectorBase &wave, + BaseFloat new_freq, Vector *new_wave); + + +/// This function is deprecated. It is provided for backward compatibility, to avoid +/// breaking older code. 
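The header comments above define the resampled signal as s(t) = (1/S) * sum_n x_n h(t - n/S), with h a sinc lowpass multiplied by a window that vanishes after num_zeros zero-crossings. The following standalone snippet simply evaluates that formula, one output time at a time, using a Hann-shaped window; it restates the math rather than LinearResample/ArbitraryResample, which precompute weights and touch only a small input range per output sample:

// Direct evaluation of s(t) = (1/S) * sum_n x[n] * h(t - n/S), where
// h(t) = 2*fc*sinc(2*fc*t) * window(t) and the Hann-shaped window reaches
// zero after num_zeros zero-crossings of the sinc.  O(N) per output sample;
// purely an illustration of the formula in the comment above.
#include <cmath>
#include <cstddef>
#include <vector>

const double kPi = 3.14159265358979323846;

double WindowedSinc(double t, double cutoff, int num_zeros) {
  double support = num_zeros / (2.0 * cutoff);  // window half-width in seconds
  if (std::abs(t) >= support) return 0.0;
  double window = 0.5 * (1.0 + std::cos(kPi * t / support));     // Hann window
  if (t == 0.0) return 2.0 * cutoff * window;                    // sinc's peak
  return window * std::sin(2.0 * kPi * cutoff * t) / (kPi * t);  // lowpass sinc
}

// x[] holds samples taken at 'samp_rate'; evaluate the reconstruction at an
// arbitrary time t in seconds, as ArbitraryResample conceptually does.
double EvaluateAt(const std::vector<double> &x, double samp_rate,
                  double cutoff, int num_zeros, double t) {
  double sum = 0.0;
  for (std::size_t n = 0; n < x.size(); n++)
    sum += x[n] * WindowedSinc(t - n / samp_rate, cutoff, num_zeros);
  return sum / samp_rate;  // the 1/S factor from the formula
}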
+inline void DownsampleWaveForm(BaseFloat orig_freq, const VectorBase &wave, + BaseFloat new_freq, Vector *new_wave) { + ResampleWaveform(orig_freq, wave, new_freq, new_wave); +} + /// @} End of "addtogroup feat" } // namespace kaldi diff --git a/src/featbin/Makefile b/src/featbin/Makefile index 8e72d0f744c..861ba3f7a93 100644 --- a/src/featbin/Makefile +++ b/src/featbin/Makefile @@ -25,7 +25,7 @@ TESTFILES = ADDLIBS = ../hmm/kaldi-hmm.a ../feat/kaldi-feat.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/featbin/extract-feature-segments.cc b/src/featbin/extract-feature-segments.cc index 93f599feb3a..f6cdcb96b18 100644 --- a/src/featbin/extract-feature-segments.cc +++ b/src/featbin/extract-feature-segments.cc @@ -25,7 +25,7 @@ #include "matrix/kaldi-matrix.h" /** @brief This is a program for extracting segments from feature files/archives - - usage : + - usage : - extract-feature-segments [options ..] - "segments-file" should have the information of the segments that needs to be extracted from the feature files - the format of the segments file : speaker_name filename start_time(in secs) end_time(in secs) @@ -37,6 +37,10 @@ int main(int argc, char *argv[]) { const char *usage = "Create feature files by segmenting input files.\n" + "Note: this program should no longer be needed now that\n" + "'ranges' in scp files are supported; search for 'ranges' in\n" + "http://kaldi-asr.org/doc/io_tut.html, or see the script\n" + "utils/data/subsegment_data_dir.sh.\n" "Usage: " "extract-feature-segments [options...] " " \n" @@ -144,9 +148,9 @@ int main(int argc, char *argv[]) { } } - /* check whether a segment start time and end time exists in utterance + /* check whether a segment start time and end time exists in utterance * if fails , skips the segment. 
- */ + */ if (!feat_reader.HasKey(utterance)) { KALDI_WARN << "Did not find features for utterance " << utterance << ", skipping segment " << segment; @@ -167,7 +171,7 @@ int main(int argc, char *argv[]) { end_samp -= snip_length; } - /* start sample must be less than total number of samples + /* start sample must be less than total number of samples * otherwise skip the segment */ if (start_samp < 0 || start_samp >= num_samp) { @@ -177,7 +181,7 @@ int main(int argc, char *argv[]) { continue; } - /* end sample must be less than total number samples + /* end sample must be less than total number samples * otherwise skip the segment */ if (end_samp > num_samp) { @@ -221,4 +225,3 @@ int main(int argc, char *argv[]) { return -1; } } - diff --git a/src/fgmmbin/Makefile b/src/fgmmbin/Makefile index baa4cd9be33..5db252477b5 100644 --- a/src/fgmmbin/Makefile +++ b/src/fgmmbin/Makefile @@ -18,7 +18,6 @@ TESTFILES = ADDLIBS = ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a ../hmm/kaldi-hmm.a \ ../feat/kaldi-feat.a ../transform/kaldi-transform.a \ ../gmm/kaldi-gmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/fstbin/Makefile b/src/fstbin/Makefile index 644eb639381..a22c014a7d5 100644 --- a/src/fstbin/Makefile +++ b/src/fstbin/Makefile @@ -15,7 +15,8 @@ BINFILES = fstdeterminizestar \ fstmakecontextsyms fstaddsubsequentialloop fstaddselfloops \ fstrmepslocal fstcomposecontext fsttablecompose fstrand \ fstdeterminizelog fstphicompose fstcopy \ - fstpushspecial fsts-to-transcripts fsts-project fsts-union fsts-concat + fstpushspecial fsts-to-transcripts fsts-project fsts-union \ + fsts-concat make-grammar-fst OBJFILES = @@ -24,7 +25,7 @@ TESTFILES = # actually, this library is currently empty. Everything is a header. LIBFILE = -ADDLIBS = ../fstext/kaldi-fstext.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a +ADDLIBS = ../decoder/kaldi-decoder.a ../fstext/kaldi-fstext.a \ + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/fstbin/fstcomposecontext.cc b/src/fstbin/fstcomposecontext.cc index d5ea07df4d3..8f9d270ee6b 100644 --- a/src/fstbin/fstcomposecontext.cc +++ b/src/fstbin/fstcomposecontext.cc @@ -22,6 +22,7 @@ #include "util/common-utils.h" #include "fst/fstlib.h" #include "fstext/context-fst.h" +#include "fstext/grammar-context-fst.h" #include "fstext/fstext-utils.h" #include "fstext/kaldi-fst-io.h" @@ -34,27 +35,31 @@ ( echo " 0"; echo "a 1"; echo "b 2"; echo "c 3" ) > phones.txt fstmakecontextsyms phones.txt ilabels.sym > context.txt fstprint --isymbols=context.txt --osymbols=phones.txt tmp.fst -# 0 1 //a a -# 1 2 /a/b b -# 2 3 a/b/c c -# 3 + # and the result is: + +WARNING (fstcomposecontext[5.4]:main():fstcomposecontext.cc:130) Disambiguation symbols list is empty; this likely indicates an error in data preparation. 
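The segment checks above convert times to sample indices and skip anything that falls outside the utterance. A compact standalone version of that validation logic (names and the exact rounding are illustrative; extract-feature-segments itself also handles snip-edges and minimum-length options not shown here):

// Convert a [start_time, end_time) segment in seconds to sample indices and
// validate them against the utterance length, mirroring the checks above.
#include <cmath>
#include <cstdio>

bool SegmentToSampleRange(double start_time, double end_time, double samp_freq,
                          long num_samp, long *start_samp, long *end_samp) {
  *start_samp = std::lround(start_time * samp_freq);
  *end_samp = std::lround(end_time * samp_freq);
  if (*start_samp < 0 || *start_samp >= num_samp) {
    std::fprintf(stderr, "start sample %ld outside [0, %ld), skipping segment\n",
                 *start_samp, num_samp);
    return false;
  }
  if (*end_samp > num_samp) {
    std::fprintf(stderr, "end sample %ld past utterance end %ld, skipping segment\n",
                 *end_samp, num_samp);
    return false;
  }
  return *start_samp < *end_samp;  // reject empty or inverted segments
}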
+0 1 a +1 2 /a/b b +2 3 a/b/c c +3 4 b/c/ +4 # (2) with disambig syms: ( echo 4; echo 5) > disambig.list - ( echo " 0"; echo "a 1"; echo "b 2"; echo "c 3" ) > phones.txt + ( echo " 0"; echo "a 1"; echo "b 2"; echo "c 3"; echo "#0 4"; echo "#1 5") > phones.txt ( echo "0 1 1 1"; echo "1 2 2 2"; echo " 2 3 4 4"; echo "3 4 3 3"; echo "4 5 5 5"; echo "5 0" ) | fstcompile > in.fst - fstcomposecontext --disambig-syms=disambig.list ilabels.sym in.fst tmp.fst - fstmakecontextsyms --disambig-syms=disambig.list phones.txt ilabels.sym > context.txt + fstcomposecontext --read-disambig-syms=disambig.list ilabels.sym in.fst tmp.fst + fstmakecontextsyms phones.txt ilabels.sym > context.txt cp phones.txt phones_disambig.txt; ( echo "#0 4"; echo "#1 5" ) >> phones_disambig.txt fstprint --isymbols=context.txt --osymbols=phones_disambig.txt tmp.fst -# 0 1 //a a -# 1 2 /a/b b -# 2 3 #0 #0 -# 3 4 a/b/c c -# 4 5 #1 #1 -# 5 +0 1 #-1 a +1 2 /a/b b +2 3 #0 #0 +3 4 a/b/c c +4 5 #1 #1 +5 6 b/c/ */ @@ -86,22 +91,27 @@ int main(int argc, char *argv[]) { "\n" "Usage: fstcomposecontext [ [] ]\n" "E.g: fstcomposecontext ilabels.sym < LG.fst > CLG.fst\n"; - + ParseOptions po(usage); bool binary = true; std::string disambig_rxfilename, disambig_wxfilename; - int32 N = 3, P = 1; + int32 context_width = 3, central_position = 1; + int32 nonterm_phones_offset = -1; po.Register("binary", &binary, "If true, output ilabels-output-file in binary format"); po.Register("read-disambig-syms", &disambig_rxfilename, "List of disambiguation symbols on input of in.fst"); po.Register("write-disambig-syms", &disambig_wxfilename, "List of disambiguation symbols on input of out.fst"); - po.Register("context-size", &N, "Size of phone context window"); - po.Register("central-position", &P, + po.Register("context-size", &context_width, "Size of phone context window"); + po.Register("central-position", ¢ral_position, "Designated central position in context window"); + po.Register("nonterm-phones-offset", &nonterm_phones_offset, + "The integer id of #nonterm_bos in your phones.txt, if present " + "(only relevant for grammar-FST construction, see " + "doc/grammar.dox"); po.Read(argc, argv); @@ -130,13 +140,24 @@ int main(int argc, char *argv[]) { KALDI_WARN << "Disambiguation symbols list is empty; this likely " << "indicates an error in data preparation."; } - + std::vector > ilabels; VectorFst composed_fst; // Work gets done here (see context-fst.h) - ComposeContext(disambig_in, N, P, fst, &composed_fst, &ilabels); - + if (nonterm_phones_offset < 0) { + // The normal case. + ComposeContext(disambig_in, context_width, central_position, + fst, &composed_fst, &ilabels); + } else { + // The grammar-FST case. See ../doc/grammar.dox for an intro. + if (context_width != 2 || central_position != 1) { + KALDI_ERR << "Grammar-fst graph creation only supports models with left-" + "biphone context. (--nonterm-phones-offset option was supplied)."; + } + ComposeContextLeftBiphone(nonterm_phones_offset, disambig_in, + *fst, &composed_fst, &ilabels); + } WriteILabelInfo(Output(ilabels_out_filename, binary).Stream(), binary, ilabels); @@ -160,4 +181,3 @@ int main(int argc, char *argv[]) { return -1; } } - diff --git a/src/fstbin/fstmakecontextfst.cc b/src/fstbin/fstmakecontextfst.cc index bbe44b3566f..59655a61e9e 100644 --- a/src/fstbin/fstmakecontextfst.cc +++ b/src/fstbin/fstmakecontextfst.cc @@ -46,15 +46,15 @@ int main(int argc, char *argv[]) { bool binary = true; // binary output to ilabels_output_file. 
std::string disambig_rxfilename, disambig_wxfilename; - int32 N = 3, P = 1; - + int32 context_width = 3, central_position = 1; + ParseOptions po(usage); po.Register("read-disambig-syms", &disambig_rxfilename, "List of disambiguation symbols to read"); po.Register("write-disambig-syms", &disambig_wxfilename, "List of disambiguation symbols to write"); - po.Register("context-size", &N, "Size of phonetic context window"); - po.Register("central-position", &P, + po.Register("context-size", &context_width, "Size of phonetic context window"); + po.Register("central-position", ¢ral_position, "Designated central position in context window"); po.Register("binary", &binary, "Write ilabels output file in binary Kaldi format"); @@ -91,7 +91,7 @@ int main(int argc, char *argv[]) { if ( (disambig_wxfilename != "") && (disambig_rxfilename == "") ) KALDI_ERR << "fstmakecontextfst: cannot specify --write-disambig-syms if " "not specifying --read-disambig-syms\n"; - + std::vector disambig_in; if (disambig_rxfilename != "") { if (!ReadIntegerVectorSimple(disambig_rxfilename, &disambig_in)) @@ -100,21 +100,33 @@ int main(int argc, char *argv[]) { } if (std::binary_search(phone_syms.begin(), phone_syms.end(), subseq_sym) - ||std::binary_search(disambig_in.begin(), disambig_in.end(), subseq_sym)) - KALDI_ERR << "Invalid subsequential symbol "<<(subseq_sym)<<", already a phone or disambiguation symbol."; + || std::binary_search(disambig_in.begin(), disambig_in.end(), subseq_sym)) + KALDI_ERR << "Invalid subsequential symbol " << subseq_sym + << ", already a phone or disambiguation symbol."; + + // 'loop_fst' will be an acceptor FST with single (initial and final) state, with + // a loop for each phone and disambiguation symbol. + StdVectorFst loop_fst; + loop_fst.AddState(); // Add state zero. + loop_fst.SetStart(0); + loop_fst.SetFinal(0, TropicalWeight::One()); + for (size_t i = 0; i < phone_syms.size(); i++) { + int32 sym = phone_syms[i]; + loop_fst.AddArc(0, StdArc(sym, sym, TropicalWeight::One(), 0)); + } + for (size_t i = 0; i < disambig_in.size(); i++) { + int32 sym = disambig_in[i]; + loop_fst.AddArc(0, StdArc(sym, sym, TropicalWeight::One(), 0)); + } + std::vector > ilabels; + VectorFst context_fst; - ContextFst cfst(subseq_sym, - phone_syms, - disambig_in, - N, - P); + ComposeContext(disambig_in, context_width, central_position, + &loop_fst, &context_fst, &ilabels, true); - VectorFst vfst(cfst); // Copy the fst to a VectorFst. 
+ WriteFstKaldi(context_fst, fst_out_filename); - WriteFstKaldi(vfst, fst_out_filename); - - const std::vector > &ilabels = cfst.ILabelInfo(); WriteILabelInfo(Output(ilabels_out_filename, binary).Stream(), binary, ilabels); @@ -133,4 +145,3 @@ int main(int argc, char *argv[]) { return -1; } } - diff --git a/src/fstbin/fstmakecontextsyms.cc b/src/fstbin/fstmakecontextsyms.cc index e3c7d279053..c9d49397545 100644 --- a/src/fstbin/fstmakecontextsyms.cc +++ b/src/fstbin/fstmakecontextsyms.cc @@ -32,18 +32,23 @@ ( echo 3; echo 4 ) > disambig.list fstmakecontextfst --read-disambig-syms=disambig.list <(grep -v '#' phones.txt) 5 ilabels.int > C.fst fstmakecontextsyms phones.txt ilabels.int > context_syms.txt + fstprint --isymbols=context_syms.txt --osymbols=phones.txt C.fst > C.txt + fstrandgen C.fst | fstprint --isymbols=context_syms.txt --osymbols=phones.txt Example output: -0 1 #0 #0 -1 2 #-1 a -2 3 /a/a a -3 4 a/a/a a -4 5 #0 #0 -5 6 a/a/b b -6 7 a/b/ #$ -7 8 #1 #1 -8 + + fstrandgen C.fst | fstprint --isymbols=context_syms.txt --osymbols=phones.txt +0 1 #-1 b +1 2 /b/ #$ +2 3 #1 #1 +3 4 #0 #0 +4 5 #0 #0 +5 6 #0 #0 +6 7 #0 #0 +7 8 #0 #0 +8 9 #1 #1 +9 */ diff --git a/src/fstbin/fstrand.cc b/src/fstbin/fstrand.cc index 9344b538d9c..f0bc3938051 100644 --- a/src/fstbin/fstrand.cc +++ b/src/fstbin/fstrand.cc @@ -45,6 +45,8 @@ int main(int argc, char *argv[]) { po.Register("allow-empty", &opts.allow_empty, "If true, we may generate an empty FST."); + po.Read(argc, argv); + if (po.NumArgs() > 1) { po.PrintUsage(); exit(1); diff --git a/src/fstbin/make-grammar-fst.cc b/src/fstbin/make-grammar-fst.cc new file mode 100644 index 00000000000..fc9a17908f9 --- /dev/null +++ b/src/fstbin/make-grammar-fst.cc @@ -0,0 +1,162 @@ +// fstbin/make-grammar-fst.cc + +// Copyright 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "fst/fstlib.h" +#include "fstext/table-matcher.h" +#include "fstext/kaldi-fst-io.h" +#include "decoder/grammar-fst.h" + +namespace fst { + +// Reads an FST from disk using Kaldi I/O mechanisms, and if it is not of type +// ConstFst, copies it to that stype. +ConstFst* ReadAsConstFst(std::string rxfilename) { + // the following call will throw if there is an error. + Fst *fst = ReadFstKaldiGeneric(rxfilename); + ConstFst *const_fst = dynamic_cast* >(fst); + if (!const_fst) { + const_fst = new ConstFst(*fst); + delete fst; + } + return const_fst; +} + +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace fst; + using kaldi::int32; + + const char *usage = + "Construct GrammarFst and write it to disk (or convert it to ConstFst\n" + "and write that to disk instead). 
Mostly intended for demonstration\n" + "and testing purposes (since it may be more convenient to construct\n" + "GrammarFst from code). See kaldi-asr.org/doc/grammar.html\n" + "Can also be used to prepares FSTs for this use, by calling\n" + "PrepareForGrammarFst(), which does things like adding final-probs and\n" + "making small structural tweaks to the FST\n" + "\n" + "Usage (1): make-grammar-fst [options] \\\n" + " [ ...]] \n" + "\n" + ", are the integer ids of the corresponding\n" + " user-defined nonterminal symbols (e.g. #nonterm:contact_list) in the\n" + " phones.txt file.\n" + "e.g.: make-grammar-fst --nonterm-phones-offset=317 HCLG.fst \\\n" + " 320 HCLG1.fst HCLG_grammar.fst\n" + "\n" + "Usage (2): make-grammar-fst \n" + " Prepare individual FST for compilation into GrammarFst.\n" + " E.g. make-grammar-fst HCLG.fst HCLGmod.fst. The outputs of this\n" + " will then become the arguments , , ... for usage\n" + " pattern (1).\n" + "\n" + "The --nonterm-phones-offset option is required for both usage patterns.\n"; + + + ParseOptions po(usage); + + + int32 nonterm_phones_offset = -1; + bool write_as_grammar = true; + + po.Register("nonterm-phones-offset", &nonterm_phones_offset, + "Integer id of #nonterm_bos in phones.txt"); + po.Register("write-as-grammar", &write_as_grammar, "If true, " + "write as GrammarFst object; if false, convert to " + "ConstFst (readable by standard decoders) " + "and write that."); + + po.Read(argc, argv); + + + if (po.NumArgs() < 2 || po.NumArgs() % 2 != 0) { + po.PrintUsage(); + exit(1); + } + + if (nonterm_phones_offset < 0) + KALDI_ERR << "The --nonterm-phones-offset option must be supplied " + "and positive."; + + if (po.NumArgs() == 2) { + // this usage pattern calls PrepareForGrammarFst(). + VectorFst *fst = ReadFstKaldi(po.GetArg(1)); + PrepareForGrammarFst(nonterm_phones_offset, fst); + // This will write it as VectorFst; to avoid it having to be converted to + // ConstFst when read again by make-grammar-fst, you may want to pipe + // through fstconvert --fst_type=const. + WriteFstKaldi(*fst, po.GetArg(2)); + exit(0); + } + + std::string top_fst_str = po.GetArg(1), + fst_out_str = po.GetArg(po.NumArgs()); + + std::shared_ptr > top_fst( + ReadAsConstFst(top_fst_str)); + std::vector > > > pairs; + + int32 num_pairs = (po.NumArgs() - 2) / 2; + for (int32 i = 1; i <= num_pairs; i++) { + int32 nonterminal; + std::string nonterm_str = po.GetArg(2*i); + if (!ConvertStringToInteger(nonterm_str, &nonterminal) || + nonterminal <= 0) + KALDI_ERR << "Expected positive integer as nonterminal, got: " + << nonterm_str; + std::string fst_str = po.GetArg(2*i + 1); + std::shared_ptr > this_fst(ReadAsConstFst(fst_str)); + pairs.push_back(std::pair > >( + nonterminal, this_fst)); + } + + GrammarFst *grammar_fst = new GrammarFst(nonterm_phones_offset, + top_fst, + pairs); + + if (write_as_grammar) { + bool binary = true; // GrammarFst does not support non-binary write. + WriteKaldiObject(*grammar_fst, fst_out_str, binary); + delete grammar_fst; + } else { + VectorFst vfst; + CopyToVectorFst(grammar_fst, &vfst); + delete grammar_fst; + ConstFst cfst(vfst); + // We don't have a wrapper in kaldi-fst-io.h for writing type + // ConstFst, so do it manually. 
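As the usage text notes, it is often more convenient to build the GrammarFst from code than through this tool. A minimal sketch of that construction, mirroring the calls made above; the template arguments and the exact constructor signature are filled in by assumption (the listing has lost them) and should be checked against decoder/grammar-fst.h:

// Building a GrammarFst with one user-defined nonterminal directly from code.
// 'nonterminal' is the integer id of e.g. #nonterm:contact_list in phones.txt,
// and both FSTs are already-prepared HCLG graphs (see PrepareForGrammarFst).
#include <memory>
#include <utility>
#include <vector>
#include "base/kaldi-common.h"
#include "fst/fstlib.h"
#include "decoder/grammar-fst.h"

fst::GrammarFst *BuildGrammarFst(
    kaldi::int32 nonterm_phones_offset,
    std::shared_ptr<const fst::ConstFst<fst::StdArc> > top_hclg,
    kaldi::int32 nonterminal,
    std::shared_ptr<const fst::ConstFst<fst::StdArc> > sub_hclg) {
  std::vector<std::pair<kaldi::int32,
      std::shared_ptr<const fst::ConstFst<fst::StdArc> > > > pairs;
  pairs.push_back(std::make_pair(nonterminal, sub_hclg));
  return new fst::GrammarFst(nonterm_phones_offset, top_hclg, pairs);
}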
+ bool binary = true, write_binary_header = false; // suppress the ^@B + Output ko(fst_out_str, binary, write_binary_header); + FstWriteOptions wopts(kaldi::PrintableWxfilename(fst_out_str)); + cfst.Write(ko.Stream(), wopts); + } + + KALDI_LOG << "Created grammar FST and wrote it to " + << fst_out_str; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/fstext/Makefile b/src/fstext/Makefile index 7efd9fcfd8c..b76bd413c42 100644 --- a/src/fstext/Makefile +++ b/src/fstext/Makefile @@ -17,14 +17,14 @@ TESTFILES = determinize-star-test \ determinize-lattice-test lattice-utils-test deterministic-fst-test \ push-special-test epsilon-property-test prune-special-test -OBJFILES = push-special.o kaldi-fst-io.o +OBJFILES = push-special.o kaldi-fst-io.o context-fst.o grammar-context-fst.o LIBNAME = kaldi-fstext # tree and matrix archives needed for test-context-fst # matrix archive needed for push-special. -ADDLIBS = ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a +ADDLIBS = ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/fstext/context-fst-inl.h b/src/fstext/context-fst-inl.h deleted file mode 100644 index dc8a4a8370b..00000000000 --- a/src/fstext/context-fst-inl.h +++ /dev/null @@ -1,519 +0,0 @@ -// fstext/context-fst-inl.h - -// Copyright 2009-2011 Microsoft Corporation; Jan Silovsky - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FSTEXT_CONTEXT_FST_INL_H_ -#define KALDI_FSTEXT_CONTEXT_FST_INL_H_ -#include "base/kaldi-common.h" -#include "fstext/fstext-utils.h" - -// Do not include this file directly. It is included by context-fst.h. - - - -namespace fst { - -/// \addtogroup context_fst_group -/// @{ - -namespace internal { - -template -typename ContextFstImpl::StateId - ContextFstImpl::FindState(const vector &seq) { - // Finds state-id corresponding to this vector of phones. Inserts it if - // necessary. - assert(static_cast(seq.size()) == N_-1); - VectorToStateIter iter = state_map_.find(seq); - if (iter == state_map_.end()) { // Not already in map. - StateId this_state_id = (StateId)state_seqs_.size(); - state_seqs_.push_back(seq); - state_map_[seq] = this_state_id; - return this_state_id; - } else { - return iter->second; - } -} - -template -typename ContextFstImpl::Label -ContextFstImpl::FindLabel(const vector &label_vec) { - // Finds ilabel corresponding to this information.. Creates new ilabel if necessary. - VectorToLabelIter iter = ilabel_map_.find(label_vec); - if (iter == ilabel_map_.end()) { // Not already in map. 
- Label this_label = ilabel_info_.size(); - ilabel_info_.push_back(label_vec); - ilabel_map_[label_vec] = this_label; - return this_label; - } else { - return iter->second; - } -} - - -template -typename ContextFstImpl::StateId ContextFstImpl::Start() { - if (! CacheImpl::HasStart()) { - vector vec(N_-1, 0); // Vector of N_-1 epsilons. [e.g. N = 3]. - StateId s = FindState(vec); - assert(s == 0); - this->SetStart(s); - } - return CacheImpl::Start(); -} - - - -template -ContextFstImpl::ContextFstImpl(const ContextFstImpl &other): - phone_syms_(other.phone_syms_), - disambig_syms_(other.disambig_syms_) { - KALDI_ERR << "ContextFst copying not yet supported " - << "[not hard, but would have to test.]"; -} - - -template -ContextFstImpl::ContextFstImpl(Label subsequential_symbol, // epsilon not allowed. - const vector &phone_syms, // on output side of ifst. - const vector &disambig_syms, // on output - int N, - int P): - phone_syms_(phone_syms), disambig_syms_(disambig_syms), subsequential_symbol_(subsequential_symbol) , - N_(N), P_(P) { - - { // This block checks the inputs. - assert(subsequential_symbol != 0 - && disambig_syms_.count(subsequential_symbol) == 0 - && phone_syms_.count(subsequential_symbol) == 0); - if (phone_syms.empty()) - KALDI_WARN << "Context FST created but there are no phone symbols: probably input FST was empty."; - assert(phone_syms_.count(0) == 0); - assert(disambig_syms_.count(0) == 0); - for (size_t i = 0; i < phone_syms.size(); i++) - assert(disambig_syms_.count(phone_syms[i]) == 0); - assert(N>0 && P>=0 && P eps_vec; // empty vec. - // Make sure the symbol that equates to epsilon is zero in our numbering. - Label eps_id = FindLabel(eps_vec); // this function will add it to the input - // symbol table, if necessary. - assert(eps_id == 0); // doing this in the initializer should guarantee it is zero. - - if (N > P+1 && !disambig_syms_.empty()) { - // We add in a symbol whose sequence representation is [ 0 ], and whose symbol-id - // is 1. This is treated as a disambiguation symbol, we call it #-1 in printed - // form. It is necessary to ensure that all determinizable LG's will have determinizable - // CLG's. The problem it fixes is quite subtle-- it relates to reordering of - // disambiguation symbols (they appear earlier in CLG than in LG, relative to phones), - // and the fact that if a disambig symbol appears at the very start of a sequence in - // CLG, it's not clear exatly where it appeared on the corresponding sequence at - // the input of LG. - vector pseudo_eps_vec; - pseudo_eps_vec.push_back(0); - pseudo_eps_symbol_= FindLabel(pseudo_eps_vec); // this function will add it to the input - // symbol table, if necessary. - assert(pseudo_eps_symbol_ == 1); - } else pseudo_eps_symbol_ = 0; // use actual epsilon. -} - - - -template -typename ContextFstImpl::Weight ContextFstImpl::Final(StateId s) { - assert(static_cast(s) < state_seqs_.size()); // make sure state exists already. - if (!this->HasFinal(s)) { // Work out final-state weight. - const vector &seq = state_seqs_[s]; - - bool final_ok; - assert(static_cast(seq.size()) == N_-1); - - if (P_ < N_ - 1) { - /* Note that P_ (in zero based indexing) is the "central position", and for arcs out of - this state the thing at P_ will be the one we expand. If this is the subsequential symbol, - it means we will output nothing (and will obviously never output anything). Thus we make - this state the final state. 
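FindState() and FindLabel() in the deleted implementation share one idiom: look a vector of phone ids up in a map and, if it is absent, give it the next consecutive id while keeping a reverse table (state_seqs_ / ilabel_info_). A compact standalone version of that interning idiom:

// Map each distinct std::vector<int> to a dense integer id, assigning ids in
// order of first appearance; the reverse table recovers the sequence from the
// id, which is what state_seqs_ and ilabel_info_ stored.
#include <map>
#include <vector>

class VectorInterner {
 public:
  int IdOf(const std::vector<int> &seq) {
    std::map<std::vector<int>, int>::const_iterator it = map_.find(seq);
    if (it != map_.end()) return it->second;  // already interned
    int id = static_cast<int>(seqs_.size());
    seqs_.push_back(seq);
    map_[seq] = id;
    return id;
  }
  const std::vector<int> &SequenceOf(int id) const { return seqs_[id]; }

 private:
  std::map<std::vector<int>, int> map_;
  std::vector<std::vector<int> > seqs_;
};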
- */ - final_ok = (seq[P_] == subsequential_symbol_); - } else { - /* If P_ == N_-1, then the "central phone" is the last one in the list (we have a left-context system). - In this case everything is output immediately and there is no need for a subsequential symbol. - Here, any state in the FST can be the final state. - */ - final_ok = true; - } - Weight w = final_ok ? Weight::One() : Weight::Zero(); - this->SetFinal(s, w); - return w; - } - return CacheImpl::Final(s); -} - -// Warning! Not tested for correctness. Does not really matter, the way -// this function is being used so far. Note: could possibly be wrong, -template -size_t ContextFstImpl::NumArcs(StateId s) { - if (this->HasArcs(s)) { - return CacheImpl::NumArcs(s); - } - KALDI_ASSERT(s >= 0 && s < state_seqs_.size()); - const vector &seq = state_seqs_[s]; - KALDI_ASSERT(seq.size() == N_ - 1); - if (!seq.empty() && seq.back() == subsequential_symbol_) { - // State is not a "normal" state because it just saw the subsequential symbol, - // hence it cannot accept phones. - - if (P_ == N_ - 1 || seq[P_] == subsequential_symbol_) { // don't - // accept subsequential symbol.. c.f. logic in CreateArc(). - return disambig_syms_.size(); - } else { - return disambig_syms_.size() + 1; // Accept disambig syms and - // subsequential symbol. - } - } else { - // For normal states, in general there is potentially an arc for each phone and an arc - // for each disambiguation symbol, plus one for the subsequential symbol. - return phone_syms_.size() + disambig_syms_.size() + 1; - } -} - -template -size_t ContextFstImpl::NumInputEpsilons(StateId s) { - if (!this->HasArcs(s)) - Expand(s); - return CacheImpl::NumInputEpsilons(s); -} - -template -void ContextFstImpl::InitArcIterator(StateId s, ArcIteratorData *data) { - if (!this->HasArcs(s)) - Expand(s); - CacheImpl::InitArcIterator(s, data); -} - - -template -void ContextFstImpl::CreateDisambigArc(StateId s, - Label olabel, - Arc *oarc) { // called from CreateArc. - // Creates a self-loop arc corresponding to the disambiguation symbol. - vector label_info; // (olabel); - label_info.push_back(-olabel); // olabel is a disambiguation symbol. Use its negative - // so we can easily distinguish them. - Label ilabel = FindLabel(label_info); - oarc->ilabel = ilabel; - oarc->olabel = olabel; - oarc->weight = Weight::One(); - oarc->nextstate = s; // self-loop. -} - -template -bool ContextFstImpl::CreatePhoneOrEpsArc(StateId src, - StateId dst, - Label olabel, - const vector &phone_seq, - Arc *oarc) { - // called from CreateArc. - // creates the arc with a phone's state on its input labels (or epsilon). - // returns true if it created the arc. - // returns false if it could not create an arc due to the decision-tree returning false - // [this only happens if opts_.behavior_on_failure == ContextFstOptions::kNoArc]. - - assert(phone_seq[P_] != subsequential_symbol_); // would be coding error. - - if (phone_seq[P_] == 0) { // this can happen at the beginning of the graph. - // we don't output a real phone. Epsilon arc (but sometimes we need to - // use a special disambiguation symbol instead of epsilon). - *oarc = Arc(pseudo_eps_symbol_, olabel, Weight::One(), dst); - // This 1 is a "special" disambiguation symbol (#-1 in printed form) that - // we use to represent epsilons. - return true; - } else { - // have a phone in central position. - Label ilabel = FindLabel(phone_seq); - *oarc = Arc(ilabel, olabel, Weight::One(), dst); - return true; - } -} - - -// This function is specific to ContextFst. 
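The heart of the deleted on-demand C fst, described by the surrounding comments, is the state update: a state is the last N-1 phones seen, consuming a phone shifts that window left by one, and the arc's input label stands for the full N-phone window, with position P marking the phone being modelled. A standalone sketch of just that window arithmetic; the subsequential-symbol and disambiguation-symbol special cases handled by CreateArc() are omitted:

// A context-dependency state is the (N-1)-phone history; consuming a phone
// shifts the history left by one and yields the N-phone window whose central
// element (position P) is the phone the new input label represents.
#include <vector>

struct ContextStep {
  std::vector<int> next_state;   // new (N-1)-phone history
  std::vector<int> full_window;  // N phones: the input-label info for the arc
};

ContextStep ConsumePhone(const std::vector<int> &history, int phone) {
  ContextStep step;
  if (!history.empty()) {
    step.next_state.assign(history.begin() + 1, history.end());  // drop oldest
    step.next_state.push_back(phone);                            // append newest
  }
  step.full_window = history;
  step.full_window.push_back(phone);  // the full N-phone context window
  return step;
}

// Example: with N = 3, P = 1 and history {1, 2}, consuming phone 3 moves to
// state {2, 3}, and the arc's window is {1, 2, 3}, i.e. phone 2 in the
// context 1 _ 3.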
It's not part of the Fst -// interface but it's called (indirectly)by the special matcher. It -// attempts to create an arc out of state s, with output label -// "olabel" [it works out the input label from the value of "olabel". -// It returns true if it is able to create an arc, and false -// otherwise. -template -bool ContextFstImpl::CreateArc(StateId s, - Label olabel, - Arc *oarc) { - // Returns true to indicate the arc exists. - - if (olabel == 0) return false; // No epsilon-output arcs in this FST. - - const vector &seq = state_seqs_[s]; - - if (IsDisambigSymbol(olabel)) { // Disambiguation-symbol arcs.. create self-loop. - CreateDisambigArc(s, olabel, oarc); - return true; - } else if (IsPhoneSymbol(olabel) || olabel == subsequential_symbol_) { - // If all is OK, we shift the old sequence left by 1 and push on the new phone. - - if (olabel != subsequential_symbol_ && !seq.empty() && - seq.back() == subsequential_symbol_) { - return false; // Phone not allowed to follow subsequential symbol. - } - - if (olabel == subsequential_symbol_ && - (P_ == N_-1 || seq[P_] == subsequential_symbol_)) { - // We already had "enough" subsequential symbols in a row and don't want to - // accept any more, or we'd be making the subsequential symbol the central phone. - return false; - } - - vector newseq(N_-1); // seq shifted left by 1. - for (int i = 0;i < N_-2;i++) newseq[i] = seq[i+1]; - if (N_ > 1) newseq[N_-2] = olabel; - - vector phoneseq(seq); // copy it before FindState which - // possibly changes the address. - StateId nextstate = FindState(newseq); - - phoneseq.push_back(olabel); // Now it's the full context window of size N_. - for (int i = 1; i < N_ ; i++) - if (phoneseq[i] == subsequential_symbol_) phoneseq[i] = 0; // don't put subseq. symbol on - // the output arcs, just 0. - return CreatePhoneOrEpsArc(s, nextstate, olabel, phoneseq, oarc); - } else { - KALDI_ERR << "ContextFst: CreateArc, invalid olabel supplied [confusion " - << "about phone list or disambig symbols?]: " << olabel; - } - return false; // won't get here. suppress compiler error. -} - -// Note that Expand is not called if we do the composition using -// ContextMatcher, which is the normal case. -template -void ContextFstImpl::Expand(StateId s) { // expands arcs only [not final state weight]. - assert(static_cast(s) < state_seqs_.size()); // make sure state exists already. - - // We just try adding all possible symbols on the output side. - Arc arc; - if (this->CreateArc(s, subsequential_symbol_, &arc)) { - this->PushArc(s, arc); - } - for (typename kaldi::ConstIntegerSet