diff --git a/.travis.yml b/.travis.yml index f8e2bac0362..9f94726c07b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -22,7 +22,7 @@ addons: branches: only: - master - - shortcut + - kaldi_52 before_install: - cat /proc/sys/kernel/core_pattern diff --git a/egs/ami/s5b/RESULTS_ihm b/egs/ami/s5b/RESULTS_ihm index 1003197701e..03fcd9b00f3 100644 --- a/egs/ami/s5b/RESULTS_ihm +++ b/egs/ami/s5b/RESULTS_ihm @@ -54,17 +54,12 @@ %WER 22.4 | 12643 89977 | 80.3 12.5 7.2 2.7 22.4 53.6 | -0.503 | exp/ihm/nnet3_cleaned/lstm_bidirectional_sp/decode_eval/ascore_10/eval_hires.ctm.filt.sys ############################################ -# cleanup + chain TDNN model. +# cleanup + chain TDNN model # local/chain/run_tdnn.sh --mic ihm --stage 4 & -# for d in exp/ihm/chain_cleaned/tdnn1d_sp_bi/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done -%WER 21.7 | 13098 94488 | 81.1 10.4 8.4 2.8 21.7 54.4 | 0.096 | exp/ihm/chain_cleaned/tdnn1d_sp_bi/decode_dev/ascore_10/dev_hires.ctm.filt.sys -%WER 22.1 | 12643 89979 | 80.5 12.1 7.4 2.6 22.1 52.8 | 0.185 | exp/ihm/chain_cleaned/tdnn1d_sp_bi/decode_eval/ascore_10/eval_hires.ctm.filt.sys +# for d in exp/ihm/chain_cleaned/tdnn1e_sp_bi/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done +%WER 21.4 | 13098 94487 | 81.4 10.1 8.5 2.8 21.4 53.7 | 0.090 | exp/ihm/chain_cleaned/tdnn1e_batch_sp_bi/decode_dev/ascore_10/dev_hires.ctm.filt.sys +%WER 21.5 | 12643 89977 | 81.0 11.8 7.2 2.5 21.5 52.4 | 0.168 | exp/ihm/chain_cleaned/tdnn1e_batch_sp_bi/decode_eval/ascore_10/eval_hires.ctm.filt.sys -# cleanup + chain TDNN model. Uses LDA instead of PCA for ivector features. -# local/chain/tuning/run_tdnn_1b.sh --mic ihm --stage 4 & -# for d in exp/ihm/chain_cleaned/tdnn1b_sp_bi/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done -%WER 22.0 | 13098 94488 | 80.8 10.2 9.0 2.8 22.0 54.7 | 0.102 | exp/ihm/chain_cleaned/tdnn1b_sp_bi/decode_dev/ascore_10/dev_hires.ctm.filt.sys -%WER 22.2 | 12643 89968 | 80.3 12.1 7.6 2.6 22.2 52.9 | 0.170 | exp/ihm/chain_cleaned/tdnn1b_sp_bi/decode_eval/ascore_10/eval_hires.ctm.filt.sys # local/chain/run_tdnn.sh --mic ihm --train-set train --gmm tri3 --nnet3-affix "" --stage 4 # chain TDNN model without cleanup [note: cleanup helps very little on this IHM data.] @@ -72,6 +67,7 @@ %WER 21.8 | 13098 94484 | 80.7 9.7 9.6 2.5 21.8 54.2 | 0.114 | exp/ihm/chain/tdnn1d_sp_bi/decode_dev/ascore_10/dev_hires.ctm.filt.sys %WER 22.1 | 12643 89965 | 80.2 11.5 8.3 2.3 22.1 52.5 | 0.203 | exp/ihm/chain/tdnn1d_sp_bi/decode_eval/ascore_10/eval_hires.ctm.filt.sy + # local/chain/multi_condition/run_tdnn.sh --mic ihm # cleanup + chain TDNN model + IHM reverberated data # for d in exp/ihm/chain_cleaned_rvb/tdnn_sp_bi/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done diff --git a/egs/ami/s5b/RESULTS_mdm b/egs/ami/s5b/RESULTS_mdm index d9155eca507..628fe715394 100644 --- a/egs/ami/s5b/RESULTS_mdm +++ b/egs/ami/s5b/RESULTS_mdm @@ -54,20 +54,13 @@ %WER 41.6 | 13964 89980 | 62.7 23.1 14.2 4.3 41.6 65.6 | 0.649 | exp/mdm8/nnet3/tdnn_sp_ihmali/decode_eval/ascore_12/eval_hires_o4.ctm.filt.sys -################ - -# local/chain/run_tdnn.sh --mic mdm8 --stage 11 & -# cleanup + chain TDNN model, alignments from mdm8 data itself. 
-# for d in exp/mdm8/chain_cleaned/tdnn_sp_bi/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done -%WER 37.9 | 14471 94512 | 65.9 17.4 16.6 3.8 37.9 67.4 | 0.625 | exp/mdm8/chain_cleaned/tdnn_sp_bi/decode_dev/ascore_9/dev_hires_o4.ctm.filt.sys -%WER 41.3 | 13696 89959 | 62.0 18.6 19.4 3.3 41.3 67.2 | 0.591 | exp/mdm8/chain_cleaned/tdnn_sp_bi/decode_eval/ascore_9/eval_hires_o4.ctm.filt.sys - - +############################################ # cleanup + chain TDNN model, alignments from IHM data (IHM alignments help). # local/chain/run_tdnn.sh --mic mdm8 --use-ihm-ali true --stage 12 & -# for d in exp/mdm8/chain_cleaned/tdnn1d_sp_bi_ihmali/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done -%WER 36.4 | 15140 94513 | 67.3 17.5 15.2 3.6 36.4 63.2 | 0.613 | exp/mdm8/chain_cleaned/tdnn1d_sp_bi_ihmali/decode_dev/ascore_9/dev_hires_o4.ctm.filt.sys -%WER 39.7 | 13835 89969 | 63.2 18.4 18.4 3.0 39.7 65.7 | 0.584 | exp/mdm8/chain_cleaned/tdnn1d_sp_bi_ihmali/decode_eval/ascore_9/eval_hires_o4.ctm.filt.sys +# for d in exp/mdm8/chain_cleaned/tdnn1e_sp_bi/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done +%WER 36.0 | 14597 94517 | 67.8 17.7 14.5 3.8 36.0 64.9 | 0.623 | exp/mdm8/chain_cleaned/tdnn1e_sp_bi_ihmali/decode_dev/ascore_9/dev_hires_o4.ctm.filt.sys +%WER 39.3 | 13872 89973 | 63.9 19.0 17.1 3.2 39.3 65.1 | 0.594 | exp/mdm8/chain_cleaned/tdnn1e_sp_bi_ihmali/decode_eval/ascore_9/eval_hires_o4.ctm.filt.sys + # local/chain/run_tdnn.sh --use-ihm-ali true --mic mdm8 --train-set train --gmm tri3 --nnet3-affix "" --stage 12 & # chain TDNN model-- no cleanup, but IHM alignments. @@ -76,6 +69,14 @@ %WER 36.9 | 15282 94502 | 67.1 18.5 14.4 4.1 36.9 62.5 | 0.635 | exp/mdm8/chain/tdnn1d_sp_bi_ihmali/decode_dev/ascore_8/dev_hires_o4.ctm.filt.sys %WER 40.2 | 13729 89992 | 63.3 19.8 17.0 3.5 40.2 66.4 | 0.608 | exp/mdm8/chain/tdnn1d_sp_bi_ihmali/decode_eval/ascore_8/eval_hires_o4.ctm.filt.sys + +# local/chain/run_tdnn.sh --mic mdm8 --stage 11 & +# cleanup + chain TDNN model, alignments from mdm8 data itself. 
+# for d in exp/mdm8/chain_cleaned/tdnn_sp_bi/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done +%WER 37.9 | 14471 94512 | 65.9 17.4 16.6 3.8 37.9 67.4 | 0.625 | exp/mdm8/chain_cleaned/tdnn_sp_bi/decode_dev/ascore_9/dev_hires_o4.ctm.filt.sys +%WER 41.3 | 13696 89959 | 62.0 18.6 19.4 3.3 41.3 67.2 | 0.591 | exp/mdm8/chain_cleaned/tdnn_sp_bi/decode_eval/ascore_9/eval_hires_o4.ctm.filt.sys + + # local/chain/multi_condition/run_tdnn.sh --mic mdm8 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned # cleanup + chain TDNN model, MDM original + IHM reverberated data, alignments from IHM data # for d in exp/mdm8/chain_cleaned_rvb/tdnn_sp_rvb_bi_ihmali/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done diff --git a/egs/ami/s5b/RESULTS_sdm b/egs/ami/s5b/RESULTS_sdm index 737f8f6dc09..9f936b304be 100644 --- a/egs/ami/s5b/RESULTS_sdm +++ b/egs/ami/s5b/RESULTS_sdm @@ -52,10 +52,14 @@ %WER 37.9 | 15953 94512 | 66.7 22.0 11.3 4.7 37.9 58.9 | 0.734 | exp/sdm1/nnet3_cleaned/lstm_bidirectional_sp_ihmali/decode_dev/ascore_12/dev_hires_o4.ctm.filt.sys %WER 41.2 | 13271 89635 | 62.9 23.8 13.2 4.2 41.2 67.8 | 0.722 | exp/sdm1/nnet3_cleaned/lstm_bidirectional_sp_ihmali/decode_eval/ascore_11/eval_hires_o4.ctm.filt.sys -# ========================= +############################################ +# cleanup + chain TDNN model, alignments from IHM data (IHM alignments help) +# local/chain/run_tdnn.sh --mic sdm1 --use-ihm-ali true --stage 12 & +# for d in exp/sdm1/chain_cleaned/tdnn1e_sp_bi/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done +%WER 39.1 | 14457 94509 | 64.6 19.7 15.7 3.7 39.1 66.5 | 0.585 | exp/sdm1/chain_cleaned/tdnn1e_sp_bi_ihmali/decode_dev/ascore_9/dev_hires_o4.ctm.filt.sys +%WER 43.2 | 13551 89981 | 60.3 20.9 18.8 3.5 43.2 67.1 | 0.554 | exp/sdm1/chain_cleaned/tdnn1e_sp_bi_ihmali/decode_eval/ascore_9/eval_hires_o4.ctm.filt.sys -# local/chain/run_tdnn.sh --mic sdm1 --stage 12 & # cleanup + chain TDNN model, alignments from sdm1 data itself. # for d in exp/sdm1/chain_cleaned/tdnn_sp_bi/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done @@ -63,15 +67,6 @@ %WER 45.4 | 12886 89960 | 58.1 21.0 20.9 3.5 45.4 71.9 | 0.558 | exp/sdm1/chain_cleaned/tdnn_sp_bi/decode_eval/ascore_9/eval_hires_o4.ctm.filt.sys - -# cleanup + chain TDNN model, alignments from IHM data (IHM alignments help). -# local/chain/run_tdnn.sh --mic sdm1 --use-ihm-ali true --stage 12 & -# cleanup + chain TDNN model, cleaned data and alignments from ihm data. -# for d in exp/sdm1/chain_cleaned/tdnn1d_sp_bi_ihmali/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done -%WER 39.5 | 14280 94503 | 64.0 19.3 16.7 3.5 39.5 67.7 | 0.582 | exp/sdm1/chain_cleaned/tdnn1d_sp_bi_ihmali/decode_dev/ascore_9/dev_hires_o4.ctm.filt.sys -%WER 43.9 | 13566 89961 | 59.3 20.9 19.9 3.1 43.9 67.9 | 0.547 | exp/sdm1/chain_cleaned/tdnn1d_sp_bi_ihmali/decode_eval/ascore_9/eval_hires_o4.ctm.filt.sys - - # no-cleanup + chain TDNN model, IHM alignments. # A bit worse than with cleanup [+0.3, +0.4]. 
# local/chain/run_tdnn.sh --use-ihm-ali true --mic sdm1 --train-set train --gmm tri3 --nnet3-affix "" --stage 12 @@ -79,6 +74,7 @@ %WER 39.8 | 15384 94535 | 64.4 21.0 14.6 4.2 39.8 62.8 | 0.610 | exp/sdm1/chain/tdnn1d_sp_bi_ihmali/decode_dev/ascore_8/dev_hires_o4.ctm.filt.sys %WER 44.3 | 14046 90002 | 59.6 23.1 17.3 3.9 44.3 65.6 | 0.571 | exp/sdm1/chain/tdnn1d_sp_bi_ihmali/decode_eval/ascore_8/eval_hires_o4.ctm.filt.sys + # local/chain/multi_condition/run_tdnn.sh --mic sdm1 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned # cleanup + chain TDNN model, SDM original + IHM reverberated data, alignments from ihm data. # for d in exp/sdm1/chain_cleaned_rvb/tdnn_sp_rvb_bi_ihmali/decode_*; do grep Sum $d/*sc*/*ys | utils/best_wer.sh; done diff --git a/egs/ami/s5b/local/chain/run_tdnn.sh b/egs/ami/s5b/local/chain/run_tdnn.sh index e1adaa9346d..75da1a0a553 120000 --- a/egs/ami/s5b/local/chain/run_tdnn.sh +++ b/egs/ami/s5b/local/chain/run_tdnn.sh @@ -1 +1 @@ -tuning/run_tdnn_1d.sh \ No newline at end of file +tuning/run_tdnn_1e.sh \ No newline at end of file diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_1e.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_1e.sh new file mode 100755 index 00000000000..697fe97df94 --- /dev/null +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_1e.sh @@ -0,0 +1,266 @@ +#!/bin/bash + +# same as 1b but uses batchnorm components instead of renorm + +# Results on 03/27/2017: +# local/chain/compare_wer_general.sh ihm tdnn1b_sp_bi tdnn1e_sp_bi +# System tdnn1b_sp_bi tdnn1e_sp_bi +# WER on dev 21.9 21.4 +# WER on eval 22.2 21.5 +# Final train prob -0.0906771 -0.0857669 +# Final valid prob -0.126942 -0.124401 +# Final train prob (xent) -1.4427 -1.37837 +# Final valid prob (xent) -1.60284 -1.5634 + +set -e -o pipefail +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +mic=ihm +nj=30 +min_seg_len=1.55 +use_ihm_ali=false +train_set=train_cleaned +gmm=tri3_cleaned # the gmm for the target data +ihm_gmm=tri3 # the gmm for the IHM system (if --use-ihm-ali true). +num_threads_ubm=32 +ivector_transform_type=pca +nnet3_affix=_cleaned # cleanup affix for nnet3 and chain dirs, e.g. _cleaned + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +train_stage=-10 +tree_affix= # affix for tree directory, e.g. "a" or "b", in case we change the configuration. +tdnn_affix=1e #affix for TDNN directory, e.g. "a" or "b", in case we change the configuration. +common_egs_dir= # you can set this to use previously dumped egs. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <data/lang_chain/topo + fi +fi + +if [ $stage -le 13 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 14 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --leftmost-questions-truncate -1 \ + --cmd "$train_cmd" 4200 ${lores_train_data_dir} data/lang_chain $ali_dir $tree_dir +fi + +xent_regularize=0.1 + +if [ $stage -le 15 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=450 + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=450 + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=450 + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=450 + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=450 + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=450 + relu-batchnorm-layer name=tdnn7 input=Append(-3,0,3) dim=450 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain input=tdnn7 dim=450 target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn7 dim=450 target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 16 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/ami-$(date +'%m_%d_%H_%M')/s5b/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width 150 \ + --trainer.num-chunk-per-minibatch 128 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs 4 \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 12 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir $lat_dir \ + --dir $dir +fi + + +graph_dir=$dir/graph_${LM} +if [ $stage -le 17 ]; then + # Note: it might appear that this data/lang_chain directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_${LM} $dir $graph_dir +fi + +if [ $stage -le 18 ]; then + rm $dir/.error 2>/dev/null || true + for decode_set in dev eval; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$decode_cmd" \ + --online-ivector-dir exp/$mic/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ + --scoring-opts "--min-lmwt 5 " \ + $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set} || exit 1; + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi +exit 0 diff --git a/egs/cifar/README.txt b/egs/cifar/README.txt new file mode 100644 index 00000000000..6546dba44bc --- /dev/null +++ b/egs/cifar/README.txt @@ -0,0 +1,7 @@ + +This directory contains example scripts for image classification with the +CIFAR-10 and CIFAR-100 datasets, which are available for free from +https://www.cs.toronto.edu/~kriz/cifar.html. + +This demonstrates applying the nnet3 framework to image classification for +fixed size images. diff --git a/egs/cifar/v1/cmd.sh b/egs/cifar/v1/cmd.sh new file mode 100644 index 00000000000..a14090a74a1 --- /dev/null +++ b/egs/cifar/v1/cmd.sh @@ -0,0 +1,29 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. 
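+#
+# As an illustration only (the queue names and resource strings here are
+# site-specific examples, not requirements), a GridEngine-style conf/queue.conf
+# typically follows the same format as that 'default_config', roughly:
+#   command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64*
+#   option mem=* -l mem_free=$0,ram_free=$0
+#   option num_threads=* -pe smp $0
+#   option gpu=* -q g.q -l gpu=$0
+#   default gpu=0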
+ +export train_cmd="queue.pl" +export decode_cmd="queue.pl --mem 4G" +export mkgraph_cmd="queue.pl --mem 8G" +export cuda_cmd="queue.pl --gpu 1" + + +# the rest of this file is present for historical reasons. it's better to +# create and edit conf/queue.conf for cluster-specific configuration. +if [ "$(hostname -d)" == "fit.vutbr.cz" ]; then + # BUT cluster: + queue="all.q@@blade,all.q@@speech" + storage="matylda5" + export train_cmd="queue.pl -q $queue -l ram_free=1.5G,mem_free=1.5G,${storage}=0.25" + export decode_cmd="queue.pl -q $queue -l ram_free=2.5G,mem_free=2.5G,${storage}=0.1" + export cuda_cmd="queue.pl -q long.q -l gpu=1" +fi + diff --git a/egs/cifar/v1/image/README.txt b/egs/cifar/v1/image/README.txt new file mode 100644 index 00000000000..3982dd5ee3d --- /dev/null +++ b/egs/cifar/v1/image/README.txt @@ -0,0 +1,2 @@ +This directory contains various scripts that relate to image recognition: specifically, +the recognition of fixed-size images. diff --git a/egs/cifar/v1/image/nnet3/get_egs.sh b/egs/cifar/v1/image/nnet3/get_egs.sh new file mode 100644 index 00000000000..905c9441946 --- /dev/null +++ b/egs/cifar/v1/image/nnet3/get_egs.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +# This script is like steps/nnet3/get_egs.sh, except it is specialized for +# classification of fixed-size images; and you have to provide the +# dev or test data in a separate directory. + + +# Begin configuration section. +cmd=run.pl +egs_per_archive=25000 +test_mode=false +# end configuration section + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + + +if [ $# != 2 ]; then + echo "Usage: $0 [opts] " + echo " e.g.: $0 --egs-per-iter 25000 data/cifar10_train exp/cifar10_train_egs" + echo " or: $0 --test-mode true data/cifar10_test exp/cifar10_test_egs" + echo "Options (with defaults):" + echo " --cmd 'run.pl' How to run jobs (e.g. queue.pl)" + echo " --test-mode false Set this to true if you just want a single archive" + echo " egs.ark to be created (useful for test data)" + echo " --egs-per-archive 25000 Number of images to put in each training archive" + echo " (this is a target; the actual number will be chosen" + echo " as some fraction of the total." + exit 1; +fi diff --git a/egs/cifar/v1/image/validate_image_dir.sh b/egs/cifar/v1/image/validate_image_dir.sh new file mode 100755 index 00000000000..f49d51031b8 --- /dev/null +++ b/egs/cifar/v1/image/validate_image_dir.sh @@ -0,0 +1,78 @@ +#!/usr/bin/env bash + +# This script validates a directory containing training or test images +# for image-classification tasks with fixed-size images. + + +if [ $# != 1 ]; then + echo "Usage: $0 " + echo "e.g.: $0 data/cifar10_train" +fi + +dir=$1 + +[ -e ./path.sh ] && . ./path.sh + +if [ ! -d $dir ]; then + echo "$0: directory $dir does not exist." +fi + +for f in images.scp labels.txt classes.txt num_colors; do + if [ ! -s "$dir/$f" ]; then + echo "$0: expected file $dir/$f to exist and be nonempty" + fi +done + + +num_colors=$(cat $dir/num_colors) + +if ! [[ $num_colors -gt 0 ]]; then + echo "$0: expected the file $dir/num_colors to contain a number >0" + exit 1 +fi + +paf="--print-args=false" + +num_cols=$(head -n 1 $dir/images.scp | feat-to-dim $paf scp:- -) +if ! 
[[ $[$num_cols%$num_colors] == 0 ]]; then + echo "$0: expected the number of columns in the image matrices ($num_cols) to " + echo " be a multiple of the number of colors ($num_colors)" + exit 1 +fi + +num_rows=$(head -n 1 $dir/images.scp | feat-to-len $paf scp:- -) + +height=$[$num_cols/$num_colors] + +echo "$0: images are width=$num_rows by height=$height, with $num_colors colors." + +if ! cmp <(awk '{print $1}' $dir/images.scp) <(awk '{print $1}' $dir/labels.txt); then + echo "$0: expected the first fields of $dir/images.scp and $dir/labels.txt to match up." + exit 1; +fi + +if ! [[ $num_cols -eq $(tail -n 1 $dir/images.scp | feat-to-dim $paf scp:- -) ]]; then + echo "$0: the number of columns in the image matrices is not consistent." + exit 1 +fi + +if ! [[ $num_rows -eq $(tail -n 1 $dir/images.scp | feat-to-len scp:- -) ]]; then + echo "$0: the number of rows in the image matrices is not consistent." + exit 1 +fi + +# Note: we don't require images.scp and labels.txt to be sorted, but they +# may not contain repeated keys. +if ! awk '{if($1 in a) { print "validate_image_dir.sh: key " $1 " is repeated in labels.txt"; exit 1; } a[$1]=1; }'; then + exit 1 +fi + + +if ! utils/int2sym.pl -f 2 $dir/classes.txt <$dir/labels.txt >/dev/null; then + echo "$0: classes.txt may have the wrong format or may not cover all labels in $dir/labels.txt" + exit 1; +fi + + +echo "$0: validated image-data directory $dir" +exit 0 diff --git a/egs/cifar/v1/local/prepare_data.sh b/egs/cifar/v1/local/prepare_data.sh new file mode 100755 index 00000000000..7314ebf9188 --- /dev/null +++ b/egs/cifar/v1/local/prepare_data.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +# Copyright 2017 Johns Hopkins University (author: Hossein Hadian) +# Apache 2.0 + +# This script loads the training and test data for CIFAR-10 or CIFAR-100. + +[ -f ./path.sh ] && . ./path.sh; # source the path. + +dl_dir=data/dl +cifar10=$dl_dir/cifar-10-batches-bin +cifar10_url=https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz +cifar100=$dl_dir/cifar-100-binary +cifar100_url=https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz + +mkdir -p $dl_dir +if [ -d $cifar10 ]; then + echo Not downloading CIFAR-10 as it is already there. +else + if [ ! -f $dl_dir/cifar-10-binary.tar.gz ]; then + echo Downloading CIFAR-10... + wget -P $dl_dir $cifar10_url || exit 1; + fi + tar -xvzf $dl_dir/cifar-10-binary.tar.gz -C $dl_dir || exit 1; + echo Done downaloding and extracting CIFAR-10 +fi + +mkdir -p data/cifar10_{train,test}/data +seq 0 9 | paste -d' ' data/dl/cifar-10-batches-bin/batches.meta.txt - | grep '\S' >data/cifar10_train/classes.txt +cp data/cifar10_{train,test}/classes.txt +echo 3 > data/cifar10_train/num_colors +echo 3 > data/cifar10_test/num_colors + +local/process_data.py --dataset train $cifar10 data/cifar10_train/ | \ + copy-feats --compress=true --compression-method=6 \ + ark:- ark,scp:data/cifar10_train/data/images.ark,data/cifar10_train/images.scp || exit 1 + +local/process_data.py --dataset test $cifar10 data/cifar10_test/ | \ + copy-feats --compress=true --compression-method=6 \ + ark:- ark,scp:data/cifar10_test/data/images.ark,data/cifar10_test/images.scp || exit 1 + + + +### CIFAR 100 + +if [ -d $cifar100 ]; then + echo Not downloading CIFAR-100 as it is already there. +else + if [ ! -f $dl_dir/cifar-100-binary.tar.gz ]; then + echo Downloading CIFAR-100... 
+ wget -P $dl_dir $cifar100_url || exit 1; + fi + tar -xvzf $dl_dir/cifar-100-binary.tar.gz -C $dl_dir || exit 1; + echo Done downaloding and extracting CIFAR-100 +fi + +mkdir -p data/cifar100_{train,test}/data +seq 0 99 | paste -d' ' $cifar100/fine_label_names.txt - | grep '\S' >data/cifar100_train/fine_classes.txt +seq 0 19 | paste -d' ' $cifar100/coarse_label_names.txt - | grep '\S' >data/cifar100_train/coarse_classes.txt + +cp data/cifar100_{train,test}/fine_classes.txt +cp data/cifar100_{train,test}/coarse_classes.txt + +echo 3 > data/cifar100_train/num_colors +echo 3 > data/cifar100_test/num_colors + +local/process_data.py --dataset train $cifar100 data/cifar100_train/ | \ + copy-feats --compress=true --compression-method=6 \ + ark:- ark,scp:data/cifar100_train/data/images.ark,data/cifar100_train/images.scp || exit 1 + +local/process_data.py --dataset test $cifar100 data/cifar100_test/ | \ + copy-feats --compress=true --compression-method=6 \ + ark:- ark,scp:data/cifar100_test/data/images.ark,data/cifar100_test/images.scp || exit 1 diff --git a/egs/cifar/v1/local/process_data.py b/egs/cifar/v1/local/process_data.py new file mode 100755 index 00000000000..4b2c70bb6ad --- /dev/null +++ b/egs/cifar/v1/local/process_data.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python + +# Copyright 2017 Johns Hopkins University (author: Hossein Hadian) +# Apache 2.0 + + +""" This script prepares the training and test data for CIFAR-10 or CIFAR-100. +""" + +import argparse +import os +import sys +import re +import errno + + +sys.path.insert(0, 'steps') +import libs.common as common_lib + +parser = argparse.ArgumentParser(description="""Converts train/test data of + CIFAR-10 or CIFAR-100 to + Kaldi feature format""") +parser.add_argument('database', type=str, + default='data/dl/cifar-10-batches-bin', + help='path to downloaded cifar data (binary version)') +parser.add_argument('dir', type=str, help='output dir') +parser.add_argument('--dataset', type=str, default='train', choices=['train', 'test']) +parser.add_argument('--out-ark', type=str, default='-', help='where to write output feature data') + +args = parser.parse_args() + +# CIFAR image dimensions: +C = 3 # num_channels +H = 32 # num_rows +W = 32 # num_cols + +def load_cifar10_data_batch(datafile): + num_images_in_batch = 10000 + data = [] + labels = [] + with open(datafile, 'rb') as fh: + for i in range(num_images_in_batch): + label = ord(fh.read(1)) + bin_img = fh.read(C * H * W) + img = [[[ord(byte) / 255.0 for byte in bin_img[channel*H*W+row*W:channel*H*W+(row+1)*W]] + for row in range(H)] for channel in range(C)] + labels += [label] + data += [img] + return data, labels + +def load_cifar100_data_batch(datafile): + num_images_in_batch = 10000 + data = [] + fine_labels = [] + coarse_labels = [] + with open(datafile, 'rb') as fh: + for i in range(num_images_in_batch): + coarse_label = ord(fh.read(1)) + fine_label = ord(fh.read(1)) + bin_img = fh.read(C * H * W) + img = [[[ord(byte) / 255.0 for byte in bin_img[channel*H*W+row*W:channel*H*W+(row+1)*W]] + for row in range(H)] for channel in range(C)] + fine_labels += [fine_label] + coarse_labels += [coarse_label] + data += [img] + return data, fine_labels, coarse_labels + +def image_to_feat_matrix(img): + mat = [0]*H # 32 * 96 + for row in range(H): + mat[row] = [0]*C*W + for ch in range(C): + for col in range(W): + mat[row][col*C+ch] = img[ch][row][col] + return mat + +def write_kaldi_matrix(file_handle, matrix, key): + # matrix is a list of lists + file_handle.write(key + " [ ") + num_rows = 
len(matrix) + if num_rows == 0: + raise Exception("Matrix is empty") + num_cols = len(matrix[0]) + + for row_index in range(len(matrix)): + if num_cols != len(matrix[row_index]): + raise Exception("All the rows of a matrix are expected to " + "have the same length") + file_handle.write(" ".join(map(lambda x: str(x), matrix[row_index]))) + if row_index != num_rows - 1: + file_handle.write("\n") + file_handle.write(" ]\n") + +def zeropad(x, length): + s = str(x) + while len(s) < length: + s = '0' + s + return s + +### main ### +cifar10 = (args.database.find('cifar-100') == -1) +if args.out_ark == '-': + out_fh = sys.stdout # output file handle to write the feats to +else: + out_fh = open(args.out_ark, 'wb') + +if cifar10: + img_id = 1 # similar to utt_id + labels_file = os.path.join(args.dir, 'labels.txt') + labels_fh = open(labels_file, 'wb') + + + if args.dataset == 'train': + for i in range(1, 6): + fpath = os.path.join(args.database, 'data_batch_' + str(i) + '.bin') + data, labels = load_cifar10_data_batch(fpath) + for i in range(len(data)): + key = zeropad(img_id, 5) + labels_fh.write(key + ' ' + str(labels[i]) + '\n') + feat_mat = image_to_feat_matrix(data[i]) + write_kaldi_matrix(out_fh, feat_mat, key) + img_id += 1 + else: + fpath = os.path.join(args.database, 'test_batch.bin') + data, labels = load_cifar10_data_batch(fpath) + for i in range(len(data)): + key = zeropad(img_id, 5) + labels_fh.write(key + ' ' + str(labels[i]) + '\n') + feat_mat = image_to_feat_matrix(data[i]) + write_kaldi_matrix(out_fh, feat_mat, key) + img_id += 1 + + labels_fh.close() +else: + img_id = 1 # similar to utt_id + fine_labels_file = os.path.join(args.dir, 'fine_labels.txt') + coarse_labels_file = os.path.join(args.dir, 'coarse_labels.txt') + fine_labels_fh = open(fine_labels_file, 'wb') + coarse_labels_fh = open(coarse_labels_file, 'wb') + + if args.dataset == 'train': + fpath = os.path.join(args.database, 'train.bin') + data, fine_labels, coarse_labels = load_cifar100_data_batch(fpath) + for i in range(len(data)): + key = zeropad(img_id, 5) + fine_labels_fh.write(key + ' ' + str(fine_labels[i]) + '\n') + coarse_labels_fh.write(key + ' ' + str(coarse_labels[i]) + '\n') + feat_mat = image_to_feat_matrix(data[i]) + write_kaldi_matrix(out_fh, feat_mat, key) + img_id += 1 + else: + fpath = os.path.join(args.database, 'test.bin') + data, fine_labels, coarse_labels = load_cifar100_data_batch(fpath) + for i in range(len(data)): + key = zeropad(img_id, 5) + fine_labels_fh.write(key + ' ' + str(fine_labels[i]) + '\n') + coarse_labels_fh.write(key + ' ' + str(coarse_labels[i]) + '\n') + feat_mat = image_to_feat_matrix(data[i]) + write_kaldi_matrix(out_fh, feat_mat, key) + img_id += 1 + + fine_labels_fh.close() + coarse_labels_fh.close() + +out_fh.close() diff --git a/egs/cifar/v1/path.sh b/egs/cifar/v1/path.sh new file mode 100755 index 00000000000..2d17b17a84a --- /dev/null +++ b/egs/cifar/v1/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/cifar/v1/run.sh b/egs/cifar/v1/run.sh new file mode 100755 index 00000000000..2107a38eea4 --- /dev/null +++ b/egs/cifar/v1/run.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +stage=0 + +. 
./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. + ## This relates to the queue. +. utils/parse_options.sh # e.g. this parses the --stage option if supplied. + +if [ $stage -le 0 ]; then + # data preparation + local/prepare_data.sh + for x in cifar{10,100}_{train,test}; do + image/validate_image_dir.sh data/$x + done +fi diff --git a/egs/cifar/v1/steps b/egs/cifar/v1/steps new file mode 120000 index 00000000000..6e99bf5b5ad --- /dev/null +++ b/egs/cifar/v1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps \ No newline at end of file diff --git a/egs/cifar/v1/utils b/egs/cifar/v1/utils new file mode 120000 index 00000000000..b240885218f --- /dev/null +++ b/egs/cifar/v1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils \ No newline at end of file diff --git a/egs/swbd/s5c/RESULTS b/egs/swbd/s5c/RESULTS index f103200f966..7c82c55b012 100644 --- a/egs/swbd/s5c/RESULTS +++ b/egs/swbd/s5c/RESULTS @@ -191,11 +191,17 @@ exit 0 %WER 24.3 | 2628 21594 | 78.6 15.0 6.4 2.9 24.3 60.0 | exp/nnet3/tdnn_cnn_sp/decode_eval2000_hires_sw1_tg/score_10_0.0/eval2000_hires.ctm.callhm.filt.sys -# current best 'chain' models with TDNNs (see local/chain/run_tdnn_7g.sh) +# current best 'chain' models with TDNNs (see local/chain/tuning/run_tdnn_7k.sh) +# (4 epoch training on data being speed-perturbed, volume-perturbed) +%WER 15.0 | 4459 42989 | 86.6 8.7 4.7 1.7 15.0 51.1 | exp/chain/tdnn_7k_sp/decode_eval2000_sw1_fsh_fg/score_10_0.5/eval2000_hires.ctm.filt.sys +%WER 10.1 | 1831 21395 | 91.1 6.0 3.0 1.2 10.1 44.7 | exp/chain/tdnn_7k_sp/decode_eval2000_sw1_fsh_fg/score_10_0.5/eval2000_hires.ctm.swbd.filt.sys +%WER 19.9 | 2628 21594 | 82.3 11.4 6.3 2.2 19.9 55.6 | exp/chain/tdnn_7k_sp/decode_eval2000_sw1_fsh_fg/score_10_0.0/eval2000_hires.ctm.callhm.filt.sys + +#(see local/chain/multi_condition/run_tdnn_7f.sh) # (2 epoch training on data being speed-perturbed, volume-perturbed and reverberated with room impulse responses) -%WER 14.6 | 4459 42989 | 87.1 8.7 4.2 1.7 14.6 50.7 | exp/chain/tdnn_7g_sp/decode_eval2000_sw1_fsh_fg/score_10_0.0/eval2000_hires.ctm.filt.sys -%WER 9.8 | 1831 21395 | 91.2 5.7 3.1 1.1 9.8 43.4 | exp/chain/tdnn_7g_sp/decode_eval2000_sw1_fsh_fg/score_11_0.0/eval2000_hires.ctm.swbd.filt.sys -%WER 19.3 | 2628 21594 | 83.0 11.5 5.5 2.3 19.3 55.8 | exp/chain/tdnn_7g_sp/decode_eval2000_sw1_fsh_fg/score_10_0.0/eval2000_hires.ctm.callhm.filt.sys +%WER 14.6 | 4459 42989 | 87.1 8.7 4.2 1.7 14.6 50.7 | exp/chain/tdnn_7f_sp_rvb1/decode_eval2000_sw1_fsh_fg/score_10_0.0/eval2000_hires.ctm.filt.sys +%WER 9.8 | 1831 21395 | 91.2 5.7 3.1 1.1 9.8 43.4 | exp/chain/tdnn_7f_sp_rvb1/decode_eval2000_sw1_fsh_fg/score_11_0.0/eval2000_hires.ctm.swbd.filt.sys +%WER 19.3 | 2628 21594 | 83.0 11.5 5.5 2.3 19.3 55.8 | exp/chain/tdnn_7f_sp_rvb1/decode_eval2000_sw1_fsh_fg/score_10_0.0/eval2000_hires.ctm.callhm.filt.sys # current best 'chain' models with LSTM (see local/chain/run_lstm_d.sh) %WER 15.9 | 4459 42989 | 86.0 9.6 4.3 2.0 15.9 51.7 | exp/chain/lstm_d_ld5_sp/decode_eval2000_sw1_fsh_fg/score_10_0.0/eval2000_hires.ctm.filt.sys diff --git a/egs/swbd/s5c/local/chain/run_tdnn.sh b/egs/swbd/s5c/local/chain/run_tdnn.sh index 7b86453e14b..9e1ad7a0ba7 120000 --- a/egs/swbd/s5c/local/chain/run_tdnn.sh +++ b/egs/swbd/s5c/local/chain/run_tdnn.sh @@ -1 +1 @@ -tuning/run_tdnn_7h.sh \ No newline at end of file +tuning/run_tdnn_7k.sh \ No newline at end of file diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7k.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7k.sh new file mode 100755 index 00000000000..ae1b210f5c9 --- 
/dev/null +++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7k.sh @@ -0,0 +1,263 @@ +#!/bin/bash + +# run_tdnn_7k.sh is like run_tdnn_7h.sh but batchnorm components instead of renorm + +# local/chain/compare_wer_general.sh tdnn_7h_sp/ tdnn_7k_sp/ +# System tdnn_7h_sp/ tdnn_7k_sp/ +# WER on train_dev(tg) 13.99 13.98 +# WER on train_dev(fg) 12.82 12.66 +# WER on eval2000(tg) 16.8 16.6 +# WER on eval2000(fg) 15.3 15.0 +# Final train prob -0.087 -0.087 +# Final valid prob -0.107 -0.103 +# Final train prob (xent) -1.252 -1.223 +# Final valid prob (xent) -1.3105 -1.2945 + +set -e + +# configs for 'chain' +affix= +stage=12 +train_stage=-10 +get_egs_stage=-10 +speed_perturb=true +dir=exp/chain/tdnn_7k # Note: _sp will get added to this if $speed_perturb == true. +decode_iter= +decode_nj=50 + +# training options +num_epochs=4 +initial_effective_lrate=0.001 +final_effective_lrate=0.0001 +leftmost_questions_truncate=-1 +max_param_change=2.0 +final_layer_normalize_target=0.5 +num_jobs_initial=3 +num_jobs_final=16 +minibatch_size=128 +frames_per_eg=150 +remove_egs=false +common_egs_dir= +xent_regularize=0.1 + +test_online_decoding=false # if true, it will run the last decoding stage. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 11 ]; then + # Build a tree using our new topology. This is the critically different + # step compared with other recipes. + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --leftmost-questions-truncate $leftmost_questions_truncate \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 7000 data/$train_set $lang $ali_dir $treedir +fi + +if [ $stage -le 12 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=625 + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=625 + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=625 + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=625 + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=625 + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=625 + relu-batchnorm-layer name=tdnn7 input=Append(-3,0,3) dim=625 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain input=tdnn7 dim=625 target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. 
we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn7 dim=625 target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/swbd-$(date +'%m_%d_%H_%M')/s5c/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir exp/nnet3/ivectors_${train_set} \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $frames_per_eg \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate $initial_effective_lrate \ + --trainer.optimization.final-effective-lrate $final_effective_lrate \ + --trainer.max-param-change $max_param_change \ + --cleanup.remove-egs $remove_egs \ + --feat-dir data/${train_set}_hires \ + --tree-dir $treedir \ + --lat-dir exp/tri4_lats_nodup$suffix \ + --dir $dir || exit 1; + +fi + +if [ $stage -le 14 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_sw1_tg $dir $dir/graph_sw1_tg +fi + + +graph_dir=$dir/graph_sw1_tg +iter_opts= +if [ ! -z $decode_iter ]; then + iter_opts=" --iter $decode_iter " +fi +if [ $stage -le 15 ]; then + rm $dir/.error 2>/dev/null || true + for decode_set in train_dev eval2000; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $decode_nj --cmd "$decode_cmd" $iter_opts \ + --online-ivector-dir exp/nnet3/ivectors_${decode_set} \ + $graph_dir data/${decode_set}_hires \ + $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_tg || exit 1; + if $has_fisher; then + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \ + $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_{tg,fsh_fg} || exit 1; + fi + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi + +if $test_online_decoding && [ $stage -le 16 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. 
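+  # (for example, a pitch-based setup would typically need to pass something
+  # like an --add-pitch option and an online pitch config here; the exact
+  # option names are an assumption, so check the usage message of
+  # steps/online/nnet3/prepare_online_decoding.sh before relying on them.)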
+ steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/nnet3/extractor $dir ${dir}_online + + rm $dir/.error 2>/dev/null || true + for decode_set in train_dev eval2000; do + ( + # note: we just give it "$decode_set" as it only uses the wav.scp, the + # feature type does not matter. + + steps/online/nnet3/decode.sh --nj $decode_nj --cmd "$decode_cmd" \ + --acwt 1.0 --post-decode-acwt 10.0 \ + $graph_dir data/${decode_set}_hires \ + ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_tg || exit 1; + if $has_fisher; then + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \ + ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_{tg,fsh_fg} || exit 1; + fi + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi + + +exit 0; diff --git a/egs/tedlium/s5_r2/RESULTS b/egs/tedlium/s5_r2/RESULTS index ec4b9c24a12..cd4f6bfbaf3 100644 --- a/egs/tedlium/s5_r2/RESULTS +++ b/egs/tedlium/s5_r2/RESULTS @@ -130,12 +130,12 @@ for x in exp/nnet3_cleaned/tdnn_sp/decode_*; do grep Sum $x/*ore*/*ys | utils/be ########## nnet3+chain systems # chain+TDNN, small LM -%WER 10.4 | 507 17783 | 91.1 6.3 2.6 1.5 10.4 80.5 | 0.052 | exp/chain_cleaned/tdnn_sp_bi/decode_dev/score_10_0.0/ctm.filt.filt.sys -%WER 9.8 | 1155 27500 | 91.4 6.0 2.6 1.1 9.8 73.5 | 0.048 | exp/chain_cleaned/tdnn_sp_bi/decode_test/score_10_0.0/ctm.filt.filt.sys +%WER 9.2 | 507 17783 | 92.0 5.6 2.3 1.3 9.2 78.7 | 0.070 | exp/chain_cleaned/tdnn1e_sp_bi/decode_dev/score_9_0.0/ctm.filt.filt.sys +%WER 9.4 | 1155 27500 | 91.8 5.5 2.7 1.2 9.4 71.7 | 0.140 | exp/chain_cleaned/tdnn1e_sp_bi/decode_test/score_10_0.0/ctm.filt.filt.sys # chain+TDNN, large LM -%WER 9.8 | 507 17783 | 91.6 5.8 2.6 1.5 9.8 78.7 | 0.022 | exp/chain_cleaned/tdnn_sp_bi/decode_dev_rescore/score_10_0.0/ctm.filt.filt.sys -%WER 9.3 | 1155 27500 | 91.8 5.5 2.7 1.1 9.3 71.7 | 0.001 | exp/chain_cleaned/tdnn_sp_bi/decode_test_rescore/score_10_0.0/ctm.filt.filt.sys +%WER 8.6 | 507 17783 | 92.5 4.9 2.5 1.2 8.6 75.9 | 0.069 | exp/chain_cleaned/tdnn1e_sp_bi/decode_dev_rescore/score_10_0.0/ctm.filt.filt.sys +%WER 8.9 | 1155 27500 | 92.2 5.1 2.7 1.1 8.9 70.0 | 0.108 | exp/chain_cleaned/tdnn1e_sp_bi/decode_test_rescore/score_10_0.0/ctm.filt.filt.sys # chain+TDNN systems ran without cleanup, using the command: diff --git a/egs/tedlium/s5_r2/local/chain/run_tdnn.sh b/egs/tedlium/s5_r2/local/chain/run_tdnn.sh index e1adaa9346d..75da1a0a553 120000 --- a/egs/tedlium/s5_r2/local/chain/run_tdnn.sh +++ b/egs/tedlium/s5_r2/local/chain/run_tdnn.sh @@ -1 +1 @@ -tuning/run_tdnn_1d.sh \ No newline at end of file +tuning/run_tdnn_1e.sh \ No newline at end of file diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1e.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1e.sh new file mode 100755 index 00000000000..08eeba59c3d --- /dev/null +++ b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_1e.sh @@ -0,0 +1,248 @@ +#!/bin/bash + + +# run_tdnn_1e.sh is like run_tdnn_1d.sh but batchnorm components instead of renorm + +exp/chain_cleaned/tdnn1d_sp_bi: num-iters=253 nj=2..12 num-params=7.0M dim=40+100->3597 combine=-0.098->-0.097 xent:train/valid[167,252,final]=(-1.40,-1.34,-1.34/-1.50,-1.46,-1.46) logprob:train/valid[167,252,final]=(-0.091,-0.083,-0.083/-0.104,-0.101,-0.101) +exp/chain_cleaned/tdnn1e_sp_bi/: num-iters=253 nj=2..12 num-params=7.0M dim=40+100->3597 combine=-0.095->-0.095 
xent:train/valid[167,252,final]=(-1.37,-1.31,-1.31/-1.47,-1.44,-1.44) logprob:train/valid[167,252,final]=(-0.087,-0.078,-0.078/-0.102,-0.099,-0.099) + +# local/chain/compare_wer_general.sh exp/chain_cleaned/tdnn1d_sp_bi exp/chain_cleaned/tdnn1e_sp_bi +# System tdnn1d_sp_bi tdnn1e_sp_bi +# WER on dev(orig) 9.4 9.2 +# WER on dev(rescored) 8.6 8.6 +# WER on test(orig) 9.7 9.4 +# WER on test(rescored) 9.1 8.9 +# Final train prob -0.0827 -0.0776 +# Final valid prob -0.1011 -0.0992 +# Final train prob (xent) -1.3404 -1.3110 +# Final valid prob (xent) -1.4575 -1.4353 + +## how you run this (note: this assumes that the run_tdnn.sh soft link points here; +## otherwise call it directly in its location). +# by default, with cleanup: +# local/chain/run_tdnn.sh + +# without cleanup: +# local/chain/run_tdnn.sh --train-set train --gmm tri3 --nnet3-affix "" & + +# note, if you have already run the corresponding non-chain nnet3 system +# (local/nnet3/run_tdnn.sh), you may want to run with --stage 14. + +# This script is like run_tdnn_1a.sh except it uses an xconfig-based mechanism +# to get the configuration. + +set -e -o pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +nj=30 +decode_nj=30 +min_seg_len=1.55 +xent_regularize=0.1 +train_set=train_cleaned +gmm=tri3_cleaned # the gmm for the target data +num_threads_ubm=32 +nnet3_affix=_cleaned # cleanup affix for nnet3 and chain dirs, e.g. _cleaned + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +train_stage=-10 +tree_affix= # affix for tree directory, e.g. "a" or "b", in case we change the configuration. +tdnn_affix=1d #affix for TDNN directory, e.g. "a" or "b", in case we change the configuration. +common_egs_dir= # you can set this to use previously dumped egs. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <data/lang_chain/topo + fi +fi + +if [ $stage -le 15 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 16 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." 
+ exit 1; + fi + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --leftmost-questions-truncate -1 \ + --cmd "$train_cmd" 4000 ${lores_train_data_dir} data/lang_chain $ali_dir $tree_dir +fi + +if [ $stage -le 17 ]; then + mkdir -p $dir + + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=450 self-repair-scale=1.0e-04 + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=450 + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1,2) dim=450 + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=450 + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=450 + relu-batchnorm-layer name=tdnn6 input=Append(-6,-3,0) dim=450 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=450 target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=450 target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ + +fi + +if [ $stage -le 18 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/ami-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize 0.1 \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width 150 \ + --trainer.num-chunk-per-minibatch 128 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs 4 \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 12 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir $lat_dir \ + --dir $dir +fi + + + +if [ $stage -le 19 ]; then + # Note: it might appear that this data/lang_chain directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang $dir $dir/graph +fi + +if [ $stage -le 20 ]; then + rm $dir/.error 2>/dev/null || true + for dset in dev test; do + ( + steps/nnet3/decode.sh --num-threads 4 --nj $decode_nj --cmd "$decode_cmd" \ + --acwt 1.0 --post-decode-acwt 10.0 \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${dset}_hires \ + --scoring-opts "--min-lmwt 5 " \ + $dir/graph data/${dset}_hires $dir/decode_${dset} || exit 1; + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang data/lang_rescore \ + data/${dset}_hires ${dir}/decode_${dset} ${dir}/decode_${dset}_rescore || exit 1 + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi +exit 0 diff --git a/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py b/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py index cdbbb00a68a..b5d3e17dded 100755 --- a/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py +++ b/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py @@ -5,6 +5,7 @@ # Apache 2.0. 
from __future__ import division +from __future__ import print_function import traceback import datetime import logging @@ -15,6 +16,30 @@ logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) +g_lstmp_nonlin_regex_pattern = ''.join([".*progress.([0-9]+).log:component name=(.+) ", + "type=(.*)Component,.*", + "i_t_sigmoid.*", + "value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*", + "deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*", + "f_t_sigmoid.*", + "value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*", + "deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*", + "c_t_tanh.*", + "value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*", + "deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*", + "o_t_sigmoid.*", + "value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*", + "deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*", + "m_t_tanh.*", + "value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*", + "deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\]"]) + + +g_normal_nonlin_regex_pattern = ''.join([".*progress.([0-9]+).log:component name=(.+) ", + "type=(.*)Component,.*", + "value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*", + "deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\]"]) + class KaldiLogParseException(Exception): """ An Exception class that throws an error when there is an issue in parsing the log files. Extend this class if more granularity is needed. @@ -27,10 +52,55 @@ def __init__(self, message = None): "There was an error while trying to parse the logs." " Details : \n{0}\n".format(message)) +# This function is used to fill stats_per_component_per_iter table with the +# results of regular expression. 
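+# For a regular nonlinearity (gate_index=0) the resulting entry looks like,
+# e.g. for the Lstm3_i/Sigmoid example quoted further below:
+#   stats_table['Lstm3_i'] = {'type': 'Sigmoid',
+#                             'stats': {9: [value_mean, value_stddev,
+#                                           deriv_mean, deriv_stddev,
+#                                           value_5th, value_50th, value_95th,
+#                                           deriv_5th, deriv_50th, deriv_95th]}}
+# For an LstmNonlinearityComponent the caller invokes this function once per
+# gate (gate_index 0..4, in the order i_t, f_t, c_t, o_t, m_t), so each
+# per-iteration list is extended to 5 x 10 entries.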
+def fill_nonlin_stats_table_with_regex_result(groups, gate_index, stats_table): + iteration = int(groups[0]) + component_name = groups[1] + component_type = groups[2] + value_percentiles = groups[3+gate_index*6] + value_mean = float(groups[4+gate_index*6]) + value_stddev = float(groups[5+gate_index*6]) + value_percentiles_split = re.split(',| ',value_percentiles) + assert len(value_percentiles_split) == 13 + value_5th = float(value_percentiles_split[4]) + value_50th = float(value_percentiles_split[6]) + value_95th = float(value_percentiles_split[9]) + deriv_percentiles = groups[6+gate_index*6] + deriv_mean = float(groups[7+gate_index*6]) + deriv_stddev = float(groups[8+gate_index*6]) + deriv_percentiles_split = re.split(',| ',deriv_percentiles) + assert len(deriv_percentiles_split) == 13 + deriv_5th = float(deriv_percentiles_split[4]) + deriv_50th = float(deriv_percentiles_split[6]) + deriv_95th = float(deriv_percentiles_split[9]) + try: + if stats_table[component_name]['stats'].has_key(iteration): + stats_table[component_name]['stats'][iteration].extend( + [value_mean, value_stddev, + deriv_mean, deriv_stddev, + value_5th, value_50th, value_95th, + deriv_5th, deriv_50th, deriv_95th]) + else: + stats_table[component_name]['stats'][iteration] = [ + value_mean, value_stddev, + deriv_mean, deriv_stddev, + value_5th, value_50th, value_95th, + deriv_5th, deriv_50th, deriv_95th] + except KeyError: + stats_table[component_name] = {} + stats_table[component_name]['type'] = component_type + stats_table[component_name]['stats'] = {} + stats_table[component_name][ + 'stats'][iteration] = [value_mean, value_stddev, + deriv_mean, deriv_stddev, + value_5th, value_50th, value_95th, + deriv_5th, deriv_50th, deriv_95th] + def parse_progress_logs_for_nonlinearity_stats(exp_dir): - """ Parse progress logs for mean and std stats for non-linearities. + """ Parse progress logs for mean and std stats for non-linearities. e.g. 
for a line that is parsed from progress.*.log: exp/nnet3/lstm_self_repair_ld5_sp/log/progress.9.log:component name=Lstm3_i type=SigmoidComponent, dim=1280, self-repair-scale=1e-05, count=1.96e+05, @@ -48,39 +118,28 @@ def parse_progress_logs_for_nonlinearity_stats(exp_dir): progress_log_lines = common_lib.run_kaldi_command( 'grep -e "value-avg.*deriv-avg" {0}'.format(progress_log_files))[0] - parse_regex = re.compile( - ".*progress.([0-9]+).log:component name=(.+) " - "type=(.*)Component,.*" - "value-avg=\[.*mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*" - "deriv-avg=\[.*mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\]") + parse_regex = re.compile(g_normal_nonlin_regex_pattern) + for line in progress_log_lines.split("\n"): mat_obj = parse_regex.search(line) if mat_obj is None: continue - # groups = ('9', 'Lstm3_i', 'Sigmoid', '0.502', '0.23', - # '0.134', '0.0397') + # groups = ('9', 'Lstm3_i', 'Sigmoid', '0.05...0.99', '0.502', '0.23', + # '0.009...0.21', '0.134', '0.0397') groups = mat_obj.groups() - iteration = int(groups[0]) - component_name = groups[1] component_type = groups[2] - value_mean = float(groups[3]) - value_stddev = float(groups[4]) - deriv_mean = float(groups[5]) - deriv_stddev = float(groups[6]) - try: - stats_per_component_per_iter[component_name][ - 'stats'][iteration] = [value_mean, value_stddev, - deriv_mean, deriv_stddev] - except KeyError: - stats_per_component_per_iter[component_name] = {} - stats_per_component_per_iter[component_name][ - 'type'] = component_type - stats_per_component_per_iter[component_name]['stats'] = {} - stats_per_component_per_iter[component_name][ - 'stats'][iteration] = [value_mean, value_stddev, - deriv_mean, deriv_stddev] - + if component_type == 'LstmNonlinearity': + parse_regex_lstmp = re.compile(g_lstmp_nonlin_regex_pattern) + mat_obj = parse_regex_lstmp.search(line) + groups = mat_obj.groups() + assert len(groups) == 33 + for i in list(range(0,5)): + fill_nonlin_stats_table_with_regex_result(groups, i, + stats_per_component_per_iter) + else: + fill_nonlin_stats_table_with_regex_result(groups, 0, + stats_per_component_per_iter) return stats_per_component_per_iter diff --git a/egs/wsj/s5/steps/nnet3/report/generate_plots.py b/egs/wsj/s5/steps/nnet3/report/generate_plots.py index 6f185ad313f..6a652f9ec68 100755 --- a/egs/wsj/s5/steps/nnet3/report/generate_plots.py +++ b/egs/wsj/s5/steps/nnet3/report/generate_plots.py @@ -21,7 +21,7 @@ mpl.use('Agg') import matplotlib.pyplot as plt import numpy as np - + from matplotlib.patches import Rectangle g_plot = True except ImportError: warnings.warn( @@ -91,7 +91,6 @@ def get_args(): g_plot_colors = ['red', 'blue', 'green', 'black', 'magenta', 'yellow', 'cyan'] - class LatexReport: """Class for writing a Latex report""" @@ -213,6 +212,88 @@ def generate_acc_logprob_plots(exp_dir, output_dir, plot, key='accuracy', "Plot of {0} vs iterations for {1}".format(key, output_name)) +# The name of five gates of lstmp +g_lstm_gate = ['i_t_sigmoid', 'f_t_sigmoid', 'c_t_tanh', 'o_t_sigmoid', 'm_t_tanh'] + +# The "extra" item looks like a placeholder. As each unit in python plot is +# composed by a legend_handle(linestyle) and a legend_label(description). +# For the unit which doesn't have linestyle, we use the "extra" placeholder. 
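+# A minimal standalone sketch of this trick (hypothetical data, independent of
+# this script): an invisible Rectangle used as a legend handle produces a
+# legend cell that shows only its label, so handles and labels can be laid out
+# like a small table, e.g.:
+#
+#   import matplotlib.pyplot as plt
+#   from matplotlib.patches import Rectangle
+#   blank = Rectangle((0, 0), 1, 1, fill=False, edgecolor='none', linewidth=0)
+#   line, = plt.plot([0, 1], [0, 1], 'r--')
+#   plt.legend([blank, blank, line],
+#              ['some_exp_dir', '', '5th percentile'],
+#              ncol=1, handletextpad=-2)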
+extra = Rectangle((0, 0), 1, 1, facecolor="w", fill=False, edgecolor='none', linewidth=0)
+
+# Insert one column (one experiment directory) into the legend; column_index is 1-based.
+def insert_a_column_legend(legend_handle, legend_label, lp, mp, hp,
+                           dir, prefix_length, column_index):
+    handle = [extra, lp, mp, hp]
+    label = ["[1]{0}".format(dir[prefix_length:]), "", "", ""]
+    for row in range(1, 5):
+        legend_handle.insert(column_index*row-1, handle[row-1])
+        legend_label.insert(column_index*row-1, label[row-1])
+
+
+# Plot one ordinary nonlinearity component, or one gate of an LSTM ('lstmp') layer.
+def plot_a_nonlin_component(fig, dirs, stat_tables_per_component_per_dir,
+                            component_name, common_prefix, prefix_length,
+                            component_type, start_iter, gate_index=0):
+    fig.clf()
+    index = 0
+    legend_handle = [extra, extra, extra, extra]
+    legend_label = ["", '5th percentile', '50th percentile', '95th percentile']
+
+    for dir in dirs:
+        color_val = g_plot_colors[index]
+        index += 1
+        try:
+            iter_stats = (stat_tables_per_component_per_dir[dir][component_name])
+        except KeyError:
+            # This component does not exist in this directory's network, so we
+            # skip plotting it and fill its legend column with blank placeholders.
+            insert_a_column_legend(legend_handle, legend_label, extra, extra,
+                                   extra, dir, prefix_length, index+1)
+            continue
+
+        data = np.array(iter_stats)
+        data = data[data[:, 0] >= start_iter, :]
+        ax = plt.subplot(211)
+        lp, = ax.plot(data[:, 0], data[:, gate_index*10+5], color=color_val,
+                      linestyle='--')
+        mp, = ax.plot(data[:, 0], data[:, gate_index*10+6], color=color_val,
+                      linestyle='-')
+        hp, = ax.plot(data[:, 0], data[:, gate_index*10+7], color=color_val,
+                      linestyle='--')
+        insert_a_column_legend(legend_handle, legend_label, lp, mp, hp,
+                               dir, prefix_length, index+1)
+
+        ax.set_ylabel('Value-{0}'.format(component_type))
+        ax.grid(True)
+
+        ax = plt.subplot(212)
+        lp, = ax.plot(data[:, 0], data[:, gate_index*10+8], color=color_val,
+                      linestyle='--')
+        mp, = ax.plot(data[:, 0], data[:, gate_index*10+9], color=color_val,
+                      linestyle='-')
+        hp, = ax.plot(data[:, 0], data[:, gate_index*10+10], color=color_val,
+                      linestyle='--')
+        ax.set_xlabel('Iteration')
+        ax.set_ylabel('Derivative-{0}'.format(component_type))
+        ax.grid(True)
+
+    lgd = plt.legend(legend_handle, legend_label, loc='lower center',
+                     bbox_to_anchor=(0.5, -0.5 + len(dirs) * -0.2),
+                     ncol=4, handletextpad=-2,
+                     title="[1]:{0}".format(common_prefix),
+                     borderaxespad=0.)
+    plt.grid(True)
+    return lgd
+
+
+# Generate the statistics plots for the nonlinearity components.  The main
+# steps are:
+# 1) Collect the statistics from each experiment directory with the log_parse
+#    functions.
+# 2) Convert the collected nonlinearity statistics into tables; each table
+#    holds all the statistics of one component of one directory.
+# 3) Write the statistics of each component of the main experiment directory
+#    to a log file, one iteration per line.
+# 4) Plot the "Per-dimension average-(value, derivative) percentiles" figure
+#    for each nonlinearity component.
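+# Layout of one row of the per-component stats tables used below:
+#   [iteration,
+#    value_mean, value_stddev, deriv_mean, deriv_stddev,
+#    value_5th, value_50th, value_95th,
+#    deriv_5th, deriv_50th, deriv_95th]
+# with one further block of ten statistics appended per additional LSTM gate,
+# which is why plot_a_nonlin_component() above indexes the columns as
+# gate_index*10+5, gate_index*10+6, and so on.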
def generate_nonlin_stats_plots(exp_dir, output_dir, plot, comparison_dir=None, start_iter=1, latex_report=None): assert start_iter >= 1 @@ -230,7 +311,6 @@ def generate_nonlin_stats_plots(exp_dir, output_dir, plot, comparison_dir=None, logger.warning("Couldn't find any rows for the" "nonlin stats plot, not generating it") stats_per_dir[dir] = stats_per_component_per_iter - # convert the nonlin stats into tables stat_tables_per_component_per_dir = {} for dir in dirs: @@ -254,15 +334,15 @@ def generate_nonlin_stats_plots(exp_dir, output_dir, plot, comparison_dir=None, # this is the main experiment directory with open("{dir}/nonlinstats_{comp_name}.log".format( dir=output_dir, comp_name=component_name), "w") as f: - f.write( - "Iteration\tValueMean\tValueStddev\tDerivMean\tDerivStddev\n") + f.write("Iteration\tValueMean\tValueStddev\tDerivMean\tDerivStddev\t" + "Value_5th\tValue_50th\tValue_95th\t" + "Deriv_5th\tDeriv_50th\tDeriv_95th\n") iter_stat_report = [] iter_stats = main_stat_tables[component_name] for row in iter_stats: iter_stat_report.append("\t".join([str(x) for x in row])) f.write("\n".join(iter_stat_report)) f.close() - if plot: main_component_names = main_stat_tables.keys() main_component_names.sort() @@ -279,64 +359,50 @@ def generate_nonlin_stats_plots(exp_dir, output_dir, plot, comparison_dir=None, given experiment dirs are not the same, so comparison plots are provided only for common component names. Make sure that these are comparable experiments before analyzing these plots.""") - + fig = plt.figure() + + common_prefix = os.path.commonprefix(dirs) + prefix_length = common_prefix.rfind('/') + common_prefix = common_prefix[0:prefix_length] + for component_name in main_component_names: - fig.clf() - index = 0 - plots = [] - for dir in dirs: - color_val = g_plot_colors[index] - index += 1 - try: - iter_stats = ( - stat_tables_per_component_per_dir[dir][component_name]) - except KeyError: - # this component is not available in this network so lets - # not just plot it - continue - - data = np.array(iter_stats) - data = data[data[:, 0] >= start_iter, :] - ax = plt.subplot(211) - mp, = ax.plot(data[:, 0], data[:, 1], color=color_val, - label="Mean {0}".format(dir)) - msph, = ax.plot(data[:, 0], data[:, 1] + data[:, 2], - color=color_val, linestyle='--', - label="Mean+-Stddev {0}".format(dir)) - mspl, = ax.plot(data[:, 0], data[:, 1] - data[:, 2], - color=color_val, linestyle='--') - plots.append(mp) - plots.append(msph) - ax.set_ylabel('Value-{0}'.format(comp_type)) - ax.grid(True) - - ax = plt.subplot(212) - mp, = ax.plot(data[:, 0], data[:, 3], color=color_val) - msph, = ax.plot(data[:, 0], data[:, 3] + data[:, 4], - color=color_val, linestyle='--') - mspl, = ax.plot(data[:, 0], data[:, 3] - data[:, 4], - color=color_val, linestyle='--') - ax.set_xlabel('Iteration') - ax.set_ylabel('Derivative-{0}'.format(comp_type)) - ax.grid(True) - - lgd = plt.legend(handles=plots, loc='lower center', - bbox_to_anchor=(0.5, -0.5 + len(dirs) * -0.2), - ncol=1, borderaxespad=0.) 
- plt.grid(True) - fig.suptitle("Mean and stddev of the value and derivative at " - "{comp_name}".format(comp_name=component_name)) - comp_name = latex_compliant_name(component_name) - figfile_name = '{dir}/nonlinstats_{comp_name}.pdf'.format( - dir=output_dir, comp_name=comp_name) - fig.savefig(figfile_name, bbox_extra_artists=(lgd,), + if stats_per_dir[exp_dir][component_name]['type'] == 'LstmNonlinearity': + for i in range(0,5): + component_type = 'Lstm-' + g_lstm_gate[i] + lgd = plot_a_nonlin_component(fig, dirs, + stat_tables_per_component_per_dir, component_name, + common_prefix, prefix_length, component_type, start_iter, i) + fig.suptitle("Per-dimension average-(value, derivative) percentiles for " + "{component_name}-{gate}".format(component_name=component_name, gate=g_lstm_gate[i])) + comp_name = latex_compliant_name(component_name) + figfile_name = '{dir}/nonlinstats_{comp_name}_{gate}.pdf'.format( + dir=output_dir, comp_name=comp_name, gate=g_lstm_gate[i]) + fig.savefig(figfile_name, bbox_extra_artists=(lgd,), bbox_inches='tight') - if latex_report is not None: - latex_report.add_figure( + if latex_report is not None: + latex_report.add_figure( + figfile_name, + "Per-dimension average-(value, derivative) percentiles for " + "{0}-{1}".format(component_name, g_lstm_gate[i])) + else: + component_type = stats_per_dir[exp_dir][component_name]['type'] + lgd = plot_a_nonlin_component(fig, dirs, + stat_tables_per_component_per_dir,component_name, + common_prefix, prefix_length, component_type, start_iter, 0) + fig.suptitle("Per-dimension average-(value, derivative) percentiles for " + "{component_name}".format(component_name=component_name)) + comp_name = latex_compliant_name(component_name) + figfile_name = '{dir}/nonlinstats_{comp_name}.pdf'.format( + dir=output_dir, comp_name=comp_name) + fig.savefig(figfile_name, bbox_extra_artists=(lgd,), + bbox_inches='tight') + if latex_report is not None: + latex_report.add_figure( figfile_name, - "Mean and stddev of the value and derivative " - "at {0}".format(component_name)) + "Per-dimension average-(value, derivative) percentiles for " + "{0}".format(component_name)) + def generate_clipped_proportion_plots(exp_dir, output_dir, plot, diff --git a/src/bin/Makefile b/src/bin/Makefile index 687040889b3..d9f8d3d27ae 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -24,7 +24,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ matrix-logprob matrix-sum \ build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ - transform-vec align-text matrix-dim + transform-vec align-text matrix-dim post-to-smat OBJFILES = diff --git a/src/bin/post-to-smat.cc b/src/bin/post-to-smat.cc new file mode 100644 index 00000000000..8cd8df41647 --- /dev/null +++ b/src/bin/post-to-smat.cc @@ -0,0 +1,84 @@ +// bin/post-to-smat.cc + +// Copyright 2017 Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "hmm/posterior.h"
+
+int main(int argc, char *argv[]) {
+  using namespace kaldi;
+  typedef kaldi::int32 int32;
+  try {
+    const char *usage =
+        "This program turns an archive of per-frame posteriors, e.g. from\n"
+        "ali-to-post | post-to-pdf-post,\n"
+        "into an archive of SparseMatrix.  This is just a format transformation.\n"
+        "This may not make sense if the indexes in question are one-based (at least,\n"
+        "you'd have to increase the dimension by one).\n"
+        "\n"
+        "See also: post-to-phone-post, ali-to-post, post-to-pdf-post\n"
+        "\n"
+        "Usage: post-to-smat [options] <posteriors-rspecifier> <sparse-matrix-wspecifier>\n"
+        "e.g.: post-to-smat --dim=1038 ark:- ark:-\n";
+
+    ParseOptions po(usage);
+
+    int32 dim = -1;
+
+    po.Register("dim", &dim, "The num-cols in each output SparseMatrix. All "
+                "the integers in the input posteriors are expected to be \n"
+                ">= 0 and < dim. This must be specified.");
+
+    po.Read(argc, argv);
+
+    if (dim <= 0) {
+      KALDI_ERR << "The --dim option must be specified.";
+    }
+
+    if (po.NumArgs() != 2) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    std::string posteriors_rspecifier = po.GetArg(1),
+        sparse_matrix_wspecifier = po.GetArg(2);
+
+
+    SequentialPosteriorReader posterior_reader(posteriors_rspecifier);
+
+    TableWriter<KaldiObjectHolder<SparseMatrix<BaseFloat> > > sparse_matrix_writer(
+        sparse_matrix_wspecifier);
+
+    int32 num_done = 0;
+    for (; !posterior_reader.Done(); posterior_reader.Next()) {
+      const kaldi::Posterior &posterior = posterior_reader.Value();
+      // The following constructor will throw an error if there is some kind of
+      // dimension mismatch.
+      SparseMatrix<BaseFloat> smat(dim, posterior);
+      sparse_matrix_writer.Write(posterior_reader.Key(), smat);
+      num_done++;
+    }
+    KALDI_LOG << "Done converting " << num_done
+              << " posteriors into sparse matrices.";
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
diff --git a/src/makefiles/default_rules.mk b/src/makefiles/default_rules.mk
index 34abd905924..fb230d88e0b 100644
--- a/src/makefiles/default_rules.mk
+++ b/src/makefiles/default_rules.mk
@@ -94,7 +94,7 @@ test: test_compile
 	    rm -rf core; \
 	  fi; \
 	else \
-	  echo " $${time_taken}s... SUCCESS"; \
+	  echo " $${time_taken}s... SUCCESS $$x"; \
 	  rm -f $$x.testlog; \
 	fi; \
 	done; \
diff --git a/src/matrix/compressed-matrix.cc b/src/matrix/compressed-matrix.cc
index 45965e87651..2db73dedf75 100644
--- a/src/matrix/compressed-matrix.cc
+++ b/src/matrix/compressed-matrix.cc
@@ -58,7 +58,7 @@ void CompressedMatrix::ComputeGlobalHeader(
     case kTwoByteAuto: case kTwoByteSignedInteger:
       header->format = static_cast<int32>(kTwoByte);  // 2.
       break;
-    case kOneByteAuto: case kOneByteInteger: case kOneByteZeroOne:
+    case kOneByteAuto: case kOneByteUnsignedInteger: case kOneByteZeroOne:
       header->format = static_cast<int32>(kOneByte);  // 3.
       break;
     default:
@@ -95,6 +95,11 @@ void CompressedMatrix::ComputeGlobalHeader(
       header->range = 65535.0;
       break;
     }
+    case kOneByteUnsignedInteger: {
+      header->min_value = 0.0;
+      header->range = 255.0;
+      break;
+    }
     case kOneByteZeroOne: {
       header->min_value = 0.0;
       header->range = 1.0;
diff --git a/src/matrix/compressed-matrix.h b/src/matrix/compressed-matrix.h
index 7166192b78c..568aa7275c5 100644
--- a/src/matrix/compressed-matrix.h
+++ b/src/matrix/compressed-matrix.h
@@ -60,7 +60,10 @@ namespace kaldi {
        representable range of values chosen automatically with the minimum
       and maximum elements of the matrix as its edges.
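+       (Worked example, assuming the usual linear one-byte mapping
+       value ~= min_value + range * (stored_byte / 255.0):  with
+       min_value = 0.0 and range = 255.0 every integer 0 .. 255 is stored
+       exactly, which is what the kOneByteUnsignedInteger method below is
+       intended for; with range = 1.0 the byte instead encodes [0.0, 1.0]
+       in steps of 1/255.)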
- kOneByteZeroOne = 6 Each element is stored in + kOneByteUnsignedInteger = 6 Each element is stored in + one byte as a uint8, with the representable range of + values equal to [0.0, 255.0]. + kOneByteZeroOne = 7 Each element is stored in one byte as a uint8, with the representable range of values equal to [0.0, 1.0]. Suitable for image data that has previously been compressed as int8. @@ -75,7 +78,7 @@ enum CompressionMethod { kTwoByteAuto = 3, kTwoByteSignedInteger = 4, kOneByteAuto = 5, - kOneByteInteger = 6, + kOneByteUnsignedInteger = 6, kOneByteZeroOne = 7 }; diff --git a/src/nnet3bin/Makefile b/src/nnet3bin/Makefile index 2bae1dcdc43..e633f4c5fde 100644 --- a/src/nnet3bin/Makefile +++ b/src/nnet3bin/Makefile @@ -16,8 +16,9 @@ BINFILES = nnet3-init nnet3-info nnet3-get-egs nnet3-copy-egs nnet3-subset-egs \ nnet3-discriminative-get-egs nnet3-discriminative-copy-egs \ nnet3-discriminative-merge-egs nnet3-discriminative-shuffle-egs \ nnet3-discriminative-compute-objf nnet3-discriminative-train \ - nnet3-discriminative-subset-egs \ - nnet3-discriminative-compute-from-egs nnet3-latgen-faster-looped + nnet3-discriminative-subset-egs nnet3-get-egs-simple \ + nnet3-discriminative-compute-from-egs nnet3-latgen-faster-looped \ + nnet3-egs-augment-image OBJFILES = diff --git a/src/nnet3bin/nnet3-copy-egs.cc b/src/nnet3bin/nnet3-copy-egs.cc index 42413114af3..9375dd16ce5 100644 --- a/src/nnet3bin/nnet3-copy-egs.cc +++ b/src/nnet3bin/nnet3-copy-egs.cc @@ -270,7 +270,8 @@ int main(int argc, char *argv[]) { "e.g.\n" "nnet3-copy-egs ark:train.egs ark,t:text.egs\n" "or:\n" - "nnet3-copy-egs ark:train.egs ark:1.egs ark:2.egs\n"; + "nnet3-copy-egs ark:train.egs ark:1.egs ark:2.egs\n" + "See also: nnet3-subset-egs, nnet3-get-egs, nnet3-merge-egs, nnet3-shuffle-egs\n"; bool random = false; int32 srand_seed = 0; diff --git a/src/nnet3bin/nnet3-egs-augment-image.cc b/src/nnet3bin/nnet3-egs-augment-image.cc new file mode 100644 index 00000000000..1322896533c --- /dev/null +++ b/src/nnet3bin/nnet3-egs-augment-image.cc @@ -0,0 +1,231 @@ +// nnet3bin/nnet3-egs-augment-image.cc + +// Copyright 2017 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/transition-model.h" +#include "nnet3/nnet-example.h" +#include "nnet3/nnet-example-utils.h" + +namespace kaldi { +namespace nnet3 { + +struct ImageAugmentationConfig { + int32 num_channels; + BaseFloat horizontal_flip_prob; + BaseFloat horizontal_shift; + BaseFloat vertical_shift; + + ImageAugmentationConfig(): + num_channels(1), + horizontal_flip_prob(0.0), + horizontal_shift(0.0), + vertical_shift(0.0) { } + + + void Register(ParseOptions *po) { + po->Register("num-channels", &num_channels, "Number of colors in the image." 
+ "It is is important to specify this (helps interpret the image " + "correctly."); + po->Register("horizontal-flip-prob", &horizontal_flip_prob, + "Probability of doing horizontal flip"); + po->Register("horizontal-shift", &horizontal_shift, + "Maximum allowed horizontal shift as proportion of image " + "width. Padding is with closest pixel."); + // TODO: vertical_shift + } + + void Check() const { + KALDI_ASSERT(num_channels >= 1); + KALDI_ASSERT(horizontal_flip_prob >= 0 && + horizontal_flip_prob <= 1); + KALDI_ASSERT(horizontal_shift >= 0 && horizontal_shift <= 1); + KALDI_ASSERT(vertical_shift >= 0 && vertical_shift <= 1); + } +}; + + +/* Flips the image horizontally. */ +void HorizontalFlip(MatrixBase *image) { + int32 num_rows = image->NumRows(); + Vector temp(image->NumCols()); + for (int32 r = 0; r < num_rows / 2; r++) { + SubVector row_a(*image, r), row_b(*image, + num_rows - r - 1); + temp.CopyFromVec(row_a); + // TODO + } +} + + +// Shifts the image horizontally by 'horizontal_shift' (+ve == to the right). +void HorizontalShift(int32 horizontal_shift, + MatrixBase *image) { + // TODO. +} + +void VerticalShift(int32 vertical_shift, + int32 num_channels, + MatrixBase *image) { + // TODO. + int32 num_rows = image->NumRows(), + num_cols = image->NumCols(), height = num_cols / num_channels; + KALDI_ASSERT(num_cols % num_channels == 0); + for (int32 r = 0; r < num_rows; r++) { + BaseFloat *this_row = image->RowData(r); + // TODO: Do something with 'this_row'. + } +} + + + +/** + This function randomly modifies (perturbs) the image. + + @param [in] config Configuration class that says how + to perturb the image. + @param [in,out] image The image matrix to be modified. + image->NumRows() is the width (number of x values) in + the image; image->NumCols() is the height times number + of channels/colors (channel varies the fastest). + */ +void PerturbImage(const ImageAugmentationConfig &config, + MatrixBase *image) { + config.Check(); + int32 image_width = image->NumRows(), + num_channels = config.num_channels, + image_height = image->NumCols() / num_channels; + if (image->NumCols() % num_channels != 0) { + KALDI_ERR << "Number of columns in image must divide the number " + "of channels"; + } + if (WithProb(config.horizontal_flip_prob)) { + HorizontalFlip(image); + } + { // horizontal shift + int32 horizontal_shift_max = + static_cast(0.5 + config.horizontal_shift * image_width); + if (horizontal_shift_max > image_width - 1) + horizontal_shift_max = image_width - 1; // would be very strange. + int32 horizontal_shift = RandInt(-horizontal_shift_max, + horizontal_shift_max); + if (horizontal_shift != 0) + HorizontalShift(horizontal_shift_max, image); + } + + // TODO, vertical shift + +} + + +/** + This function does image perturbation as directed by 'config' + The example 'eg' is expected to contain a NnetIo member with the + name 'input', representing an image. + */ +void PerturbImageInNnetExample( + const ImageAugmentationConfig &config, + NnetExample *eg) { + int32 io_size = eg->io.size(); + bool found_input = false; + for (int32 i = 0; i < io_size; i++) { + NnetIo &io = eg->io[i]; + if (io.name == "input") { + found_input = true; + Matrix image; + io.features.GetMatrix(&image); + // note: 'GetMatrix' may uncompress if it was compressed. + // We won't recompress, but this won't matter because this + // program is intended to be used as part of a pipe, we + // likely won't be dumping the perturbed data to disk. + PerturbImage(config, &image); + + // modify the 'io' object. 
+ io.features = image; + } + } + if (!found_input) + KALDI_ERR << "Nnet example to perturb had no NnetIo object named 'input'"; +} + + +} // namespace nnet3 +} // namespace kaldi + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Copy examples (single frames or fixed-size groups of frames) for neural\n" + "network training, doing image augmentation inline (copies after possibly\n" + "modifying of each image, randomly chosen according to configuration\n" + "parameters).\n" + "E.g.:\n" + " nnet3-egs-augment-image --horizontal-flip-prob=0.5 --horizontal-shift=0.1\\\n" + " --vertical-shift=0.1 --srand=103 --num-channels=3 ark:- ark:-\n" + "\n" + "Requires that each eg contain a NnetIo object 'input', with successive\n" + "'t' values representing different x offsets , and the feature dimension\n" + "representing the y offset and the channel (color), with the channel\n" + "varying the fastest.\n" + "See also: nnet3-copy-egs\n"; + + + int32 srand_seed = 0; + + ImageAugmentationConfig config; + + ParseOptions po(usage); + po.Register("srand", &srand_seed, "Seed for the random number generator"); + config.Register(&po); + + po.Read(argc, argv); + + srand(srand_seed); + + if (po.NumArgs() < 2) { + po.PrintUsage(); + exit(1); + } + + std::string examples_rspecifier = po.GetArg(1), + examples_wspecifier = po.GetArg(2); + + SequentialNnetExampleReader example_reader(examples_rspecifier); + NnetExampleWriter example_writer(examples_wspecifier); + + + int64 num_done = 0; + for (; !example_reader.Done(); example_reader.Next(), num_done++) { + std::string key = example_reader.Key(); + NnetExample eg(example_reader.Value()); + PerturbImageInNnetExample(config, &eg); + example_writer.Write(key, eg); + } + KALDI_LOG << "Perturbed" << num_done << " neural-network training images."; + return (num_done == 0 ? 1 : 0); + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} diff --git a/src/nnet3bin/nnet3-get-egs.cc b/src/nnet3bin/nnet3-get-egs.cc index efab3e89a5f..229192bc4b1 100644 --- a/src/nnet3bin/nnet3-get-egs.cc +++ b/src/nnet3bin/nnet3-get-egs.cc @@ -163,7 +163,8 @@ int main(int argc, char *argv[]) { "An example [where $feats expands to the actual features]:\n" "nnet3-get-egs --num-pdfs=2658 --left-context=12 --right-context=9 --num-frames=8 \"$feats\"\\\n" "\"ark:gunzip -c exp/nnet/ali.1.gz | ali-to-pdf exp/nnet/1.nnet ark:- ark:- | ali-to-post ark:- ark:- |\" \\\n" - " ark:- \n"; + " ark:- \n" + "See also: nnet3-chain-get-egs, nnet3-get-egs-simple\n"; bool compress = true;