diff --git a/egs/aishell2/s5/conf/online_cmvn.conf b/egs/aishell2/s5/conf/online_cmvn.conf new file mode 100644 index 00000000000..048bdfa65de --- /dev/null +++ b/egs/aishell2/s5/conf/online_cmvn.conf @@ -0,0 +1 @@ +# configuration file for apply-cmvn-online diff --git a/egs/ami/s5/local/ami_download.sh b/egs/ami/s5/local/ami_download.sh index b14f8550c75..cba130c8467 100755 --- a/egs/ami/s5/local/ami_download.sh +++ b/egs/ami/s5/local/ami_download.sh @@ -53,12 +53,12 @@ cat local/split_train.orig local/split_eval.orig local/split_dev.orig > $wdir/am wgetfile=$wdir/wget_$mic.sh # TODO fix this with Pawel, files don't exist anymore, -manifest="wget --continue -O $adir/MANIFEST.TXT http://groups.inf.ed.ac.uk/ami/download/temp/amiBuild-04237-Sun-Jun-15-2014.manifest.txt" -license="wget --continue -O $adir/LICENCE.TXT http://groups.inf.ed.ac.uk/ami/download/temp/Creative-Commons-Attribution-NonCommercial-ShareAlike-2.5.txt" +manifest="wget --continue -O $adir/MANIFEST.TXT http://groups.inf.ed.ac.uk/ami/download/temp/amiBuild-0153-Tue-Oct-2-2018.manifest.txt" + echo "#!/bin/bash" > $wgetfile echo $manifest >> $wgetfile -echo $license >> $wgetfile + while read line; do if [ "$mic" == "ihm" ]; then extra_headset= #some meetings have 5 sepakers (headsets) @@ -100,8 +100,7 @@ else fi fi -echo "Downloads of AMI corpus completed succesfully. License can be found under $adir/LICENCE.TXT" +echo "Downloads of AMI corpus completed succesfully." exit 0; - diff --git a/egs/ami/s5/local/tfrnnlm/run_lstm.sh b/egs/ami/s5/local/tfrnnlm/run_lstm.sh index 31ae4a8bad7..a298590a31d 100755 --- a/egs/ami/s5/local/tfrnnlm/run_lstm.sh +++ b/egs/ami/s5/local/tfrnnlm/run_lstm.sh @@ -27,7 +27,7 @@ mkdir -p $dir if [ $stage -le 2 ]; then # the following script uses TensorFlow. You could use tools/extras/install_tensorflow_py.sh to install it $cuda_cmd $dir/train_rnnlm.log utils/parallel/limit_num_gpus.sh \ - python steps/tfrnnlm/lstm.py --data-path=$dir --save-path=$dir/rnnlm --vocab-path=$dir/wordlist.rnn.final + python steps/tfrnnlm/lstm.py --data_path=$dir --save_path=$dir/rnnlm --vocab_path=$dir/wordlist.rnn.final fi final_lm=ami_fsh.o3g.kn diff --git a/egs/ami/s5/local/tfrnnlm/run_lstm_fast.sh b/egs/ami/s5/local/tfrnnlm/run_lstm_fast.sh index 8dd876c2b2c..4cc71b55b5c 100755 --- a/egs/ami/s5/local/tfrnnlm/run_lstm_fast.sh +++ b/egs/ami/s5/local/tfrnnlm/run_lstm_fast.sh @@ -27,7 +27,7 @@ mkdir -p $dir if [ $stage -le 2 ]; then # the following script uses TensorFlow. You could use tools/extras/install_tensorflow_py.sh to install it $cuda_cmd $dir/train_rnnlm.log utils/parallel/limit_num_gpus.sh \ - python steps/tfrnnlm/lstm_fast.py --data-path=$dir --save-path=$dir/rnnlm --vocab-path=$dir/wordlist.rnn.final + python steps/tfrnnlm/lstm_fast.py --data_path=$dir --save_path=$dir/rnnlm --vocab_path=$dir/wordlist.rnn.final fi final_lm=ami_fsh.o3g.kn diff --git a/egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh b/egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh index 7a4635f07a4..15d237b0e12 100755 --- a/egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh +++ b/egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh @@ -27,7 +27,7 @@ mkdir -p $dir if [ $stage -le 2 ]; then # the following script uses TensorFlow. You could use tools/extras/install_tensorflow_py.sh to install it $cuda_cmd $dir/train_rnnlm.log utils/parallel/limit_num_gpus.sh \ - python steps/tfrnnlm/vanilla_rnnlm.py --data-path=$dir --save-path=$dir/rnnlm --vocab-path=$dir/wordlist.rnn.final + python steps/tfrnnlm/vanilla_rnnlm.py --data_path=$dir --save_path=$dir/rnnlm --vocab_path=$dir/wordlist.rnn.final fi final_lm=ami_fsh.o3g.kn diff --git a/egs/ami/s5b/RESULTS_ihm b/egs/ami/s5b/RESULTS_ihm index 42af5763829..ee2e5dd1cc2 100644 --- a/egs/ami/s5b/RESULTS_ihm +++ b/egs/ami/s5b/RESULTS_ihm @@ -88,6 +88,10 @@ # local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh --mic ihm # cleanup + chain TDNN+LSTM model + IHM reverberated data +# Old results: %WER 19.4 | 13098 94479 | 83.8 10.0 6.1 3.2 19.4 51.8 | -0.168 | exp/ihm/chain_cleaned_rvb/tdnn_lstm1i_sp_rvb_bi/decode_dev/ascore_10/dev_hires.ctm.filt.sys %WER 19.3 | 12643 89977 | 83.3 11.0 5.7 2.6 19.3 49.6 | -0.046 | exp/ihm/chain_cleaned_rvb/tdnn_lstm1i_sp_rvb_bi/decode_eval/ascore_10/eval_hires.ctm.filt.sys +# New results after simplifying scripts to remove combining short segments etc.: +%WER 19.4 | 12643 89979 | 83.1 10.9 6.0 2.5 19.4 50.7 | 0.010 | exp/ihm/chain_cleaned_rvb/tdnn_lstm1a_sp_rvb_bi/decode_eval/ascore_11/eval_hires.ctm.filt.sys +%WER 19.4 | 13098 94484 | 83.7 10.2 6.1 3.1 19.4 52.0 | -0.119 | exp/ihm/chain_cleaned_rvb/tdnn_lstm1a_sp_rvb_bi/decode_dev/ascore_11/dev_hires.ctm.filt.sys diff --git a/egs/ami/s5b/RESULTS_sdm b/egs/ami/s5b/RESULTS_sdm index 0993b2eb52a..96568ac70cb 100644 --- a/egs/ami/s5b/RESULTS_sdm +++ b/egs/ami/s5b/RESULTS_sdm @@ -96,6 +96,10 @@ # local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh --mic sdm1 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned # cleanup + chain TDNN+LSTM model, SDM original + IHM reverberated data, alignments from ihm data. # *** best system *** +# Old results: %WER 34.0 | 14455 94497 | 69.8 17.7 12.5 3.8 34.0 63.9 | 0.675 | exp/sdm1/chain_cleaned_rvb/tdnn_lstm1i_sp_rvb_bi_ihmali/decode_dev/ascore_10/dev_hires_o4.ctm.filt.sys %WER 37.5 | 13261 89982 | 65.9 19.3 14.7 3.5 37.5 66.2 | 0.642 | exp/sdm1/chain_cleaned_rvb/tdnn_lstm1i_sp_rvb_bi_ihmali/decode_eval/ascore_10/eval_hires_o4.ctm.filt.sys +# New results after simplifying scripts to remove combining short segments etc.: +%WER 34.6 | 14604 94498 | 69.6 18.8 11.6 4.2 34.6 64.4 | 0.652 | exp/sdm1/chain_cleaned_rvb/tdnn_lstm1a_sp_rvb_bi_ihmali/decode_dev/ascore_11/dev_hires_o4.ctm.filt.sys +%WER 37.6 | 13606 89636 | 66.1 21.0 12.9 3.7 37.6 65.1 | 0.613 | exp/sdm1/chain_cleaned_rvb/tdnn_lstm1a_sp_rvb_bi_ihmali/decode_eval/ascore_11/eval_hires_o4.ctm.filt.sys diff --git a/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh b/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh index 2869049843f..a8494420b0d 100755 --- a/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh +++ b/egs/ami/s5b/local/chain/multi_condition/tuning/run_tdnn_lstm_1a.sh @@ -19,7 +19,6 @@ set -e -o pipefail stage=0 mic=ihm nj=30 -min_seg_len=1.55 use_ihm_ali=false train_set=train_cleaned gmm=tri3_cleaned # the gmm for the target data @@ -27,7 +26,7 @@ ihm_gmm=tri3_cleaned # the gmm for the IHM system (if --use-ihm-ali true). num_threads_ubm=32 num_data_reps=1 -chunk_width=150 +chunk_width=160,140,110,80 chunk_left_context=40 chunk_right_context=0 label_delay=5 @@ -35,13 +34,13 @@ label_delay=5 # are just hardcoded at this level, in the commands below. train_stage=-10 tree_affix= # affix for tree directory, e.g. "a" or "b", in case we change the configuration. -tlstm_affix=1i #affix for TDNN-LSTM directory, e.g. "a" or "b", in case we change the configuration. +tlstm_affix=1a #affix for TDNN-LSTM directory, e.g. "a" or "b", in case we change the configuration. common_egs_dir= # you can set this to use previously dumped egs. # decode options extra_left_context=50 -frames_per_chunk= +frames_per_chunk=160 # End configuration section. @@ -75,21 +74,19 @@ rvb_affix=_rvb if $use_ihm_ali; then gmm_dir=exp/ihm/${ihm_gmm} - ali_dir=exp/${mic}/${ihm_gmm}_ali_${train_set}_sp_comb_ihmdata - lores_train_data_dir=data/$mic/${train_set}_ihmdata_sp_comb + lores_train_data_dir=data/$mic/${train_set}_ihmdata_sp tree_dir=exp/$mic/chain${nnet3_affix}/tree_bi${tree_affix}_ihmdata - original_lat_dir=exp/$mic/chain${nnet3_affix}/${ihm_gmm}_${train_set}_sp_comb_lats_ihmdata - lat_dir=exp/$mic/chain${nnet3_affix}${rvb_affix}/${ihm_gmm}_${train_set}_sp${rvb_affix}_comb_lats_ihmdata + original_lat_dir=exp/$mic/chain${nnet3_affix}/${ihm_gmm}_${train_set}_sp_lats_ihmdata + lat_dir=exp/$mic/chain${nnet3_affix}${rvb_affix}/${ihm_gmm}_${train_set}_sp${rvb_affix}_lats_ihmdata dir=exp/$mic/chain${nnet3_affix}${rvb_affix}/tdnn_lstm${tlstm_affix}_sp${rvb_affix}_bi_ihmali # note: the distinction between when we use the 'ihmdata' suffix versus # 'ihmali' is pretty arbitrary. else gmm_dir=exp/${mic}/$gmm - ali_dir=exp/${mic}/${gmm}_ali_${train_set}_sp_comb - lores_train_data_dir=data/$mic/${train_set}_sp_comb + lores_train_data_dir=data/$mic/${train_set}_sp tree_dir=exp/$mic/chain${nnet3_affix}/tree_bi${tree_affix} - original_lat_dir=exp/$mic/chain${nnet3_affix}/${gmm}_${train_set}_sp_comb_lats - lat_dir=exp/$mic/chain${nnet3_affix}${rvb_affix}/${gmm}_${train_set}_sp${rvb_affix}_comb_lats + original_lat_dir=exp/$mic/chain${nnet3_affix}/${gmm}_${train_set}_sp_lats + lat_dir=exp/$mic/chain${nnet3_affix}${rvb_affix}/${gmm}_${train_set}_sp${rvb_affix}_lats dir=exp/$mic/chain${nnet3_affix}${rvb_affix}/tdnn_lstm${tlstm_affix}_sp${rvb_affix}_bi fi @@ -97,9 +94,7 @@ fi local/nnet3/multi_condition/run_ivector_common.sh --stage $stage \ --mic $mic \ --nj $nj \ - --min-seg-len $min_seg_len \ --train-set $train_set \ - --gmm $gmm \ --num-threads-ubm $num_threads_ubm \ --num-data-reps $num_data_reps \ --nnet3-affix "$nnet3_affix" @@ -109,13 +104,13 @@ local/nnet3/multi_condition/run_ivector_common.sh --stage $stage \ local/nnet3/prepare_lores_feats.sh --stage $stage \ --mic $mic \ --nj $nj \ - --min-seg-len $min_seg_len \ + --min-seg-len "" \ --use-ihm-ali $use_ihm_ali \ --train-set $train_set -train_data_dir=data/$mic/${train_set}_sp${rvb_affix}_hires_comb -train_ivector_dir=exp/$mic/nnet3${nnet3_affix}${rvb_affix}/ivectors_${train_set}_sp${rvb_affix}_hires_comb +train_data_dir=data/$mic/${train_set}_sp${rvb_affix}_hires +train_ivector_dir=exp/$mic/nnet3${nnet3_affix}${rvb_affix}/ivectors_${train_set}_sp${rvb_affix}_hires final_lm=`cat data/local/lm/final_lm` LM=$final_lm.pr1-7 @@ -126,19 +121,6 @@ for f in $gmm_dir/final.mdl $lores_train_data_dir/feats.scp \ done -if [ $stage -le 11 ]; then - if [ -f $ali_dir/ali.1.gz ]; then - echo "$0: alignments in $ali_dir appear to already exist. Please either remove them " - echo " ... or use a later --stage option." - exit 1 - fi - echo "$0: aligning perturbed, short-segment-combined ${maybe_ihm}data" - steps/align_fmllr.sh --nj $nj --cmd "$train_cmd" \ - ${lores_train_data_dir} data/lang $gmm_dir $ali_dir -fi - -[ ! -f $ali_dir/ali.1.gz ] && echo "$0: expected $ali_dir/ali.1.gz to exist" && exit 1 - if [ $stage -le 12 ]; then echo "$0: creating lang directory with one state per phone." # Create a version of the lang/ directory that has one state per phone in the @@ -165,28 +147,42 @@ fi if [ $stage -le 13 ]; then # Get the alignments as lattices (gives the chain training more freedom). # use the same num-jobs as the alignments - steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \ + steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" \ + --generate-ali-from-lats true ${lores_train_data_dir} \ data/lang $gmm_dir $original_lat_dir rm $original_lat_dir/fsts.*.gz # save space - lat_dir_ihmdata=exp/ihm/chain${nnet3_affix}/${gmm}_${train_set}_sp_comb_lats + lat_dir_ihmdata=exp/ihm/chain${nnet3_affix}/${gmm}_${train_set}_sp_lats + + original_lat_nj=$(cat $original_lat_dir/num_jobs) + ihm_lat_nj=$(cat $lat_dir_ihmdata/num_jobs) - mkdir -p $lat_dir/temp/ - mkdir -p $lat_dir/temp2/ - lattice-copy "ark:gunzip -c $original_lat_dir/lat.*.gz |" ark,scp:$lat_dir/temp/lats.ark,$lat_dir/temp/lats.scp - lattice-copy "ark:gunzip -c $lat_dir_ihmdata/lat.*.gz |" ark,scp:$lat_dir/temp2/lats.ark,$lat_dir/temp2/lats.scp + $train_cmd --max-jobs-run 10 JOB=1:$original_lat_nj $lat_dir/temp/log/copy_original_lats.JOB.log \ + lattice-copy "ark:gunzip -c $original_lat_dir/lat.JOB.gz |" ark,scp:$lat_dir/temp/lats.JOB.ark,$lat_dir/temp/lats.JOB.scp + + $train_cmd --max-jobs-run 10 JOB=1:$ihm_lat_nj $lat_dir/temp2/log/copy_ihm_lats.JOB.log \ + lattice-copy "ark:gunzip -c $lat_dir_ihmdata/lat.JOB.gz |" ark,scp:$lat_dir/temp2/lats.JOB.ark,$lat_dir/temp2/lats.JOB.scp + + for n in $(seq $original_lat_nj); do + cat $lat_dir/temp/lats.$n.scp + done > $lat_dir/temp/combined_lats.scp - # copy the lattices for the reverberated data - rm -f $lat_dir/temp/combined_lats.scp - touch $lat_dir/temp/combined_lats.scp - cat $lat_dir/temp/lats.scp >> $lat_dir/temp/combined_lats.scp for i in `seq 1 $num_data_reps`; do - cat $lat_dir/temp2/lats.scp | sed -e "s/^/rev${i}_/" >> $lat_dir/temp/combined_lats.scp - done + for n in $(seq $ihm_lat_nj); do + cat $lat_dir/temp2/lats.$n.scp + done | sed -e "s/^/rev${i}_/" + done >> $lat_dir/temp/combined_lats.scp + sort -u $lat_dir/temp/combined_lats.scp > $lat_dir/temp/combined_lats_sorted.scp - lattice-copy scp:$lat_dir/temp/combined_lats_sorted.scp "ark:|gzip -c >$lat_dir/lat.1.gz" || exit 1; - echo "1" > $lat_dir/num_jobs + utils/split_data.sh $train_data_dir $nj + + $train_cmd --max-jobs-run 10 JOB=1:$nj $lat_dir/copy_combined_lats.JOB.log \ + lattice-copy --include=$train_data_dir/split$nj/JOB/utt2spk \ + scp:$lat_dir/temp/combined_lats_sorted.scp \ + "ark:|gzip -c >$lat_dir/lat.JOB.gz" || exit 1; + + echo $nj > $lat_dir/num_jobs # copy other files from original lattice dir for f in cmvn_opts final.mdl splice_opts tree; do @@ -206,7 +202,7 @@ if [ $stage -le 14 ]; then steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ --context-opts "--context-width=2 --central-position=1" \ --leftmost-questions-truncate -1 \ - --cmd "$train_cmd" 4200 ${lores_train_data_dir} data/lang_chain $ali_dir $tree_dir + --cmd "$train_cmd" 4200 ${lores_train_data_dir} data/lang_chain $original_lat_dir $tree_dir fi xent_regularize=0.1 @@ -312,7 +308,6 @@ if [ $stage -le 18 ]; then rm $dir/.error 2>/dev/null || true [ -z $extra_left_context ] && extra_left_context=$chunk_left_context; - [ -z $frames_per_chunk ] && frames_per_chunk=$chunk_width; for decode_set in dev eval; do ( diff --git a/egs/ami/s5b/local/nnet3/multi_condition/run_ivector_common.sh b/egs/ami/s5b/local/nnet3/multi_condition/run_ivector_common.sh index eb20415e515..5ba35fa421c 100755 --- a/egs/ami/s5b/local/nnet3/multi_condition/run_ivector_common.sh +++ b/egs/ami/s5b/local/nnet3/multi_condition/run_ivector_common.sh @@ -10,19 +10,17 @@ set -e -o pipefail stage=1 mic=ihm nj=30 -min_seg_len=1.55 # min length in seconds... we do this because chain training - # will discard segments shorter than 1.5 seconds. Must remain in sync with - # the same option given to prepare_lores_feats.sh. train_set=train_cleaned # you might set this to e.g. train_cleaned. -gmm=tri3_cleaned # This specifies a GMM-dir from the features of the type you're training the system on; - # it should contain alignments for 'train_set'. - +norvb_datadir=data/ihm/train_cleaned_sp num_threads_ubm=32 rvb_affix=_rvb nnet3_affix=_cleaned # affix for exp/$mic/nnet3 directory to put iVector stuff in, so it # becomes exp/$mic/nnet3_cleaned or whatever. num_data_reps=1 +sample_rate=16000 + +max_jobs_run=10 . ./cmd.sh . ./path.sh @@ -30,10 +28,7 @@ num_data_reps=1 nnet3_affix=${nnet3_affix}$rvb_affix -gmmdir=exp/${mic}/${gmm} - - -for f in data/${mic}/${train_set}/feats.scp ${gmmdir}/final.mdl; do +for f in data/${mic}/${train_set}/feats.scp; do if [ ! -f $f ]; then echo "$0: expected file $f to exist" exit 1 @@ -73,36 +68,23 @@ if [ $stage -le 1 ]; then for datadir in ${train_set}_sp dev eval; do steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \ - --cmd "$train_cmd" data/$mic/${datadir}_hires + --cmd "$train_cmd --max-jobs-run $max_jobs_run" data/$mic/${datadir}_hires steps/compute_cmvn_stats.sh data/$mic/${datadir}_hires utils/fix_data_dir.sh data/$mic/${datadir}_hires done fi -if [ $stage -le 2 ]; then - echo "$0: combining short segments of speed-perturbed high-resolution MFCC training data" - # we have to combine short segments or we won't be able to train chain models - # on those segments. - utils/data/combine_short_segments.sh \ - data/${mic}/${train_set}_sp_hires $min_seg_len data/${mic}/${train_set}_sp_hires_comb - - # just copy over the CMVN to avoid having to recompute it. - cp data/${mic}/${train_set}_sp_hires/cmvn.scp data/${mic}/${train_set}_sp_hires_comb/ - utils/fix_data_dir.sh data/${mic}/${train_set}_sp_hires_comb/ -fi if [ $stage -le 3 ]; then echo "$0: creating reverberated MFCC features" - datadir=data/ihm/train_cleaned_sp - - mfccdir=${datadir}_rvb${num_data_reps}_hires/data + mfccdir=${norvb_datadir}${rvb_affix}${num_data_reps}_hires/data if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then utils/create_split_dir.pl /export/b0{5,6,7,8}/$USER/kaldi-data/egs/ami-$mic-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage fi - if [ ! -f ${datadir}_rvb${num_data_reps}_hires/feats.scp ]; then - if [ ! -d "RIRS_NOISES" ]; then + if [ ! -f ${norvb_datadir}${rvb_affix}${num_data_reps}_hires/feats.scp ]; then + if [ ! -d "RIRS_NOISES/" ]; then # Download the package that includes the real RIRs, simulated RIRs, isotropic noises and point-source noises wget --no-check-certificate http://www.openslr.org/resources/28/rirs_noises.zip unzip rirs_noises.zip @@ -123,60 +105,29 @@ if [ $stage -le 3 ]; then --isotropic-noise-addition-probability 1 \ --num-replications ${num_data_reps} \ --max-noises-per-minute 1 \ - --source-sampling-rate 16000 \ - ${datadir} ${datadir}_rvb${num_data_reps} + --source-sampling-rate $sample_rate \ + ${norvb_datadir} ${norvb_datadir}${rvb_affix}${num_data_reps} - utils/copy_data_dir.sh ${datadir}_rvb${num_data_reps} ${datadir}_rvb${num_data_reps}_hires - utils/data/perturb_data_dir_volume.sh ${datadir}_rvb${num_data_reps}_hires + utils/copy_data_dir.sh ${norvb_datadir}${rvb_affix}${num_data_reps} ${norvb_datadir}${rvb_affix}${num_data_reps}_hires + utils/data/perturb_data_dir_volume.sh ${norvb_datadir}${rvb_affix}${num_data_reps}_hires steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \ - --cmd "$train_cmd" ${datadir}_rvb${num_data_reps}_hires - steps/compute_cmvn_stats.sh ${datadir}_rvb${num_data_reps}_hires - utils/fix_data_dir.sh ${datadir}_rvb${num_data_reps}_hires - - utils/data/combine_short_segments.sh \ - ${datadir}_rvb${num_data_reps}_hires $min_seg_len ${datadir}_rvb${num_data_reps}_hires_comb - - # just copy over the CMVN to avoid having to recompute it. - cp ${datadir}_rvb${num_data_reps}_hires/cmvn.scp ${datadir}_rvb${num_data_reps}_hires_comb/ - utils/fix_data_dir.sh ${datadir}_rvb${num_data_reps}_hires_comb/ + --cmd "$train_cmd --max-jobs-run $max_jobs_run" ${norvb_datadir}${rvb_affix}${num_data_reps}_hires + steps/compute_cmvn_stats.sh ${norvb_datadir}${rvb_affix}${num_data_reps}_hires + utils/fix_data_dir.sh ${norvb_datadir}${rvb_affix}${num_data_reps}_hires fi - utils/combine_data.sh data/${mic}/${train_set}_sp_rvb_hires data/${mic}/${train_set}_sp_hires ${datadir}_rvb${num_data_reps}_hires - utils/combine_data.sh data/${mic}/${train_set}_sp_rvb_hires_comb data/${mic}/${train_set}_sp_hires_comb ${datadir}_rvb${num_data_reps}_hires_comb + utils/combine_data.sh data/${mic}/${train_set}_sp${rvb_affix}_hires data/${mic}/${train_set}_sp_hires ${norvb_datadir}${rvb_affix}${num_data_reps}_hires fi - if [ $stage -le 4 ]; then - echo "$0: selecting segments of hires training data that were also present in the" - echo " ... original training data." - - # note, these data-dirs are temporary; we put them in a sub-directory - # of the place where we'll make the alignments. - temp_data_root=exp/$mic/nnet3${nnet3_affix}/tri5 - mkdir -p $temp_data_root - - utils/data/subset_data_dir.sh --utt-list data/${mic}/${train_set}/feats.scp \ - data/${mic}/${train_set}_sp_hires $temp_data_root/${train_set}_hires - - # note: essentially all the original segments should be in the hires data. - n1=$(wc -l $tmpdir/ihmutt2utt # Map the 1st field of the segments file from the ihm data (the 1st field being # the utterance-id) to the corresponding SDM or MDM utterance-id. The other # fields remain the same (e.g. we want the recording-ids from the IHM data). -utils/apply_map.pl -f 1 $tmpdir/ihmutt2utt data/$mic/train_ihmdata/segments +utils/apply_map.pl -f 1 $tmpdir/ihmutt2utt data/$mic/${train_set}_ihmdata/segments -utils/fix_data_dir.sh data/$mic/train_ihmdata +utils/fix_data_dir.sh data/$mic/${train_set}_ihmdata rm $tmpdir/ihmutt2utt diff --git a/egs/aspire/s5/local/nnet3/segment_and_decode.sh b/egs/aspire/s5/local/nnet3/segment_and_decode.sh index d66b72200c1..e8917d091e2 100755 --- a/egs/aspire/s5/local/nnet3/segment_and_decode.sh +++ b/egs/aspire/s5/local/nnet3/segment_and_decode.sh @@ -109,9 +109,9 @@ fi if [ $stage -le 4 ]; then utils/copy_data_dir.sh $sad_work_dir/${segmented_data_set}_seg \ - data/${segmented_data_set}_hires - steps/compute_cmvn_stats.sh data/${segmented_data_set}_hires - utils/fix_data_dir.sh data/${segmented_data_set}_hires + data/${segmented_data_set}_seg_hires + steps/compute_cmvn_stats.sh data/${segmented_data_set}_seg_hires + utils/fix_data_dir.sh data/${segmented_data_set}_seg_hires fi if [ $stage -le 5 ]; then @@ -122,11 +122,11 @@ if [ $stage -le 5 ]; then # acoustic conditions drift over time within the speaker's data. steps/online/nnet2/extract_ivectors.sh --cmd "$train_cmd" --nj $decode_num_jobs \ --sub-speaker-frames $sub_speaker_frames --max-count $max_count \ - data/${segmented_data_set}_hires $lang $ivector_root_dir/extractor \ - $ivector_root_dir/ivectors_${segmented_data_set} + data/${segmented_data_set}_seg_hires $lang $ivector_root_dir/extractor \ + $ivector_root_dir/ivectors_${segmented_data_set}_seg fi -decode_dir=$dir/decode_${segmented_data_set}${affix}_pp +decode_dir=$dir/decode_${segmented_data_set}_seg${affix}_pp if [ $stage -le 6 ]; then echo "Generating lattices" rm -f ${decode_dir}_tg/.error @@ -138,8 +138,8 @@ if [ $stage -le 6 ]; then --extra-right-context-final $extra_right_context_final \ --frames-per-chunk "$frames_per_chunk" \ --skip-scoring true ${iter:+--iter $iter} --lattice-beam $lattice_beam \ - --online-ivector-dir $ivector_root_dir/ivectors_${segmented_data_set} \ - $graph data/${segmented_data_set}_hires ${decode_dir}_tg || \ + --online-ivector-dir $ivector_root_dir/ivectors_${segmented_data_set}_seg \ + $graph data/${segmented_data_set}_seg_hires ${decode_dir}_tg || \ { echo "$0: Error decoding" && exit 1; } fi @@ -147,7 +147,7 @@ if [ $stage -le 7 ]; then echo "Rescoring lattices" steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ --skip-scoring true \ - ${lang}_pp_test{,_fg} data/${segmented_data_set}_hires \ + ${lang}_pp_test{,_fg} data/${segmented_data_set}_seg_hires \ ${decode_dir}_{tg,fg}; fi @@ -161,5 +161,5 @@ if [ $stage -le 8 ]; then ${iter:+--iter $iter} \ --decode-mbr true \ --tune-hyper true \ - $lang $decode_dir $act_data_set $segmented_data_set $out_file + $lang $decode_dir $act_data_set ${segmented_data_set}_seg $out_file fi diff --git a/egs/callhome_diarization/v1/diarization/make_rttm.py b/egs/callhome_diarization/v1/diarization/make_rttm.py index 1705411069f..a5242c7654f 100755 --- a/egs/callhome_diarization/v1/diarization/make_rttm.py +++ b/egs/callhome_diarization/v1/diarization/make_rttm.py @@ -80,7 +80,7 @@ def main(): # Cut up overlapping segments so they are contiguous contiguous_segs = [] - for reco in reco2segs: + for reco in sorted(reco2segs): segs = reco2segs[reco].strip().split() new_segs = "" for i in range(1, len(segs)-1): diff --git a/egs/chime1/s5/local/chime1_prepare_data.sh b/egs/chime1/s5/local/chime1_prepare_data.sh index e60c46ff8da..c5963b5d4ab 100755 --- a/egs/chime1/s5/local/chime1_prepare_data.sh +++ b/egs/chime1/s5/local/chime1_prepare_data.sh @@ -53,7 +53,7 @@ for x in "devel" "test"; do for sid in `seq 34`; do sid2=`printf "s%02d" $sid` ls -1 $wav_dir/*/s${sid}_*.wav \ - | perl -ape "s/(.*)\/(.*)\/s.*_(.*).wav/${sid2}_\3_\2$\t\1\/\2\/s${sid}_\3.wav/;" \ + | perl -ape "s/(.*)\/(.*)\/s.*_(.*).wav/${sid2}_\3_\2\t\1\/\2\/s${sid}_\3.wav/;" \ | sort >> $scp done fi @@ -68,7 +68,7 @@ for x in $set_list; do # Create utt2spk files # No speaker ID - perl -ape "s/(.*)\t.*/\1$\t\1/;" < "$scp" > "$data/$x/utt2spk" + perl -ape "s/(.*)\t.*/\1\t\1/;" < "$scp" > "$data/$x/utt2spk" # Use speaker ID # perl -ape "s/(s..)(.*)\\t.*/\1\2\t\1/;" < "$scp" > "$data/$x/utt2spk" diff --git a/egs/chime5/s5/local/score_for_submit.sh b/egs/chime5/s5/local/score_for_submit.sh index 5502c5994e5..23121d68b93 100755 --- a/egs/chime5/s5/local/score_for_submit.sh +++ b/egs/chime5/s5/local/score_for_submit.sh @@ -43,7 +43,7 @@ for session in S02 S09; do # get nerror nerr=`grep "\#csid" $score_result | grep $room | grep $session | awk '{sum+=$4+$5+$6} END {print sum}'` # get nwords from references (NF-2 means to exclude utterance id and " ref ") - nwrd=`grep " ref " $score_result | grep $room | grep $session | sed -e "s/\*//g" | awk '{sum+=NF-2} END {print sum}'` + nwrd=`grep "\#csid" $score_result | grep $room | grep $session | awk '{sum+=$3+$4+$6} END {print sum}'` # compute wer with scale=2 wer=`echo "scale=2; 100 * $nerr / $nwrd" | bc` @@ -59,7 +59,7 @@ echo -n "overall: " # get nerror nerr=`grep "\#csid" $score_result | awk '{sum+=$4+$5+$6} END {print sum}'` # get nwords from references (NF-2 means to exclude utterance id and " ref ") -nwrd=`grep " ref " $score_result | sed -e "s/\*//g" | awk '{sum+=NF-2} END {print sum}'` +nwrd=`grep "\#csid" $score_result | awk '{sum+=$3+$4+$6} END {print sum}'` # compute wer with scale=2 wer=`echo "scale=2; 100 * $nerr / $nwrd" | bc` echo -n "#words $nwrd, " @@ -81,7 +81,7 @@ for session in S01 S21; do # get nerror nerr=`grep "\#csid" $score_result | grep $room | grep $session | awk '{sum+=$4+$5+$6} END {print sum}'` # get nwords from references (NF-2 means to exclude utterance id and " ref ") - nwrd=`grep " ref " $score_result | grep $room | grep $session | sed -e "s/\*//g" | awk '{sum+=NF-2} END {print sum}'` + nwrd=`grep "\#csid" $score_result | grep $room | grep $session | awk '{sum+=$3+$4+$6} END {print sum}'` # compute wer with scale=2 wer=`echo "scale=2; 100 * $nerr / $nwrd" | bc` @@ -98,7 +98,7 @@ if $do_eval; then # get nerror nerr=`grep "\#csid" $score_result | awk '{sum+=$4+$5+$6} END {print sum}'` # get nwords from references (NF-2 means to exclude utterance id and " ref ") - nwrd=`grep " ref " $score_result | sed -e "s/\*//g" | awk '{sum+=NF-2} END {print sum}'` + nwrd=`grep "\#csid" $score_result | awk '{sum+=$3+$4+$6} END {print sum}'` # compute wer with scale=2 wer=`echo "scale=2; 100 * $nerr / $nwrd" | bc` echo -n "overall: " diff --git a/egs/cifar/v1/image/ocr/make_features.py b/egs/cifar/v1/image/ocr/make_features.py index 7ab75498277..a11cbcc7a82 100755 --- a/egs/cifar/v1/image/ocr/make_features.py +++ b/egs/cifar/v1/image/ocr/make_features.py @@ -43,10 +43,15 @@ parser.add_argument('--padding', type=int, default=5, help='Number of white pixels to pad on the left' 'and right side of the image.') +parser.add_argument('--num-channels', type=int, default=1, + help='Number of color channels') +parser.add_argument('--vertical-shift', type=int, default=0, + help='total number of padding pixel per column') parser.add_argument('--fliplr', type=lambda x: (str(x).lower()=='true'), default=False, help="Flip the image left-right for right to left languages") -parser.add_argument("--augment", type=lambda x: (str(x).lower()=='true'), default=False, - help="performs image augmentation") +parser.add_argument('--augment_type', type=str, default='no_aug', + choices=['no_aug', 'random_scale','random_shift'], + help='Subset of data to process.') args = parser.parse_args() @@ -66,7 +71,6 @@ def write_kaldi_matrix(file_handle, matrix, key): file_handle.write("\n") file_handle.write(" ]\n") - def horizontal_pad(im, allowed_lengths = None): if allowed_lengths is None: left_padding = right_padding = args.padding @@ -84,9 +88,9 @@ def horizontal_pad(im, allowed_lengths = None): left_padding = int(padding // 2) right_padding = padding - left_padding dim_y = im.shape[0] # height - im_pad = np.concatenate((255 * np.ones((dim_y, left_padding), + im_pad = np.concatenate((255 * np.ones((dim_y, left_padding, args.num_channels), dtype=int), im), axis=1) - im_pad1 = np.concatenate((im_pad, 255 * np.ones((dim_y, right_padding), + im_pad1 = np.concatenate((im_pad, 255 * np.ones((dim_y, right_padding, args.num_channels), dtype=int)), axis=1) return im_pad1 @@ -110,6 +114,33 @@ def get_scaled_image_aug(im, mode='normal'): return im_scaled_up return im +def vertical_shift(im, mode='normal'): + if args.vertical_shift == 0: + return im + total = args.vertical_shift + if mode == 'notmid': + val = random.randint(0, 1) + if val == 0: + mode = 'top' + else: + mode = 'bottom' + if mode == 'normal': + top = int(total / 2) + bottom = total - top + elif mode == 'top': # more padding on top + top = random.randint(total / 2, total) + bottom = total - top + elif mode == 'bottom': # more padding on bottom + top = random.randint(0, total / 2) + bottom = total - top + width = im.shape[1] + im_pad = np.concatenate( + (255 * np.ones((top, width), dtype=int) - + np.random.normal(2, 1, (top, width)).astype(int), im), axis=0) + im_pad = np.concatenate( + (im_pad, 255 * np.ones((bottom, width), dtype=int) - + np.random.normal(2, 1, (bottom, width)).astype(int)), axis=0) + return im_pad ### main ### random.seed(1) @@ -132,7 +163,6 @@ def get_scaled_image_aug(im, mode='normal'): num_fail = 0 num_ok = 0 -aug_setting = ['normal', 'scaled'] with open(data_list_path) as f: for line in f: line = line.strip() @@ -142,15 +172,25 @@ def get_scaled_image_aug(im, mode='normal'): im = misc.imread(image_path) if args.fliplr: im = np.fliplr(im) - if args.augment: - im_aug = get_scaled_image_aug(im, aug_setting[1]) - else: - im_aug = get_scaled_image_aug(im, aug_setting[0]) - im_horizontal_padded = horizontal_pad(im_aug, allowed_lengths) - if im_horizontal_padded is None: + if args.augment_type == 'no_aug' or 'random_shift': + im = get_scaled_image_aug(im, 'normal') + elif args.augment_type == 'random_scale': + im = get_scaled_image_aug(im, 'scaled') + im = horizontal_pad(im, allowed_lengths) + if im is None: num_fail += 1 continue - data = np.transpose(im_horizontal_padded, (1, 0)) + if args.augment_type == 'no_aug' or 'random_scale': + im = vertical_shift(im, 'normal') + elif args.augment_type == 'random_shift': + im = vertical_shift(im, 'notmid') + if args.num_channels == 1: + data = np.transpose(im, (1, 0)) + elif args.num_channels == 3: + H = im.shape[0] + W = im.shape[1] + C = im.shape[2] + data = np.reshape(np.transpose(im, (1, 0, 2)), (W, H * C)) data = np.divide(data, 255.0) num_ok += 1 write_kaldi_matrix(out_fh, data, image_id) diff --git a/egs/librispeech/s5/local/chain/run_cnn_tdnn.sh b/egs/librispeech/s5/local/chain/run_cnn_tdnn.sh new file mode 100755 index 00000000000..cd8f38d8309 --- /dev/null +++ b/egs/librispeech/s5/local/chain/run_cnn_tdnn.sh @@ -0,0 +1 @@ +tuning/run_cnn_tdnn_1a.sh diff --git a/egs/librispeech/s5/local/chain/tuning/run_cnn_tdnn_1a.sh b/egs/librispeech/s5/local/chain/tuning/run_cnn_tdnn_1a.sh new file mode 100755 index 00000000000..2a60587fc35 --- /dev/null +++ b/egs/librispeech/s5/local/chain/tuning/run_cnn_tdnn_1a.sh @@ -0,0 +1,278 @@ +#!/bin/bash + +# This is based on tdnn_1d_sp, but adding cnn as the front-end. +# The cnn-tdnn-f (tdnn_cnn_1a_sp) outperforms the tdnn-f (tdnn_1d_sp). + +# bash local/chain/compare_wer.sh exp/chain_cleaned/tdnn_1d_sp exp/chain_cleaned/tdnn_cnn_1a_sp/ +# System tdnn_1d_sp tdnn_cnn_1a_sp +# WER on dev(fglarge) 3.29 3.34 +# WER on dev(tglarge) 3.44 3.39 +# WER on dev(tgmed) 4.22 4.29 +# WER on dev(tgsmall) 4.72 4.77 +# WER on dev_other(fglarge) 8.71 8.62 +# WER on dev_other(tglarge) 9.05 9.00 +# WER on dev_other(tgmed) 11.09 10.93 +# WER on dev_other(tgsmall) 12.13 12.02 +# WER on test(fglarge) 3.80 3.69 +# WER on test(tglarge) 3.89 3.80 +# WER on test(tgmed) 4.72 4.64 +# WER on test(tgsmall) 5.19 5.16 +# WER on test_other(fglarge) 8.76 8.71 +# WER on test_other(tglarge) 9.19 9.11 +# WER on test_other(tgmed) 11.22 11.00 +# WER on test_other(tgsmall) 12.24 12.16 +# Final train prob -0.0378 -0.0420 +# Final valid prob -0.0374 -0.0400 +# Final train prob (xent) -0.6099 -0.6881 +# Final valid prob (xent) -0.6353 -0.7180 +# Num-parameters 22623456 18100736 + + +set -e + +# configs for 'chain' +stage=0 +decode_nj=50 +train_set=train_960_cleaned +gmm=tri6b_cleaned +nnet3_affix=_cleaned + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=cnn_1a +tree_affix= +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +# TDNN options +frames_per_eg=150,110,100 +remove_egs=true +common_egs_dir= +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.5@0.50,0' + +test_online_decoding=true # if true, it will run the last decoding stage. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # MFCC to filterbank + idct-layer name=idct input=input dim=40 cepstral-lifter=22 affine-transform-file=$dir/configs/idct.mat + + linear-component name=ivector-linear $ivector_affine_opts dim=200 input=ReplaceIndex(ivector, t, 0) + batchnorm-component name=ivector-batchnorm target-rms=0.025 + batchnorm-component name=idct-batchnorm input=idct + + combine-feature-maps-layer name=combine_inputs input=Append(idct-batchnorm, ivector-batchnorm) num-filters1=1 num-filters2=5 height=40 + conv-relu-batchnorm-layer name=cnn1 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn2 $cnn_opts height-in=40 height-out=40 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=64 + conv-relu-batchnorm-layer name=cnn3 $cnn_opts height-in=40 height-out=20 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=128 + conv-relu-batchnorm-layer name=cnn4 $cnn_opts height-in=20 height-out=20 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=128 + conv-relu-batchnorm-layer name=cnn5 $cnn_opts height-in=20 height-out=10 height-subsample-out=2 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=256 + conv-relu-batchnorm-layer name=cnn6 $cnn_opts height-in=10 height-out=10 time-offsets=-1,0,1 height-offsets=-1,0,1 num-filters-out=256 + + # the first TDNN-F layer has no bypass + tdnnf-layer name=tdnnf7 $tdnnf_first_opts dim=1536 bottleneck-dim=256 time-stride=0 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf14 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf15 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf16 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf17 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf18 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + linear-component name=prefinal-l dim=256 $linear_opts + + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 15 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b{09,10,11,12}/$USER/kaldi-data/egs/swbd-$(date +'%m_%d_%H_%M')/s5c/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --use-gpu "wait" \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.0 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0 --constrained false" \ + --egs.chunk-width $frames_per_eg \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.num-chunk-per-minibatch 64 \ + --trainer.frames-per-iter 2500000 \ + --trainer.num-epochs 4 \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.00015 \ + --trainer.optimization.final-effective-lrate 0.000015 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs $remove_egs \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir $lat_dir \ + --dir $dir || exit 1; + +fi + +graph_dir=$dir/graph_tgsmall +if [ $stage -le 16 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 --remove-oov data/lang_test_tgsmall $dir $graph_dir + # remove from the graph, and convert back to const-FST. + fstrmsymbols --apply-to-output=true --remove-arcs=true "echo 3|" $graph_dir/HCLG.fst - | \ + fstconvert --fst_type=const > $graph_dir/temp.fst + mv $graph_dir/temp.fst $graph_dir/HCLG.fst +fi + +iter_opts= +if [ ! -z $decode_iter ]; then + iter_opts=" --iter $decode_iter " +fi +if [ $stage -le 17 ]; then + rm $dir/.error 2>/dev/null || true + for decode_set in test_clean test_other dev_clean dev_other; do + ( + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $decode_nj --cmd "$decode_cmd" $iter_opts \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ + $graph_dir data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_tgsmall || exit 1 + steps/lmrescore.sh --cmd "$decode_cmd" --self-loop-scale 1.0 data/lang_test_{tgsmall,tgmed} \ + data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_{tgsmall,tgmed} || exit 1 + steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_test_{tgsmall,tglarge} \ + data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_{tgsmall,tglarge} || exit 1 + steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_test_{tgsmall,fglarge} \ + data/${decode_set}_hires $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_{tgsmall,fglarge} || exit 1 + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi + +if $test_online_decoding && [ $stage -le 18 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. + steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/nnet3${nnet3_affix}/extractor $dir ${dir}_online + + rm $dir/.error 2>/dev/null || true + for data in test_clean test_other dev_clean dev_other; do + ( + nspk=$(wc -l $lat_dir/splice_opts - fi if [ $stage -le 3 ]; then @@ -185,7 +181,7 @@ if [ $stage -le 5 ]; then --chain.leaky-hmm-coefficient=0.1 \ --chain.l2-regularize=0.00005 \ --chain.apply-deriv-weights=false \ - --chain.lm-opts="--num-extra-lm-states=500" \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ --chain.frame-subsampling-factor=$frame_subsampling_factor \ --chain.alignment-subsampling-factor=1 \ --chain.left-tolerance 3 \ @@ -201,11 +197,8 @@ if [ $stage -le 5 ]; then --trainer.optimization.shrink-value=1.0 \ --trainer.num-chunk-per-minibatch=64,32 \ --trainer.optimization.momentum=0.0 \ + --trainer.add-option="--optimization.memory-compression-level=2" \ --egs.chunk-width=$chunk_width \ - --egs.chunk-left-context=$chunk_left_context \ - --egs.chunk-right-context=$chunk_right_context \ - --egs.chunk-left-context-initial=0 \ - --egs.chunk-right-context-final=0 \ --egs.dir="$common_egs_dir" \ --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ --cleanup.remove-egs=$remove_egs \ @@ -226,18 +219,20 @@ if [ $stage -le 6 ]; then # as long as phones.txt was compatible. utils/mkgraph.sh \ - --self-loop-scale 1.0 data/$lang_test \ + --self-loop-scale 1.0 $lang_decode \ $dir $dir/graph || exit 1; fi if [ $stage -le 7 ]; then frames_per_chunk=$(echo $chunk_width | cut -d, -f1) steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --extra-left-context $chunk_left_context \ - --extra-right-context $chunk_right_context \ - --extra-left-context-initial 0 \ - --extra-right-context-final 0 \ --frames-per-chunk $frames_per_chunk \ --nj $nj --cmd "$cmd" \ $dir/graph data/test $dir/decode_test || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test $dir/decode_test{,_rescored} || exit 1 fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/madcat_ar/v1/local/chain/tuning/run_e2e_cnn_1a.sh b/egs/madcat_ar/v1/local/chain/tuning/run_e2e_cnn_1a.sh index 033cb88df10..2891e50da9e 100755 --- a/egs/madcat_ar/v1/local/chain/tuning/run_e2e_cnn_1a.sh +++ b/egs/madcat_ar/v1/local/chain/tuning/run_e2e_cnn_1a.sh @@ -27,16 +27,12 @@ affix=1a # training options tdnn_dim=450 -num_epochs=2 -num_jobs_initial=6 -num_jobs_final=16 minibatch_size=150=128,64/300=128,64/600=64,32/1200=32,16 common_egs_dir= -l2_regularize=0.00005 -frames_per_iter=2000000 -cmvn_opts="--norm-means=true --norm-vars=true" +cmvn_opts="--norm-means=false --norm-vars=false" train_set=train -lang_test=lang_test +lang_decode=data/lang +lang_rescore=data/lang_rescore_6g # End configuration section. echo "$0 $@" # Print the command line for logging @@ -118,7 +114,7 @@ if [ $stage -le 3 ]; then --cmd "$cmd" \ --feat.cmvn-opts "$cmvn_opts" \ --chain.leaky-hmm-coefficient 0.1 \ - --chain.l2-regularize $l2_regularize \ + --chain.l2-regularize 0.00005 \ --chain.apply-deriv-weights false \ --egs.dir "$common_egs_dir" \ --egs.stage $get_egs_stage \ @@ -128,11 +124,11 @@ if [ $stage -le 3 ]; then --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ --trainer.add-option="--optimization.memory-compression-level=2" \ --trainer.num-chunk-per-minibatch $minibatch_size \ - --trainer.frames-per-iter $frames_per_iter \ - --trainer.num-epochs $num_epochs \ + --trainer.frames-per-iter 2000000 \ + --trainer.num-epochs 2 \ --trainer.optimization.momentum 0 \ - --trainer.optimization.num-jobs-initial $num_jobs_initial \ - --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.num-jobs-initial 6 \ + --trainer.optimization.num-jobs-final 16 \ --trainer.optimization.initial-effective-lrate 0.001 \ --trainer.optimization.final-effective-lrate 0.0001 \ --trainer.optimization.shrink-value 1.0 \ @@ -152,7 +148,7 @@ if [ $stage -le 4 ]; then # as long as phones.txt was compatible. utils/mkgraph.sh \ - --self-loop-scale 1.0 data/$lang_test \ + --self-loop-scale 1.0 $lang_decode \ $dir $dir/graph || exit 1; fi @@ -161,6 +157,9 @@ if [ $stage -le 5 ]; then steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ --nj $nj --cmd "$cmd" \ $dir/graph data/test $dir/decode_test || exit 1; + + steps/lmrescore_const_arpa.sh --cmd "$cmd" $lang_decode $lang_rescore \ + data/test $dir/decode_test{,_rescored} || exit 1 fi echo "Done. Date: $(date). Results:" diff --git a/egs/madcat_ar/v1/local/create_line_image_from_page_image.py b/egs/madcat_ar/v1/local/create_line_image_from_page_image.py index 34e339f1877..778555c427e 100755 --- a/egs/madcat_ar/v1/local/create_line_image_from_page_image.py +++ b/egs/madcat_ar/v1/local/create_line_image_from_page_image.py @@ -21,22 +21,10 @@ import numpy as np from math import atan2, cos, sin, pi, degrees, sqrt from collections import namedtuple - +import random from scipy.spatial import ConvexHull from PIL import Image from scipy.misc import toimage -import logging - -sys.path.insert(0, 'steps') -logger = logging.getLogger('libs') -logger.setLevel(logging.INFO) -handler = logging.StreamHandler() -handler.setLevel(logging.INFO) -formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - " - "%(funcName)s - %(levelname)s ] %(message)s") -handler.setFormatter(formatter) -logger.addHandler(handler) - parser = argparse.ArgumentParser(description="Creates line images from page image", epilog="E.g. " + sys.argv[0] + " data/LDC2012T15" " data/LDC2013T09 data/LDC2013T15 data/madcat.train.raw.lineid " @@ -60,8 +48,12 @@ help='Path to the downloaded (and extracted) writing conditions file 3') parser.add_argument('--padding', type=int, default=400, help='padding across horizontal/verticle direction') +parser.add_argument('--pixel-scaling', type=int, default=30, + help='padding across horizontal/verticle direction') parser.add_argument("--subset", type=lambda x: (str(x).lower()=='true'), default=False, help="only processes subset of data based on writing condition") +parser.add_argument("--augment", type=lambda x: (str(x).lower()=='true'), default=False, + help="performs image augmentation") args = parser.parse_args() """ @@ -196,21 +188,6 @@ def rectangle_corners(rectangle): return rotate_points(rectangle['rectangle_center'], rectangle['unit_vector_angle'], corner_points) -def get_orientation(origin, p1, p2): - """ - Given origin and two points, return the orientation of the Point p1 with - regards to Point p2 using origin. - Returns - ------- - integer: Negative if p1 is clockwise of p2. - """ - difference = ( - ((p2[0] - origin[0]) * (p1[1] - origin[1])) - - ((p1[0] - origin[0]) * (p2[1] - origin[1])) - ) - return difference - - def minimum_bounding_box(points): """ Given a list of 2D points, it returns the minimum area rectangle bounding all the points in the point cloud. @@ -357,6 +334,36 @@ def update_minimum_bounding_box_input(bounding_box_input): return updated_minimum_bounding_box_input +def dilate_polygon(points, amount_increase): + """ Increases size of polygon given as a list of tuples. + Assumes points in polygon are given in CCW + """ + expanded_points = [] + for index, point in enumerate(points): + prev_point = points[(index - 1) % len(points)] + next_point = points[(index + 1) % len(points)] + prev_edge = np.subtract(point, prev_point) + next_edge = np.subtract(next_point, point) + + prev_normal = ((1 * prev_edge[1]), (-1 * prev_edge[0])) + prev_normal = np.divide(prev_normal, np.linalg.norm(prev_normal)) + next_normal = ((1 * next_edge[1]), (-1 * next_edge[0])) + next_normal = np.divide(next_normal, np.linalg.norm(next_normal)) + + bisect = np.add(prev_normal, next_normal) + bisect = np.divide(bisect, np.linalg.norm(bisect)) + + cos_theta = np.dot(next_normal, bisect) + hyp = amount_increase / cos_theta + + new_point = np.around(point + hyp * bisect) + new_point = new_point.astype(int) + new_point = new_point.tolist() + new_point = tuple(new_point) + expanded_points.append(new_point) + return expanded_points + + def set_line_image_data(image, line_id, image_file_name, image_fh): """ Given an image, saves a flipped line image. Line image file name is formed by appending the line id at the end page image name. @@ -395,50 +402,83 @@ def get_line_images_from_page_image(image_file_name, madcat_file_path, image_fh) word_coordinate = (int(word_node.getAttribute('x')), int(word_node.getAttribute('y'))) minimum_bounding_box_input.append(word_coordinate) updated_mbb_input = update_minimum_bounding_box_input(minimum_bounding_box_input) - bounding_box = minimum_bounding_box(updated_mbb_input) - - p1, p2, p3, p4 = bounding_box.corner_points - x1, y1 = p1 - x2, y2 = p2 - x3, y3 = p3 - x4, y4 = p4 - min_x = int(min(x1, x2, x3, x4)) - min_y = int(min(y1, y2, y3, y4)) - max_x = int(max(x1, x2, x3, x4)) - max_y = int(max(y1, y2, y3, y4)) - box = (min_x, min_y, max_x, max_y) - region_initial = im.crop(box) - rot_points = [] - p1_new = (x1 - min_x, y1 - min_y) - p2_new = (x2 - min_x, y2 - min_y) - p3_new = (x3 - min_x, y3 - min_y) - p4_new = (x4 - min_x, y4 - min_y) - rot_points.append(p1_new) - rot_points.append(p2_new) - rot_points.append(p3_new) - rot_points.append(p4_new) - - cropped_bounding_box = bounding_box_tuple(bounding_box.area, - bounding_box.length_parallel, - bounding_box.length_orthogonal, - bounding_box.length_orthogonal, - bounding_box.unit_vector, - bounding_box.unit_vector_angle, - set(rot_points) - ) - - rotation_angle_in_rad = get_smaller_angle(cropped_bounding_box) - img2 = region_initial.rotate(degrees(rotation_angle_in_rad), resample = Image.BICUBIC) - x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 = rotated_points( + points_ordered = [updated_mbb_input[index] for index in ConvexHull(updated_mbb_input).vertices] + if args.augment: + for i in range(0, 3): + additional_pixel = random.randint(1, args.pixel_scaling) + mar = dilate_polygon(points_ordered, (i-1)*args.pixel_scaling + additional_pixel + 1) + bounding_box = minimum_bounding_box(mar) + (x1, y1), (x2, y2), (x3, y3), (x4, y4) = bounding_box.corner_points + min_x, min_y = int(min(x1, x2, x3, x4)), int(min(y1, y2, y3, y4)) + max_x, max_y = int(max(x1, x2, x3, x4)), int(max(y1, y2, y3, y4)) + box = (min_x, min_y, max_x, max_y) + region_initial = im.crop(box) + rot_points = [] + p1, p2 = (x1 - min_x, y1 - min_y), (x2 - min_x, y2 - min_y) + p3, p4 = (x3 - min_x, y3 - min_y), (x4 - min_x, y4 - min_y) + rot_points.append(p1) + rot_points.append(p2) + rot_points.append(p3) + rot_points.append(p4) + + cropped_bounding_box = bounding_box_tuple(bounding_box.area, + bounding_box.length_parallel, + bounding_box.length_orthogonal, + bounding_box.length_orthogonal, + bounding_box.unit_vector, + bounding_box.unit_vector_angle, + set(rot_points) + ) + + rotation_angle_in_rad = get_smaller_angle(cropped_bounding_box) + img2 = region_initial.rotate(degrees(rotation_angle_in_rad), resample = Image.BICUBIC) + x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 = rotated_points( + cropped_bounding_box, get_center(region_initial)) + + min_x = int(min(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + min_y = int(min(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + max_x = int(max(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + max_y = int(max(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + box = (min_x, min_y, max_x, max_y) + region_final = img2.crop(box) + line_id = id + '_scale' + str(i) + set_line_image_data(region_final, line_id, image_file_name, image_fh) + else: + bounding_box = minimum_bounding_box(points_ordered) + (x1, y1), (x2, y2), (x3, y3), (x4, y4) = bounding_box.corner_points + min_x, min_y = int(min(x1, x2, x3, x4)), int(min(y1, y2, y3, y4)) + max_x, max_y = int(max(x1, x2, x3, x4)), int(max(y1, y2, y3, y4)) + box = (min_x, min_y, max_x, max_y) + region_initial = im.crop(box) + rot_points = [] + p1, p2 = (x1 - min_x, y1 - min_y), (x2 - min_x, y2 - min_y) + p3, p4 = (x3 - min_x, y3 - min_y), (x4 - min_x, y4 - min_y) + rot_points.append(p1) + rot_points.append(p2) + rot_points.append(p3) + rot_points.append(p4) + + cropped_bounding_box = bounding_box_tuple(bounding_box.area, + bounding_box.length_parallel, + bounding_box.length_orthogonal, + bounding_box.length_orthogonal, + bounding_box.unit_vector, + bounding_box.unit_vector_angle, + set(rot_points) + ) + + rotation_angle_in_rad = get_smaller_angle(cropped_bounding_box) + img2 = region_initial.rotate(degrees(rotation_angle_in_rad), resample = Image.BICUBIC) + x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 = rotated_points( cropped_bounding_box, get_center(region_initial)) - min_x = int(min(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) - min_y = int(min(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) - max_x = int(max(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) - max_y = int(max(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) - box = (min_x, min_y, max_x, max_y) - region_final = img2.crop(box) - set_line_image_data(region_final, id, image_file_name, image_fh) + min_x = int(min(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + min_y = int(min(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + max_x = int(max(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + max_y = int(max(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + box = (min_x, min_y, max_x, max_y) + region_final = img2.crop(box) + set_line_image_data(region_final, id, image_file_name, image_fh) def check_file_location(base_name, wc_dict1, wc_dict2, wc_dict3): @@ -496,6 +536,8 @@ def check_writing_condition(wc_dict, base_name): writing_condition = wc_dict[base_name].strip() if writing_condition != 'IUC': return False + else: + return True else: return True diff --git a/egs/madcat_ar/v1/local/extract_features.sh b/egs/madcat_ar/v1/local/extract_features.sh index 56a8443e328..9fe588f31b8 100755 --- a/egs/madcat_ar/v1/local/extract_features.sh +++ b/egs/madcat_ar/v1/local/extract_features.sh @@ -9,6 +9,8 @@ nj=4 cmd=run.pl feat_dim=40 +augment='no_aug' +verticle_shift=0 echo "$0 $@" . ./cmd.sh @@ -34,9 +36,10 @@ done utils/split_scp.pl $scp $split_scps || exit 1; $cmd JOB=1:$nj $logdir/extract_features.JOB.log \ - local/make_features.py $logdir/images.JOB.scp \ + image/ocr/make_features.py $logdir/images.JOB.scp \ --allowed_len_file_path $data/allowed_lengths.txt \ - --feat-dim $feat_dim \| \ + --feat-dim $feat_dim --augment_type $augment \ + --vertical-shift $verticle_shift \| \ copy-feats --compress=true --compression-method=7 \ ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp diff --git a/egs/madcat_ar/v1/local/extract_lines.sh b/egs/madcat_ar/v1/local/extract_lines.sh index 50129ad38c9..ab87836ae3a 100755 --- a/egs/madcat_ar/v1/local/extract_lines.sh +++ b/egs/madcat_ar/v1/local/extract_lines.sh @@ -11,6 +11,8 @@ writing_condition2=/export/corpora/LDC/LDC2013T09/docs/writing_conditions.tab writing_condition3=/export/corpora/LDC/LDC2013T15/docs/writing_conditions.tab data_split_file=data/download/data_splits/madcat.dev.raw.lineid data=data/local/dev +subset=false +augment=false echo "$0 $@" . ./cmd.sh @@ -35,7 +37,7 @@ done $cmd JOB=1:$nj $log_dir/extract_lines.JOB.log \ local/create_line_image_from_page_image.py $download_dir1 $download_dir2 $download_dir3 \ $log_dir/lines.JOB.scp $data/JOB $writing_condition1 $writing_condition2 $writing_condition3 \ - || exit 1; + --subset $subset --augment $augment || exit 1; ## concatenate the .scp files together. for n in $(seq $nj); do diff --git a/egs/madcat_ar/v1/local/make_features.py b/egs/madcat_ar/v1/local/make_features.py deleted file mode 100755 index a21276d32c2..00000000000 --- a/egs/madcat_ar/v1/local/make_features.py +++ /dev/null @@ -1,138 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2017 Chun Chieh Chang -# 2017 Ashish Arora -# 2018 Hossein Hadian - -""" This script converts images to Kaldi-format feature matrices. The input to - this script is the path to a data directory, e.g. "data/train". This script - reads the images listed in images.scp and writes them to standard output - (by default) as Kaldi-formatted matrices (in text form). It also scales the - images so they have the same height (via --feat-dim). It can optionally pad - the images (on left/right sides) with white pixels. - If an 'image2num_frames' file is found in the data dir, it will be used - to enforce the images to have the specified length in that file by padding - white pixels (the --padding option will be ignored in this case). This relates - to end2end chain training. - - eg. local/make_features.py data/train --feat-dim 40 -""" - -import argparse -import os -import sys -import numpy as np -from scipy import misc - -parser = argparse.ArgumentParser(description="""Converts images (in 'dir'/images.scp) to features and - writes them to standard output in text format.""") -parser.add_argument('images_scp_path', type=str, - help='Path of images.scp file') -parser.add_argument('--allowed_len_file_path', type=str, default=None, - help='If supplied, each images will be padded to reach the ' - 'target length (this overrides --padding).') -parser.add_argument('--out-ark', type=str, default='-', - help='Where to write the output feature file') -parser.add_argument('--feat-dim', type=int, default=40, - help='Size to scale the height of all images') -parser.add_argument('--padding', type=int, default=5, - help='Number of white pixels to pad on the left' - 'and right side of the image.') - - -args = parser.parse_args() - - -def write_kaldi_matrix(file_handle, matrix, key): - file_handle.write(key + " [ ") - num_rows = len(matrix) - if num_rows == 0: - raise Exception("Matrix is empty") - num_cols = len(matrix[0]) - - for row_index in range(len(matrix)): - if num_cols != len(matrix[row_index]): - raise Exception("All the rows of a matrix are expected to " - "have the same length") - file_handle.write(" ".join(map(lambda x: str(x), matrix[row_index]))) - if row_index != num_rows - 1: - file_handle.write("\n") - file_handle.write(" ]\n") - - -def get_scaled_image(im): - scale_size = args.feat_dim - sx = im.shape[1] # width - sy = im.shape[0] # height - scale = (1.0 * scale_size) / sy - nx = int(scale_size) - ny = int(scale * sx) - im = misc.imresize(im, (nx, ny)) - return im - - -def horizontal_pad(im, allowed_lengths = None): - if allowed_lengths is None: - left_padding = right_padding = args.padding - else: # Find an allowed length for the image - imlen = im.shape[1] # width - allowed_len = 0 - for l in allowed_lengths: - if l > imlen: - allowed_len = l - break - if allowed_len == 0: - # No allowed length was found for the image (the image is too long) - return None - padding = allowed_len - imlen - left_padding = int(padding // 2) - right_padding = padding - left_padding - dim_y = im.shape[0] # height - im_pad = np.concatenate((255 * np.ones((dim_y, left_padding), - dtype=int), im), axis=1) - im_pad1 = np.concatenate((im_pad, 255 * np.ones((dim_y, right_padding), - dtype=int)), axis=1) - return im_pad1 - - -### main ### - -data_list_path = args.images_scp_path - -if args.out_ark == '-': - out_fh = sys.stdout -else: - out_fh = open(args.out_ark,'wb') - -allowed_lengths = None -allowed_len_handle = args.allowed_len_file_path -if os.path.isfile(allowed_len_handle): - print("Found 'allowed_lengths.txt' file...", file=sys.stderr) - allowed_lengths = [] - with open(allowed_len_handle) as f: - for line in f: - allowed_lengths.append(int(line.strip())) - print("Read {} allowed lengths and will apply them to the " - "features.".format(len(allowed_lengths)), file=sys.stderr) - -num_fail = 0 -num_ok = 0 -with open(data_list_path) as f: - for line in f: - line = line.strip() - line_vect = line.split(' ') - image_id = line_vect[0] - image_path = line_vect[1] - im = misc.imread(image_path) - im_scaled = get_scaled_image(im) - im_horizontal_padded = horizontal_pad(im_scaled, allowed_lengths) - if im_horizontal_padded is None: - num_fail += 1 - continue - data = np.transpose(im_horizontal_padded, (1, 0)) - data = np.divide(data, 255.0) - num_ok += 1 - write_kaldi_matrix(out_fh, data, image_id) - -print('Generated features for {} images. Failed for {} (image too ' - 'long).'.format(num_ok, num_fail), file=sys.stderr) diff --git a/egs/madcat_ar/v1/local/prepare_data.sh b/egs/madcat_ar/v1/local/prepare_data.sh deleted file mode 100755 index d808d736845..00000000000 --- a/egs/madcat_ar/v1/local/prepare_data.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash - -# Copyright 2017 Chun Chieh Chang -# 2017 Ashish Arora -# 2017 Hossein Hadian -# Apache 2.0 - -# This script prepares the training and test data for MADCAT Arabic dataset -# (i.e text, images.scp, utt2spk and spk2utt). It calls process_data.py. - -# Eg. local/prepare_data.sh -# Eg. text file: LDC0001_000404_NHR_ARB_20070113.0052_11_LDC0001_00z2 ﻮﺠﻫ ﻮﻌﻘﻟ ﻍﺍﺮﻗ ﺢﺗّﻯ ﺎﻠﻨﺧﺎﻋ -# utt2spk file: LDC0001_000397_NHR_ARB_20070113.0052_11_LDC0001_00z1 LDC0001 -# images.scp file: LDC0009_000000_arb-NG-2-76513-5612324_2_LDC0009_00z0 -# data/local/lines/1/arb-NG-2-76513-5612324_2_LDC0009_00z0.tif - -stage=0 -download_dir1=/export/corpora/LDC/LDC2012T15/data -download_dir2=/export/corpora/LDC/LDC2013T09/data -download_dir3=/export/corpora/LDC/LDC2013T15/data -writing_condition1=/export/corpora/LDC/LDC2012T15/docs/writing_conditions.tab -writing_condition2=/export/corpora/LDC/LDC2013T09/docs/writing_conditions.tab -writing_condition3=/export/corpora/LDC/LDC2013T15/docs/writing_conditions.tab -data_splits_dir=data/download/data_splits -images_scp_dir=data/local - -. ./cmd.sh -. ./path.sh -. ./utils/parse_options.sh || exit 1; - -mkdir -p data/{train,test,dev} - -if [ $stage -le 1 ]; then - echo "$0: Processing dev, train and test data..." - echo "Date: $(date)." - local/process_data.py $download_dir1 $download_dir2 $download_dir3 \ - $data_splits_dir/madcat.dev.raw.lineid data/dev $images_scp_dir/dev/images.scp \ - $writing_condition1 $writing_condition2 $writing_condition3 || exit 1 - - local/process_data.py $download_dir1 $download_dir2 $download_dir3 \ - $data_splits_dir/madcat.test.raw.lineid data/test $images_scp_dir/test/images.scp \ - $writing_condition1 $writing_condition2 $writing_condition3 || exit 1 - - local/process_data.py $download_dir1 $download_dir2 $download_dir3 \ - $data_splits_dir/madcat.train.raw.lineid data/train $images_scp_dir/train/images.scp \ - $writing_condition1 $writing_condition2 $writing_condition3 || exit 1 - - for dataset in dev test train; do - echo "$0: Fixing data directory for dataset: $dataset" - echo "Date: $(date)." - image/fix_data_dir.sh data/$dataset - done -fi diff --git a/egs/madcat_ar/v1/local/prepend_words.py b/egs/madcat_ar/v1/local/prepend_words.py deleted file mode 100755 index d53eb8974bf..00000000000 --- a/egs/madcat_ar/v1/local/prepend_words.py +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# This script, prepend '|' to every words in the transcript to mark -# the beginning of the words for finding the initial-space of every word -# after decoding. - -import sys, io - -infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') -output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') -for line in infile: - output.write(' '.join(["|" + word for word in line.split()]) + '\n') diff --git a/egs/madcat_ar/v1/local/process_data.py b/egs/madcat_ar/v1/local/process_data.py index 920cb6f700b..e476b67cb96 100755 --- a/egs/madcat_ar/v1/local/process_data.py +++ b/egs/madcat_ar/v1/local/process_data.py @@ -42,6 +42,8 @@ help='Path to the downloaded (and extracted) writing conditions file 2') parser.add_argument('writing_condition3', type=str, help='Path to the downloaded (and extracted) writing conditions file 3') +parser.add_argument("--augment", type=lambda x: (str(x).lower()=='true'), default=False, + help="performs image augmentation") parser.add_argument("--subset", type=lambda x: (str(x).lower()=='true'), default=False, help="only processes subset of data based on writing condition") args = parser.parse_args() @@ -103,6 +105,8 @@ def check_writing_condition(wc_dict): writing_condition = wc_dict[base_name].strip() if writing_condition != 'IUC': return False + else: + return True else: return True @@ -184,14 +188,30 @@ def get_line_image_location(): writer_id = writer[0].getAttribute('id') text_line_word_dict = read_text(madcat_xml_path) base_name = os.path.basename(image_file_path).split('.tif')[0] - for lineID in sorted(text_line_word_dict): - updated_base_name = base_name + '_' + str(lineID).zfill(4) +'.png' - location = image_loc_dict[updated_base_name] - image_file_path = os.path.join(location, updated_base_name) - line = text_line_word_dict[lineID] - text = ' '.join(line) - utt_id = writer_id + '_' + str(image_num).zfill(6) + '_' + base_name + '_' + str(lineID).zfill(4) - text_fh.write(utt_id + ' ' + text + '\n') - utt2spk_fh.write(utt_id + ' ' + writer_id + '\n') - image_fh.write(utt_id + ' ' + image_file_path + '\n') - image_num += 1 + for line_id in sorted(text_line_word_dict): + if args.augment: + key = (line_id + '.')[:-1] + for i in range(0, 3): + location_id = '_' + line_id + '_scale' + str(i) + line_image_file_name = base_name + location_id + '.png' + location = image_loc_dict[line_image_file_name] + image_file_path = os.path.join(location, line_image_file_name) + line = text_line_word_dict[key] + text = ' '.join(line) + base_line_image_file_name = line_image_file_name.split('.png')[0] + utt_id = writer_id + '_' + str(image_num).zfill(6) + '_' + base_line_image_file_name + text_fh.write(utt_id + ' ' + text + '\n') + utt2spk_fh.write(utt_id + ' ' + writer_id + '\n') + image_fh.write(utt_id + ' ' + image_file_path + '\n') + image_num += 1 + else: + updated_base_name = base_name + '_' + str(line_id).zfill(4) +'.png' + location = image_loc_dict[updated_base_name] + image_file_path = os.path.join(location, updated_base_name) + line = text_line_word_dict[line_id] + text = ' '.join(line) + utt_id = writer_id + '_' + str(image_num).zfill(6) + '_' + base_name + '_' + str(line_id).zfill(4) + text_fh.write(utt_id + ' ' + text + '\n') + utt2spk_fh.write(utt_id + ' ' + writer_id + '\n') + image_fh.write(utt_id + ' ' + image_file_path + '\n') + image_num += 1 diff --git a/egs/madcat_ar/v1/local/tl/augment_data.sh b/egs/madcat_ar/v1/local/tl/augment_data.sh new file mode 100755 index 00000000000..cc44aa58a62 --- /dev/null +++ b/egs/madcat_ar/v1/local/tl/augment_data.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright 2018 Hossein Hadian +# 2018 Ashish Arora + +# Apache 2.0 +# This script performs data augmentation. + +nj=4 +cmd=run.pl +feat_dim=40 +verticle_shift=0 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +srcdir=$1 +outdir=$2 +datadir=$3 +aug_set=aug1 +mkdir -p $datadir/augmentations +echo "copying $srcdir to $datadir/augmentations/$aug_set, allowed length, creating feats.scp" + +for set in $aug_set; do + image/copy_data_dir.sh --spk-prefix $set- --utt-prefix $set- \ + $srcdir $datadir/augmentations/$set + cat $srcdir/allowed_lengths.txt > $datadir/augmentations/$set/allowed_lengths.txt + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim $feat_dim \ + --vertical-shift $verticle_shift \ + --augment 'random_shift' $datadir/augmentations/$set +done + +echo " combine original data and data from different augmentations" +utils/combine_data.sh --extra-files images.scp $outdir $srcdir $datadir/augmentations/$aug_set +cat $srcdir/allowed_lengths.txt > $outdir/allowed_lengths.txt diff --git a/egs/madcat_ar/v1/local/tl/chain/run_cnn_e2eali.sh b/egs/madcat_ar/v1/local/tl/chain/run_cnn_e2eali.sh new file mode 100755 index 00000000000..e0cca104f50 --- /dev/null +++ b/egs/madcat_ar/v1/local/tl/chain/run_cnn_e2eali.sh @@ -0,0 +1,229 @@ +#!/bin/bash + +# ./local/chain/compare_wer.sh exp/chain/cnn_e2eali_1a/ +# System cnn_e2eali_1a +# WER 16.78 +# CER 5.22 +# Final train prob -0.1189 +# Final valid prob -0.1319 +# Final train prob (xent) -0.6395 +# Final valid prob (xent) -0.6732 +# Parameters 3.73M + +# steps/info/chain_dir_info.pl exp/chain/cnn_e2eali_1a/ +# exp/chain/cnn_e2eali_1a/: num-iters=24 nj=3..15 num-params=3.7M dim=56->392 combine=-0.125->-0.125 (over 1) xent:train/valid[15,23,final]=(-0.850,-1.24,-0.640/-0.901,-1.31,-0.673) logprob:train/valid[15,23,final]=(-0.149,-0.209,-0.119/-0.166,-0.229,-0.132) +set -e -o pipefail + +stage=0 + +nj=30 +train_set=train +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1a #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=500 +tdnn_dim=450 +srand=0 +remove_egs=true +lang_decode=data/lang +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} data/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts + +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 4 \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=56 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=56 height-out=56 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=56 height-out=28 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=28 height-out=28 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=28 height-out=28 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=28 height-out=14 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=14 height-out=14 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=14 height-out=14 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ + --chain.frame-subsampling-factor=4 \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=2 \ + --trainer.frames-per-iter=2000000 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=16 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=64,32 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/test $dir/decode_test || exit 1; +fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/madcat_ar/v1/local/tl/chain/run_e2e_cnn.sh b/egs/madcat_ar/v1/local/tl/chain/run_e2e_cnn.sh new file mode 100755 index 00000000000..3fca8cf5fdc --- /dev/null +++ b/egs/madcat_ar/v1/local/tl/chain/run_e2e_cnn.sh @@ -0,0 +1,165 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +# This script does end2end chain training (i.e. from scratch) + +# ./local/chain/compare_wer.sh exp/chain/e2e_cnn_1a/ +# System e2e_cnn_1a +# WER 19.30 +# CER 5.72 +# Final train prob -0.0734 +# Final valid prob -0.0607 +# Final train prob (xent) +# Final valid prob (xent) +# Parameters 3.30M + +# steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1a/ +# exp/chain/e2e_cnn_1a/: num-iters=24 nj=3..15 num-params=3.3M dim=56->292 combine=-0.060->-0.060 (over 1) logprob:train/valid[15,23,final]=(-0.122,-0.143,-0.073/-0.105,-0.132,-0.061) + +set -e + + +# configs for 'chain' +stage=0 +nj=30 +train_stage=-10 +get_egs_stage=-10 +affix=1a + +# training options +tdnn_dim=450 +minibatch_size=150=64,32/300=32,16/600=16,8/1200=8,4 +common_egs_dir= +frames_per_iter=1000000 +cmvn_opts="--norm-means=false --norm-vars=false" +train_set=train +lang_decode=data/lang + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj $nj --cmd "$cmd" \ + --shared-phones true \ + --type mono \ + data/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat data/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py data/lang \| \ + utils/sym2int.pl -f 2- data/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + common1="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=36" + common2="required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=70" + common3="required-time-offsets= height-offsets=-1,0,1 num-filters-out=70" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=56 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=56 height-out=56 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=56 height-out=28 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=28 height-out=28 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=28 height-out=28 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=28 height-out=14 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=14 height-out=14 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=14 height-out=14 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + # no need to store the egs in a shared storage because we always + # remove them. Anyway, it takes only 5 minutes to generate them. + + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts "$cmvn_opts" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 2000000 \ + --trainer.num-epochs 2 \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir data/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi + +if [ $stage -le 4 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $lang_decode \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 5 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$cmd" \ + $dir/graph data/test $dir/decode_test || exit 1; +fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/madcat_ar/v1/local/tl/process_waldo_data.py b/egs/madcat_ar/v1/local/tl/process_waldo_data.py new file mode 100755 index 00000000000..0d278e64122 --- /dev/null +++ b/egs/madcat_ar/v1/local/tl/process_waldo_data.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +""" This script reads image and transcription mapping and creates the following files :text, utt2spk, images.scp. + Eg. local/process_waldo_data.py lines/hyp_line_image_transcription_mapping_kaldi.txt data/test + Eg. text file: LDC0001_000404_NHR_ARB_20070113.0052_11_LDC0001_00z2 ﻮﺠﻫ ﻮﻌﻘﻟ ﻍﺍﺮﻗ ﺢﺗّﻯ ﺎﻠﻨﺧﺎﻋ + utt2spk file: LDC0001_000397_NHR_ARB_20070113.0052_11_LDC0001_00z1 LDC0001 + images.scp file: LDC0009_000000_arb-NG-2-76513-5612324_2_LDC0009_00z0 + data/local/lines/1/arb-NG-2-76513-5612324_2_LDC0009_00z0.tif +""" + +import argparse +import os +import sys + +parser = argparse.ArgumentParser(description="Creates text, utt2spk and images.scp files", + epilog="E.g. " + sys.argv[0] + " data/train data/local/lines ", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('image_transcription_file', type=str, + help='Path to the file containing line image path and transcription information') +parser.add_argument('out_dir', type=str, + help='directory location to write output files.') +args = parser.parse_args() + + +def read_image_text(image_text_path): + """ Given the file path containing, mapping information of line image + and transcription, it returns a dict. The dict contains this mapping + info. It can be accessed via line_id and will provide transcription. + Returns: + -------- + dict: line_id and transcription mapping + """ + image_transcription_dict = dict() + with open(image_text_path, encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split(' ') + image_path = line_vect[0] + line_id = os.path.basename(image_path).split('.png')[0] + transcription = line_vect[1:] + joined_transcription = list() + for word in transcription: + joined_transcription.append(word) + joined_transcription = " ".join(joined_transcription) + image_transcription_dict[line_id] = joined_transcription + return image_transcription_dict + + +### main ### +print("Processing '{}' data...".format(args.out_dir)) +text_file = os.path.join(args.out_dir, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') +utt2spk_file = os.path.join(args.out_dir, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join(args.out_dir, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') + +image_transcription_dict = read_image_text(args.image_transcription_file) +for line_id in sorted(image_transcription_dict.keys()): + writer_id = line_id.strip().split('_')[-3] + updated_line_id = line_id + '.png' + image_file_path = os.path.join('lines', updated_line_id) + text = image_transcription_dict[line_id] + utt_id = line_id + text_fh.write(utt_id + ' ' + text + '\n') + utt2spk_fh.write(utt_id + ' ' + writer_id + '\n') + image_fh.write(utt_id + ' ' + image_file_path + '\n') + diff --git a/egs/madcat_ar/v1/local/tl/run_text_localization.sh b/egs/madcat_ar/v1/local/tl/run_text_localization.sh new file mode 100755 index 00000000000..8d12f7d802f --- /dev/null +++ b/egs/madcat_ar/v1/local/tl/run_text_localization.sh @@ -0,0 +1,143 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian +# 2018 Ashish Arora + +# This script performs full page text recognition on automatically extracted line images +# from madcat arabic data. It is created as a separate scrip, because it performs +# data augmentation, uses smaller language model and calls process_waldo_data for +# test images (automatically extracted line images). Data augmentation increases image +# height hence requires different DNN arachitecture and different chain scripts. + +set -e +stage=0 +nj=70 +# download_dir{1,2,3} points to the database path on the JHU grid. If you have not +# already downloaded the database you can set it to a local directory +# This corpus can be purchased here: +# https://catalog.ldc.upenn.edu/{LDC2012T15,LDC2013T09/,LDC2013T15/} +download_dir1=/export/corpora/LDC/LDC2012T15/data +download_dir2=/export/corpora/LDC/LDC2013T09/data +download_dir3=/export/corpora/LDC/LDC2013T15/data +writing_condition1=/export/corpora/LDC/LDC2012T15/docs/writing_conditions.tab +writing_condition2=/export/corpora/LDC/LDC2013T09/docs/writing_conditions.tab +writing_condition3=/export/corpora/LDC/LDC2013T15/docs/writing_conditions.tab +data_splits_dir=data/download/data_splits +images_scp_dir=data/local +overwrite=false +subset=true +augment=true +verticle_shift=16 +. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. + ## This relates to the queue. +. ./path.sh +. ./utils/parse_options.sh # e.g. this parses the above options + # if supplied. +./local/check_tools.sh + +mkdir -p data/{train,test,dev}/data +mkdir -p data/local/{train,test,dev} +if [ $stage -le 0 ]; then + + if [ -f data/train/text ] && ! $overwrite; then + echo "$0: Not processing, probably script have run from wrong stage" + echo "Exiting with status 1 to avoid data corruption" + exit 1; + fi + echo "$0: Downloading data splits...$(date)" + local/download_data.sh --data_splits $data_splits_dir --download_dir1 $download_dir1 \ + --download_dir2 $download_dir2 --download_dir3 $download_dir3 + + for set in train dev; do + data_split_file=$data_splits_dir/madcat.$set.raw.lineid + local/extract_lines.sh --nj $nj --cmd $cmd --data_split_file $data_split_file \ + --download_dir1 $download_dir1 --download_dir2 $download_dir2 \ + --download_dir3 $download_dir3 --writing_condition1 $writing_condition1 \ + --writing_condition2 $writing_condition2 --writing_condition3 $writing_condition3 \ + --data data/local/$set --subset $subset --augment $augment || exit 1 + done + + echo "$0: Preparing data..." + for set in dev train; do + local/process_data.py $download_dir1 $download_dir2 $download_dir3 \ + $data_splits_dir/madcat.$set.raw.lineid data/$set $images_scp_dir/$set/images.scp \ + $writing_condition1 $writing_condition2 $writing_condition3 --augment $augment --subset $subset + image/fix_data_dir.sh data/${set} + done + + local/tl/process_waldo_data.py lines/hyp_line_image_transcription_mapping_kaldi.txt data/test + utils/utt2spk_to_spk2utt.pl data/test/utt2spk > data/test/spk2utt +fi + +if [ $stage -le 1 ]; then + echo "$0: Obtaining image groups. calling get_image2num_frames $(date)." + image/get_image2num_frames.py data/train + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train + for set in dev train test; do + echo "$0: Extracting features and calling compute_cmvn_stats for dataset: $set. $(date)" + local/extract_features.sh --nj $nj --cmd $cmd --feat-dim 40 \ + --verticle_shift $verticle_shift data/$set + steps/compute_cmvn_stats.sh data/$set || exit 1; + done + echo "$0: Fixing data directory for train dataset $(date)." + image/fix_data_dir.sh data/train +fi + +if [ $stage -le 2 ]; then + for set in train; do + echo "$(date) stage 2: Performing augmentation, it will double training data" + local/tl/augment_data.sh --nj $nj --cmd "$cmd" --feat-dim 40 \ + --verticle_shift $verticle_shift data/${set} data/${set}_aug data + steps/compute_cmvn_stats.sh data/${set}_aug || exit 1; + done +fi + +if [ $stage -le 3 ]; then + echo "$0: Preparing BPE..." + cut -d' ' -f2- data/train/text | utils/lang/bpe/reverse.py | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/learn_bpe.py -s 700 > data/local/bpe.txt + + for set in test train dev train_aug; do + cut -d' ' -f1 data/$set/text > data/$set/ids + cut -d' ' -f2- data/$set/text | utils/lang/bpe/reverse.py | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > data/$set/bpe_text + + mv data/$set/text data/$set/text.old + paste -d' ' data/$set/ids data/$set/bpe_text > data/$set/text + rm -f data/$set/bpe_text data/$set/ids + done + + echo "$0:Preparing dictionary and lang..." + local/prepare_dict.sh + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ + data/local/dict "" data/lang/temp data/lang + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 data/lang +fi + +if [ $stage -le 4 ]; then + echo "$0: Estimating a language model for decoding..." + local/tl/train_lm.sh --order 3 + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/3gram_unpruned.arpa.gz \ + data/local/dict/lexicon.txt data/lang +fi + +nj=30 +if [ $stage -le 5 ]; then + echo "$0: Calling the flat-start chain recipe... $(date)." + local/tl/chain/run_e2e_cnn.sh --nj $nj --train_set train_aug +fi + +if [ $stage -le 6 ]; then + echo "$0: Aligning the training data using the e2e chain model...$(date)." + steps/nnet3/align.sh --nj $nj --cmd "$cmd" \ + --use-gpu false \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0 --acoustic-scale=1.0' \ + data/train_aug data/lang exp/chain/e2e_cnn_1a exp/chain/e2e_ali_train +fi + +if [ $stage -le 7 ]; then + echo "$0: Building a tree and training a regular chain model using the e2e alignments...$(date)" + local/tl/chain/run_cnn_e2eali.sh --nj $nj --train_set train_aug +fi diff --git a/egs/madcat_ar/v1/local/tl/train_lm.sh b/egs/madcat_ar/v1/local/tl/train_lm.sh new file mode 100755 index 00000000000..524bb2e9f40 --- /dev/null +++ b/egs/madcat_ar/v1/local/tl/train_lm.sh @@ -0,0 +1,102 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +order=3 +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # use the validation data as the dev set. + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + + cat data/dev/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + + # use the training data as an additional data source. + # we can later fold the dev data into this. + cat data/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < data/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from MADCAT text + cat ${dir}/data/text/train.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=20 --warm-start-ratio=20 \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi diff --git a/egs/madcat_ar/v1/local/wer_output_filter b/egs/madcat_ar/v1/local/wer_output_filter index c0f03e7178a..d6d46f3f565 100755 --- a/egs/madcat_ar/v1/local/wer_output_filter +++ b/egs/madcat_ar/v1/local/wer_output_filter @@ -2,6 +2,9 @@ # Copyright 2012-2014 Johns Hopkins University (Author: Yenda Trmal) # Apache 2.0 +# This script converts a BPE-encoded text to normal text and performs normalization. +# It is used in scoring. + use utf8; use open qw(:encoding(utf8)); diff --git a/egs/madcat_ar/v1/run.sh b/egs/madcat_ar/v1/run.sh index f6a63320497..d3937582662 100755 --- a/egs/madcat_ar/v1/run.sh +++ b/egs/madcat_ar/v1/run.sh @@ -32,7 +32,6 @@ mkdir -p data/{train,test,dev}/data mkdir -p data/local/{train,test,dev} if [ $stage -le 0 ]; then - if [ -f data/train/text ] && ! $overwrite; then echo "$0: Not processing, probably script have run from wrong stage" echo "Exiting with status 1 to avoid data corruption" @@ -42,30 +41,27 @@ if [ $stage -le 0 ]; then echo "$0: Downloading data splits...$(date)" local/download_data.sh --data_splits $data_splits_dir --download_dir1 $download_dir1 \ --download_dir2 $download_dir2 --download_dir3 $download_dir3 -fi -if [ $stage -le 1 ]; then - for dataset in test train dev; do - data_split_file=$data_splits_dir/madcat.$dataset.raw.lineid + for set in test train dev; do + data_split_file=$data_splits_dir/madcat.$set.raw.lineid local/extract_lines.sh --nj $nj --cmd $cmd --data_split_file $data_split_file \ --download_dir1 $download_dir1 --download_dir2 $download_dir2 \ --download_dir3 $download_dir3 --writing_condition1 $writing_condition1 \ --writing_condition2 $writing_condition2 --writing_condition3 $writing_condition3 \ - --data data/local/$dataset + --data data/local/$set --subset $subset --augment $augment || exit 1 done -fi -if [ $stage -le 2 ]; then echo "$0: Preparing data..." - local/prepare_data.sh --download_dir1 $download_dir1 --download_dir2 $download_dir2 \ - --download_dir3 $download_dir3 --images_scp_dir data/local \ - --data_splits_dir $data_splits_dir --writing_condition1 $writing_condition1 \ - --writing_condition2 $writing_condition2 --writing_condition3 $writing_condition3 + for set in dev train test; do + local/process_data.py $download_dir1 $download_dir2 $download_dir3 \ + $data_splits_dir/madcat.$set.raw.lineid data/$set $images_scp_dir/$set/images.scp \ + $writing_condition1 $writing_condition2 $writing_condition3 --augment $augment --subset $subset + image/fix_data_dir.sh data/${set} + done fi -mkdir -p data/{train,test,dev}/data -if [ $stage -le 3 ]; then +if [ $stage -le 1 ]; then for dataset in test train; do local/extract_features.sh --nj $nj --cmd $cmd --feat-dim 40 data/$dataset steps/compute_cmvn_stats.sh data/$dataset || exit 1; @@ -73,33 +69,53 @@ if [ $stage -le 3 ]; then utils/fix_data_dir.sh data/train fi -if [ $stage -le 4 ]; then - echo "$0: Preparing dictionary and lang..." +if [ $stage -le 2 ]; then + echo "$0: Preparing BPE..." + cut -d' ' -f2- data/train/text | utils/lang/bpe/reverse.py | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/learn_bpe.py -s 700 > data/local/bpe.txt + + for set in test train dev; do + cut -d' ' -f1 data/$set/text > data/$set/ids + cut -d' ' -f2- data/$set/text | utils/lang/bpe/reverse.py | \ + utils/lang/bpe/prepend_words.py | \ + utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ + | sed 's/@@//g' > data/$set/bpe_text + + mv data/$set/text data/$set/text.old + paste -d' ' data/$set/ids data/$set/bpe_text > data/$set/text + rm -f data/$set/bpe_text data/$set/ids + done + + echo "$0:Preparing dictionary and lang..." local/prepare_dict.sh - utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.95 \ - data/local/dict "" data/lang/temp data/lang + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ + data/local/dict "" data/lang/temp data/lang + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 data/lang fi -if [ $stage -le 5 ]; then +if [ $stage -le 3 ]; then echo "$0: Estimating a language model for decoding..." local/train_lm.sh - utils/format_lm.sh data/lang data/local/local_lm/data/arpa/6gram_unpruned.arpa.gz \ - data/local/dict/lexicon.txt data/lang_test + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/6gram_small.arpa.gz \ + data/local/dict/lexicon.txt data/lang + utils/build_const_arpa_lm.sh data/local/local_lm/data/arpa/6gram_unpruned.arpa.gz \ + data/lang data/lang_rescore_6g fi -if [ $stage -le 6 ]; then +if [ $stage -le 4 ]; then steps/train_mono.sh --nj $nj --cmd $cmd --totgauss 10000 data/train \ data/lang exp/mono fi -if [ $stage -le 7 ] && $decode_gmm; then - utils/mkgraph.sh --mono data/lang_test exp/mono exp/mono/graph +if [ $stage -le 5 ] && $decode_gmm; then + utils/mkgraph.sh --mono data/lang exp/mono exp/mono/graph steps/decode.sh --nj $nj --cmd $cmd exp/mono/graph data/test \ exp/mono/decode_test fi -if [ $stage -le 8 ]; then +if [ $stage -le 6 ]; then steps/align_si.sh --nj $nj --cmd $cmd data/train data/lang \ exp/mono exp/mono_ali @@ -107,14 +123,14 @@ if [ $stage -le 8 ]; then exp/mono_ali exp/tri fi -if [ $stage -le 9 ] && $decode_gmm; then - utils/mkgraph.sh data/lang_test exp/tri exp/tri/graph +if [ $stage -le 7 ] && $decode_gmm; then + utils/mkgraph.sh data/lang exp/tri exp/tri/graph steps/decode.sh --nj $nj --cmd $cmd exp/tri/graph data/test \ exp/tri/decode_test fi -if [ $stage -le 10 ]; then +if [ $stage -le 8 ]; then steps/align_si.sh --nj $nj --cmd $cmd data/train data/lang \ exp/tri exp/tri_ali @@ -123,22 +139,22 @@ if [ $stage -le 10 ]; then data/train data/lang exp/tri_ali exp/tri3 fi -if [ $stage -le 11 ] && $decode_gmm; then - utils/mkgraph.sh data/lang_test exp/tri3 exp/tri3/graph +if [ $stage -le 9 ] && $decode_gmm; then + utils/mkgraph.sh data/lang exp/tri3 exp/tri3/graph steps/decode.sh --nj $nj --cmd $cmd exp/tri3/graph \ data/test exp/tri3/decode_test fi -if [ $stage -le 12 ]; then +if [ $stage -le 10 ]; then steps/align_fmllr.sh --nj $nj --cmd $cmd --use-graphs true \ data/train data/lang exp/tri3 exp/tri3_ali fi -if [ $stage -le 13 ]; then +if [ $stage -le 11 ]; then local/chain/run_cnn.sh fi -if [ $stage -le 14 ]; then +if [ $stage -le 12 ]; then local/chain/run_cnn_chainali.sh --stage 2 fi diff --git a/egs/madcat_ar/v1/run_end2end.sh b/egs/madcat_ar/v1/run_end2end.sh index 3986ede9d7f..de67e444f39 100755 --- a/egs/madcat_ar/v1/run_end2end.sh +++ b/egs/madcat_ar/v1/run_end2end.sh @@ -15,8 +15,10 @@ writing_condition1=/export/corpora/LDC/LDC2012T15/docs/writing_conditions.tab writing_condition2=/export/corpora/LDC/LDC2013T09/docs/writing_conditions.tab writing_condition3=/export/corpora/LDC/LDC2013T15/docs/writing_conditions.tab data_splits_dir=data/download/data_splits +images_scp_dir=data/local overwrite=false - +subset=false +augment=false . ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. ## This relates to the queue. . ./path.sh @@ -37,20 +39,23 @@ if [ $stage -le 0 ]; then local/download_data.sh --data_splits $data_splits_dir --download_dir1 $download_dir1 \ --download_dir2 $download_dir2 --download_dir3 $download_dir3 - for dataset in test train dev; do - data_split_file=$data_splits_dir/madcat.$dataset.raw.lineid + for set in test train dev; do + data_split_file=$data_splits_dir/madcat.$set.raw.lineid local/extract_lines.sh --nj $nj --cmd $cmd --data_split_file $data_split_file \ --download_dir1 $download_dir1 --download_dir2 $download_dir2 \ --download_dir3 $download_dir3 --writing_condition1 $writing_condition1 \ --writing_condition2 $writing_condition2 --writing_condition3 $writing_condition3 \ - --data data/local/$dataset + --data data/local/$set --subset $subset --augment $augment || exit 1 done echo "$0: Preparing data..." - local/prepare_data.sh --download_dir1 $download_dir1 --download_dir2 $download_dir2 \ - --download_dir3 $download_dir3 --images_scp_dir data/local \ - --data_splits_dir $data_splits_dir --writing_condition1 $writing_condition1 \ - --writing_condition2 $writing_condition2 --writing_condition3 $writing_condition3 + for set in dev train test; do + local/process_data.py $download_dir1 $download_dir2 $download_dir3 \ + $data_splits_dir/madcat.$set.raw.lineid data/$set $images_scp_dir/$set/images.scp \ + $writing_condition1 $writing_condition2 $writing_condition3 --augment $augment --subset $subset + image/fix_data_dir.sh data/${set} + done + fi if [ $stage -le 1 ]; then @@ -58,10 +63,10 @@ if [ $stage -le 1 ]; then image/get_image2num_frames.py data/train image/get_allowed_lengths.py --frame-subsampling-factor 4 10 data/train - for dataset in test train; do - echo "$0: Extracting features and calling compute_cmvn_stats for dataset: $dataset. $(date)" - local/extract_features.sh --nj $nj --cmd $cmd --feat-dim 40 data/$dataset - steps/compute_cmvn_stats.sh data/$dataset || exit 1; + for set in test train; do + echo "$0: Extracting features and calling compute_cmvn_stats for dataset: $set. $(date)" + local/extract_features.sh --nj $nj --cmd $cmd --feat-dim 40 data/$set + steps/compute_cmvn_stats.sh data/$set || exit 1; done echo "$0: Fixing data directory for train dataset $(date)." utils/fix_data_dir.sh data/train @@ -69,14 +74,14 @@ fi if [ $stage -le 2 ]; then echo "$0: Preparing BPE..." - cut -d' ' -f2- data/train/text | local/reverse.py | \ - utils/lang/bpe/prepend_words.py --encoding 'utf-8' | \ + cut -d' ' -f2- data/train/text | utils/lang/bpe/reverse.py | \ + utils/lang/bpe/prepend_words.py | \ utils/lang/bpe/learn_bpe.py -s 700 > data/local/bpe.txt for set in test train dev; do cut -d' ' -f1 data/$set/text > data/$set/ids - cut -d' ' -f2- data/$set/text | local/reverse.py | \ - utils/lang/bpe/prepend_words.py --encoding 'utf-8' | \ + cut -d' ' -f2- data/$set/text | utils/lang/bpe/reverse.py | \ + utils/lang/bpe/prepend_words.py | \ utils/lang/bpe/apply_bpe.py -c data/local/bpe.txt \ | sed 's/@@//g' > data/$set/bpe_text @@ -95,8 +100,10 @@ fi if [ $stage -le 3 ]; then echo "$0: Estimating a language model for decoding..." local/train_lm.sh - utils/format_lm.sh data/lang data/local/local_lm/data/arpa/6gram_unpruned.arpa.gz \ - data/local/dict/lexicon.txt data/lang_test + utils/format_lm.sh data/lang data/local/local_lm/data/arpa/6gram_big.arpa.gz \ + data/local/dict/lexicon.txt data/lang + utils/build_const_arpa_lm.sh data/local/local_lm/data/arpa/6gram_unpruned.arpa.gz \ + data/lang data/lang_rescore_6g fi if [ $stage -le 4 ]; then diff --git a/egs/multi_en/s5/local/g2p/apply_g2p.sh b/egs/multi_en/s5/local/g2p/apply_g2p.sh deleted file mode 100755 index 8484155800d..00000000000 --- a/egs/multi_en/s5/local/g2p/apply_g2p.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash - -# Copyright 2016 Allen Guo -# 2017 Xiaohui Zhang -# Apache License 2.0 - -# This script applies a trained Phonetisarus G2P model to -# synthesize pronunciations for missing words (i.e., words in -# transcripts but not the lexicon), and output the expanded lexicon. - -var_counts=1 - -. ./path.sh || exit 1 -. parse_options.sh || exit 1; - -if [ $# -ne "4" ]; then - echo "Usage: $0 " - exit 1 -fi - -model=$1 -workdir=$2 -lexicon=$3 -outlexicon=$4 - -mkdir -p $workdir - -# awk command from http://stackoverflow.com/questions/2626274/print-all-but-the-first-three-columns -echo 'Gathering missing words...' -cat data/*/train/text | \ - local/count_oovs.pl $lexicon | \ - awk '{if (NF > 3 ) {for(i=4; i $workdir/missing.txt -cat $workdir/missing.txt | \ - grep "^[a-z]*$" > $workdir/missing_onlywords.txt - -echo 'Synthesizing pronunciations for missing words...' -phonetisaurus-apply --nbest $var_counts --model $model --thresh 5 --accumulate --word_list $workdir/missing_onlywords.txt > $workdir/missing_g2p_${var_counts}.txt - -echo "Adding new pronunciations to $lexicon" -cat "$lexicon" $workdir/missing_g2p_${var_counts}.txt | sort | uniq > $outlexicon diff --git a/egs/multi_en/s5/local/g2p/train_g2p.sh b/egs/multi_en/s5/local/g2p/train_g2p.sh deleted file mode 100755 index 43e75f6608d..00000000000 --- a/egs/multi_en/s5/local/g2p/train_g2p.sh +++ /dev/null @@ -1,67 +0,0 @@ -#!/bin/bash - -# Copyright 2017 Intellisist, Inc. (Author: Navneeth K) -# 2017 Xiaohui Zhang -# Apache License 2.0 - -# This script trains a g2p model using Phonetisaurus and SRILM. - -stage=0 -silence_phones= - -echo "$0 $@" # Print the command line for logging - -[ -f ./path.sh ] && . ./path.sh; # source the path. -. utils/parse_options.sh || exit 1; - - -if [ $# -ne 2 ]; then - echo "Usage: $0 " - exit 1; -fi - -lexicondir=$1 -outdir=$2 - -[ ! -f $lexicondir/lexicon.txt ] && echo "Cannot find $lexicondir/lexicon.txt" && exit - -isuconv=`which uconv` -if [ -z $isuconv ]; then - echo "uconv was not found. You must install the icu4c package." - exit 1; -fi - -mkdir -p $outdir - - -# For input lexicon, remove pronunciations containing non-utf-8-encodable characters, -# and optionally remove words that are mapped to a single silence phone from the lexicon. -if [ $stage -le 0 ]; then - lexicon=$lexicondir/lexicon.txt - if [ ! -z "$silence_phones" ]; then - awk 'NR==FNR{a[$1] = 1; next} {s=$2;for(i=3;i<=NF;i++) s=s" "$i; if(!(s in a)) print $1" "s}' \ - $silence_phones $lexicon | \ - awk '{printf("%s\t",$1); for (i=2;i 0'> $outdir/lexicon_tab_separated.txt - else - awk '{printf("%s\t",$1); for (i=2;i 0'> $outdir/lexicon_tab_separated.txt - fi -fi - -if [ $stage -le 1 ]; then - # Align lexicon stage. Lexicon is assumed to have first column tab separated - phonetisaurus-align --input=$outdir/lexicon_tab_separated.txt --ofile=${outdir}/aligned_lexicon.corpus || exit 1; -fi - -if [ $stage -le 2 ]; then - # Convert aligned lexicon to arpa using srilm. - ngram-count -order 7 -kn-modify-counts-at-end -gt1min 0 -gt2min 0 \ - -gt3min 0 -gt4min 0 -gt5min 0 -gt6min 0 -gt7min 0 -ukndiscount \ - -text ${outdir}/aligned_lexicon.corpus -lm ${outdir}/aligned_lexicon.arpa -fi - -if [ $stage -le 3 ]; then - # Convert the arpa file to FST. - phonetisaurus-arpa2wfst --lm=${outdir}/aligned_lexicon.arpa --ofile=${outdir}/model.fst -fi diff --git a/egs/multi_en/s5/run.sh b/egs/multi_en/s5/run.sh index 3a1262101aa..034ffeb4e66 100755 --- a/egs/multi_en/s5/run.sh +++ b/egs/multi_en/s5/run.sh @@ -58,8 +58,8 @@ if [ $stage -le 1 ]; then # We prepare the basic dictionary in data/local/dict_combined. local/prepare_dict.sh $swbd $tedlium2 ( - local/g2p/train_g2p.sh --stage 0 --silence-phones \ - "data/local/dict_combined/silence_phones.txt" data/local/dict_combined exp/g2p || touch exp/g2p/.error + steps/dict/train_g2p_phonetisaurus.sh --stage 0 --silence-phones \ + "data/local/dict_combined/silence_phones.txt" data/local/dict_combined/lexicon.txt exp/g2p || touch exp/g2p/.error ) & fi @@ -114,8 +114,28 @@ if [ $stage -le 4 ]; then mkdir -p $dict_dir rm $dict_dir/lexiconp.txt 2>/dev/null || true cp data/local/dict_combined/{extra_questions,nonsilence_phones,silence_phones,optional_silence}.txt $dict_dir - local/g2p/apply_g2p.sh --var-counts 1 exp/g2p/model.fst data/local/g2p_phonetisarus \ - data/local/dict_combined/lexicon.txt $dict_dir/lexicon.txt || exit 1; + + echo 'Gathering missing words...' + + lexicon=data/local/dict_combined/lexicon.txt + g2p_tmp_dir=data/local/g2p_phonetisarus + mkdir -p $g2p_tmp_dir + + # awk command from http://stackoverflow.com/questions/2626274/print-all-but-the-first-three-columns + cat data/*/train/text | \ + local/count_oovs.pl $lexicon | \ + awk '{if (NF > 3 ) {for(i=4; i $g2p_tmp_dir/missing.txt + cat $g2p_tmp_dir/missing.txt | \ + grep "^[a-z]*$" > $g2p_tmp_dir/missing_onlywords.txt + + steps/dict/apply_g2p_phonetisaurus.sh --nbest 1 $g2p_tmp_dir/missing_onlywords.txt exp/g2p exp/g2p/oov_lex || exit 1; + cp exp/g2p/oov_lex/lexicon.lex $g2p_tmp_dir/missing_lexicon.txt + + extended_lexicon=$dict_dir/lexicon.txt + echo "Adding new pronunciations to get extended lexicon $extended_lexicon" + cat <(cut -f 1,3 $g2p_tmp_dir/missing_lexicon.txt) $lexicon | sort | uniq > $extended_lexicon fi # We'll do multiple iterations of pron/sil-prob estimation. So the structure of diff --git a/egs/tedlium/s5_r2/local/rnnlm/tuning/run_lstm_tdnn_with_lm1b.sh b/egs/tedlium/s5_r2/local/rnnlm/tuning/run_lstm_tdnn_with_lm1b.sh index a0b16dea890..ec289df81ef 100755 --- a/egs/tedlium/s5_r2/local/rnnlm/tuning/run_lstm_tdnn_with_lm1b.sh +++ b/egs/tedlium/s5_r2/local/rnnlm/tuning/run_lstm_tdnn_with_lm1b.sh @@ -3,10 +3,10 @@ # Copyright 2012 Johns Hopkins University (author: Daniel Povey) Tony Robinson # 2018 Ke Li -# rnnlm/train_rnnlm.sh: best iteration (out of 9) was 8, linking it to final iteration. -# rnnlm/train_rnnlm.sh: train/dev perplexity was 32.2 / 123.2. -# Train objf: -4.02 -3.71 -3.64 -3.58 -3.55 -3.52 -3.50 -3.48 -3.44 -# Dev objf: -11.92 -5.13 -5.03 -4.94 -4.91 -4.87 -4.85 -4.83 -4.81 +# rnnlm/train_rnnlm.sh: best iteration (out of 60) was 58, linking it to final iteration. +# rnnlm/train_rnnlm.sh: train/dev perplexity was 25.1 / 104.5. +# Train objf: -3.60 -3.52 -3.48 -3.44 -3.41 -3.38 -3.36 -3.35 -3.33 -3.31 -3.29 -3.29 -3.28 -3.28 -3.27 -3.25 -3.25 -3.23 -3.23 -3.22 -3.22 -3.21 -3.20 -3.19 -3.19 -3.18 -3.18 -3.18 -3.17 -3.17 -3.15 -3.15 -3.15 -3.15 -3.14 -3.14 -3.12 -3.14 -3.16 -3.13 -3.12 -3.13 -3.11 -3.12 -3.11 -3.10 -3.09 -3.10 -3.06 -3.08 -3.10 -3.09 -3.08 -3.09 -3.02 -3.01 -3.02 -2.98 -3.02 +# Dev objf: -5.12 -5.04 -4.98 -4.93 -4.91 -4.89 -4.87 -4.86 -4.82 -4.80 -4.79 -4.79 -4.78 -4.77 -4.78 -4.76 -4.76 -4.75 -4.75 -4.74 -4.74 -4.73 -4.73 -4.72 -4.71 -4.72 -4.71 -4.70 -4.70 -4.70 -4.70 -4.70 -4.70 -4.70 -4.69 -4.69 -4.68 -4.68 -4.67 -4.67 -4.68 -4.67 -4.67 -4.67 -4.67 -4.67 -4.67 -4.67 -4.66 -4.68 -4.68 -4.72 -4.68 -4.66 -4.71 -4.65 -4.65 -4.65 -4.65 # 1-pass results # %WER 8.3 | 1155 27500 | 92.7 4.9 2.4 1.0 8.3 68.8 | -0.019 | /export/a12/ywang/kaldi/egs/tedlium/s5_r2/exp/chain_cleaned/tdnn_lstm1i_adversarial1.0_interval4_epoches7_lin_to_5_sp_bi/decode_looped_test/score_10_0.0/ctm.filt.filt.sys @@ -15,10 +15,10 @@ # %WER 7.8 | 1155 27500 | 93.1 4.5 2.4 0.9 7.8 66.4 | -0.089 | /export/a12/ywang/kaldi/egs/tedlium/s5_r2/exp/chain_cleaned/tdnn_lstm1i_adversarial1.0_interval4_epoches7_lin_to_5_sp_bi/decode_looped_test_rescore/score_10_0.0/ctm.filt.filt.sys # RNNLM lattice rescoring -# %WER 7.3 | 1155 27500 | 93.6 4.0 2.4 0.9 7.3 65.4 | -0.138 | exp/decode_test_rnnlm_lm1b_tedlium_weight3/score_10_0.0/ctm.filt.filt.sys +# %WER 6.8 | 1155 27500 | 94.0 3.7 2.3 0.8 6.8 62.3 | -0.844 | exp/decode_looped_test_rnnlm_lm1b_tedlium_weight3_rescore//score_10_0.0/ctm.filt.filt.sys # RNNLM nbest rescoring -# %WER 7.3 | 1155 27500 | 93.6 4.3 2.1 0.9 7.3 65.0 | -0.895 | exp/decode_test_rnnlm_lm1b_tedlium_weight3_nbest/score_8_0.0/ctm.filt.filt.sys +# %WER 6.9 | 1155 27500 | 94.0 3.8 2.2 0.9 6.9 61.3 | -0.769 | exp/decode_looped_test_rnnlm_lm1b_tedlium_weight3_nbest_rescore//score_10_0.0/ctm.filt.filt.sys # Begin configuration section. cmd=run.pl @@ -29,7 +29,7 @@ lstm_rpd=256 lstm_nrpd=256 stage=0 train_stage=-10 -epochs=3 +epochs=20 # variables for lattice rescoring run_lat_rescore=true diff --git a/egs/tunisian_msa/s5/conf/pitch.conf b/egs/tunisian_msa/s5/conf/pitch.conf deleted file mode 100644 index e959a19d5b8..00000000000 --- a/egs/tunisian_msa/s5/conf/pitch.conf +++ /dev/null @@ -1 +0,0 @@ ---sample-frequency=16000 diff --git a/egs/tunisian_msa/s5/conf/plp.conf b/egs/tunisian_msa/s5/conf/plp.conf deleted file mode 100644 index e959a19d5b8..00000000000 --- a/egs/tunisian_msa/s5/conf/plp.conf +++ /dev/null @@ -1 +0,0 @@ ---sample-frequency=16000 diff --git a/egs/tunisian_msa/s5/local/buckwalter2utf8.pl b/egs/tunisian_msa/s5/local/buckwalter2utf8.pl deleted file mode 100755 index c952e554f86..00000000000 --- a/egs/tunisian_msa/s5/local/buckwalter2utf8.pl +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env perl -# Input buckwalter encoded Arabic and print it out as utf-8 encoded Arabic. -use strict; -use warnings; -use Carp; - -use Encode::Arabic::Buckwalter; # imports just like 'use Encode' would, plus more - -while ( my $line = <>) { - print encode 'utf8', decode 'buckwalter', $line; -} diff --git a/egs/tunisian_msa/s5/local/chain/tuning/run_tdnn_1a.sh b/egs/tunisian_msa/s5/local/chain/tuning/run_tdnn_1a.sh index d3c4a4ef11f..a2662584549 100755 --- a/egs/tunisian_msa/s5/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/tunisian_msa/s5/local/chain/tuning/run_tdnn_1a.sh @@ -25,7 +25,7 @@ nnet3_affix= # are just hardcoded at this level, in the commands below. affix=1a # affix for the TDNN directory name tree_affix= -train_stage=22 +train_stage=-10 get_egs_stage=-10 decode_iter= diff --git a/egs/tunisian_msa/s5/local/qcri_buckwalter2utf8.pl b/egs/tunisian_msa/s5/local/qcri_buckwalter2utf8.pl deleted file mode 100755 index 9074d4807c2..00000000000 --- a/egs/tunisian_msa/s5/local/qcri_buckwalter2utf8.pl +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env perl -#qcri_buckwalter2utf8.pl - convert the qcri dictionary toutf8 - -use strict; -use warnings; -use Carp; - -use Encode::Arabic::Buckwalter; # imports just like 'use Encode' would, plus more - -my $bw_dict = "qcri.txt"; - -open my $B, '<', $bw_dict or croak "Problem with $bw_dict $!"; - - LINE: while ( my $line = <$B> ) { - chomp $line; - next LINE if ( $line =~ /^\#/ ); - my ($w,$p) = split / /, $line, 2; - print encode 'utf8', decode 'buckwalter', $w; - print " $p\n"; -} - diff --git a/egs/tunisian_msa/s5/local/qcri_buckwalter2utf8.sh b/egs/tunisian_msa/s5/local/qcri_buckwalter2utf8.sh index b8433967e14..0468c04ebd8 100755 --- a/egs/tunisian_msa/s5/local/qcri_buckwalter2utf8.sh +++ b/egs/tunisian_msa/s5/local/qcri_buckwalter2utf8.sh @@ -1,5 +1,8 @@ #!/bin/bash +# Copyright 2018 John Morgan +# Apache 2.0. + # write separate files for word and pronunciation fields cut -d " " -f 1 qcri.txt > qcri_words_buckwalter.txt cut -d " " -f 2- qcri.txt > qcri_prons.txt diff --git a/egs/wsj/s5/local/nnet3/run_ivector_common.sh b/egs/wsj/s5/local/nnet3/run_ivector_common.sh index 2c218ae3673..813c6e14aed 100755 --- a/egs/wsj/s5/local/nnet3/run_ivector_common.sh +++ b/egs/wsj/s5/local/nnet3/run_ivector_common.sh @@ -149,8 +149,8 @@ if [ $stage -le 5 ]; then done fi -if [ -f data/${train_set}_sp/feats.scp ] && [ $stage -le 8 ]; then - echo "$0: $feats already exists. Refusing to overwrite the features " +if [ -f data/${train_set}_sp/feats.scp ] && [ $stage -le 7 ]; then + echo "$0: data/${train_set}_sp/feats.scp already exists. Refusing to overwrite the features " echo " to avoid wasting time. Please remove the file and continue if you really mean this." exit 1; fi diff --git a/egs/wsj/s5/steps/cleanup/debug_lexicon.sh b/egs/wsj/s5/steps/cleanup/debug_lexicon.sh index d35c9557af8..eca807ad247 100755 --- a/egs/wsj/s5/steps/cleanup/debug_lexicon.sh +++ b/egs/wsj/s5/steps/cleanup/debug_lexicon.sh @@ -116,21 +116,21 @@ if [ $stage -le 8 ]; then export LC_ALL=C - cat $dir/phone.ctm | utils/apply_map.pl -f 5 $dir/phone_map.txt | sort > $dir/phone_mapped.ctm + cat $dir/phone.ctm | utils/apply_map.pl -f 5 $dir/phone_map.txt > $dir/phone_mapped.ctm cat $dir/word.ctm | awk '{printf("%s-%s %010.0f START %s\n", $1, $2, 1000*$3, $5); printf("%s-%s %010.0f END %s\n", $1, $2, 1000*($3+$4), $5);}' | \ sort > $dir/word_processed.ctm # filter out those utteraces which only appea in phone_processed.ctm but not in word_processed.ctm cat $dir/phone_mapped.ctm | awk '{printf("%s-%s %010.0f PHONE %s\n", $1, $2, 1000*($3+(0.5*$4)), $5);}' | \ - awk 'NR==FNR{a[$1] = 1; next} {if($1 in a) print $0}' $dir/word_processed.ctm - \ - > $dir/phone_processed.ctm + awk 'NR==FNR{a[$1] = 1; next} {if($1 in a) print $0}' $dir/word_processed.ctm - | \ + sort > $dir/phone_processed.ctm # merge-sort both ctm's sort -m $dir/word_processed.ctm $dir/phone_processed.ctm > $dir/combined.ctm fi - # after merge-sort of the two ctm's, we add to cover "deserted" phones due to precision limits, and then merge all consecutive 's. +# after merge-sort of the two ctm's, we add to cover "deserted" phones due to precision limits, and then merge all consecutive 's. if [ $stage -le 9 ]; then awk '{print $1, $3, $4}' $dir/combined.ctm | \ perl -e ' while (<>) { chop; @A = split(" ", $_); ($utt, $a,$b) = @A; diff --git a/egs/wsj/s5/steps/compare_alignments.sh b/egs/wsj/s5/steps/compare_alignments.sh new file mode 100755 index 00000000000..d94d2197fee --- /dev/null +++ b/egs/wsj/s5/steps/compare_alignments.sh @@ -0,0 +1,220 @@ +#!/bin/bash + +# Copyright 2018 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0. + +set -e +stage=0 +cmd=run.pl # We use this only for get_ctm.sh, which can be a little slow. +num_to_sample=1000 # We sample this many utterances for human-readable display, starting from the worst and then + # starting from the middle. +cleanup=true + +if [ -f ./path.sh ]; then . ./path.sh; fi + +. ./utils/parse_options.sh + +if [ $# -ne 5 ] && [ $# -ne 7 ]; then + cat < + or: $0 [options] + e.g.: $0 data/lang data/train exp/tri2_ali exp/tri3_ali exp/compare_ali_2_3 + + Options: + --cmd (run.pl|queue.pl...) # specify how to run the sub-processes. + # (passed through to get_train_ctm.sh) + --cleanup # Specify --cleanup false to prevent + # cleanup of temporary files. + --stage # Enables you to run part of the script. + +EOF + exit 1 +fi + +if [ $# -eq 5 ]; then + lang1=$1 + lang2=$1 + data1=$2 + data2=$2 + ali_dir1=$3 + ali_dir2=$4 + dir=$5 +else + lang1=$1 + lang2=$2 + data1=$3 + data2=$4 + ali_dir1=$5 + ali_dir2=$6 + dir=$7 +fi + +for f in $lang1/phones.txt $lang2/phones.txt $data1/utt2spk $data2/utt2spk \ + $ali_dir1/ali.1.gz $ali_dir2/ali.2.gz; do + if [ ! -f $f ]; then + echo "$0: expected file $f to exist" + exit 1 + fi +done + +# This will exit if the phone symbol id's are different, due to +# `set -e` above. +utils/lang/check_phones_compatible.sh $lang1/phones.txt $lang2/phones.txt + +nj1=$(cat $ali_dir1/num_jobs) +nj2=$(cat $ali_dir2/num_jobs) + +mkdir -p $dir/log + + +if [ $stage -le 0 ]; then + echo "$0: converting alignments to phones." + + for j in $(seq $nj1); do gunzip -c $ali_dir1/ali.$j.gz; done | \ + ali-to-phones --per-frame=true $ali_dir1/final.mdl ark:- ark:- | gzip -c > $dir/phones1.gz + + for j in $(seq $nj2); do gunzip -c $ali_dir2/ali.$j.gz; done | \ + ali-to-phones --per-frame=true $ali_dir2/final.mdl ark:- ark:- | gzip -c > $dir/phones2.gz +fi + +if [ $stage -le 1 ]; then + echo "$0: getting comparison stats and utterance stats." + compare-int-vector --binary=false --write-confusion-matrix=$dir/conf.mat \ + "ark:gunzip -c $dir/phones1.gz|" "ark:gunzip -c $dir/phones2.gz|" 2>$dir/log/compare_phones.log > $dir/utt_stats.phones + tail -n 8 $dir/log/compare_phones.log +fi + +if [ $stage -le 3 ]; then + cat $dir/conf.mat | grep -v -F '[' | sed 's/]//' | awk '{n=NF; for (k=1;k<=n;k++) { conf[NR,k] = $k; row_tot[NR] += $k; col_tot[k] += $k; } } END{ + for (row=1;row<=n;row++) for (col=1;col<=n;col++) { + val = conf[row,col]; this_row_tot = row_tot[row]; this_col_tot = col_tot[col]; + rval=conf[col,row] + min_tot = (this_row_tot < this_col_tot ? this_row_tot : this_col_tot); + if (val != 0) { + phone1 = row-1; phone2 = col-1; + if (row == col) printf("COR %d %d %.2f%\n", phone1, val, (val * 100 / this_row_tot)); + else { + norm_prob = val * val / min_tot; # heuristic for sorting. + printf("SUB %d %d %d %d %.2f%% %.2f%%\n", + norm_prob, phone1, phone2, val, (val * 100 / min_tot), (rval * 100 / min_tot)); }}}}' > $dir/phone_stats.all + + ( + echo "# Format: " + grep '^COR' $dir/phone_stats.all | sort -n -k4,4 | awk '{print $2, $3, $4}' | utils/int2sym.pl -f 1 $lang1/phones.txt + ) > $dir/phones_correct.txt + + ( + echo "#Format: " + echo "# is the number of frames that were labeled in the first" + echo "# set of alignments and in the second." + echo "# is divided by the smaller of the total num-frames of" + echo "# phone1 or phone2, whichever is smaller; expressed as a percentage." + echo "# is the same but for the reverse substitution, from" + echo "# to ; the comparison with the substitutions are)." + grep '^SUB' $dir/phone_stats.all | sort -nr -k2,2 | awk '{print $3,$4,$5,$6,$7}' | utils/int2sym.pl -f 1-2 $lang1/phones.txt + ) > $dir/phone_subs.txt +fi + +if [ $stage -le 4 ]; then + echo "$0: getting CTMs" + steps/get_train_ctm.sh --use-segments false --print-silence true --cmd "$cmd" --frame-shift 1.0 $data1 $lang1 $ali_dir1 $dir/ctm1 + steps/get_train_ctm.sh --use-segments false --print-silence true --cmd "$cmd" --frame-shift 1.0 $data2 $lang2 $ali_dir2 $dir/ctm2 +fi + +if [ $stage -le 5 ]; then + oov=$(cat $lang1/oov.int) + # Note: below, we use $lang1 for both setups; this is by design as compare-int-vector + # assumes they use the same symbol table. + for n in 1 2; do + cat $dir/ctm${n}/ctm | utils/sym2int.pl --map-oov $oov -f 5 $lang1/words.txt | \ + awk 'BEGIN{utt_id="";} { if (utt_id != $1) { if (utt_id != "") printf("\n"); utt_id=$1; printf("%s ", utt_id); } t_start=int($3); t_end=t_start + int($4); word=$5; for (t=t_start; t$dir/words${n}.gz + done +fi + +if [ $stage -le 5 ]; then + compare-int-vector --binary=false --write-tot-counts=$dir/words_tot.vec --write-diff-counts=$dir/words_diff.vec \ + "ark:gunzip -c $dir/words1.gz|" "ark:gunzip -c $dir/words2.gz|" 2>$dir/log/compare_words.log >$dir/utt_stats.words + tail -n 8 $dir/log/compare_words.log +fi + +if [ $stage -le 6 ]; then + + ( echo "# Word stats. Format:"; + echo " " + + paste <(awk '{for (n=2;n 0) print $1*$1/$2, $1/$2, $1, $2, (NR-1)}' | utils/int2sym.pl -f 5 $lang1/words.txt | \ + sort -nr | awk '{print $2, $3, $4, $5;}' + ) > $dir/word_stats.txt + +fi + +if [ $stage -le 7 ]; then + for type in phones words; do + num_utts=$(wc -l <$dir/utt_stats.$type) + cat $dir/utt_stats.$type | awk -v type=$type 'BEGIN{print "Utterance-id proportion-"type"-changed num-frames num-wrong-frames"; } + {print $1, $3 * 1.0 / $2, $2, $3; }' | sort -nr -k2,2 > $dir/utt_stats.$type.sorted + ( + echo "$0: Percentiles 100, 90, .. 0 of proportion-$type-changed distribution (over utterances) are:" + cat $dir/utt_stats.$type.sorted | awk -v n=$num_utts 'BEGIN{k=int((n-1)/10);} {if (NR % k == 1) printf("%s ", $2); } END{print "";}' + ) | tee $dir/utt_stats.$type.percentiles + done +fi + + +if [ $stage -le 8 ]; then + # Display the 1000 worst utterances, and 1000 utterances from the middle of the pack, in a readable format. + num_utts=$(wc -l <$dir/utt_stats.words.sorted) + half_num_utts=$[$num_utts/2]; + if [ $num_to_sample -gt $half_num_utts ]; then + num_to_sample=$half_num_utts + fi + head -n $num_to_sample $dir/utt_stats.words.sorted | awk '{print $1}' > $dir/utt_ids.worst + tail -n +$half_num_utts $dir/utt_stats.words.sorted | head -n $num_to_sample | awk '{print $1}' > $dir/utt_ids.mid + + for suf in worst mid; do + for n in 1 2; do + gunzip -c $dir/phones${n}.gz | copy-int-vector ark:- ark,t:- | utils/filter_scp.pl $dir/utt_ids.$suf >$dir/temp + # the next command reorders them, and duplicates the utterance-idwhich we'll later use + # that to display the word sequence. + awk '{print $1,$1,$1}' <$dir/utt_ids.$suf | utils/apply_map.pl -f 3 $dir/temp > $dir/phones${n}.$suf + rm $dir/temp + done + # the stuff with 0 and below is a kind of hack so that if the phones are the same, we end up + # with just the phone, but if different, we end up with p1/p2. + # The apply_map.pl stuff is to put the transcript there. + + ( + echo "# Format: ... ... " + echo "# If the two alignments have the same phone, just that phone will be printed;" + echo "# otherwise the two phones will be printed, as in 'phone1/phone2'. So '/' is present" + echo "# whenever there is a mismatch." + + paste $dir/phones1.$suf $dir/phones2.$suf | perl -ane ' @A = split("\t", $_); @A1 = split(" ", $A[0]); @A2 = split(" ", $A[1]); + $utt = shift @A1; shift @A2; print $utt, " "; + for ($n = 0; $n < @A1 && $n < @A2; $n++) { $a1=$A1[$n]; $a2=$A2[$n]; if ($a1 eq $a2) { print "$a1 "; } else { print "$a1 0 $a2 "; }} + print "\n" ' | utils/int2sym.pl -f 3- $lang1/phones.txt | sed 's: :/:g' | \ + utils/apply_map.pl -f 2 $data1/text + ) > $dir/compare_phones_${suf}.txt + done +fi + + +if [ $stage -le 9 ] && $cleanup; then + rm $dir/phones{1,2}.gz $dir/words{1,2}.gz $dir/ctm*/ctm $dir/*.vec $dir/conf.mat \ + $dir/utt_ids.* $dir/phones{1,2}.{mid,worst} $dir/utt_stats.{phones,words} \ + $dir/phone_stats.all +fi + +# clean up +exit 0 diff --git a/egs/wsj/s5/steps/data/reverberate_data_dir.py b/egs/wsj/s5/steps/data/reverberate_data_dir.py index f6be7a286ec..570613855a0 100755 --- a/egs/wsj/s5/steps/data/reverberate_data_dir.py +++ b/egs/wsj/s5/steps/data/reverberate_data_dir.py @@ -1,5 +1,6 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2016 Tom Ko +# 2018 David Snyder # Apache 2.0 # script to generate reverberated data @@ -167,14 +168,13 @@ def ParseFileToDict(file, assert2fields = False, value_processor = None): # This function creates a file and write the content of a dictionary into it def WriteDictToFile(dict, file_name): file = open(file_name, 'w') - keys = dict.keys() - keys.sort() + keys = sorted(dict.keys()) for key in keys: value = dict[key] if type(value) in [list, tuple] : if type(value) is tuple: value = list(value) - value.sort() + value = sorted(value) value = ' '.join(str(value)) file.write('{0} {1}\n'.format(key, value)) file.close() @@ -185,8 +185,7 @@ def CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_origina corrupted_utt2uniq = {} # Parse the utt2spk to get the utterance id utt2spk = ParseFileToDict(input_dir + "/utt2spk", value_processor = lambda x: " ".join(x)) - keys = utt2spk.keys() - keys.sort() + keys = sorted(utt2spk.keys()) if include_original: start_index = 0 else: @@ -290,8 +289,8 @@ def GenerateReverberationOpts(room_dict, # the room dictionary, please refer to assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['snrs']) if len(noise_addition_descriptor['noise_io']) > 0: reverberate_opts += "--additive-signals='{0}' ".format(','.join(noise_addition_descriptor['noise_io'])) - reverberate_opts += "--start-times='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['start_times']))) - reverberate_opts += "--snrs='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['snrs']))) + reverberate_opts += "--start-times='{0}' ".format(','.join([str(x) for x in noise_addition_descriptor['start_times']])) + reverberate_opts += "--snrs='{0}' ".format(','.join([str(x) for x in noise_addition_descriptor['snrs']])) return reverberate_opts @@ -331,8 +330,7 @@ def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kal foreground_snrs = list_cyclic_iterator(foreground_snr_array) background_snrs = list_cyclic_iterator(background_snr_array) corrupted_wav_scp = {} - keys = wav_scp.keys() - keys.sort() + keys = sorted(wav_scp.keys()) if include_original: start_index = 0 else: @@ -373,7 +371,7 @@ def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kal # This function replicate the entries in files like segments, utt2spk, text def AddPrefixToFields(input_file, output_file, num_replicas, include_original, prefix, field = [0]): - list = map(lambda x: x.strip(), open(input_file)) + list = [x.strip() for x in open(input_file)] f = open(output_file, "w") if include_original: start_index = 0 @@ -415,8 +413,8 @@ def CreateReverberatedCopy(input_dir, print("Getting the duration of the recordings..."); data_lib.RunKaldiCommand("utils/data/get_reco2dur.sh {}".format(input_dir)) durations = ParseFileToDict(input_dir + "/reco2dur", value_processor = lambda x: float(x[0])) - foreground_snr_array = map(lambda x: float(x), foreground_snr_string.split(':')) - background_snr_array = map(lambda x: float(x), background_snr_string.split(':')) + foreground_snr_array = [float(x) for x in foreground_snr_string.split(':')] + background_snr_array = [float(x) for x in background_snr_string.split(':')] GenerateReverberatedWavScp(wav_scp, durations, output_dir, room_dict, pointsource_noise_list, iso_noise_dict, foreground_snr_array, background_snr_array, num_replicas, include_original, prefix, @@ -445,11 +443,11 @@ def CreateReverberatedCopy(input_dir, # This function smooths the probability distribution in the list -def SmoothProbabilityDistribution(list, smoothing_weight=0.0, target_sum=1.0): - if len(list) > 0: +def SmoothProbabilityDistribution(set_list, smoothing_weight=0.0, target_sum=1.0): + if len(list(set_list)) > 0: num_unspecified = 0 accumulated_prob = 0 - for item in list: + for item in set_list: if item.probability is None: num_unspecified += 1 else: @@ -463,7 +461,7 @@ def SmoothProbabilityDistribution(list, smoothing_weight=0.0, target_sum=1.0): warnings.warn("The sum of probabilities specified by user is larger than or equal to 1. " "The items without probabilities specified will be given zero to their probabilities.") - for item in list: + for item in set_list: if item.probability is None: item.probability = uniform_probability else: @@ -471,11 +469,11 @@ def SmoothProbabilityDistribution(list, smoothing_weight=0.0, target_sum=1.0): item.probability = (1 - smoothing_weight) * item.probability + smoothing_weight * uniform_probability # Normalize the probability - sum_p = sum(item.probability for item in list) - for item in list: + sum_p = sum(item.probability for item in set_list) + for item in set_list: item.probability = item.probability / sum_p * target_sum - return list + return set_list # This function parse the array of rir set parameter strings. @@ -521,7 +519,7 @@ def ParseRirList(rir_set_para_array, smoothing_weight, sampling_rate = None): rir_list = [] for rir_set in set_list: - current_rir_list = map(lambda x: rir_parser.parse_args(shlex.split(x.strip())),open(rir_set.filename)) + current_rir_list = [rir_parser.parse_args(shlex.split(x.strip())) for x in open(rir_set.filename)] for rir in current_rir_list: if sampling_rate is not None: # check if the rspecifier is a pipe or not @@ -586,7 +584,7 @@ def ParseNoiseList(noise_set_para_array, smoothing_weight, sampling_rate = None) pointsource_noise_list = [] iso_noise_dict = {} for noise_set in set_list: - current_noise_list = map(lambda x: noise_parser.parse_args(shlex.split(x.strip())),open(noise_set.filename)) + current_noise_list = [noise_parser.parse_args(shlex.split(x.strip())) for x in open(noise_set.filename)] current_pointsource_noise_list = [] for noise in current_noise_list: if sampling_rate is not None: diff --git a/egs/wsj/s5/steps/dict/apply_g2p_phonetisaurus.sh b/egs/wsj/s5/steps/dict/apply_g2p_phonetisaurus.sh new file mode 100755 index 00000000000..a793f91fd0a --- /dev/null +++ b/egs/wsj/s5/steps/dict/apply_g2p_phonetisaurus.sh @@ -0,0 +1,99 @@ +#!/bin/bash +# Copyright 2014 Johns Hopkins University (Author: Yenda Trmal) +# Copyright 2016 Xiaohui Zhang +# 2018 Ruizhe Huang +# Apache 2.0 + +# This script applies a trained Phonetisarus G2P model to +# synthesize pronunciations for missing words (i.e., words in +# transcripts but not the lexicon), and output the expanded lexicon. +# The user could specify either nbest or pmass option +# to determine the number of output pronunciation variants, +# or use them together to get the intersection of two options. + +# Begin configuration section. +stage=0 +nbest= # Generate up to N, like N=3, pronunciation variants for each word + # (The maximum size of the nbest list, not considering pruning and taking the prob-mass yet). +thresh=5 # Pruning threshold for the n-best list, in (0, 99], which is a -log-probability value. + # A large threshold makes the nbest list shorter, and less likely to hit the max size. + # This value corresponds to the weight_threshold in shortest-path.h of openfst. +pmass= # Select the top variants from the pruned nbest list, + # summing up to this total prob-mass for a word. + # On the "boundary", it's greedy by design, e.g. if pmass = 0.8, + # and we have prob(pron_1) = 0.5, and prob(pron_2) = 0.4, then we get both. +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. utils/parse_options.sh || exit 1; + +set -u +set -e + +if [ $# != 3 ]; then + echo "Usage: $0 [options] " + echo "... where is a list of words whose pronunciation is to be generated." + echo " is a directory used as a target during training of G2P" + echo " is the directory where the output lexicon should be stored." + echo " The format of the output lexicon output-dir/lexicon.lex is" + echo " \t\t per line." + echo "e.g.: $0 --nbest 1 exp/g2p/oov_words.txt exp/g2p exp/g2p/oov_lex" + echo "" + echo "main options (for others, see top of script file)" + echo " --nbest # Generate upto N pronunciation variants for each word." + echo " --pmass # Select the top variants from the pruned nbest list," + echo " # summing up to this total prob-mass, within [0, 1], for a word." + echo " --thresh # Pruning threshold for n-best." + exit 1; +fi + +wordlist=$1 +modeldir=$2 +outdir=$3 + +model=$modeldir/model.fst +output_lex=$outdir/lexicon.lex +mkdir -p $outdir + +[ ! -f ${model:-} ] && echo "$0: File $model not found in the directory $modeldir." && exit 1 +[ ! -f $wordlist ] && echo "$0: File $wordlist not found!" && exit 1 +[ -z $pmass ] && [ -z $nbest ] && echo "$0: nbest or/and pmass should be specified." && exit 1; +if ! phonetisaurus=`which phonetisaurus-apply` ; then + echo "Phonetisarus was not found !" + echo "Go to $KALDI_ROOT/tools and execute extras/install_phonetisaurus.sh" + exit 1 +fi + +cp $wordlist $outdir/wordlist.txt + +# three options: 1) nbest, 2) pmass, 3) nbest+pmass, +nbest=${nbest:-20} # if nbest is not specified, set it to 20, due to Phonetisaurus mechanism +pmass=${pmass:-1.0} # if pmass is not specified, set it to 1.0, due to Phonetisaurus mechanism + +[[ ! $nbest =~ ^[1-9][0-9]*$ ]] && echo "$0: nbest should be a positive integer." && exit 1; + +echo "Applying the G2P model to wordlist $wordlist" +phonetisaurus-apply --pmass $pmass --nbest $nbest --thresh $thresh \ + --word_list $wordlist --model $model \ + --accumulate --verbose --prob \ + 1>$output_lex + +echo "Completed. Synthesized lexicon for new words is in $output_lex" + +# Some words might have been removed or skipped during the process, +# let's check it and warn the user if so... +nlex=`cut -f 1 $output_lex | sort -u | wc -l` +nwlist=`cut -f 1 $wordlist | sort -u | wc -l` +if [ $nlex -ne $nwlist ] ; then + failed_wordlist=$outdir/lexicon.failed + echo "WARNING: Unable to generate pronunciation for all words. "; + echo "WARINNG: Wordlist: $nwlist words" + echo "WARNING: Lexicon : $nlex words" + comm -13 <(cut -f 1 $output_lex | sort -u ) \ + <(cut -f 1 $wordlist | sort -u ) \ + >$failed_wordlist && echo "WARNING: The list of failed words is in $failed_wordlist" +fi +exit 0 + diff --git a/egs/wsj/s5/steps/dict/train_g2p_phonetisaurus.sh b/egs/wsj/s5/steps/dict/train_g2p_phonetisaurus.sh new file mode 100755 index 00000000000..94c483e09e2 --- /dev/null +++ b/egs/wsj/s5/steps/dict/train_g2p_phonetisaurus.sh @@ -0,0 +1,88 @@ +#!/bin/bash + +# Copyright 2017 Intellisist, Inc. (Author: Navneeth K) +# 2017 Xiaohui Zhang +# 2018 Ruizhe Huang +# Apache License 2.0 + +# This script trains a g2p model using Phonetisaurus. + +stage=0 +encoding='utf-8' +only_words=true +silence_phones= + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. utils/parse_options.sh || exit 1; + +set -u +set -e + +if [ $# != 2 ]; then + echo "Usage: $0 [options] " + echo " where is the training lexicon (one pronunciation per " + echo " word per line, with lines like 'hello h uh l ow') and" + echo " is directory where the models will be stored" + echo "e.g.: $0 --silence-phones data/local/dict/silence_phones.txt data/local/dict/lexicon.txt exp/g2p/" + echo "" + echo "main options (for others, see top of script file)" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --silence-phones # e.g. data/local/dict/silence_phones.txt." + echo " # A list of silence phones, one or more per line" + echo " # Relates to --only-words option" + echo " --only-words (true|false) (default: true) # If true, exclude silence words, i.e." + echo " # words with one or multiple phones which are all silence." + exit 1; +fi + +lexicon=$1 +wdir=$2 + +[ ! -f $lexicon ] && echo "Cannot find $lexicon" && exit + +isuconv=`which uconv` +if [ -z $isuconv ]; then + echo "uconv was not found. You must install the icu4c package." + exit 1; +fi + +if ! phonetisaurus=`which phonetisaurus-apply` ; then + echo "Phonetisarus was not found !" + echo "Go to $KALDI_ROOT/tools and execute extras/install_phonetisaurus.sh" + exit 1 +fi + +mkdir -p $wdir + + +# For input lexicon, remove pronunciations containing non-utf-8-encodable characters, +# and optionally remove words that are mapped to a single silence phone from the lexicon. +if [ $stage -le 0 ]; then + if $only_words && [ ! -z "$silence_phones" ]; then + awk 'NR==FNR{a[$1] = 1; next} {s=$2;for(i=3;i<=NF;i++) s=s" "$i; if(!(s in a)) print $1" "s}' \ + $silence_phones $lexicon | \ + awk '{printf("%s\t",$1); for (i=2;i 0'> $wdir/lexicon_tab_separated.txt + else + awk '{printf("%s\t",$1); for (i=2;i 0'> $wdir/lexicon_tab_separated.txt + fi +fi + +if [ $stage -le 1 ]; then + # Align lexicon stage. Lexicon is assumed to have first column tab separated + phonetisaurus-align --input=$wdir/lexicon_tab_separated.txt --ofile=${wdir}/aligned_lexicon.corpus || exit 1; +fi + +if [ $stage -le 2 ]; then + # Convert aligned lexicon to arpa using make_kn_lm.py, a re-implementation of srilm's ngram-count functionality. + ./utils/lang/make_kn_lm.py -ngram-order 7 -text ${wdir}/aligned_lexicon.corpus -lm ${wdir}/aligned_lexicon.arpa +fi + +if [ $stage -le 3 ]; then + # Convert the arpa file to FST. + phonetisaurus-arpa2wfst --lm=${wdir}/aligned_lexicon.arpa --ofile=${wdir}/model.fst +fi + diff --git a/egs/wsj/s5/steps/get_train_ctm.sh b/egs/wsj/s5/steps/get_train_ctm.sh index 878e11e45ac..6942014fc88 100755 --- a/egs/wsj/s5/steps/get_train_ctm.sh +++ b/egs/wsj/s5/steps/get_train_ctm.sh @@ -20,8 +20,9 @@ echo "$0 $@" # Print the command line for logging [ -f ./path.sh ] && . ./path.sh . parse_options.sh || exit 1; -if [ $# -ne 3 ]; then - echo "Usage: $0 [options] " +if [ $# -ne 3 ] && [ $# -ne 4 ]; then + echo "Usage: $0 [options] []" + echo "( defaults to .)" echo " Options:" echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." echo " --stage (0|1|2) # start scoring script from part-way through." @@ -39,27 +40,31 @@ fi data=$1 lang=$2 # Note: may be graph directory not lang directory, but has the necessary stuff copied. -dir=$3 +ali_dir=$3 +dir=$4 +if [ -z $dir ]; then + dir=$ali_dir +fi -model=$dir/final.mdl # assume model one level up from decoding dir. +model=$ali_dir/final.mdl # assume model one level up from decoding dir. -for f in $lang/words.txt $model $dir/ali.1.gz $lang/oov.int; do +for f in $lang/words.txt $model $ali_dir/ali.1.gz $lang/oov.int; do [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; done oov=`cat $lang/oov.int` || exit 1; -nj=`cat $dir/num_jobs` || exit 1; +nj=`cat $ali_dir/num_jobs` || exit 1; split_data.sh $data $nj || exit 1; sdata=$data/split$nj -mkdir -p $dir/log +mkdir -p $dir/log || exit 1; if [ $stage -le 0 ]; then if [ -f $lang/phones/word_boundary.int ]; then $cmd JOB=1:$nj $dir/log/get_ctm.JOB.log \ - set -o pipefail '&&' linear-to-nbest "ark:gunzip -c $dir/ali.JOB.gz|" \ + set -o pipefail '&&' linear-to-nbest "ark:gunzip -c $ali_dir/ali.JOB.gz|" \ "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $sdata/JOB/text |" \ '' '' ark:- \| \ lattice-align-words $lang/phones/word_boundary.int $model ark:- ark:- \| \ @@ -72,7 +77,7 @@ if [ $stage -le 0 ]; then exit 1; fi $cmd JOB=1:$nj $dir/log/get_ctm.JOB.log \ - set -o pipefail '&&' linear-to-nbest "ark:gunzip -c $dir/ali.JOB.gz|" \ + set -o pipefail '&&' linear-to-nbest "ark:gunzip -c $ali_dir/ali.JOB.gz|" \ "ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $sdata/JOB/text |" \ '' '' ark:- \| \ lattice-align-words-lexicon $lang/phones/align_lexicon.int $model ark:- ark:- \| \ @@ -94,4 +99,3 @@ if [ $stage -le 1 ]; then fi rm $dir/ctm.*.gz fi - diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py index 01c1b1e533c..611e1a3fdef 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py @@ -76,6 +76,7 @@ 'linear-component': xlayers.XconfigLinearComponent, 'affine-component': xlayers.XconfigAffineComponent, 'scale-component': xlayers.XconfigPerElementScaleComponent, + 'offset-component': xlayers.XconfigPerElementOffsetComponent, 'combine-feature-maps-layer': xlayers.XconfigCombineFeatureMapsLayer } diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/trivial_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/trivial_layers.py index 42cc20293a5..f91258bab04 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/trivial_layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/trivial_layers.py @@ -261,9 +261,13 @@ def _generate_config(self): class XconfigCombineFeatureMapsLayer(XconfigLayerBase): """This class is for parsing lines like 'combine-feature-maps-layer name=combine_features1 height=40 num-filters1=1 num-filters2=4' - It produces a PermuteComponent. It expects its input to be two things + or + 'combine-feature-maps-layer name=combine_features1 height=40 num-filters1=1 num-filters2=4 num-filters3=2' + + It produces a PermuteComponent. It expects its input to be two or three things appended together, where the first is of dimension height * num-filters1 and - the second is of dimension height * num-filters2; it interpolates the filters + the second is of dimension height * num-filters2 (and the third, if present is + of dimension height * num-filters2; it interpolates the filters so the output can be interpreted as a single feature map with the same height as the input and the sum of the num-filters. @@ -278,21 +282,24 @@ def set_default_configs(self): self.config = { 'input': '[-1]', 'num-filters1': -1, 'num-filters2': -1, + 'num-filters3': 0, 'height': -1 } def check_configs(self): input_dim = self.descriptors['input']['dim'] if (self.config['num-filters1'] <= 0 or self.config['num-filters2'] <= 0 or + self.config['num-filters3'] < 0 or self.config['height'] <= 0): raise RuntimeError("invalid values of num-filters1, num-filters2 and/or height") f1 = self.config['num-filters1'] f2 = self.config['num-filters2'] + f3 = self.config['num-filters3'] h = self.config['height'] - if input_dim != (f1 + f2) * h: - raise RuntimeError("Expected input-dim={0} based on num-filters1={1}, num-filters2={2} " - "and height={3}, but got input-dim={4}".format( - (f1 + f2) * h, f1, f2, h, input_dim)) + if input_dim != (f1 + f2 + f3) * h: + raise RuntimeError("Expected input-dim={0} based on num-filters1={1}, num-filters2={2}, " + "num-filters3={3} and height={4}, but got input-dim={5}".format( + (f1 + f2 + f3) * h, f1, f2, f3, h, input_dim)) def output_name(self, auxiliary_output=None): assert auxiliary_output is None @@ -321,15 +328,18 @@ def _generate_config(self): dim = self.descriptors['input']['dim'] num_filters1 = self.config['num-filters1'] num_filters2 = self.config['num-filters2'] + num_filters3 = self.config['num-filters3'] # normally 0. height = self.config['height'] - assert dim == (num_filters1 + num_filters2) * height + assert dim == (num_filters1 + num_filters2 + num_filters3) * height column_map = [] for h in range(height): for f in range(num_filters1): column_map.append(h * num_filters1 + f) for f in range(num_filters2): - column_map.append((height * num_filters1) + h * num_filters2 + f) + column_map.append(height * num_filters1 + h * num_filters2 + f) + for f in range(num_filters3): + column_map.append(height * (num_filters1 + num_filters2) + h * num_filters3 + f) configs = [] line = ('component name={0} type=PermuteComponent column-map={1} '.format( @@ -496,3 +506,77 @@ def _generate_config(self): self.name, input_desc)) configs.append(line) return configs + +class XconfigPerElementOffsetComponent(XconfigLayerBase): + """This class is for parsing lines like + 'offset-component name=offset1 input=Append(-3,0,3)' + which will produce just a single component, of type PerElementOffsetComponent, with + output-dim 1024 in this case, and input-dim determined by the dimension of the input . + + Parameters of the class, and their defaults: + input='[-1]' [Descriptor giving the input of the layer.] + + The following (shown with their effective defaults) are just passed through + to the component's config line. (These defaults are mostly set in the + code). + + max-change=0.75 + l2-regularize=0.0 + param-mean=0.0 # affects initialization + param-stddev=0.0 # affects initialization + learning-rate-factor=1.0 + """ + def __init__(self, first_token, key_to_value, prev_names=None): + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = {'input': '[-1]', + 'l2-regularize': '', + 'max-change': 0.75, + 'param-mean': '', + 'param-stddev': '', + 'learning-rate-factor': '' } + + def check_configs(self): + pass + + def output_name(self, auxiliary_output=None): + assert auxiliary_output is None + return self.name + + def output_dim(self, auxiliary_output=None): + assert auxiliary_output is None + return self.descriptors['input']['dim'] + + def get_full_config(self): + ans = [] + config_lines = self._generate_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + # we do not support user specified matrices in this layer + # so 'ref' and 'final' configs are the same. + ans.append((config_name, line)) + return ans + + def _generate_config(self): + # by 'descriptor_final_string' we mean a string that can appear in + # config-files, i.e. it contains the 'final' names of nodes. + input_desc = self.descriptors['input']['final-string'] + dim = self.descriptors['input']['dim'] + + opts = '' + for opt_name in ['learning-rate-factor', 'max-change', 'l2-regularize', 'param-mean', + 'param-stddev' ]: + value = self.config[opt_name] + if value != '': + opts += ' {0}={1}'.format(opt_name, value) + + configs = [] + line = ('component name={0} type=PerElementOffsetComponent dim={1} {2} ' + ''.format(self.name, dim, opts)) + configs.append(line) + line = ('component-node name={0} component={0} input={1}'.format( + self.name, input_desc)) + configs.append(line) + return configs diff --git a/egs/wsj/s5/steps/nnet3/chain/e2e/prepare_e2e.sh b/egs/wsj/s5/steps/nnet3/chain/e2e/prepare_e2e.sh index a060f0f3b36..c211381bf8b 100755 --- a/egs/wsj/s5/steps/nnet3/chain/e2e/prepare_e2e.sh +++ b/egs/wsj/s5/steps/nnet3/chain/e2e/prepare_e2e.sh @@ -18,6 +18,7 @@ treedir= # if specified, the tree and model will be copied from the # note that it may not be flat start anymore. type=mono # can be either mono or biphone -- either way # the resulting tree is full (i.e. it doesn't do any tying) +ci_silence=false # if true, silence phones will be treated as context independent scale_opts="--transition-scale=0.0 --self-loop-scale=0.0" # End configuration section. @@ -63,12 +64,17 @@ if $shared_phones; then shared_phones_opt="--shared-phones=$lang/phones/sets.int" fi +ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; +if $ci_silence; then + ci_opt="--ci-phones=$ciphonelist" +fi + if [ $stage -le 0 ]; then if [ -z $treedir ]; then echo "$0: Initializing $type system." # feat dim does not matter here. Just set it to 10 $cmd $dir/log/init_${type}_mdl_tree.log \ - gmm-init-$type $shared_phones_opt $lang/topo 10 \ + gmm-init-$type $ci_opt $shared_phones_opt $lang/topo 10 \ $dir/0.mdl $dir/tree || exit 1; else echo "$0: Copied tree/mdl from $treedir." >$dir/log/init_mdl_tree.log diff --git a/egs/wsj/s5/steps/nnet3/compute_output.sh b/egs/wsj/s5/steps/nnet3/compute_output.sh index da3cb704878..e55f705043b 100755 --- a/egs/wsj/s5/steps/nnet3/compute_output.sh +++ b/egs/wsj/s5/steps/nnet3/compute_output.sh @@ -54,7 +54,7 @@ fdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print model=$srcdir/$iter.raw if [ ! -f $srcdir/$iter.raw ]; then - echo "$0: WARNING: no such file $srcdir/$iter.raw. Trying $srcdir/$iter.mdl instead." && exit 1 + echo "$0: WARNING: no such file $srcdir/$iter.raw. Trying $srcdir/$iter.mdl instead." model=$srcdir/$iter.mdl fi @@ -104,12 +104,15 @@ gpu_queue_opt= if $use_gpu; then gpu_queue_opt="--gpu 1" + suffix="-batch" gpu_opt="--use-gpu=yes" +else + gpu_opt="--use-gpu=no" fi if [ $stage -le 2 ]; then $cmd $gpu_queue_opt JOB=1:$nj $dir/log/compute_output.JOB.log \ - nnet3-compute $gpu_opt $ivector_opts $frame_subsampling_opt \ + nnet3-compute$suffix $gpu_opt $ivector_opts $frame_subsampling_opt \ --frames-per-chunk=$frames_per_chunk \ --extra-left-context=$extra_left_context \ --extra-right-context=$extra_right_context \ diff --git a/egs/wsj/s5/steps/nnet3/decode.sh b/egs/wsj/s5/steps/nnet3/decode.sh index 5b8374a5a1d..14dda2bd457 100755 --- a/egs/wsj/s5/steps/nnet3/decode.sh +++ b/egs/wsj/s5/steps/nnet3/decode.sh @@ -20,6 +20,10 @@ ivector_scale=1.0 lattice_beam=8.0 # Beam we use in lattice generation. iter=final num_threads=1 # if >1, will use gmm-latgen-faster-parallel +use_gpu=false # If true, will use a GPU, with nnet3-latgen-faster-batch. + # In that case it is recommended to set num-threads to a large + # number, e.g. 20 if you have that many free CPU slots on a GPU + # node, and to use a small number of jobs. scoring_opts= skip_diagnostics=false skip_scoring=false @@ -49,6 +53,9 @@ if [ $# -ne 3 ]; then echo " --iter # Iteration of model to decode; default is final." echo " --scoring-opts # options to local/score.sh" echo " --num-threads # number of threads to use, default 1." + echo " --use-gpu # default: false. If true, we recommend" + echo " # to use large --num-threads as the graph" + echo " # search becomes the limiting factor." exit 1; fi @@ -74,7 +81,16 @@ done sdata=$data/split$nj; cmvn_opts=`cat $srcdir/cmvn_opts` || exit 1; thread_string= -[ $num_threads -gt 1 ] && thread_string="-parallel --num-threads=$num_threads" +if $use_gpu; then + if [ $num_threads -eq 1 ]; then + echo "$0: **Warning: we recommend to use --num-threads > 1 for GPU-based decoding." + fi + thread_string="-batch --num-threads=$num_threads" + queue_opt="--num-threads $num_threads --gpu 1" +elif [ $num_threads -gt 1 ]; then + thread_string="-parallel --num-threads=$num_threads" + queue_opt="--num-threads $num_threads" +fi mkdir -p $dir/log [[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1; @@ -104,7 +120,7 @@ if [ -f $srcdir/frame_subsampling_factor ]; then fi if [ $stage -le 1 ]; then - $cmd --num-threads $num_threads JOB=1:$nj $dir/log/decode.JOB.log \ + $cmd $queue_opt JOB=1:$nj $dir/log/decode.JOB.log \ nnet3-latgen-faster$thread_string $ivector_opts $frame_subsampling_opt \ --frames-per-chunk=$frames_per_chunk \ --extra-left-context=$extra_left_context \ diff --git a/egs/wsj/s5/steps/nnet3/decode_score_fusion.sh b/egs/wsj/s5/steps/nnet3/decode_score_fusion.sh index 7fb5daefdf3..cb678e84245 100755 --- a/egs/wsj/s5/steps/nnet3/decode_score_fusion.sh +++ b/egs/wsj/s5/steps/nnet3/decode_score_fusion.sh @@ -119,7 +119,7 @@ fi if [ $frame_subsampling_factor -eq 3 ]; then if [ $acwt != 1.0 ] || [ $post_decode_acwt != 10.0 ]; then echo -e '\n\n' - echo "$0 WARNING: In standard chain sysemt, acwt = 1.0, post_decode_acwt = 10.0" + echo "$0 WARNING: In standard chain system, acwt = 1.0, post_decode_acwt = 10.0" echo "$0 WARNING: Your acwt = $acwt, post_decode_acwt = $post_decode_acwt" echo "$0 WARNING: This is OK if you know what you are doing." echo -e '\n\n' diff --git a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py index 34214169d5d..ab2b6873fee 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py @@ -101,7 +101,14 @@ def get_args(): help="Directory with features used for training " "the neural network.") parser.add_argument("--targets-scp", type=str, required=False, - help="Targets for training neural network.") + help="""Targets for training neural network. + This is a kaldi-format SCP file of target matrices. + . + The target matrix's column dim must match + the neural network output dim, and the + row dim must match the number of output frames + i.e. after subsampling if "--frame-subsampling-factor" + option is passed to --egs.opts.""") parser.add_argument("--dir", type=str, required=True, help="Directory to store the models and " "all other files.") diff --git a/egs/wsj/s5/steps/segmentation/detect_speech_activity.sh b/egs/wsj/s5/steps/segmentation/detect_speech_activity.sh index f71a14aebf1..831283bb5ec 100755 --- a/egs/wsj/s5/steps/segmentation/detect_speech_activity.sh +++ b/egs/wsj/s5/steps/segmentation/detect_speech_activity.sh @@ -56,7 +56,15 @@ acwt=0.3 # e.g. --speech-in-sil-weight=0.0 --garbage-in-sil-weight=0.0 --sil-in-speech-weight=0.0 --garbage-in-speech-weight=0.3 transform_probs_opts="" +# Postprocessing options segment_padding=0.2 # Duration (in seconds) of padding added to segments +min_segment_dur=0 # Minimum duration (in seconds) required for a segment to be included + # This is before any padding. Segments shorter than this duration will be removed. + # This is an alternative to --min-speech-duration above. +merge_consecutive_max_dur=0 # Merge consecutive segments as long as the merged segment is no longer than this many + # seconds. The segments are only merged if their boundaries are touching. + # This is after padding by --segment-padding seconds. + # 0 means do not merge. Use 'inf' to not limit the duration. echo $* @@ -225,7 +233,8 @@ fi if [ $stage -le 7 ]; then steps/segmentation/post_process_sad_to_segments.sh \ - --segment-padding $segment_padding \ + --segment-padding $segment_padding --min-segment-dur $min_segment_dur \ + --merge-consecutive-max-dur $merge_consecutive_max_dur \ --cmd "$cmd" --frame-shift $(perl -e "print $frame_subsampling_factor * $frame_shift") \ ${test_data_dir} ${seg_dir} ${seg_dir} fi diff --git a/egs/wsj/s5/steps/segmentation/internal/sad_to_segments.py b/egs/wsj/s5/steps/segmentation/internal/sad_to_segments.py index 9b1c0f12b9a..cf19f9bbfb3 100755 --- a/egs/wsj/s5/steps/segmentation/internal/sad_to_segments.py +++ b/egs/wsj/s5/steps/segmentation/internal/sad_to_segments.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # Copyright 2017 Vimal Manohar +# 2018 Capital One (Author: Zhiyuan Guan) # Apache 2.0 """ @@ -29,6 +30,7 @@ global_verbose = 0 + def get_args(): parser = argparse.ArgumentParser( description=""" @@ -44,18 +46,31 @@ def get_args(): parser.add_argument("--utt2dur", type=str, help="File containing durations of utterances.") + parser.add_argument("--frame-shift", type=float, default=0.01, help="Frame shift to convert frame indexes to time") parser.add_argument("--segment-padding", type=float, default=0.2, help="Additional padding on speech segments. But we " - "ensure that the padding does not go beyond the " - "adjacent segment.") + "ensure that the padding does not go beyond the " + "adjacent segment.") + parser.add_argument("--min-segment-dur", type=float, default=0, + help="Minimum duration (in seconds) required for a segment " + "to be included. This is before any padding. Segments " + "shorter than this duration will be removed.") + + parser.add_argument("--merge-consecutive-max-dur", type=float, default=0, + help="Merge consecutive segments as long as the merged " + "segment is no longer than this many seconds. The segments " + "are only merged if their boundaries are touching. " + "This is after padding by --segment-padding seconds." + "0 means do not merge. Use 'inf' to not limit the duration.") parser.add_argument("in_sad", type=str, help="Input file containing alignments in " - "text archive format") + "text archive format") + parser.add_argument("out_segments", type=str, help="Output kaldi segments file") @@ -80,28 +95,45 @@ def to_str(segment): class SegmenterStats(object): """Stores stats about the post-process stages""" + def __init__(self): - self.num_segments = 0 + self.num_segments_initial = 0 + self.num_short_segments_filtered = 0 + self.num_merges = 0 + self.num_segments_final = 0 self.initial_duration = 0.0 self.padding_duration = 0.0 + self.filter_short_duration = 0.0 self.final_duration = 0.0 def add(self, other): """Adds stats from another object""" - self.num_segments += other.num_segments + self.num_segments_initial += other.num_segments_initial + self.num_short_segments_filtered += other.num_short_segments_filtered + self.num_merges += other.num_merges + self.num_segments_final += other.num_segments_final self.initial_duration += other.initial_duration - self.padding_duration = other.padding_duration - self.final_duration = other.final_duration + self.filter_short_duration += other.filter_short_duration + self.padding_duration += other.padding_duration + self.final_duration += other.final_duration def __str__(self): - return ("num-segments={num_segments}, " + return ("num-segments-initial={num_segments_initial}, " + "num-short-segments-filtered={num_short_segments_filtered}, " + "num-merges={num_merges}, " + "num-segments-final={num_segments_final}, " "initial-duration={initial_duration}, " + "filter-short-duration={filter_short_duration}, " "padding-duration={padding_duration}, " "final-duration={final_duration}".format( - num_segments=self.num_segments, - initial_duration=self.initial_duration, - padding_duration=self.padding_duration, - final_duration=self.final_duration)) + num_segments_initial=self.num_segments_initial, + num_short_segments_filtered=self.num_short_segments_filtered, + num_merges=self.num_merges, + num_segments_final=self.num_segments_final, + initial_duration=self.initial_duration, + filter_short_duration=self.filter_short_duration, + padding_duration=self.padding_duration, + final_duration=self.final_duration)) def process_label(text_label): @@ -114,13 +146,14 @@ def process_label(text_label): prev_label = int(text_label) if prev_label not in [1, 2]: raise ValueError("Expecting label to 1 (non-speech) or 2 (speech); " - "got {0}".format(prev_label)) + "got {}".format(prev_label)) return prev_label class Segmentation(object): """Stores segmentation for an utterances""" + def __init__(self): self.segments = None self.stats = SegmenterStats() @@ -143,8 +176,8 @@ def initialize_segments(self, alignment, frame_shift=0.01): float(i) * frame_shift, prev_label]) prev_label = process_label(text_label) - prev_length = 0 self.stats.initial_duration += (prev_length * frame_shift) + prev_length = 0 elif prev_label is None: prev_label = process_label(text_label) @@ -156,7 +189,27 @@ def initialize_segments(self, alignment, frame_shift=0.01): float(len(alignment)) * frame_shift, prev_label]) self.stats.initial_duration += (prev_length * frame_shift) - self.stats.num_segments = len(self.segments) + self.stats.num_segments_initial = len(self.segments) + self.stats.num_segments_final = len(self.segments) + self.stats.final_duration = self.stats.initial_duration + + def filter_short_segments(self, min_dur): + """Filters out segments with durations shorter than 'min_dur'.""" + if min_dur <= 0: + return + + segments_kept = [] + for segment in self.segments: + assert segment[2] == 2, segment + dur = segment[1] - segment[0] + if dur < min_dur: + self.stats.filter_short_duration += dur + self.stats.num_short_segments_filtered += 1 + else: + segments_kept.append(segment) + self.segments = segments_kept + self.stats.num_segments_final = len(self.segments) + self.stats.final_duration -= self.stats.filter_short_duration def pad_speech_segments(self, segment_padding, max_duration=float("inf")): """Pads segments by duration 'segment_padding' on either sides, but @@ -166,19 +219,19 @@ def pad_speech_segments(self, segment_padding, max_duration=float("inf")): max_duration = float("inf") for i, segment in enumerate(self.segments): assert segment[2] == 2, segment - segment[0] -= segment_padding # try adding padding on the left side + segment[0] -= segment_padding # try adding padding on the left side self.stats.padding_duration += segment_padding if segment[0] < 0.0: # Padding takes the segment start to before the beginning of the utterance. # Reduce padding. self.stats.padding_duration += segment[0] segment[0] = 0.0 - if i >= 1 and self.segments[i-1][1] > segment[0]: + if i >= 1 and self.segments[i - 1][1] > segment[0]: # Padding takes the segment start to before the end the previous segment. # Reduce padding. self.stats.padding_duration -= ( - self.segments[i-1][1] - segment[0]) - segment[0] = self.segments[i-1][1] + self.segments[i - 1][1] - segment[0]) + segment[0] = self.segments[i - 1][1] segment[1] += segment_padding self.stats.padding_duration += segment_padding @@ -188,12 +241,35 @@ def pad_speech_segments(self, segment_padding, max_duration=float("inf")): self.stats.padding_duration -= (segment[1] - max_duration) segment[1] = max_duration if (i + 1 < len(self.segments) - and segment[1] > self.segments[i+1][0]): + and segment[1] > self.segments[i + 1][0]): # Padding takes the segment end beyond the start of the next segment. # Reduce padding. self.stats.padding_duration -= ( - segment[1] - self.segments[i+1][0]) - segment[1] = self.segments[i+1][0] + segment[1] - self.segments[i + 1][0]) + segment[1] = self.segments[i + 1][0] + self.stats.final_duration += self.stats.padding_duration + + def merge_consecutive_segments(self, max_dur): + """Merge consecutive segments (happens after padding), provided that + the merged segment is no longer than 'max_dur'.""" + if max_dur <= 0 or not self.segments: + return + + merged_segments = [self.segments[0]] + for segment in self.segments[1:]: + assert segment[2] == 2, segment + if segment[0] == merged_segments[-1][1] and \ + segment[1] - merged_segments[-1][1] <= max_dur: + # The segment starts at the same time the last segment ends, + # and the merged segment is shorter than 'max_dur'. + # Extend the previous segment. + merged_segments[-1][1] = segment[1] + self.stats.num_merges += 1 + else: + merged_segments.append(segment) + + self.segments = merged_segments + self.stats.num_segments_final = len(self.segments) def write(self, key, file_handle): """Write segments to file""" @@ -203,9 +279,9 @@ def write(self, key, file_handle): for segment in self.segments: seg_id = "{key}-{st:07d}-{end:07d}".format( key=key, st=int(segment[0] * 100), end=int(segment[1] * 100)) - print ("{seg_id} {key} {st:.2f} {end:.2f}".format( + print("{seg_id} {key} {st:.2f} {end:.2f}".format( seg_id=seg_id, key=key, st=segment[0], end=segment[1]), - file=file_handle) + file=file_handle) def run(args): @@ -235,9 +311,11 @@ def run(args): segmentation = Segmentation() segmentation.initialize_segments( parts[1:], args.frame_shift) + segmentation.filter_short_segments(args.min_segment_dur) segmentation.pad_speech_segments(args.segment_padding, None if args.utt2dur is None else utt2dur[utt_id]) + segmentation.merge_consecutive_segments(args.merge_consecutive_max_dur) segmentation.write(utt_id, out_segments_fh) global_stats.add(segmentation.stats) logger.info(global_stats) diff --git a/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh b/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh index ca9cea2518b..b168c307b57 100755 --- a/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh +++ b/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh @@ -18,6 +18,8 @@ nj=18 # The values below are in seconds frame_shift=0.01 segment_padding=0.2 +min_segment_dur=0 +merge_consecutive_max_dur=0 . utils/parse_options.sh @@ -53,6 +55,7 @@ if [ $stage -le 0 ]; then copy-int-vector "ark:gunzip -c $vad_dir/ali.JOB.gz |" ark,t:- \| \ steps/segmentation/internal/sad_to_segments.py \ --frame-shift=$frame_shift --segment-padding=$segment_padding \ + --min-segment-dur=$min_segment_dur --merge-consecutive-max-dur=$merge_consecutive_max_dur \ --utt2dur=$data_dir/utt2dur - $dir/segments.JOB fi diff --git a/egs/wsj/s5/steps/tfrnnlm/lstm.py b/egs/wsj/s5/steps/tfrnnlm/lstm.py index 06969fbcb5d..5f175212c4b 100644 --- a/egs/wsj/s5/steps/tfrnnlm/lstm.py +++ b/egs/wsj/s5/steps/tfrnnlm/lstm.py @@ -16,8 +16,8 @@ # this script trains a vanilla RNNLM with TensorFlow. # to call the script, do -# python steps/tfrnnlm/lstm.py --data-path=$datadir \ -# --save-path=$savepath --vocab-path=$rnn.wordlist [--hidden-size=$size] +# python steps/tfrnnlm/lstm.py --data_path=$datadir \ +# --save_path=$savepath --vocab_path=$rnn.wordlist [--hidden-size=$size] # # One example recipe is at egs/ami/s5/local/tfrnnlm/run_lstm.sh @@ -38,15 +38,15 @@ flags = tf.flags logging = tf.logging -flags.DEFINE_integer("hidden-size", 200, "hidden dim of RNN") +flags.DEFINE_integer("hidden_size", 200, "hidden dim of RNN") -flags.DEFINE_string("data-path", None, +flags.DEFINE_string("data_path", None, "Where the training/test data is stored.") -flags.DEFINE_string("vocab-path", None, +flags.DEFINE_string("vocab_path", None, "Where the wordlist file is stored.") -flags.DEFINE_string("save-path", None, +flags.DEFINE_string("save_path", None, "Model output directory.") -flags.DEFINE_bool("use-fp16", False, +flags.DEFINE_bool("use_fp16", False, "Train using 16-bit floats instead of 32bit floats") FLAGS = flags.FLAGS diff --git a/egs/wsj/s5/steps/tfrnnlm/lstm_fast.py b/egs/wsj/s5/steps/tfrnnlm/lstm_fast.py index 9643468ccfb..440962a3780 100644 --- a/egs/wsj/s5/steps/tfrnnlm/lstm_fast.py +++ b/egs/wsj/s5/steps/tfrnnlm/lstm_fast.py @@ -16,8 +16,8 @@ # this script trains a vanilla RNNLM with TensorFlow. # to call the script, do -# python steps/tfrnnlm/lstm_fast.py --data-path=$datadir \ -# --save-path=$savepath --vocab-path=$rnn.wordlist [--hidden-size=$size] +# python steps/tfrnnlm/lstm_fast.py --data_path=$datadir \ +# --save_path=$savepath --vocab_path=$rnn.wordlist [--hidden-size=$size] # # One example recipe is at egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh @@ -38,15 +38,15 @@ flags = tf.flags logging = tf.logging -flags.DEFINE_integer("hidden-size", 200, "hidden dim of RNN") +flags.DEFINE_integer("hidden_size", 200, "hidden dim of RNN") -flags.DEFINE_string("data-path", None, +flags.DEFINE_string("data_path", None, "Where the training/test data is stored.") -flags.DEFINE_string("vocab-path", None, +flags.DEFINE_string("vocab_path", None, "Where the wordlist file is stored.") -flags.DEFINE_string("save-path", None, +flags.DEFINE_string("save_path", None, "Model output directory.") -flags.DEFINE_bool("use-fp16", False, +flags.DEFINE_bool("use_fp16", False, "Train using 16-bit floats instead of 32bit floats") FLAGS = flags.FLAGS diff --git a/egs/wsj/s5/steps/tfrnnlm/vanilla_rnnlm.py b/egs/wsj/s5/steps/tfrnnlm/vanilla_rnnlm.py index de263c6923f..f3ce1a5c297 100644 --- a/egs/wsj/s5/steps/tfrnnlm/vanilla_rnnlm.py +++ b/egs/wsj/s5/steps/tfrnnlm/vanilla_rnnlm.py @@ -16,8 +16,8 @@ # this script trains a vanilla RNNLM with TensorFlow. # to call the script, do -# python steps/tfrnnlm/vanilla_rnnlm.py --data-path=$datadir \ -# --save-path=$savepath --vocab-path=$rnn.wordlist [--hidden-size=$size] +# python steps/tfrnnlm/vanilla_rnnlm.py --data_path=$datadir \ +# --save_path=$savepath --vocab_path=$rnn.wordlist [--hidden-size=$size] # # One example recipe is at egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh @@ -38,15 +38,15 @@ flags = tf.flags logging = tf.logging -flags.DEFINE_integer("hidden-size", 200, "hidden dim of RNN") +flags.DEFINE_integer("hidden_size", 200, "hidden dim of RNN") -flags.DEFINE_string("data-path", None, +flags.DEFINE_string("data_path", None, "Where the training/test data is stored.") -flags.DEFINE_string("vocab-path", None, +flags.DEFINE_string("vocab_path", None, "Where the wordlist file is stored.") -flags.DEFINE_string("save-path", None, +flags.DEFINE_string("save_path", None, "Model output directory.") -flags.DEFINE_bool("use-fp16", False, +flags.DEFINE_bool("use_fp16", False, "Train using 16-bit floats instead of 32bit floats") FLAGS = flags.FLAGS diff --git a/egs/wsj/s5/utils/lang/bpe/bidi.py b/egs/wsj/s5/utils/lang/bpe/bidi.py new file mode 100755 index 00000000000..447313a5d02 --- /dev/null +++ b/egs/wsj/s5/utils/lang/bpe/bidi.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2018 Chun-Chieh Chang + +# This script is largely written by Stephen Rawls +# and uses the python package https://pypi.org/project/PyICU_BiDi/ +# The code leaves right to left text alone and reverses left to right text. + +import icu_bidi +import io +import sys +import unicodedata +# R=strong right-to-left; AL=strong arabic right-to-left +rtl_set = set(chr(i) for i in range(sys.maxunicode) + if unicodedata.bidirectional(chr(i)) in ['R','AL']) +def determine_text_direction(text): + # Easy case first + for char in text: + if char in rtl_set: + return icu_bidi.UBiDiLevel.UBIDI_RTL + # If we made it here we did not encounter any strongly rtl char + return icu_bidi.UBiDiLevel.UBIDI_LTR + +def utf8_visual_to_logical(text): + text_dir = determine_text_direction(text) + + bidi = icu_bidi.Bidi() + bidi.inverse = True + bidi.reordering_mode = icu_bidi.UBiDiReorderingMode.UBIDI_REORDER_INVERSE_LIKE_DIRECT + bidi.reordering_options = icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_DEFAULT # icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_INSERT_MARKS + + bidi.set_para(text, text_dir, None) + + res = bidi.get_reordered(0 | icu_bidi.UBidiWriteReorderedOpt.UBIDI_DO_MIRRORING | icu_bidi.UBidiWriteReorderedOpt.UBIDI_KEEP_BASE_COMBINING) + + return res + +def utf8_logical_to_visual(text): + text_dir = determine_text_direction(text) + + bidi = icu_bidi.Bidi() + + bidi.reordering_mode = icu_bidi.UBiDiReorderingMode.UBIDI_REORDER_DEFAULT + bidi.reordering_options = icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_DEFAULT #icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_INSERT_MARKS + + bidi.set_para(text, text_dir, None) + + res = bidi.get_reordered(0 | icu_bidi.UBidiWriteReorderedOpt.UBIDI_DO_MIRRORING | icu_bidi.UBidiWriteReorderedOpt.UBIDI_KEEP_BASE_COMBINING) + + return res + + +##main## +sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf8") +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf8") +for line in sys.stdin: + line = line.strip() + line = utf8_logical_to_visual(line)[::-1] + sys.stdout.write(line + '\n') diff --git a/egs/madcat_ar/v1/local/reverse.py b/egs/wsj/s5/utils/lang/bpe/reverse.py similarity index 100% rename from egs/madcat_ar/v1/local/reverse.py rename to egs/wsj/s5/utils/lang/bpe/reverse.py diff --git a/egs/wsj/s5/utils/lang/make_kn_lm.py b/egs/wsj/s5/utils/lang/make_kn_lm.py new file mode 100755 index 00000000000..00c4c8f2378 --- /dev/null +++ b/egs/wsj/s5/utils/lang/make_kn_lm.py @@ -0,0 +1,379 @@ +#!/usr/bin/env python3 + +# Copyright 2016 Johns Hopkins University (Author: Daniel Povey) +# 2018 Ruizhe Huang +# Apache 2.0. + +# This is an implementation of computing Kneser-Ney smoothed language model +# in the same way as srilm. This is a back-off, unmodified version of +# Kneser-Ney smoothing, which produces the same results as the following +# command (as an example) of srilm: +# +# $ ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \ +# -text corpus.txt -lm lm.arpa +# +# The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py +# The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html + +import sys +import os +import re +import io +import math +import argparse +from collections import Counter, defaultdict + + +parser = argparse.ArgumentParser(description=""" + Generate kneser-ney language model as arpa format. By default, + it will read the corpus from standard input, and output to standard output. + """) +parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram") +parser.add_argument("-text", type=str, default=None, help="Path to the corpus file") +parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models") +parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level") +args = parser.parse_args() + +default_encoding = "latin-1" # For encoding-agnostic scripts, we assume byte stream as input. + # Need to be very careful about the use of strip() and split() + # in this case, because there is a latin-1 whitespace character + # (nbsp) which is part of the unicode encoding range. + # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717 +strip_chars = " \t\r\n" +whitespace = re.compile("[ \t]+") + + +class CountsForHistory: + # This class (which is more like a struct) stores the counts seen in a + # particular history-state. It is used inside class NgramCounts. + # It really does the job of a dict from int to float, but it also + # keeps track of the total count. + def __init__(self): + # The 'lambda: defaultdict(float)' is an anonymous function taking no + # arguments that returns a new defaultdict(float). + self.word_to_count = defaultdict(int) + self.word_to_context = defaultdict(set) # using a set to count the number of unique contexts + self.word_to_f = dict() # discounted probability + self.word_to_bow = dict() # back-off weight + self.total_count = 0 + + def words(self): + return self.word_to_count.keys() + + def __str__(self): + # e.g. returns ' total=12: 3->4, 4->6, -1->2' + return ' total={0}: {1}'.format( + str(self.total_count), + ', '.join(['{0} -> {1}'.format(word, count) + for word, count in self.word_to_count.items()])) + + def add_count(self, predicted_word, context_word, count): + assert count >= 0 + + self.total_count += count + self.word_to_count[predicted_word] += count + if context_word is not None: + self.word_to_context[predicted_word].add(context_word) + + +class NgramCounts: + # A note on data-structure. Firstly, all words are represented as + # integers. We store n-gram counts as an array, indexed by (history-length + # == n-gram order minus one) (note: python calls arrays "lists") of dicts + # from histories to counts, where histories are arrays of integers and + # "counts" are dicts from integer to float. For instance, when + # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd + # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an + # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict. + def __init__(self, ngram_order, bos_symbol='', eos_symbol=''): + assert ngram_order >= 2 + + self.ngram_order = ngram_order + self.bos_symbol = bos_symbol + self.eos_symbol = eos_symbol + + self.counts = [] + for n in range(ngram_order): + self.counts.append(defaultdict(lambda: CountsForHistory())) + + self.d = [] # list of discounting factor for each order of ngram + + # adds a raw count (called while processing input data). + # Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history' + # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be + # 1. + def add_count(self, history, predicted_word, context_word, count): + self.counts[len(history)][history].add_count(predicted_word, context_word, count) + + # 'line' is a string containing a sequence of integer word-ids. + # This function adds the un-smoothed counts from this line of text. + def add_raw_counts_from_line(self, line): + words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol] + + for i in range(len(words)): + for n in range(1, self.ngram_order+1): + if i + n > len(words): + break + + ngram = words[i: i + n] + predicted_word = ngram[-1] + history = tuple(ngram[: -1]) + if i == 0 or n == self.ngram_order: + context_word = None + else: + context_word = words[i-1] + + self.add_count(history, predicted_word, context_word, 1) + + def add_raw_counts_from_standard_input(self): + lines_processed = 0 + infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding) # byte stream as input + for line in infile: + line = line.strip(strip_chars) + if line == '': + break + self.add_raw_counts_from_line(line) + lines_processed += 1 + if lines_processed == 0 or args.verbose > 0: + print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr) + + def add_raw_counts_from_file(self, filename): + lines_processed = 0 + with open(filename, encoding=default_encoding) as fp: + for line in fp: + line = line.strip(strip_chars) + if line == '': + break + self.add_raw_counts_from_line(line) + lines_processed += 1 + if lines_processed == 0 or args.verbose > 0: + print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr) + + def cal_discounting_constants(self): + # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N), + # where n1_N is the number of unique N-grams with count = 1 (counts-of-counts). + # This constant is used similarly to absolute discounting. + # Return value: d is a list of floats, where d[N+1] = D_N + + self.d = [0] # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0 + # This is a special case: as we currently assumed having seen all vocabularies in the dictionary, + # but perhaps this is not the case for some other scenarios. + for n in range(1, self.ngram_order): + this_order_counts = self.counts[n] + n1 = 0 + n2 = 0 + for hist, counts_for_hist in this_order_counts.items(): + stat = Counter(counts_for_hist.word_to_count.values()) + n1 += stat[1] + n2 += stat[2] + assert n1 + 2 * n2 > 0 + self.d.append(n1 * 1.0 / (n1 + 2 * n2)) + + def cal_f(self): + # f(a_z) is a probability distribution of word sequence a_z. + # Typically f(a_z) is discounted to be less than the ML estimate so we have + # some leftover probability for the z words unseen in the context (a_). + # + # f(a_z) = (c(a_z) - D0) / c(a_) ;; for highest order N-grams + # f(_z) = (n(*_z) - D1) / n(*_*) ;; for lower order N-grams + + # highest order N-grams + n = self.ngram_order - 1 + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w, c in counts_for_hist.word_to_count.items(): + counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count + + # lower order N-grams + for n in range(0, self.ngram_order - 1): + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + + n_star_star = 0 + for w in counts_for_hist.word_to_count.keys(): + n_star_star += len(counts_for_hist.word_to_context[w]) + + if n_star_star != 0: + for w in counts_for_hist.word_to_count.keys(): + n_star_z = len(counts_for_hist.word_to_context[w]) + counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star + else: # patterns begin with , they do not have "modified count", so use raw count instead + for w in counts_for_hist.word_to_count.keys(): + n_star_z = counts_for_hist.word_to_count[w] + counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count + + def cal_bow(self): + # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram. + # Thus, two sorts of ngrams do not have a bow: + # 1) highest order ngram + # 2) ngrams ending in + # + # bow(a_) = (1 - Sum_Z1 f(a_z)) / (1 - Sum_Z1 f(_z)) + # Note that Z1 is the set of all words with c(a_z) > 0 + + # highest order N-grams + n = self.ngram_order - 1 + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + counts_for_hist.word_to_bow[w] = None + + # lower order N-grams + for n in range(0, self.ngram_order - 1): + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + if w == self.eos_symbol: + counts_for_hist.word_to_bow[w] = None + else: + a_ = hist + (w,) + + assert len(a_) < self.ngram_order + assert a_ in self.counts[len(a_)].keys() + + a_counts_for_hist = self.counts[len(a_)][a_] + + sum_z1_f_a_z = 0 + for u in a_counts_for_hist.word_to_count.keys(): + sum_z1_f_a_z += a_counts_for_hist.word_to_f[u] + + sum_z1_f_z = 0 + _ = a_[1:] + _counts_for_hist = self.counts[len(_)][_] + for u in a_counts_for_hist.word_to_count.keys(): # Should be careful here: what is Z1 + sum_z1_f_z += _counts_for_hist.word_to_f[u] + + counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z) + + def print_raw_counts(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])) + res.sort(reverse=True) + for r in res: + print(r) + + def print_modified_counts(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + modified_count = len(counts_for_hist.word_to_context[w]) + raw_count = counts_for_hist.word_to_count[w] + + if modified_count == 0: + res.append("{0}\t{1}".format(ngram, raw_count)) + else: + res.append("{0}\t{1}".format(ngram, modified_count)) + res.sort(reverse=True) + for r in res: + print(r) + + def print_f(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + f = counts_for_hist.word_to_f[w] + if f == 0: # f() is always 0 + f = 1e-99 + + res.append("{0}\t{1}".format(ngram, math.log(f, 10))) + res.sort(reverse=True) + for r in res: + print(r) + + def print_f_and_bow(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + f = counts_for_hist.word_to_f[w] + if f == 0: # f() is always 0 + f = 1e-99 + + bow = counts_for_hist.word_to_bow[w] + if bow is None: + res.append("{1}\t{0}".format(ngram, math.log(f, 10))) + else: + res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10))) + res.sort(reverse=True) + for r in res: + print(r) + + def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')): + # print as ARPA format. + + print('\\data\\', file=fout) + for hist_len in range(self.ngram_order): + # print the number of n-grams. + print('ngram {0}={1}'.format( + hist_len + 1, + sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])), + file=fout + ) + + print('', file=fout) + + for hist_len in range(self.ngram_order): + print('\\{0}-grams:'.format(hist_len + 1), file=fout) + + this_order_counts = self.counts[hist_len] + for hist, counts_for_hist in this_order_counts.items(): + for word in counts_for_hist.word_to_count.keys(): + ngram = hist + (word,) + prob = counts_for_hist.word_to_f[word] + bow = counts_for_hist.word_to_bow[word] + + if prob == 0: # f() is always 0 + prob = 1e-99 + + line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram)) + if bow is not None: + line += '\t{0}'.format('%.7f' % math.log10(bow)) + print(line, file=fout) + print('', file=fout) + print('\\end\\', file=fout) + + +if __name__ == "__main__": + + ngram_counts = NgramCounts(args.ngram_order) + + if args.text is None: + ngram_counts.add_raw_counts_from_standard_input() + else: + assert os.path.isfile(args.text) + ngram_counts.add_raw_counts_from_file(args.text) + + ngram_counts.cal_discounting_constants() + ngram_counts.cal_f() + ngram_counts.cal_bow() + + if args.lm is None: + ngram_counts.print_as_arpa() + else: + with open(args.lm, 'w', encoding=default_encoding) as f: + ngram_counts.print_as_arpa(fout=f) diff --git a/egs/wsj/s5/utils/parallel/pbs.pl b/egs/wsj/s5/utils/parallel/pbs.pl index 35a33ba2dca..d61bb1d4566 100755 --- a/egs/wsj/s5/utils/parallel/pbs.pl +++ b/egs/wsj/s5/utils/parallel/pbs.pl @@ -401,7 +401,7 @@ () exit(1); } -my $sge_job_id; +my $pbs_job_id; if (! $sync) { # We're not submitting with -sync y, so we # need to wait for the jobs to finish. We wait for the # sync-files we "touched" in the script to exist. @@ -413,25 +413,25 @@ () push @syncfiles, "$syncfile.$jobid"; } } - # We will need the sge_job_id, to check that job still exists - { # Get the SGE job-id from the log file in q/ - open(L, "<$queue_logfile") || die "Error opening log file $queue_logfile"; - undef $sge_job_id; - while () { - if (m/Your job\S* (\d+)[. ].+ has been submitted/) { - if (defined $sge_job_id) { + # We will need the pbs_job_id, to check that job still exists + { # Get the PBS job-id from the log file in q/ + open my $L, '<', $queue_logfile || die "Error opening log file $queue_logfile"; + undef $pbs_job_id; + while (<$L>) { + if (/(\d+.+\.pbsserver)/) { + if (defined $pbs_job_id) { die "Error: your job was submitted more than once (see $queue_logfile)"; } else { - $sge_job_id = $1; + $pbs_job_id = $1; } } } - close(L); - if (!defined $sge_job_id) { - die "Error: log file $queue_logfile does not specify the SGE job-id."; + close $L; + if (!defined $pbs_job_id) { + die "Error: log file $queue_logfile does not specify the PBS job-id."; } } - my $check_sge_job_ctr=1; + my $check_pbs_job_ctr=1; # my $wait = 0.1; my $counter = 0; @@ -460,11 +460,11 @@ () } } - # Check that the job exists in SGE. Job can be killed if duration + # Check that the job exists in PBS. Job can be killed if duration # exceeds some hard limit, or in case of a machine shutdown. - if (($check_sge_job_ctr++ % 10) == 0) { # Don't run qstat too often, avoid stress on SGE. + if (($check_pbs_job_ctr++ % 10) == 0) { # Don't run qstat too often, avoid stress on PBS. if ( -f $f ) { next; }; #syncfile appeared: OK. - $ret = system("qstat -t $sge_job_id >/dev/null 2>/dev/null"); + $ret = system("qstat -t $pbs_job_id >/dev/null 2>/dev/null"); # system(...) : To get the actual exit value, shift $ret right by eight bits. if ($ret>>8 == 1) { # Job does not seem to exist # Don't consider immediately missing job as error, first wait some @@ -513,7 +513,7 @@ () exit(1); } } elsif ($ret != 0) { - print STDERR "pbs.pl: Warning: qstat command returned status $ret (qstat -t $sge_job_id,$!)\n"; + print STDERR "pbs.pl: Warning: qstat command returned status $ret (qstat -t $pbs_job_id,$!)\n"; } } } diff --git a/egs/wsj/s5/utils/parallel/queue.pl b/egs/wsj/s5/utils/parallel/queue.pl index e14af5ef6e3..bddcb4fec23 100755 --- a/egs/wsj/s5/utils/parallel/queue.pl +++ b/egs/wsj/s5/utils/parallel/queue.pl @@ -176,7 +176,7 @@ sub caught_signal { option max_jobs_run=* -tc $0 default gpu=0 option gpu=0 -option gpu=* -l gpu=$0 -q g.q +option gpu=* -l gpu=$0 -q '*.q' EOF # Here the configuration options specified by the user on the command line diff --git a/egs/yomdle_fa/README.txt b/egs/yomdle_fa/README.txt new file mode 100644 index 00000000000..984ffdb53b5 --- /dev/null +++ b/egs/yomdle_fa/README.txt @@ -0,0 +1,3 @@ +This directory contains example scripts for OCR on the Yomdle and Slam datasets. +Training is done on the Yomdle dataset and testing is done on Slam. +LM rescoring is also done with extra corpus data obtained from various newswires (e.g. Hamshahri) diff --git a/egs/yomdle_fa/v1/cmd.sh b/egs/yomdle_fa/v1/cmd.sh new file mode 100755 index 00000000000..3c8eb9f93a5 --- /dev/null +++ b/egs/yomdle_fa/v1/cmd.sh @@ -0,0 +1,13 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export cmd="queue.pl" diff --git a/egs/yomdle_fa/v1/image b/egs/yomdle_fa/v1/image new file mode 120000 index 00000000000..1668ee99922 --- /dev/null +++ b/egs/yomdle_fa/v1/image @@ -0,0 +1 @@ +../../cifar/v1/image/ \ No newline at end of file diff --git a/egs/yomdle_fa/v1/local/augment_data.sh b/egs/yomdle_fa/v1/local/augment_data.sh new file mode 100755 index 00000000000..1c38bcb072d --- /dev/null +++ b/egs/yomdle_fa/v1/local/augment_data.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Copyright 2018 Hossein Hadian +# 2018 Ashish Arora + +# Apache 2.0 +# This script performs data augmentation. + +nj=4 +cmd=run.pl +feat_dim=40 +fliplr=false +verticle_shift=0 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +srcdir=$1 +outdir=$2 +datadir=$3 + +mkdir -p $datadir/augmentations +echo "copying $srcdir to $datadir/augmentations/aug1, allowed length, creating feats.scp" + +for set in aug1; do + image/copy_data_dir.sh --spk-prefix $set- --utt-prefix $set- \ + $srcdir $datadir/augmentations/$set + cat $srcdir/allowed_lengths.txt > $datadir/augmentations/$set/allowed_lengths.txt + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim $feat_dim \ + --vertical-shift $verticle_shift \ + --fliplr $fliplr --augment 'random_scale' $datadir/augmentations/$set + +done + +echo " combine original data and data from different augmentations" +utils/combine_data.sh --extra-files images.scp $outdir $srcdir $datadir/augmentations/aug1 +cat $srcdir/allowed_lengths.txt > $outdir/allowed_lengths.txt diff --git a/egs/yomdle_fa/v1/local/bidi.py b/egs/yomdle_fa/v1/local/bidi.py new file mode 100755 index 00000000000..447313a5d02 --- /dev/null +++ b/egs/yomdle_fa/v1/local/bidi.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2018 Chun-Chieh Chang + +# This script is largely written by Stephen Rawls +# and uses the python package https://pypi.org/project/PyICU_BiDi/ +# The code leaves right to left text alone and reverses left to right text. + +import icu_bidi +import io +import sys +import unicodedata +# R=strong right-to-left; AL=strong arabic right-to-left +rtl_set = set(chr(i) for i in range(sys.maxunicode) + if unicodedata.bidirectional(chr(i)) in ['R','AL']) +def determine_text_direction(text): + # Easy case first + for char in text: + if char in rtl_set: + return icu_bidi.UBiDiLevel.UBIDI_RTL + # If we made it here we did not encounter any strongly rtl char + return icu_bidi.UBiDiLevel.UBIDI_LTR + +def utf8_visual_to_logical(text): + text_dir = determine_text_direction(text) + + bidi = icu_bidi.Bidi() + bidi.inverse = True + bidi.reordering_mode = icu_bidi.UBiDiReorderingMode.UBIDI_REORDER_INVERSE_LIKE_DIRECT + bidi.reordering_options = icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_DEFAULT # icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_INSERT_MARKS + + bidi.set_para(text, text_dir, None) + + res = bidi.get_reordered(0 | icu_bidi.UBidiWriteReorderedOpt.UBIDI_DO_MIRRORING | icu_bidi.UBidiWriteReorderedOpt.UBIDI_KEEP_BASE_COMBINING) + + return res + +def utf8_logical_to_visual(text): + text_dir = determine_text_direction(text) + + bidi = icu_bidi.Bidi() + + bidi.reordering_mode = icu_bidi.UBiDiReorderingMode.UBIDI_REORDER_DEFAULT + bidi.reordering_options = icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_DEFAULT #icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_INSERT_MARKS + + bidi.set_para(text, text_dir, None) + + res = bidi.get_reordered(0 | icu_bidi.UBidiWriteReorderedOpt.UBIDI_DO_MIRRORING | icu_bidi.UBidiWriteReorderedOpt.UBIDI_KEEP_BASE_COMBINING) + + return res + + +##main## +sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf8") +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf8") +for line in sys.stdin: + line = line.strip() + line = utf8_logical_to_visual(line)[::-1] + sys.stdout.write(line + '\n') diff --git a/egs/yomdle_fa/v1/local/chain/compare_wer.sh b/egs/yomdle_fa/v1/local/chain/compare_wer.sh new file mode 100755 index 00000000000..ab880c1adb5 --- /dev/null +++ b/egs/yomdle_fa/v1/local/chain/compare_wer.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/cnn{1a,1b} + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... ]" + echo "e.g.: $0 exp/chain/cnn{1a,1b}" + exit 1 +fi + +echo "# $0 $*" +used_epochs=false + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +echo -n "# WER " +for x in $*; do + wer=$(cat $x/decode_test/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER " +for x in $*; do + cer=$(cat $x/decode_test/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo diff --git a/egs/yomdle_fa/v1/local/chain/run_cnn_e2eali_1b.sh b/egs/yomdle_fa/v1/local/chain/run_cnn_e2eali_1b.sh new file mode 100755 index 00000000000..e7c125d16de --- /dev/null +++ b/egs/yomdle_fa/v1/local/chain/run_cnn_e2eali_1b.sh @@ -0,0 +1,244 @@ +#!/bin/bash + +# e2eali_1b is the same as chainali_1a but uses the e2e chain model to get the +# lattice alignments and to build a tree + +# local/chain/compare_wer.sh scale_baseline2/exp_yomdle_farsi/chain/e2e_cnn_1a scale_baseline2/exp_yomdle_farsi/chain/cnn_e2eali_1b +# System e2e_cnn_1a cnn_e2eali_1b +# WER 19.55 18.45 +# CER 5.64 4.94 +# Final train prob -0.0065 -0.0633 +# Final valid prob 0.0015 -0.0619 +# Final train prob (xent) -0.2636 +# Final valid prob (xent) -0.2511 + +set -e -o pipefail + +data_dir=data +exp_dir=exp + +stage=0 + +nj=30 +train_set=train +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1b #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=500 +# we don't need extra left/right context for TDNN systems. +chunk_left_context=0 +chunk_right_context=0 +tdnn_dim=450 +# training options +srand=0 +remove_egs=true +lang_test=lang_test +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} $data_dir/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts + +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=72" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=144" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=196" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=120 name=input + + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-8,-4,0,4,8) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=500" \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=16 \ + --trainer.frames-per-iter=1000000 \ + --trainer.optimization.num-jobs-initial=4 \ + --trainer.optimization.num-jobs-final=8 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=32,16 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $data_dir/$lang_test \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --extra-left-context $chunk_left_context \ + --extra-right-context $chunk_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph $data_dir/test $dir/decode_test || exit 1; +fi diff --git a/egs/yomdle_fa/v1/local/chain/run_flatstart_cnn1a.sh b/egs/yomdle_fa/v1/local/chain/run_flatstart_cnn1a.sh new file mode 100755 index 00000000000..bb5352943f6 --- /dev/null +++ b/egs/yomdle_fa/v1/local/chain/run_flatstart_cnn1a.sh @@ -0,0 +1,170 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +# This script does end2end chain training (i.e. from scratch) + +# local/chain/compare_wer.sh exp_yomdle_farsi/chain/e2e_cnn_1a exp_yomdle_farsi/chain/cnn_e2eali_1b +# System e2e_cnn_1a cnn_e2eali_1b +# WER 19.55 18.45 +# CER 5.64 4.94 +# Final train prob -0.0065 -0.0633 +# Final valid prob 0.0015 -0.0619 +# Final train prob (xent) -0.2636 +# Final valid prob (xent) -0.2511 + +set -e + +data_dir=data +exp_dir=exp + +# configs for 'chain' +stage=0 +nj=30 +train_stage=-10 +get_egs_stage=-10 +affix=1a + +# training options +tdnn_dim=450 +num_epochs=4 +num_jobs_initial=4 +num_jobs_final=8 +minibatch_size=150=64,32/300=32,16/600=16,8/1200=8,4 +common_egs_dir= +l2_regularize=0.00005 +frames_per_iter=1000000 +cmvn_opts="--norm-means=false --norm-vars=false" +train_set=train +lang_test=lang_test + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj $nj --cmd "$cmd" \ + --shared-phones true \ + --type mono \ + $data_dir/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat $data_dir/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py $data_dir/lang \| \ + utils/sym2int.pl -f 2- $data_dir/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=72" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=144" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=144" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=120 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=40 height-out=40 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=40 height-out=20 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=20 height-out=20 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=20 height-out=10 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=10 height-out=10 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-8,-4,0,4,8) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $output_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + # no need to store the egs in a shared storage because we always + # remove them. Anyway, it takes only 5 minutes to generate them. + + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts "$cmvn_opts" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize $l2_regularize \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter $frames_per_iter \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir $data_dir/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi + +if [ $stage -le 4 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $data_dir/$lang_test \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 5 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$cmd" \ + $dir/graph $data_dir/test $dir/decode_test || exit 1; +fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/yomdle_fa/v1/local/create_download.sh b/egs/yomdle_fa/v1/local/create_download.sh new file mode 100755 index 00000000000..1040ecc2165 --- /dev/null +++ b/egs/yomdle_fa/v1/local/create_download.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Copyright 2018 Chun-Chieh Chang + +# The original format of the dataset given is GEDI and page images. +# This script is written to create line images from page images. +# It also creates csv files from the GEDI files. + +database_slam=/export/corpora5/slam/SLAM/Farsi/transcribed +database_yomdle=/export/corpora5/slam/YOMDLE/final_farsi +slam_dir=download/slam_farsi +yomdle_dir=download/yomdle_farsi + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +echo "$0: Processing SLAM ${language}" +echo "Date: $(date)." +mkdir -p ${slam_dir}/{truth_csv,truth_csv_raw,truth_line_image} +local/GEDI2CSV_enriched.py \ + --inputDir ${database_slam} \ + --outputDir ${slam_dir}/truth_csv_raw \ + --log ${slam_dir}/GEDI2CSV_enriched.log +local/create_line_image_from_page_image.py \ + ${database_slam} \ + ${slam_dir}/truth_csv_raw \ + ${slam_dir} + +echo "$0: Processing YOMDLE ${language}" +echo "Date: $(date)." +mkdir -p ${yomdle_dir}/{truth_csv,truth_csv_raw,truth_line_image} +local/YOMDLE2CSV.py \ + --inputDir ${database_yomdle} \ + --outputDir ${yomdle_dir}/truth_csv_raw/ \ + --log ${yomdle_dir}/YOMDLE2CSV.log +local/create_line_image_from_page_image.py \ + --im-format "jpg" \ + ${database_yomdle}/images \ + ${yomdle_dir}/truth_csv_raw \ + ${yomdle_dir} diff --git a/egs/yomdle_fa/v1/local/create_line_image_from_page_image.py b/egs/yomdle_fa/v1/local/create_line_image_from_page_image.py new file mode 100755 index 00000000000..77a6791d5d7 --- /dev/null +++ b/egs/yomdle_fa/v1/local/create_line_image_from_page_image.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# Apache 2.0 +# minimum bounding box part in this script is originally from +#https://github.com/BebeSparkelSparkel/MinimumBoundingBox +#https://startupnextdoor.com/computing-convex-hull-in-python/ +""" This module will be used for extracting line images from page image. + Given the word segmentation (bounding box around a word) for every word, it will + extract line segmentation. To extract line segmentation, it will take word bounding + boxes of a line as input, will create a minimum area bounding box that will contain + all corner points of word bounding boxes. The obtained bounding box (will not necessarily + be vertically or horizontally aligned). Hence to extract line image from line bounding box, + page image is rotated and line image is cropped and saved. +""" + +import argparse +import csv +import itertools +import sys +import os +import numpy as np +from math import atan2, cos, sin, pi, degrees, sqrt +from collections import namedtuple + +from scipy.spatial import ConvexHull +from PIL import Image +from scipy.misc import toimage + +parser = argparse.ArgumentParser(description="Creates line images from page image") +parser.add_argument('image_dir', type=str, help='Path to full page images') +parser.add_argument('csv_dir', type=str, help='Path to csv files') +parser.add_argument('out_dir', type=str, help='Path to output directory') +parser.add_argument('--im-format', type=str, default='png', help='What file format are the images') +parser.add_argument('--padding', type=int, default=100, help='Padding so BBox does not exceed image area') +parser.add_argument('--head', type=int, default=-1, help='Number of csv files to process') +args = parser.parse_args() + +""" +bounding_box is a named tuple which contains: + area (float): area of the rectangle + length_parallel (float): length of the side that is parallel to unit_vector + length_orthogonal (float): length of the side that is orthogonal to unit_vector + rectangle_center(int, int): coordinates of the rectangle center + (use rectangle_corners to get the corner points of the rectangle) + unit_vector (float, float): direction of the length_parallel side. + (it's orthogonal vector can be found with the orthogonal_vector function + unit_vector_angle (float): angle of the unit vector to be in radians. + corner_points [(float, float)]: set that contains the corners of the rectangle +""" + +bounding_box_tuple = namedtuple('bounding_box_tuple', 'area ' + 'length_parallel ' + 'length_orthogonal ' + 'rectangle_center ' + 'unit_vector ' + 'unit_vector_angle ' + 'corner_points' + ) + + +def unit_vector(pt0, pt1): + """ Given two points pt0 and pt1, return a unit vector that + points in the direction of pt0 to pt1. + Returns + ------- + (float, float): unit vector + """ + dis_0_to_1 = sqrt((pt0[0] - pt1[0])**2 + (pt0[1] - pt1[1])**2) + return (pt1[0] - pt0[0]) / dis_0_to_1, \ + (pt1[1] - pt0[1]) / dis_0_to_1 + + +def orthogonal_vector(vector): + """ Given a vector, returns a orthogonal/perpendicular vector of equal length. + Returns + ------ + (float, float): A vector that points in the direction orthogonal to vector. + """ + return -1 * vector[1], vector[0] + + +def bounding_area(index, hull): + """ Given index location in an array and convex hull, it gets two points + hull[index] and hull[index+1]. From these two points, it returns a named + tuple that mainly contains area of the box that bounds the hull. This + bounding box orintation is same as the orientation of the lines formed + by the point hull[index] and hull[index+1]. + Returns + ------- + a named tuple that contains: + area: area of the rectangle + length_parallel: length of the side that is parallel to unit_vector + length_orthogonal: length of the side that is orthogonal to unit_vector + rectangle_center: coordinates of the rectangle center + unit_vector: direction of the length_parallel side. + (it's orthogonal vector can be found with the orthogonal_vector function) + """ + unit_vector_p = unit_vector(hull[index], hull[index+1]) + unit_vector_o = orthogonal_vector(unit_vector_p) + + dis_p = tuple(np.dot(unit_vector_p, pt) for pt in hull) + dis_o = tuple(np.dot(unit_vector_o, pt) for pt in hull) + + min_p = min(dis_p) + min_o = min(dis_o) + len_p = max(dis_p) - min_p + len_o = max(dis_o) - min_o + + return {'area': len_p * len_o, + 'length_parallel': len_p, + 'length_orthogonal': len_o, + 'rectangle_center': (min_p + len_p / 2, min_o + len_o / 2), + 'unit_vector': unit_vector_p, + } + + +def to_xy_coordinates(unit_vector_angle, point): + """ Given angle from horizontal axis and a point from origin, + returns converted unit vector coordinates in x, y coordinates. + angle of unit vector should be in radians. + Returns + ------ + (float, float): converted x,y coordinate of the unit vector. + """ + angle_orthogonal = unit_vector_angle + pi / 2 + return point[0] * cos(unit_vector_angle) + point[1] * cos(angle_orthogonal), \ + point[0] * sin(unit_vector_angle) + point[1] * sin(angle_orthogonal) + + +def rotate_points(center_of_rotation, angle, points): + """ Rotates a point cloud around the center_of_rotation point by angle + input + ----- + center_of_rotation (float, float): angle of unit vector to be in radians. + angle (float): angle of rotation to be in radians. + points [(float, float)]: Points to be a list or tuple of points. Points to be rotated. + Returns + ------ + [(float, float)]: Rotated points around center of rotation by angle + """ + rot_points = [] + ang = [] + for pt in points: + diff = tuple([pt[d] - center_of_rotation[d] for d in range(2)]) + diff_angle = atan2(diff[1], diff[0]) + angle + ang.append(diff_angle) + diff_length = sqrt(sum([d**2 for d in diff])) + rot_points.append((center_of_rotation[0] + diff_length * cos(diff_angle), + center_of_rotation[1] + diff_length * sin(diff_angle))) + + return rot_points + + +def rectangle_corners(rectangle): + """ Given rectangle center and its inclination, returns the corner + locations of the rectangle. + Returns + ------ + [(float, float)]: 4 corner points of rectangle. + """ + corner_points = [] + for i1 in (.5, -.5): + for i2 in (i1, -1 * i1): + corner_points.append((rectangle['rectangle_center'][0] + i1 * rectangle['length_parallel'], + rectangle['rectangle_center'][1] + i2 * rectangle['length_orthogonal'])) + + return rotate_points(rectangle['rectangle_center'], rectangle['unit_vector_angle'], corner_points) + + +def get_orientation(origin, p1, p2): + """ + Given origin and two points, return the orientation of the Point p1 with + regards to Point p2 using origin. + Returns + ------- + integer: Negative if p1 is clockwise of p2. + """ + difference = ( + ((p2[0] - origin[0]) * (p1[1] - origin[1])) + - ((p1[0] - origin[0]) * (p2[1] - origin[1])) + ) + return difference + + +def compute_hull(points): + """ + Given input list of points, return a list of points that + made up the convex hull. + Returns + ------- + [(float, float)]: convexhull points + """ + hull_points = [] + start = points[0] + min_x = start[0] + for p in points[1:]: + if p[0] < min_x: + min_x = p[0] + start = p + + point = start + hull_points.append(start) + + far_point = None + while far_point is not start: + p1 = None + for p in points: + if p is point: + continue + else: + p1 = p + break + + far_point = p1 + + for p2 in points: + if p2 is point or p2 is p1: + continue + else: + direction = get_orientation(point, far_point, p2) + if direction > 0: + far_point = p2 + + hull_points.append(far_point) + point = far_point + return hull_points + + +def minimum_bounding_box(points): + """ Given a list of 2D points, it returns the minimum area rectangle bounding all + the points in the point cloud. + Returns + ------ + returns a namedtuple that contains: + area: area of the rectangle + length_parallel: length of the side that is parallel to unit_vector + length_orthogonal: length of the side that is orthogonal to unit_vector + rectangle_center: coordinates of the rectangle center + unit_vector: direction of the length_parallel side. RADIANS + unit_vector_angle: angle of the unit vector + corner_points: set that contains the corners of the rectangle + """ + + if len(points) <= 2: raise ValueError('More than two points required.') + + hull_ordered = [points[index] for index in ConvexHull(points).vertices] + hull_ordered.append(hull_ordered[0]) + #hull_ordered = compute_hull(points) + hull_ordered = tuple(hull_ordered) + + min_rectangle = bounding_area(0, hull_ordered) + for i in range(1, len(hull_ordered)-1): + rectangle = bounding_area(i, hull_ordered) + if rectangle['area'] < min_rectangle['area']: + min_rectangle = rectangle + + min_rectangle['unit_vector_angle'] = atan2(min_rectangle['unit_vector'][1], min_rectangle['unit_vector'][0]) + min_rectangle['rectangle_center'] = to_xy_coordinates(min_rectangle['unit_vector_angle'], min_rectangle['rectangle_center']) + + return bounding_box_tuple( + area = min_rectangle['area'], + length_parallel = min_rectangle['length_parallel'], + length_orthogonal = min_rectangle['length_orthogonal'], + rectangle_center = min_rectangle['rectangle_center'], + unit_vector = min_rectangle['unit_vector'], + unit_vector_angle = min_rectangle['unit_vector_angle'], + corner_points = set(rectangle_corners(min_rectangle)) + ) + + +def get_center(im): + """ Given image, returns the location of center pixel + Returns + ------- + (int, int): center of the image + """ + center_x = im.size[0] / 2 + center_y = im.size[1] / 2 + return int(center_x), int(center_y) + + +def get_horizontal_angle(unit_vector_angle): + """ Given an angle in radians, returns angle of the unit vector in + first or fourth quadrant. + Returns + ------ + (float): updated angle of the unit vector to be in radians. + It is only in first or fourth quadrant. + """ + if unit_vector_angle > pi / 2 and unit_vector_angle <= pi: + unit_vector_angle = unit_vector_angle - pi + elif unit_vector_angle > -pi and unit_vector_angle < -pi / 2: + unit_vector_angle = unit_vector_angle + pi + + return unit_vector_angle + + +def get_smaller_angle(bounding_box): + """ Given a rectangle, returns its smallest absolute angle from horizontal axis. + Returns + ------ + (float): smallest angle of the rectangle to be in radians. + """ + unit_vector = bounding_box.unit_vector + unit_vector_angle = bounding_box.unit_vector_angle + ortho_vector = orthogonal_vector(unit_vector) + ortho_vector_angle = atan2(ortho_vector[1], ortho_vector[0]) + + unit_vector_angle_updated = get_horizontal_angle(unit_vector_angle) + ortho_vector_angle_updated = get_horizontal_angle(ortho_vector_angle) + + if abs(unit_vector_angle_updated) < abs(ortho_vector_angle_updated): + return unit_vector_angle_updated + else: + return ortho_vector_angle_updated + + +def rotated_points(bounding_box, center): + """ Given the rectangle, returns corner points of rotated rectangle. + It rotates the rectangle around the center by its smallest angle. + Returns + ------- + [(int, int)]: 4 corner points of rectangle. + """ + p1, p2, p3, p4 = bounding_box.corner_points + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + x4, y4 = p4 + center_x, center_y = center + rotation_angle_in_rad = -get_smaller_angle(bounding_box) + x_dash_1 = (x1 - center_x) * cos(rotation_angle_in_rad) - (y1 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_2 = (x2 - center_x) * cos(rotation_angle_in_rad) - (y2 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_3 = (x3 - center_x) * cos(rotation_angle_in_rad) - (y3 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_4 = (x4 - center_x) * cos(rotation_angle_in_rad) - (y4 - center_y) * sin(rotation_angle_in_rad) + center_x + + y_dash_1 = (y1 - center_y) * cos(rotation_angle_in_rad) + (x1 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_2 = (y2 - center_y) * cos(rotation_angle_in_rad) + (x2 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_3 = (y3 - center_y) * cos(rotation_angle_in_rad) + (x3 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_4 = (y4 - center_y) * cos(rotation_angle_in_rad) + (x4 - center_x) * sin(rotation_angle_in_rad) + center_y + return x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 + + +def pad_image(image): + """ Given an image, returns a padded image around the border. + This routine save the code from crashing if bounding boxes that are + slightly outside the page boundary. + Returns + ------- + image: page image + """ + offset = int(args.padding // 2) + padded_image = Image.new('RGB', (image.size[0] + int(args.padding), image.size[1] + int(args.padding)), "white") + padded_image.paste(im = image, box = (offset, offset)) + return padded_image + +def update_minimum_bounding_box_input(bounding_box_input): + """ Given list of 2D points, returns list of 2D points shifted by an offset. + Returns + ------ + points [(float, float)]: points, a list or tuple of 2D coordinates + """ + updated_minimum_bounding_box_input = [] + offset = int(args.padding // 2) + for point in bounding_box_input: + x, y = point + new_x = x + offset + new_y = y + offset + word_coordinate = (new_x, new_y) + updated_minimum_bounding_box_input.append(word_coordinate) + + return updated_minimum_bounding_box_input + + +### main ### +csv_count = 0 +for filename in sorted(os.listdir(args.csv_dir)): + if filename.endswith('.csv') and (csv_count < args.head or args.head < 0): + csv_count = csv_count + 1 + with open(os.path.join(args.csv_dir, filename), 'r', encoding='utf-8') as f: + image_file = os.path.join(args.image_dir, os.path.splitext(filename)[0] + '.' + args.im_format) + if not os.path.isfile(image_file): + continue + csv_out_file = os.path.join(args.out_dir, 'truth_csv', filename) + csv_out_fh = open(csv_out_file, 'w', encoding='utf-8') + csv_out_writer = csv.writer(csv_out_fh) + im = Image.open(image_file) + im = pad_image(im) + count = 1 + for row in itertools.islice(csv.reader(f), 0, None): + if count == 1: + count = 0 + continue + + points = [] + points.append((int(row[2]), int(row[3]))) + points.append((int(row[4]), int(row[5]))) + points.append((int(row[6]), int(row[7]))) + points.append((int(row[8]), int(row[9]))) + + x = [int(row[2]), int(row[4]), int(row[6]), int(row[8])] + y = [int(row[3]), int(row[5]), int(row[7]), int(row[9])] + min_x, min_y = min(x), min(y) + max_x, max_y = max(x), max(y) + if min_x == max_x or min_y == max_y: + continue + + try: + updated_mbb_input = update_minimum_bounding_box_input(points) + bounding_box = minimum_bounding_box(updated_mbb_input) + except Exception as e: + print("Error: Skipping Image " + row[1]) + continue + + p1, p2, p3, p4 = bounding_box.corner_points + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + x4, y4 = p4 + min_x = int(min(x1, x2, x3, x4)) + min_y = int(min(y1, y2, y3, y4)) + max_x = int(max(x1, x2, x3, x4)) + max_y = int(max(y1, y2, y3, y4)) + box = (min_x, min_y, max_x, max_y) + region_initial = im.crop(box) + rot_points = [] + p1_new = (x1 - min_x, y1 - min_y) + p2_new = (x2 - min_x, y2 - min_y) + p3_new = (x3 - min_x, y3 - min_y) + p4_new = (x4 - min_x, y4 - min_y) + rot_points.append(p1_new) + rot_points.append(p2_new) + rot_points.append(p3_new) + rot_points.append(p4_new) + + cropped_bounding_box = bounding_box_tuple(bounding_box.area, + bounding_box.length_parallel, + bounding_box.length_orthogonal, + bounding_box.length_orthogonal, + bounding_box.unit_vector, + bounding_box.unit_vector_angle, + set(rot_points)) + + rotation_angle_in_rad = get_smaller_angle(cropped_bounding_box) + img2 = region_initial.rotate(degrees(rotation_angle_in_rad), resample = Image.BICUBIC) + x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 = rotated_points( + cropped_bounding_box, get_center(region_initial)) + + min_x = int(min(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + min_y = int(min(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + max_x = int(max(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + max_y = int(max(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + box = (min_x, min_y, max_x, max_y) + region_final = img2.crop(box) + csv_out_writer.writerow(row) + image_out_file = os.path.join(args.out_dir, 'truth_line_image', row[1]) + region_final.save(image_out_file) diff --git a/egs/yomdle_fa/v1/local/extract_features.sh b/egs/yomdle_fa/v1/local/extract_features.sh new file mode 100755 index 00000000000..f75837ae5b3 --- /dev/null +++ b/egs/yomdle_fa/v1/local/extract_features.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright 2017 Yiwen Shao +# 2018 Ashish Arora + +nj=4 +cmd=run.pl +feat_dim=40 +fliplr=false +augment='no_aug' +num_channels=3 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +data=$1 +featdir=$data/data +scp=$data/images.scp +logdir=$data/log + +mkdir -p $logdir +mkdir -p $featdir + +# make $featdir an absolute pathname +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +for n in $(seq $nj); do + split_scps="$split_scps $logdir/images.$n.scp" +done + +# split images.scp +utils/split_scp.pl $scp $split_scps || exit 1; + +$cmd JOB=1:$nj $logdir/extract_features.JOB.log \ + image/ocr/make_features.py $logdir/images.JOB.scp \ + --allowed_len_file_path $data/allowed_lengths.txt \ + --feat-dim $feat_dim --num-channels $num_channels --fliplr $fliplr --augment_type $augment \| \ + copy-feats --compress=true --compression-method=7 \ + ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp + +## aggregates the output scp's to get feats.scp +for n in $(seq $nj); do + cat $featdir/images.$n.scp || exit 1; +done > $data/feats.scp || exit 1 diff --git a/egs/yomdle_fa/v1/local/gedi2csv.py b/egs/yomdle_fa/v1/local/gedi2csv.py new file mode 100755 index 00000000000..43a07421dd1 --- /dev/null +++ b/egs/yomdle_fa/v1/local/gedi2csv.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 + +""" +GEDI2CSV +Convert GEDI-type bounding boxes to CSV format + +GEDI Format Example: + + + + + + + + + +CSV Format Example +ID,name,col1,row1,col2,row2,col3,row3,col4,row4,confidence,truth,pgrot,bbrot,qual,script,lang +0,chinese_scanned_books_0001_0.png,99,41,99,14,754,14,754,41,100,凡我的邻人说是好的,有一大部分在我灵魂中却,0,0.0,0,,zh-cn +""" + +import logging +import os +import sys +import time +import glob +import csv +import imghdr +from PIL import Image +import argparse +import pdb +import cv2 +import numpy as np +import xml.etree.ElementTree as ET + +sin = np.sin +cos = np.cos +pi = np.pi + +def Rotate2D(pts, cnt, ang=90): + M = np.array([[cos(ang),-sin(ang)],[sin(ang),cos(ang)]]) + res = np.dot(pts-cnt,M)+cnt + return M, res + +def npbox2string(npar): + if np.shape(npar)[0] != 1: + print('Error during CSV conversion\n') + c1,r1 = npar[0][0],npar[0][1] + c2,r2 = npar[0][2],npar[0][3] + c3,r3 = npar[0][4],npar[0][5] + c4,r4 = npar[0][6],npar[0][7] + + return c1,r1,c2,r2,c3,r3,c4,r4 + +# cv2.minAreaRect() returns a Box2D structure which contains following detals - ( center (x,y), (width, height), angle of rotation ) +# Get 4 corners of the rectangle using cv2.boxPoints() + +class GEDI2CSV(): + + """ Initialize the extractor""" + def __init__(self, logger, args): + self._logger = logger + self._args = args + + """ + Segment image with GEDI bounding box information + """ + def csvfile(self, coords, polys, baseName, pgrot): + + """ for writing the files """ + writePath = self._args.outputDir + writePath = os.path.join(writePath,'') + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + rotlist = [] + + header=['ID','name','col1','row1','col2','row2','col3','row3','col4','row4','confidence','truth','pgrot','bbrot','qual','script','text_type'] + conf=100 + write_ctr = 0 + if len(coords) == 0 and len(polys) == 0: + self._logger.info('Found %s with no text content',(baseName)) + print('...Found %s with no text content' % (baseName)) + return + + strPos = writePath + baseName + + """ for each group of coordinates """ + for i in coords: + + [id,x,y,w,h,degrees,text,qual,script,text_type] = i + + contour = np.array([(x,y),(x+w,y),(x+w,y+h),(x,y+h)]) + + """ + First rotate around upper left corner based on orientationD keyword + """ + M, rot = Rotate2D(contour, np.array([x,y]), degrees*pi/180) + rot = np.int0(rot) + + # rot is the 8 points rotated by degrees + # pgrot is the rotation after extraction, so save + + # save rotated points to list or array + rot = np.reshape(rot,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(rot) + + text = text.replace(u'\ufeff','') + + bbrot = degrees + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,text_type]) + + # if there are polygons, first save the text + for j in polys: + arr = [] + [id,poly_val,text,qual,script,text_type] = j + for i in poly_val: + arr.append(eval(i)) + + contour = np.asarray(arr) + convex = cv2.convexHull(contour) + rect = cv2.minAreaRect(convex) + box = cv2.boxPoints(rect) + box = np.int0(box) + box = np.reshape(box,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(box) + + bbrot = 0.0 + + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,text_type]) + + # then write out all of list to file + with open(strPos + ".csv", "w", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(header) + for row in rotlist: + writer.writerow(row) + write_ctr += 1 + + return write_ctr + + +def main(args): + + startTime = time.clock() + + writePath = args.outputDir + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + """ Setup logging """ + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + if args.log: + handler = logging.FileHandler(args.log) + handler.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + gtconverter = GEDI2CSV(logger, args) + namespaces = {"gedi" : "http://lamp.cfar.umd.edu/media/projects/GEDI/"} + keyCnt=0 + + fileCnt = 0 + line_write_ctr = 0 + line_error_ctr = 0 + + """ + Get all XML files in the directory and sub folders + """ + for root, dirnames, filenames in os.walk(args.inputDir, followlinks=True): + for file in filenames: + if file.lower().endswith('.xml'): + fullName = os.path.join(root,file) + baseName = os.path.splitext(fullName) + + fileCnt += 1 + + """ read the XML file """ + tree = ET.parse(fullName) + gedi_root = tree.getroot() + child = gedi_root.findall('gedi:DL_DOCUMENT',namespaces)[0] + totalpages = int(child.attrib['NrOfPages']) + coordinates=[] + polygons = [] + if args.ftype == 'boxed': + fileTypeStr = 'col' + elif args.ftype == 'transcribed': + fileTypeStr = 'Text_Content' + else: + print('Filetype must be either boxed or transcribed!') + logger.info('Filetype must be either boxed or transcribed!') + sys.exit(-1) + + if args.quality == 'both': + qualset = {'Regular','Low-Quality'} + elif args.quality == 'low': + qualset = {'Low-Quality'} + elif args.quality == 'regular': + qualset = {'Regular'} + else: + print('Quality must be both, low or regular!') + logger.info('Quality must be both, low or regular!') + sys.exit(-1) + + + + """ and for each page """ + for i, pgs in enumerate(child.iterfind('gedi:DL_PAGE',namespaces)): + + if 'GEDI_orientation' not in pgs.attrib: + pageRot=0 + else: + pageRot = int(pgs.attrib['GEDI_orientation']) + logger.info(' PAGE ROTATION %s, %s' % (fullName, str(pageRot))) + + """ find children for each page """ + for zone in pgs.findall('gedi:DL_ZONE',namespaces): + + if zone.attrib['gedi_type']=='Text' and zone.attrib['Type'] in \ + ('Machine_Print','Confusable_Allograph','Handwriting') and zone.attrib['Quality'] in qualset: + if zone.get('polygon'): + keyCnt+=1 + polygons.append([zone.attrib['id'],zone.get('polygon').split(';'), + zone.get('Text_Content'),zone.get('Quality'),zone.get('Script'),zone.get('Type')]) + elif zone.get(fileTypeStr) != None: + keyCnt+=1 + coord = [zone.attrib['id'],int(zone.attrib['col']),int(zone.attrib['row']), + int(zone.attrib['width']), int(zone.attrib['height']), + float(zone.get('orientationD',0.0)), + zone.get('Text_Content'),zone.get('Quality'),zone.get('Script'),zone.get('Type')] + coordinates.append(coord) + + if len(coordinates) > 0 or len(polygons) > 0: + line_write_ctr += gtconverter.csvfile(coordinates, polygons, os.path.splitext(file)[0], pageRot) + else: + print('...%s has no applicable content' % (baseName[0])) + + print('complete...total files %d, lines written %d' % (fileCnt, line_write_ctr)) + + +def parse_arguments(argv): + """ Args and defaults """ + parser = argparse.ArgumentParser() + + parser.add_argument('--inputDir', type=str, help='Input directory', required=True) + parser.add_argument('--outputDir', type=str, help='Output directory', required=True) + parser.add_argument('--ftype', type=str, help='GEDI file type (either "boxed" or "transcribed")', default='transcribed') + parser.add_argument('--quality', type=str, help='GEDI file quality (either "both" or "low" or "regular")', default='regular') + parser.add_argument('--log', type=str, help='Log directory', default='./GEDI2CSV_enriched.log') + + return parser.parse_args(argv) + +if __name__ == '__main__': + """ Run """ + main(parse_arguments(sys.argv[1:])) + + + + + + diff --git a/egs/yomdle_fa/v1/local/prepare_dict.sh b/egs/yomdle_fa/v1/local/prepare_dict.sh new file mode 100755 index 00000000000..f1b1a8d70cc --- /dev/null +++ b/egs/yomdle_fa/v1/local/prepare_dict.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +# Copyright 2017 Hossein Hadian +# 2017 Chun Chieh Chang +# 2017 Ashish Arora + +# This script prepares the dictionary. + +set -e +dir=data/local/dict +data_dir=data + +. ./utils/parse_options.sh || exit 1; + +base_dir=$(echo "$DIRECTORY" | cut -d "/" -f2) + +mkdir -p $dir + +local/prepare_lexicon.py --data-dir $data_dir $dir + +sed -i '/^\s*$/d' $dir/lexicon.txt +cut -d' ' -f2- $dir/lexicon.txt | sed 's/SIL//g' | tr ' ' '\n' | sort -u | sed '/^$/d' >$dir/nonsilence_phones.txt || exit 1; + +echo ' SIL' >> $dir/lexicon.txt + +echo SIL > $dir/silence_phones.txt + +echo SIL >$dir/optional_silence.txt + +echo -n "" >$dir/extra_questions.txt diff --git a/egs/yomdle_fa/v1/local/prepare_lexicon.py b/egs/yomdle_fa/v1/local/prepare_lexicon.py new file mode 100755 index 00000000000..46be4f37970 --- /dev/null +++ b/egs/yomdle_fa/v1/local/prepare_lexicon.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora + +import argparse +import os + +parser = argparse.ArgumentParser(description="""Creates the list of characters and words in lexicon""") +parser.add_argument('dir', type=str, help='output path') +parser.add_argument('--data-dir', type=str, default='data', help='Path to text file') +args = parser.parse_args() + +### main ### +lex = {} +text_path = os.path.join(args.data_dir, 'train', 'text') +text_fh = open(text_path, 'r', encoding='utf-8') + +with open(text_path, 'r', encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split(' ') + for i in range(1, len(line_vect)): + characters = list(line_vect[i]) + # Put SIL instead of "|". Because every "|" in the beginning of the words is for initial-space of that word + characters = " ".join([ 'SIL' if char == '|' else char for char in characters]) + characters = characters.replace('#','') + lex[line_vect[i]] = characters + +with open(os.path.join(args.dir, 'lexicon.txt'), 'w', encoding='utf-8') as fp: + for key in sorted(lex): + fp.write(key + " " + lex[key] + "\n") diff --git a/egs/yomdle_fa/v1/local/process_data.py b/egs/yomdle_fa/v1/local/process_data.py new file mode 100755 index 00000000000..3423cc5380e --- /dev/null +++ b/egs/yomdle_fa/v1/local/process_data.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# 2018 Chun Chieh Chang + +""" This script reads the extracted Farsi OCR (yomdle and slam) database files + and creates the following files (for the data subset selected via --dataset): + text, utt2spk, images.scp. + Eg. local/process_data.py data/download/ data/local/splits/train.txt data/train + Eg. text file: english_phone_books_0001_1 To sum up, then, it would appear that + utt2spk file: english_phone_books_0001_0 english_phone_books_0001 + images.scp file: english_phone_books_0001_0 \ + data/download/truth_line_image/english_phone_books_0001_0.png +""" + +import argparse +import os +import sys +import csv +import itertools +import unicodedata + +parser = argparse.ArgumentParser(description="Creates text, utt2spk, and images.scp files") +parser.add_argument('database_path', type=str, help='Path to data') +parser.add_argument('out_dir', type=str, help='directory to output files') +parser.add_argument('--head', type=int, default=-1, help='limit on number of synth data') +args = parser.parse_args() + +### main ### +print("Processing '{}' data...".format(args.out_dir)) + +text_file = os.path.join(args.out_dir, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') +utt2spk_file = os.path.join(args.out_dir, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join(args.out_dir, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') + +count = 0 +for filename in sorted(os.listdir(os.path.join(args.database_path, 'truth_csv'))): + if filename.endswith('.csv') and (count < args.head or args.head < 0): + count = count + 1 + csv_filepath = os.path.join(args.database_path, 'truth_csv', filename) + csv_file = open(csv_filepath, 'r', encoding='utf-8') + row_count = 0 + for row in csv.reader(csv_file): + if row_count == 0: + row_count = 1 + continue + image_id = os.path.splitext(row[1])[0] + image_filepath = os.path.join(args.database_path, 'truth_line_image', row[1]) + text = unicodedata.normalize('NFC', row[11]) + file_info = os.stat(image_filepath) + if file_info.st_size != 0: + if text: + text_fh.write(image_id + ' ' + text + '\n') + utt2spk_fh.write(image_id + ' ' + '_'.join(image_id.split('_')[:-1]) + '\n') + image_fh.write(image_id + ' ' + image_filepath + ' ' + row[13] + '\n') diff --git a/egs/yomdle_fa/v1/local/score.sh b/egs/yomdle_fa/v1/local/score.sh new file mode 100755 index 00000000000..f2405205f02 --- /dev/null +++ b/egs/yomdle_fa/v1/local/score.sh @@ -0,0 +1,5 @@ +#!/bin/bash + + +steps/scoring/score_kaldi_wer.sh --max-lmwt 10 "$@" +steps/scoring/score_kaldi_cer.sh --max-lmwt 10 --stage 2 "$@" diff --git a/egs/yomdle_fa/v1/local/train_lm.sh b/egs/yomdle_fa/v1/local/train_lm.sh new file mode 100755 index 00000000000..bc738f217da --- /dev/null +++ b/egs/yomdle_fa/v1/local/train_lm.sh @@ -0,0 +1,110 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the YOMDLE training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +data_dir=data + +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +#bypass_metaparam_optim_opt= +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + nr=`cat $data_dir/train/text | wc -l` + nr_dev=$(($nr / 10 )) + nr_train=$(( $nr - $nr_dev )) + + # use the training data as an additional data source. + # we can later fold the dev data into this. + head -n $nr_train $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + tail -n $nr_dev $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < $data_dir/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from MADCAT text + cat ${dir}/data/text/train.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +order=3 + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=5 --warm-start-ratio=1 \ + --min-counts="$min_counts" \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi diff --git a/egs/yomdle_fa/v1/local/train_lm_lr.sh b/egs/yomdle_fa/v1/local/train_lm_lr.sh new file mode 100755 index 00000000000..5bfc20acdeb --- /dev/null +++ b/egs/yomdle_fa/v1/local/train_lm_lr.sh @@ -0,0 +1,113 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the YOMDLE+Extra training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +data_dir=data +extra_lm=download/extra_lm.txt +order=3 + +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +#bypass_metaparam_optim_opt= +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + cat ${extra_lm} | local/bidi.py | utils/lang/bpe/prepend_words.py --encoding 'utf-8' | python3 utils/lang/bpe/apply_bpe.py -c $data_dir/train/bpe.out | sed 's/@@//g' > ${dir}/data/text/extra_lm.txt + + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + nr=`cat $data_dir/train/text | wc -l` + nr_dev=$(($nr / 10 )) + nr_train=$(( $nr - $nr_dev )) + + # use the training data as an additional data source. + # we can later fold the dev data into this. + head -n $nr_train $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + tail -n $nr_dev $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < $data_dir/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from MADCAT text + cat ${dir}/data/text/{train,extra_lm}.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + #cat ${dir}/data/text/extra_fa.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='extra_lm=10 train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=30 --warm-start-ratio=1 \ + --min-counts="$min_counts" \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi diff --git a/egs/yomdle_fa/v1/local/wer_output_filter b/egs/yomdle_fa/v1/local/wer_output_filter new file mode 100755 index 00000000000..08d5563bca4 --- /dev/null +++ b/egs/yomdle_fa/v1/local/wer_output_filter @@ -0,0 +1,151 @@ +#!/usr/bin/env perl +# Copyright 2012-2014 Johns Hopkins University (Author: Yenda Trmal) +# Apache 2.0 + +use utf8; + +use open qw(:encoding(utf8)); +binmode STDIN, ":utf8"; +binmode STDOUT, ":utf8"; +binmode STDERR, ":utf8"; + +# Arabic-specific normalization +while (<>) { + @F = split " "; + print "$F[0] "; + foreach $s (@F[1..$#F]) { + # Normalize tabs, spaces, and no-break spaces + $s =~ s/[\x{0009}\x{0020}\x{00A0}]+/ /g; + # Normalize "dots"/"filled-circles" to periods + $s =~ s/[\x{25CF}\x{u2022}\x{2219}]+/\x{002E}/g; + # Normalize dashes to regular hyphen + $s =~ s/[\x{2010}\x{2011}\x{2012}\x{2013}\x{2014}\x{2015}]+/\x{002D}/g; + # Normalize various parenthesis to regular parenthesis + $s =~ s/\x{UFF09}/\x{0029}/g; + $s =~ s/\x{UFF08}/\x{0028}/g; + + # Convert various presentation forms to base form + $s =~ s/[\x{FED1}\x{FED3}\x{FED4}\x{FED2}]+/\x{0641}/g; + $s =~ s/[\x{FBB0}\x{FBB1}]+/\x{06D3}/g; + $s =~ s/[\x{FECD}\x{FECF}\x{FED0}\x{FECE}]+/\x{063A}/g; + $s =~ s/[\x{FBDD}]+/\x{0677}/g; + $s =~ s/[\x{FBA6}\x{FBA8}\x{FBA9}\x{FBA7}]+/\x{06C1}/g; + $s =~ s/[\x{FEC1}\x{FEC3}\x{FEC4}\x{FEC2}]+/\x{0637}/g; + $s =~ s/[\x{FE85}\x{FE86}]+/\x{0624}/g; + $s =~ s/[\x{FEA5}\x{FEA7}\x{FEA8}\x{FEA6}]+/\x{062E}/g; + $s =~ s/[\x{FBD9}\x{FBDA}]+/\x{06C6}/g; + $s =~ s/[\x{FE8F}\x{FE91}\x{FE92}\x{FE90}]+/\x{0628}/g; + $s =~ s/[\x{FEED}\x{FEEE}]+/\x{0648}/g; + $s =~ s/[\x{FE99}\x{FE9B}\x{FE9C}\x{FE9A}]+/\x{062B}/g; + $s =~ s/[\x{FEBD}\x{FEBF}\x{FEC0}\x{FEBE}]+/\x{0636}/g; + $s =~ s/[\x{FEE5}\x{FEE7}\x{FEE8}\x{FEE6}]+/\x{0646}/g; + $s =~ s/[\x{FBFC}\x{FBFE}\x{FBFF}\x{FBFD}]+/\x{06CC}/g; + $s =~ s/[\x{FBA4}\x{FBA5}]+/\x{06C0}/g; + $s =~ s/[\x{FB72}\x{FB74}\x{FB75}\x{FB73}]+/\x{0684}/g; + $s =~ s/[\x{FBD3}\x{FBD5}\x{FBD6}\x{FBD4}]+/\x{06AD}/g; + $s =~ s/[\x{FB6A}\x{FB6C}\x{FB6D}\x{FB6B}]+/\x{06A4}/g; + $s =~ s/[\x{FB66}\x{FB68}\x{FB69}\x{FB67}]+/\x{0679}/g; + $s =~ s/[\x{FB5E}\x{FB60}\x{FB61}\x{FB5F}]+/\x{067A}/g; + $s =~ s/[\x{FB88}\x{FB89}]+/\x{0688}/g; + $s =~ s/[\x{FB7E}\x{FB80}\x{FB81}\x{FB7F}]+/\x{0687}/g; + $s =~ s/[\x{FB8E}\x{FB90}\x{FB91}\x{FB8F}]+/\x{06A9}/g; + $s =~ s/[\x{FB86}\x{FB87}]+/\x{068E}/g; + $s =~ s/[\x{FE83}\x{FE84}]+/\x{0623}/g; + $s =~ s/[\x{FB8A}\x{FB8B}]+/\x{0698}/g; + $s =~ s/[\x{FED5}\x{FED7}\x{FED8}\x{FED6}]+/\x{0642}/g; + $s =~ s/[\x{FED9}\x{FEDB}\x{FEDC}\x{FEDA}]+/\x{0643}/g; + $s =~ s/[\x{FBE0}\x{FBE1}]+/\x{06C5}/g; + $s =~ s/[\x{FEB9}\x{FEBB}\x{FEBC}\x{FEBA}]+/\x{0635}/g; + $s =~ s/[\x{FEC5}\x{FEC7}\x{FEC8}\x{FEC6}]+/\x{0638}/g; + $s =~ s/[\x{FE8D}\x{FE8E}]+/\x{0627}/g; + $s =~ s/[\x{FB9A}\x{FB9C}\x{FB9D}\x{FB9B}]+/\x{06B1}/g; + $s =~ s/[\x{FEAD}\x{FEAE}]+/\x{0631}/g; + $s =~ s/[\x{FEF1}\x{FEF3}\x{FEF4}\x{FEF2}]+/\x{064A}/g; + $s =~ s/[\x{FE93}\x{FE94}]+/\x{0629}/g; + $s =~ s/[\x{FBE4}\x{FBE6}\x{FBE7}\x{FBE5}]+/\x{06D0}/g; + $s =~ s/[\x{FE89}\x{FE8B}\x{FE8C}\x{FE8A}]+/\x{0626}/g; + $s =~ s/[\x{FB84}\x{FB85}]+/\x{068C}/g; + $s =~ s/[\x{FE9D}\x{FE9F}\x{FEA0}\x{FE9E}]+/\x{062C}/g; + $s =~ s/[\x{FB82}\x{FB83}]+/\x{068D}/g; + $s =~ s/[\x{FEA1}\x{FEA3}\x{FEA4}\x{FEA2}]+/\x{062D}/g; + $s =~ s/[\x{FB52}\x{FB54}\x{FB55}\x{FB53}]+/\x{067B}/g; + $s =~ s/[\x{FB92}\x{FB94}\x{FB95}\x{FB93}]+/\x{06AF}/g; + $s =~ s/[\x{FB7A}\x{FB7C}\x{FB7D}\x{FB7B}]+/\x{0686}/g; + $s =~ s/[\x{FBDB}\x{FBDC}]+/\x{06C8}/g; + $s =~ s/[\x{FB56}\x{FB58}\x{FB59}\x{FB57}]+/\x{067E}/g; + $s =~ s/[\x{FEB5}\x{FEB7}\x{FEB8}\x{FEB6}]+/\x{0634}/g; + $s =~ s/[\x{FBE2}\x{FBE3}]+/\x{06C9}/g; + $s =~ s/[\x{FB96}\x{FB98}\x{FB99}\x{FB97}]+/\x{06B3}/g; + $s =~ s/[\x{FE80}]+/\x{0621}/g; + $s =~ s/[\x{FBAE}\x{FBAF}]+/\x{06D2}/g; + $s =~ s/[\x{FB62}\x{FB64}\x{FB65}\x{FB63}]+/\x{067F}/g; + $s =~ s/[\x{FEE9}\x{FEEB}\x{FEEC}\x{FEEA}]+/\x{0647}/g; + $s =~ s/[\x{FE81}\x{FE82}]+/\x{0622}/g; + $s =~ s/[\x{FBDE}\x{FBDF}]+/\x{06CB}/g; + $s =~ s/[\x{FE87}\x{FE88}]+/\x{0625}/g; + $s =~ s/[\x{FB6E}\x{FB70}\x{FB71}\x{FB6F}]+/\x{06A6}/g; + $s =~ s/[\x{FBA0}\x{FBA2}\x{FBA3}\x{FBA1}]+/\x{06BB}/g; + $s =~ s/[\x{FBAA}\x{FBAC}\x{FBAD}\x{FBAB}]+/\x{06BE}/g; + $s =~ s/[\x{FEA9}\x{FEAA}]+/\x{062F}/g; + $s =~ s/[\x{FEE1}\x{FEE3}\x{FEE4}\x{FEE2}]+/\x{0645}/g; + $s =~ s/[\x{FEEF}\x{FBE8}\x{FBE9}\x{FEF0}]+/\x{0649}/g; + $s =~ s/[\x{FB8C}\x{FB8D}]+/\x{0691}/g; + $s =~ s/[\x{FB76}\x{FB78}\x{FB79}\x{FB77}]+/\x{0683}/g; + $s =~ s/[\x{FB5A}\x{FB5C}\x{FB5D}\x{FB5B}]+/\x{0680}/g; + $s =~ s/[\x{FB9E}\x{FB9F}]+/\x{06BA}/g; + $s =~ s/[\x{FEC9}\x{FECB}\x{FECC}\x{FECA}]+/\x{0639}/g; + $s =~ s/[\x{FEDD}\x{FEDF}\x{FEE0}\x{FEDE}]+/\x{0644}/g; + $s =~ s/[\x{FB50}\x{FB51}]+/\x{0671}/g; + $s =~ s/[\x{FEB1}\x{FEB3}\x{FEB4}\x{FEB2}]+/\x{0633}/g; + $s =~ s/[\x{FE95}\x{FE97}\x{FE98}\x{FE96}]+/\x{062A}/g; + $s =~ s/[\x{FBD7}\x{FBD8}]+/\x{06C7}/g; + $s =~ s/[\x{FEAF}\x{FEB0}]+/\x{0632}/g; + $s =~ s/[\x{FEAB}\x{FEAC}]+/\x{0630}/g; + + # Remove tatweel + $s =~ s/\x{0640}//g; + # Remove vowels and hamza + $s =~ s/[\x{064B}-\x{0655}]+//g; + # Remove right-to-left and left-to-right + $s =~ s/[\x{200F}\x{200E}]+//g; + # Arabic Keheh to Arabic Kaf + $s =~ s/\x{06A9}/\x{0643}/g; + # Arabic Yeh to Farsi Yeh + $s =~ s/\x{064A}/\x{06CC}/g; + # Decompose RIAL + $s =~ s/\x{FDFC}/\x{0631}\x{06CC}\x{0627}\x{0644}/g; + # Farsi arabic-indic digits to arabic-indic digits + $s =~ s/\x{06F0}/\x{0660}/g; + $s =~ s/\x{06F1}/\x{0661}/g; + $s =~ s/\x{06F2}/\x{0662}/g; + $s =~ s/\x{06F3}/\x{0663}/g; + $s =~ s/\x{06F4}/\x{0664}/g; + $s =~ s/\x{06F5}/\x{0665}/g; + $s =~ s/\x{06F6}/\x{0666}/g; + $s =~ s/\x{06F7}/\x{0667}/g; + $s =~ s/\x{06F8}/\x{0668}/g; + $s =~ s/\x{06F9}/\x{0669}/g; + # Arabic-indic digits to digits + $s =~ s/\x{0660}/0/g; + $s =~ s/\x{0661}/1/g; + $s =~ s/\x{0662}/2/g; + $s =~ s/\x{0663}/3/g; + $s =~ s/\x{0664}/4/g; + $s =~ s/\x{0665}/5/g; + $s =~ s/\x{0666}/6/g; + $s =~ s/\x{0667}/7/g; + $s =~ s/\x{0668}/8/g; + $s =~ s/\x{0669}/9/g; + # Arabic comma to comma + $s =~ s/\x{060C}/\x{002C}/g; + + $s =~ s/\|/ /g; + if ($s ne "") { + print "$s"; + } else { + print ""; + } + } + print "\n"; +} + diff --git a/egs/yomdle_fa/v1/local/yomdle2csv.py b/egs/yomdle_fa/v1/local/yomdle2csv.py new file mode 100755 index 00000000000..3641de90324 --- /dev/null +++ b/egs/yomdle_fa/v1/local/yomdle2csv.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 + +""" +GEDI2CSV +Convert GEDI-type bounding boxes to CSV format + +GEDI Format Example: + + + + + + + + + +CSV Format Example +ID,name,col1,row1,col2,row2,col3,row3,col4,row4,confidence,truth,pgrot,bbrot,qual,script,lang +0,chinese_scanned_books_0001_0.png,99,41,99,14,754,14,754,41,100,凡我的邻人说是好的,有一大部分在我灵魂中却,0,0.0,0,,zh-cn +""" + +import logging +import os +import sys +import time +import glob +import csv +import imghdr +from PIL import Image +import argparse +import pdb +import cv2 +import numpy as np +import xml.etree.ElementTree as ET + +sin = np.sin +cos = np.cos +pi = np.pi + +def Rotate2D(pts, cnt, ang=90): + M = np.array([[cos(ang),-sin(ang)],[sin(ang),cos(ang)]]) + res = np.dot(pts-cnt,M)+cnt + return M, res + +def npbox2string(npar): + if np.shape(npar)[0] != 1: + print('Error during CSV conversion\n') + c1,r1 = npar[0][0],npar[0][1] + c2,r2 = npar[0][2],npar[0][3] + c3,r3 = npar[0][4],npar[0][5] + c4,r4 = npar[0][6],npar[0][7] + + return c1,r1,c2,r2,c3,r3,c4,r4 + +# cv2.minAreaRect() returns a Box2D structure which contains following detals - ( center (x,y), (width, height), angle of rotation ) +# Get 4 corners of the rectangle using cv2.boxPoints() + +class GEDI2CSV(): + + """ Initialize the extractor""" + def __init__(self, logger, args): + self._logger = logger + self._args = args + + """ + Segment image with GEDI bounding box information + """ + def csvfile(self, coords, polys, baseName, pgrot): + + """ for writing the files """ + writePath = self._args.outputDir + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + rotlist = [] + + header=['ID','name','col1','row1','col2','row2','col3','row3','col4','row4','confidence','truth','pgrot','bbrot','qual','script','lang'] + conf=100 + pgrot = 0 + bbrot = 0 + qual = 0 + script = '' + + write_ctr = 0 + if len(coords) == 0 and len(polys) == 0: + self._logger.info('Found %s with no text content',(baseName)) + print('...Found %s with no text content' % (baseName)) + return + + strPos = writePath + baseName + + for j in polys: + try: + arr = [] + [id,poly_val,text,qual,lang] = j + script=None + #print(j) + for i in poly_val: + if len(i.strip()) > 0: + #print(i) + arr.append(eval(i)) + + contour = np.asarray(arr) + #print(contour) + convex = cv2.convexHull(contour) + rect = cv2.minAreaRect(convex) + box = cv2.boxPoints(rect) + box = np.int0(box) + box = np.reshape(box,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(box) + + bbrot = 0.0 + + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,lang]) + + except: + print('...polygon error %s, %s' % (j, baseName)) + continue + + # then write out all of list to file + with open(strPos + ".csv", "w", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(header) + for row in rotlist: + writer.writerow(row) + write_ctr += 1 + + return write_ctr + + +def main(args): + + startTime = time.clock() + + writePath = args.outputDir + print('write to %s' % (writePath)) + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + """ Setup logging """ + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + if args.log: + handler = logging.FileHandler(args.log) + handler.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + gtconverter = GEDI2CSV(logger, args) + namespaces = {"gedi" : "http://lamp.cfar.umd.edu/media/projects/GEDI/"} + keyCnt=0 + + fileCnt = 0 + line_write_ctr = 0 + line_error_ctr = 0 + file_error_ctr = 0 + """ + Get all XML files in the directory and sub folders + """ + print('reading %s' % (args.inputDir)) + for root, dirnames, filenames in os.walk(args.inputDir, followlinks=True): + for file in filenames: + if file.lower().endswith('.xml'): + fullName = os.path.join(root,file) + baseName = os.path.splitext(fullName) + + fileCnt += 1 + + try: + """ read the XML file """ + tree = ET.parse(fullName) + except: + print('...ERROR parsing %s' % (fullName)) + file_error_ctr += 1 + continue + + gedi_root = tree.getroot() + child = gedi_root.findall('gedi:DL_DOCUMENT',namespaces)[0] + totalpages = int(child.attrib['NrOfPages']) + coordinates=[] + polygons = [] + + """ and for each page """ + for i, pgs in enumerate(child.iterfind('gedi:DL_PAGE',namespaces)): + + if 'GEDI_orientation' not in pgs.attrib: + pageRot=0 + else: + pageRot = int(pgs.attrib['GEDI_orientation']) + logger.info(' PAGE ROTATION %s, %s' % (fullName, str(pageRot))) + + """ find children for each page """ + for zone in pgs.findall('gedi:DL_ZONE',namespaces): + + if zone.attrib['gedi_type']=='Text' : + if zone.get('polygon'): + keyCnt+=1 + polygons.append([zone.attrib['id'],zone.get('polygon').split(';'), + zone.get('Text_Content'),zone.get('Illegible'),zone.get('Language')]) + else: + print('...Not polygon') + + + if len(coordinates) > 0 or len(polygons) > 0: + line_write_ctr += gtconverter.csvfile(coordinates, polygons, os.path.splitext(file)[0], pageRot) + else: + print('...%s has no text content' % (baseName[0])) + + + print('complete...total files %d, lines written %d, img errors %d, line error %d' % (fileCnt, line_write_ctr, file_error_ctr, line_error_ctr)) + + +def parse_arguments(argv): + """ Args and defaults """ + parser = argparse.ArgumentParser() + + parser.add_argument('--inputDir', type=str, help='Input directory', default='/data/YOMDLE/final_arabic/xml') + parser.add_argument('--outputDir', type=str, help='Output directory', default='/exp/YOMDLE/final_arabic/csv_truth/') + parser.add_argument('--log', type=str, help='Log directory', default='/exp/logs.txt') + + return parser.parse_args(argv) + + +if __name__ == '__main__': + """ Run """ + main(parse_arguments(sys.argv[1:])) diff --git a/egs/yomdle_fa/v1/path.sh b/egs/yomdle_fa/v1/path.sh new file mode 100644 index 00000000000..2d17b17a84a --- /dev/null +++ b/egs/yomdle_fa/v1/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/yomdle_fa/v1/run.sh b/egs/yomdle_fa/v1/run.sh new file mode 100755 index 00000000000..a7547b1ee69 --- /dev/null +++ b/egs/yomdle_fa/v1/run.sh @@ -0,0 +1,120 @@ +#!/bin/bash + +set -e +stage=0 +nj=60 + +database_slam=/export/corpora5/slam/SLAM/Farsi/transcribed +database_yomdle=/export/corpora5/slam/YOMDLE/final_farsi +download_dir=data_yomdle_farsi/download/ +extra_lm=download/extra_lm.txt +data_dir=data_yomdle_farsi +exp_dir=exp_yomdle_farsi + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + local/create_download.sh --database-slam $database_slam \ + --database-yomdle $database_yomdle \ + --slam-dir download/slam_farsi \ + --yomdle-dir download/yomdle_farsi +fi + +if [ $stage -le 0 ]; then + mkdir -p data_slam_farsi/slam + mkdir -p data_yomdle_farsi/yomdle + local/process_data.py download/slam_farsi data_slam_farsi/slam + local/process_data.py download/yomdle_farsi data_yomdle_farsi/yomdle + ln -s ../data_slam_farsi/slam ${data_dir}/test + ln -s ../data_yomdle_farsi/yomdle ${data_dir}/train + image/fix_data_dir.sh ${data_dir}/test + image/fix_data_dir.sh ${data_dir}/train +fi + +mkdir -p $data_dir/{train,test}/data +if [ $stage -le 1 ]; then + echo "$0: Obtaining image groups. calling get_image2num_frames" + echo "Date: $(date)." + image/get_image2num_frames.py --feat-dim 40 $data_dir/train + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 $data_dir/train + + for datasplit in train test; do + echo "$0: Extracting features and calling compute_cmvn_stats for dataset: $datasplit. " + echo "Date: $(date)." + local/extract_features.sh --nj $nj --cmd "$cmd" \ + --feat-dim 40 --num-channels 3 --fliplr true \ + $data_dir/${datasplit} + steps/compute_cmvn_stats.sh $data_dir/${datasplit} || exit 1; + done + + echo "$0: Fixing data directory for train dataset" + echo "Date: $(date)." + utils/fix_data_dir.sh $data_dir/train +fi + +if [ $stage -le 2 ]; then + for datasplit in train; do + echo "$(date) stage 2: Performing augmentation, it will double training data" + local/augment_data.sh --nj $nj --cmd "$cmd" --feat-dim 40 --fliplr false $data_dir/${datasplit} $data_dir/${datasplit}_aug $data_dir + steps/compute_cmvn_stats.sh $data_dir/${datasplit}_aug || exit 1; + done +fi + +if [ $stage -le 3 ]; then + echo "$0: Preparing dictionary and lang..." + if [ ! -f $data_dir/train/bpe.out ]; then + cut -d' ' -f2- $data_dir/train/text | local/bidi.py | utils/lang/bpe/prepend_words.py | python3 utils/lang/bpe/learn_bpe.py -s 700 > $data_dir/train/bpe.out + for datasplit in test train train_aug; do + cut -d' ' -f1 $data_dir/$datasplit/text > $data_dir/$datasplit/ids + cut -d' ' -f2- $data_dir/$datasplit/text | local/bidi.py | utils/lang/bpe/prepend_words.py | python3 utils/lang/bpe/apply_bpe.py -c $data_dir/train/bpe.out | sed 's/@@//g' > $data_dir/$datasplit/bpe_text + mv $data_dir/$datasplit/text $data_dir/$datasplit/text.old + paste -d' ' $data_dir/$datasplit/ids $data_dir/$datasplit/bpe_text > $data_dir/$datasplit/text + done + fi + + local/prepare_dict.sh --data-dir $data_dir --dir $data_dir/local/dict + # This recipe uses byte-pair encoding, the silences are part of the words' pronunciations. + # So we set --sil-prob to 0.0 + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ + $data_dir/local/dict "" $data_dir/lang/temp $data_dir/lang + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 $data_dir/lang +fi + +if [ $stage -le 4 ]; then + echo "$0: Estimating a language model for decoding..." + local/train_lm.sh --data-dir $data_dir --dir $data_dir/local/local_lm + utils/format_lm.sh $data_dir/lang $data_dir/local/local_lm/data/arpa/3gram_unpruned.arpa.gz \ + $data_dir/local/dict/lexicon.txt $data_dir/lang_test +fi + +if [ $stage -le 5 ]; then + echo "$0: Calling the flat-start chain recipe..." + echo "Date: $(date)." + local/chain/run_flatstart_cnn1a.sh --nj $nj --train-set train_aug --data-dir $data_dir --exp-dir $exp_dir +fi + +if [ $stage -le 6 ]; then + echo "$0: Aligning the training data using the e2e chain model..." + echo "Date: $(date)." + steps/nnet3/align.sh --nj $nj --cmd "$cmd" \ + --scale-opts '--transition-scale=1.0 --acoustic-scale=1.0 --self-loop-scale=1.0' \ + $data_dir/train_aug $data_dir/lang $exp_dir/chain/e2e_cnn_1a $exp_dir/chain/e2e_ali_train +fi + +if [ $stage -le 7 ]; then + echo "$0: Building a tree and training a regular chain model using the e2e alignments..." + echo "Date: $(date)." + local/chain/run_cnn_e2eali_1b.sh --nj $nj --train-set train_aug --data-dir $data_dir --exp-dir $exp_dir +fi + +if [ $stage -le 8 ]; then + echo "$0: Estimating a language model for lattice rescoring...$(date)" + local/train_lm_lr.sh --data-dir $data_dir --dir $data_dir/local/local_lm_lr --extra-lm $extra_lm --order 6 + + utils/build_const_arpa_lm.sh $data_dir/local/local_lm_lr/data/arpa/6gram_unpruned.arpa.gz \ + $data_dir/lang_test $data_dir/lang_test_lr + steps/lmrescore_const_arpa.sh $data_dir/lang_test $data_dir/lang_test_lr \ + $data_dir/test $exp_dir/chain/cnn_e2eali_1b/decode_test $exp_dir/chain/cnn_e2eali_1b/decode_test_lr +fi diff --git a/egs/yomdle_fa/v1/steps b/egs/yomdle_fa/v1/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/yomdle_fa/v1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/yomdle_fa/v1/utils b/egs/yomdle_fa/v1/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/yomdle_fa/v1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/yomdle_tamil/v1/local/augment_data.sh b/egs/yomdle_tamil/v1/local/augment_data.sh index 82fa5230a43..136bfd24eb2 100755 --- a/egs/yomdle_tamil/v1/local/augment_data.sh +++ b/egs/yomdle_tamil/v1/local/augment_data.sh @@ -8,6 +8,7 @@ nj=4 cmd=run.pl feat_dim=40 +verticle_shift=0 echo "$0 $@" . ./cmd.sh @@ -26,7 +27,8 @@ for set in aug1; do $srcdir $datadir/augmentations/$set cat $srcdir/allowed_lengths.txt > $datadir/augmentations/$set/allowed_lengths.txt local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim $feat_dim \ - --fliplr false --augment true $datadir/augmentations/$set + --vertical-shift $verticle_shift \ + --fliplr false --augment 'random_scale' $datadir/augmentations/$set done echo " combine original data and data from different augmentations" diff --git a/egs/yomdle_tamil/v1/local/extract_features.sh b/egs/yomdle_tamil/v1/local/extract_features.sh index 4ed6ba04348..3880ebad3e8 100755 --- a/egs/yomdle_tamil/v1/local/extract_features.sh +++ b/egs/yomdle_tamil/v1/local/extract_features.sh @@ -9,7 +9,7 @@ nj=4 cmd=run.pl feat_dim=40 -augment=false +augment='no_aug' fliplr=false echo "$0 $@" @@ -38,7 +38,7 @@ utils/split_scp.pl $scp $split_scps || exit 1; $cmd JOB=1:$nj $logdir/extract_features.JOB.log \ image/ocr/make_features.py $logdir/images.JOB.scp \ --allowed_len_file_path $data/allowed_lengths.txt \ - --feat-dim $feat_dim --fliplr $fliplr --augment $augment \| \ + --feat-dim $feat_dim --fliplr $fliplr --augment_type $augment \| \ copy-feats --compress=true --compression-method=7 \ ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp diff --git a/egs/yomdle_zh/README.txt b/egs/yomdle_zh/README.txt new file mode 100644 index 00000000000..39d2348ca10 --- /dev/null +++ b/egs/yomdle_zh/README.txt @@ -0,0 +1,3 @@ +This directory contains example scripts for OCR on the Yomdle and Slam datasets. +Training is done on the Yomdle dataset and testing is done on Slam. +LM rescoring is also done with extra corpus data obtained from various sources (e.g. Hamshahri) diff --git a/egs/yomdle_zh/v1/cmd.sh b/egs/yomdle_zh/v1/cmd.sh new file mode 100755 index 00000000000..3c8eb9f93a5 --- /dev/null +++ b/egs/yomdle_zh/v1/cmd.sh @@ -0,0 +1,13 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export cmd="queue.pl" diff --git a/egs/yomdle_zh/v1/image b/egs/yomdle_zh/v1/image new file mode 120000 index 00000000000..1668ee99922 --- /dev/null +++ b/egs/yomdle_zh/v1/image @@ -0,0 +1 @@ +../../cifar/v1/image/ \ No newline at end of file diff --git a/egs/yomdle_zh/v1/local/augment_data.sh b/egs/yomdle_zh/v1/local/augment_data.sh new file mode 100755 index 00000000000..1f13ed15ded --- /dev/null +++ b/egs/yomdle_zh/v1/local/augment_data.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2018 Hossein Hadian +# 2018 Ashish Arora + +# Apache 2.0 +# This script performs data augmentation. + +nj=4 +cmd=run.pl +feat_dim=40 +fliplr=false +verticle_shift=0 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +srcdir=$1 +outdir=$2 +datadir=$3 + +mkdir -p $datadir/augmentations +echo "copying $srcdir to $datadir/augmentations/aug1, allowed length, creating feats.scp" + +for set in aug1; do + image/copy_data_dir.sh --spk-prefix $set- --utt-prefix $set- \ + $srcdir $datadir/augmentations/$set + cat $srcdir/allowed_lengths.txt > $datadir/augmentations/$set/allowed_lengths.txt + local/extract_features.sh --nj $nj --cmd "$cmd" --feat-dim $feat_dim \ + --vertical-shift $verticle_shift \ + --fliplr $fliplr --augment 'random_scale' $datadir/augmentations/$set +done + +echo " combine original data and data from different augmentations" +utils/combine_data.sh --extra-files images.scp $outdir $srcdir $datadir/augmentations/aug1 +cat $srcdir/allowed_lengths.txt > $outdir/allowed_lengths.txt diff --git a/egs/yomdle_zh/v1/local/bidi.py b/egs/yomdle_zh/v1/local/bidi.py new file mode 100755 index 00000000000..447313a5d02 --- /dev/null +++ b/egs/yomdle_zh/v1/local/bidi.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2018 Chun-Chieh Chang + +# This script is largely written by Stephen Rawls +# and uses the python package https://pypi.org/project/PyICU_BiDi/ +# The code leaves right to left text alone and reverses left to right text. + +import icu_bidi +import io +import sys +import unicodedata +# R=strong right-to-left; AL=strong arabic right-to-left +rtl_set = set(chr(i) for i in range(sys.maxunicode) + if unicodedata.bidirectional(chr(i)) in ['R','AL']) +def determine_text_direction(text): + # Easy case first + for char in text: + if char in rtl_set: + return icu_bidi.UBiDiLevel.UBIDI_RTL + # If we made it here we did not encounter any strongly rtl char + return icu_bidi.UBiDiLevel.UBIDI_LTR + +def utf8_visual_to_logical(text): + text_dir = determine_text_direction(text) + + bidi = icu_bidi.Bidi() + bidi.inverse = True + bidi.reordering_mode = icu_bidi.UBiDiReorderingMode.UBIDI_REORDER_INVERSE_LIKE_DIRECT + bidi.reordering_options = icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_DEFAULT # icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_INSERT_MARKS + + bidi.set_para(text, text_dir, None) + + res = bidi.get_reordered(0 | icu_bidi.UBidiWriteReorderedOpt.UBIDI_DO_MIRRORING | icu_bidi.UBidiWriteReorderedOpt.UBIDI_KEEP_BASE_COMBINING) + + return res + +def utf8_logical_to_visual(text): + text_dir = determine_text_direction(text) + + bidi = icu_bidi.Bidi() + + bidi.reordering_mode = icu_bidi.UBiDiReorderingMode.UBIDI_REORDER_DEFAULT + bidi.reordering_options = icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_DEFAULT #icu_bidi.UBiDiReorderingOption.UBIDI_OPTION_INSERT_MARKS + + bidi.set_para(text, text_dir, None) + + res = bidi.get_reordered(0 | icu_bidi.UBidiWriteReorderedOpt.UBIDI_DO_MIRRORING | icu_bidi.UBidiWriteReorderedOpt.UBIDI_KEEP_BASE_COMBINING) + + return res + + +##main## +sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf8") +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf8") +for line in sys.stdin: + line = line.strip() + line = utf8_logical_to_visual(line)[::-1] + sys.stdout.write(line + '\n') diff --git a/egs/yomdle_zh/v1/local/chain/compare_wer.sh b/egs/yomdle_zh/v1/local/chain/compare_wer.sh new file mode 100755 index 00000000000..ab880c1adb5 --- /dev/null +++ b/egs/yomdle_zh/v1/local/chain/compare_wer.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/cnn{1a,1b} + +# Copyright 2017 Chun Chieh Chang +# 2017 Ashish Arora + +if [ $# == 0 ]; then + echo "Usage: $0: [ ... ]" + echo "e.g.: $0 exp/chain/cnn{1a,1b}" + exit 1 +fi + +echo "# $0 $*" +used_epochs=false + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +echo -n "# WER " +for x in $*; do + wer=$(cat $x/decode_test/scoring_kaldi/best_wer | awk '{print $2}') + printf "% 10s" $wer +done +echo + +echo -n "# CER " +for x in $*; do + cer=$(cat $x/decode_test/scoring_kaldi/best_cer | awk '{print $2}') + printf "% 10s" $cer +done +echo + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent) " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo diff --git a/egs/yomdle_zh/v1/local/chain/run_cnn_e2eali_1b.sh b/egs/yomdle_zh/v1/local/chain/run_cnn_e2eali_1b.sh new file mode 100755 index 00000000000..4183aa74587 --- /dev/null +++ b/egs/yomdle_zh/v1/local/chain/run_cnn_e2eali_1b.sh @@ -0,0 +1,245 @@ +#!/bin/bash + +# e2eali_1b is the same as chainali_1a but uses the e2e chain model to get the +# lattice alignments and to build a tree + +# ./local/chain/compare_wer.sh exp_yomdle_chinese/chain/e2e_cnn_1a exp_yomdle_chinese/chain/cnn_e2eali_1b +# System e2e_cnn_1a cnn_e2eali_1b +# CER 15.44 13.57 +# Final train prob 0.0616 -0.0512 +# Final valid prob 0.0390 -0.0718 +# Final train prob (xent) -0.6199 +# Final valid prob (xent) -0.7448 + +set -e -o pipefail + +data_dir=data +exp_dir=exp + +stage=0 + +nj=30 +train_set=train +nnet3_affix= # affix for exp dirs, e.g. it was _cleaned in tedlium. +affix=_1b #affix for TDNN+LSTM directory e.g. "1a" or "1b", in case we change the configuration. +common_egs_dir= +reporting_email= + +# chain options +train_stage=-10 +xent_regularize=0.1 +frame_subsampling_factor=4 +# training chunk-options +chunk_width=340,300,200,100 +num_leaves=1000 +# we don't need extra left/right context for TDNN systems. +chunk_left_context=0 +chunk_right_context=0 +tdnn_dim=450 +# training options +srand=0 +remove_egs=true +lang_test=lang_test +# End configuration section. +echo "$0 $@" # Print the command line for logging + + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 2 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/nnet3/align_lats.sh --nj $nj --cmd "$cmd" \ + --acoustic-scale 1.0 \ + --scale-opts '--transition-scale=1.0 --self-loop-scale=1.0' \ + ${train_data_dir} $data_dir/lang $e2echain_model_dir $lat_dir + echo "" >$lat_dir/splice_opts + +fi + +if [ $stage -le 3 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --alignment-subsampling-factor 1 \ + --context-opts "--context-width=3 --central-position=1" \ + --cmd "$cmd" $num_leaves ${train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 4 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $tree_dir/tree | grep num-pdfs | awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=32" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=128" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=512" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=180 name=input + + conv-relu-batchnorm-layer name=cnn1 height-in=60 height-out=60 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=60 height-out=60 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn3 height-in=60 height-out=30 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn4 height-in=30 height-out=30 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=30 height-out=30 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn6 height-in=30 height-out=15 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn7 height-in=15 height-out=15 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn8 height-in=15 height-out=15 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn9 height-in=15 height-out=15 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-8,-4,0,4,8) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' mod?els... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn3 dim=$tdnn_dim target-rms=0.5 $tdnn_opts + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 5 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/iam-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$cmd" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=500" \ + --chain.frame-subsampling-factor=$frame_subsampling_factor \ + --chain.alignment-subsampling-factor=1 \ + --chain.left-tolerance 3 \ + --chain.right-tolerance 3 \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=6 \ + --trainer.frames-per-iter=1000000 \ + --trainer.optimization.num-jobs-initial=4 \ + --trainer.optimization.num-jobs-final=8 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=16,8 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0 --constrained false" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 6 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $data_dir/$lang_test \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 7 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --extra-left-context $chunk_left_context \ + --extra-right-context $chunk_right_context \ + --extra-left-context-initial 0 \ + --extra-right-context-final 0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$cmd" \ + $dir/graph $data_dir/test $dir/decode_test || exit 1; +fi diff --git a/egs/yomdle_zh/v1/local/chain/run_flatstart_cnn1a.sh b/egs/yomdle_zh/v1/local/chain/run_flatstart_cnn1a.sh new file mode 100755 index 00000000000..88bbd32790c --- /dev/null +++ b/egs/yomdle_zh/v1/local/chain/run_flatstart_cnn1a.sh @@ -0,0 +1,169 @@ +#!/bin/bash +# Copyright 2017 Hossein Hadian + +# This script does end2end chain training (i.e. from scratch) + +# ./local/chain/compare_wer.sh exp_yomdle_chinese/chain/e2e_cnn_1a exp_yomdle_chinese/chain/cnn_e2eali_1b +# System e2e_cnn_1a cnn_e2eali_1b +# CER 15.44 13.57 +# Final train prob 0.0616 -0.0512 +# Final valid prob 0.0390 -0.0718 +# Final train prob (xent) -0.6199 +# Final valid prob (xent) -0.7448 + +set -e + +data_dir=data +exp_dir=exp + +# configs for 'chain' +stage=0 +nj=30 +train_stage=-10 +get_egs_stage=-10 +affix=1a + +# training options +tdnn_dim=450 +num_epochs=4 +num_jobs_initial=4 +num_jobs_final=8 +minibatch_size=150=64,32/300=32,16/600=16,8/1200=8,4 +common_egs_dir= +l2_regularize=0.00005 +frames_per_iter=1000000 +cmvn_opts="--norm-means=false --norm-vars=false" +train_set=train +lang_test=lang_test + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 1 ]; then + steps/nnet3/chain/e2e/prepare_e2e.sh --nj $nj --cmd "$cmd" \ + --shared-phones true \ + --type mono \ + $data_dir/$train_set $lang $treedir + $cmd $treedir/log/make_phone_lm.log \ + cat $data_dir/$train_set/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py $data_dir/lang \| \ + utils/sym2int.pl -f 2- $data_dir/lang/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=500 \ + ark:- $treedir/phone_lm.fst +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(tree-info $treedir/tree | grep num-pdfs | awk '{print $2}') + + cnn_opts="l2-regularize=0.075" + tdnn_opts="l2-regularize=0.075" + output_opts="l2-regularize=0.1" + + common1="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=32" + common2="$cnn_opts required-time-offsets= height-offsets=-2,-1,0,1,2 num-filters-out=128" + common3="$cnn_opts required-time-offsets= height-offsets=-1,0,1 num-filters-out=512" + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=180 name=input + conv-relu-batchnorm-layer name=cnn1 height-in=60 height-out=60 time-offsets=-3,-2,-1,0,1,2,3 $common1 + conv-relu-batchnorm-layer name=cnn2 height-in=60 height-out=30 time-offsets=-2,-1,0,1,2 $common1 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn3 height-in=30 height-out=30 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn4 height-in=30 height-out=30 time-offsets=-4,-2,0,2,4 $common2 + conv-relu-batchnorm-layer name=cnn5 height-in=30 height-out=15 time-offsets=-4,-2,0,2,4 $common2 height-subsample-out=2 + conv-relu-batchnorm-layer name=cnn6 height-in=15 height-out=15 time-offsets=-4,0,4 $common3 + conv-relu-batchnorm-layer name=cnn7 height-in=15 height-out=15 time-offsets=-4,0,4 $common3 + relu-batchnorm-layer name=tdnn1 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn2 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + relu-batchnorm-layer name=tdnn3 input=Append(-4,0,4) dim=$tdnn_dim $tdnn_opts + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain dim=$tdnn_dim target-rms=0.5 $output_opts + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 $output_opts +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs +fi + +if [ $stage -le 3 ]; then + # no need to store the egs in a shared storage because we always + # remove them. Anyway, it takes only 5 minutes to generate them. + + steps/nnet3/chain/e2e/train_e2e.py --stage $train_stage \ + --cmd "$cmd" \ + --feat.cmvn-opts "$cmvn_opts" \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize $l2_regularize \ + --chain.apply-deriv-weights false \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \ + --chain.frame-subsampling-factor 4 \ + --chain.alignment-subsampling-factor 4 \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter $frames_per_iter \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.momentum 0 \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.optimization.shrink-value 1.0 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir $data_dir/${train_set} \ + --tree-dir $treedir \ + --dir $dir || exit 1; +fi + +if [ $stage -le 4 ]; then + # The reason we are using data/lang here, instead of $lang, is just to + # emphasize that it's not actually important to give mkgraph.sh the + # lang directory with the matched topology (since it gets the + # topology file from the model). So you could give it a different + # lang directory, one that contained a wordlist and LM of your choice, + # as long as phones.txt was compatible. + + utils/mkgraph.sh \ + --self-loop-scale 1.0 $data_dir/$lang_test \ + $dir $dir/graph || exit 1; +fi + +if [ $stage -le 5 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj --cmd "$cmd" \ + $dir/graph $data_dir/test $dir/decode_test || exit 1; +fi + +echo "Done. Date: $(date). Results:" +local/chain/compare_wer.sh $dir diff --git a/egs/yomdle_zh/v1/local/create_download.sh b/egs/yomdle_zh/v1/local/create_download.sh new file mode 100755 index 00000000000..a440a331747 --- /dev/null +++ b/egs/yomdle_zh/v1/local/create_download.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright 2018 Chun-Chieh Chang + +# The original format of the dataset given is GEDI and page images. +# This script is written to create line images from page images. +# It also creates csv files from the GEDI files. + +database_slam=/export/corpora5/slam/SLAM/Farsi/transcribed +database_yomdle=/export/corpora5/slam/YOMDLE/final_farsi +cangjie_url=https://raw.githubusercontent.com/wanleung/libcangjie/master/tables/cj5-cc.txt +download_dir=download +slam_dir=$download_dir/slam_farsi +yomdle_dir=$download_dir/yomdle_farsi + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +echo "$0: Processing SLAM ${language}" +echo "Date: $(date)." +mkdir -p ${slam_dir}/{truth_csv,truth_csv_raw,truth_line_image} +local/gedi2csv.py \ + --inputDir ${database_slam} \ + --outputDir ${slam_dir}/truth_csv_raw \ + --log ${slam_dir}/GEDI2CSV_enriched.log +local/create_line_image_from_page_image.py \ + ${database_slam} \ + ${slam_dir}/truth_csv_raw \ + ${slam_dir} + +echo "$0: Processing YOMDLE ${language}" +echo "Date: $(date)." +mkdir -p ${yomdle_dir}/{truth_csv,truth_csv_raw,truth_line_image} +local/yomdle2csv.py \ + --inputDir ${database_yomdle} \ + --outputDir ${yomdle_dir}/truth_csv_raw/ \ + --log ${yomdle_dir}/YOMDLE2CSV.log +local/create_line_image_from_page_image.py \ + --im-format "jpg" \ + ${database_yomdle}/images \ + ${yomdle_dir}/truth_csv_raw \ + ${yomdle_dir} + +echo "Downloading table for CangJie." +wget -P $download_dir/ $cangjie_url || exit 1; +sed -ie '1,8d' $download_dir/cj5-cc.txt diff --git a/egs/yomdle_zh/v1/local/create_line_image_from_page_image.py b/egs/yomdle_zh/v1/local/create_line_image_from_page_image.py new file mode 100755 index 00000000000..77a6791d5d7 --- /dev/null +++ b/egs/yomdle_zh/v1/local/create_line_image_from_page_image.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# Apache 2.0 +# minimum bounding box part in this script is originally from +#https://github.com/BebeSparkelSparkel/MinimumBoundingBox +#https://startupnextdoor.com/computing-convex-hull-in-python/ +""" This module will be used for extracting line images from page image. + Given the word segmentation (bounding box around a word) for every word, it will + extract line segmentation. To extract line segmentation, it will take word bounding + boxes of a line as input, will create a minimum area bounding box that will contain + all corner points of word bounding boxes. The obtained bounding box (will not necessarily + be vertically or horizontally aligned). Hence to extract line image from line bounding box, + page image is rotated and line image is cropped and saved. +""" + +import argparse +import csv +import itertools +import sys +import os +import numpy as np +from math import atan2, cos, sin, pi, degrees, sqrt +from collections import namedtuple + +from scipy.spatial import ConvexHull +from PIL import Image +from scipy.misc import toimage + +parser = argparse.ArgumentParser(description="Creates line images from page image") +parser.add_argument('image_dir', type=str, help='Path to full page images') +parser.add_argument('csv_dir', type=str, help='Path to csv files') +parser.add_argument('out_dir', type=str, help='Path to output directory') +parser.add_argument('--im-format', type=str, default='png', help='What file format are the images') +parser.add_argument('--padding', type=int, default=100, help='Padding so BBox does not exceed image area') +parser.add_argument('--head', type=int, default=-1, help='Number of csv files to process') +args = parser.parse_args() + +""" +bounding_box is a named tuple which contains: + area (float): area of the rectangle + length_parallel (float): length of the side that is parallel to unit_vector + length_orthogonal (float): length of the side that is orthogonal to unit_vector + rectangle_center(int, int): coordinates of the rectangle center + (use rectangle_corners to get the corner points of the rectangle) + unit_vector (float, float): direction of the length_parallel side. + (it's orthogonal vector can be found with the orthogonal_vector function + unit_vector_angle (float): angle of the unit vector to be in radians. + corner_points [(float, float)]: set that contains the corners of the rectangle +""" + +bounding_box_tuple = namedtuple('bounding_box_tuple', 'area ' + 'length_parallel ' + 'length_orthogonal ' + 'rectangle_center ' + 'unit_vector ' + 'unit_vector_angle ' + 'corner_points' + ) + + +def unit_vector(pt0, pt1): + """ Given two points pt0 and pt1, return a unit vector that + points in the direction of pt0 to pt1. + Returns + ------- + (float, float): unit vector + """ + dis_0_to_1 = sqrt((pt0[0] - pt1[0])**2 + (pt0[1] - pt1[1])**2) + return (pt1[0] - pt0[0]) / dis_0_to_1, \ + (pt1[1] - pt0[1]) / dis_0_to_1 + + +def orthogonal_vector(vector): + """ Given a vector, returns a orthogonal/perpendicular vector of equal length. + Returns + ------ + (float, float): A vector that points in the direction orthogonal to vector. + """ + return -1 * vector[1], vector[0] + + +def bounding_area(index, hull): + """ Given index location in an array and convex hull, it gets two points + hull[index] and hull[index+1]. From these two points, it returns a named + tuple that mainly contains area of the box that bounds the hull. This + bounding box orintation is same as the orientation of the lines formed + by the point hull[index] and hull[index+1]. + Returns + ------- + a named tuple that contains: + area: area of the rectangle + length_parallel: length of the side that is parallel to unit_vector + length_orthogonal: length of the side that is orthogonal to unit_vector + rectangle_center: coordinates of the rectangle center + unit_vector: direction of the length_parallel side. + (it's orthogonal vector can be found with the orthogonal_vector function) + """ + unit_vector_p = unit_vector(hull[index], hull[index+1]) + unit_vector_o = orthogonal_vector(unit_vector_p) + + dis_p = tuple(np.dot(unit_vector_p, pt) for pt in hull) + dis_o = tuple(np.dot(unit_vector_o, pt) for pt in hull) + + min_p = min(dis_p) + min_o = min(dis_o) + len_p = max(dis_p) - min_p + len_o = max(dis_o) - min_o + + return {'area': len_p * len_o, + 'length_parallel': len_p, + 'length_orthogonal': len_o, + 'rectangle_center': (min_p + len_p / 2, min_o + len_o / 2), + 'unit_vector': unit_vector_p, + } + + +def to_xy_coordinates(unit_vector_angle, point): + """ Given angle from horizontal axis and a point from origin, + returns converted unit vector coordinates in x, y coordinates. + angle of unit vector should be in radians. + Returns + ------ + (float, float): converted x,y coordinate of the unit vector. + """ + angle_orthogonal = unit_vector_angle + pi / 2 + return point[0] * cos(unit_vector_angle) + point[1] * cos(angle_orthogonal), \ + point[0] * sin(unit_vector_angle) + point[1] * sin(angle_orthogonal) + + +def rotate_points(center_of_rotation, angle, points): + """ Rotates a point cloud around the center_of_rotation point by angle + input + ----- + center_of_rotation (float, float): angle of unit vector to be in radians. + angle (float): angle of rotation to be in radians. + points [(float, float)]: Points to be a list or tuple of points. Points to be rotated. + Returns + ------ + [(float, float)]: Rotated points around center of rotation by angle + """ + rot_points = [] + ang = [] + for pt in points: + diff = tuple([pt[d] - center_of_rotation[d] for d in range(2)]) + diff_angle = atan2(diff[1], diff[0]) + angle + ang.append(diff_angle) + diff_length = sqrt(sum([d**2 for d in diff])) + rot_points.append((center_of_rotation[0] + diff_length * cos(diff_angle), + center_of_rotation[1] + diff_length * sin(diff_angle))) + + return rot_points + + +def rectangle_corners(rectangle): + """ Given rectangle center and its inclination, returns the corner + locations of the rectangle. + Returns + ------ + [(float, float)]: 4 corner points of rectangle. + """ + corner_points = [] + for i1 in (.5, -.5): + for i2 in (i1, -1 * i1): + corner_points.append((rectangle['rectangle_center'][0] + i1 * rectangle['length_parallel'], + rectangle['rectangle_center'][1] + i2 * rectangle['length_orthogonal'])) + + return rotate_points(rectangle['rectangle_center'], rectangle['unit_vector_angle'], corner_points) + + +def get_orientation(origin, p1, p2): + """ + Given origin and two points, return the orientation of the Point p1 with + regards to Point p2 using origin. + Returns + ------- + integer: Negative if p1 is clockwise of p2. + """ + difference = ( + ((p2[0] - origin[0]) * (p1[1] - origin[1])) + - ((p1[0] - origin[0]) * (p2[1] - origin[1])) + ) + return difference + + +def compute_hull(points): + """ + Given input list of points, return a list of points that + made up the convex hull. + Returns + ------- + [(float, float)]: convexhull points + """ + hull_points = [] + start = points[0] + min_x = start[0] + for p in points[1:]: + if p[0] < min_x: + min_x = p[0] + start = p + + point = start + hull_points.append(start) + + far_point = None + while far_point is not start: + p1 = None + for p in points: + if p is point: + continue + else: + p1 = p + break + + far_point = p1 + + for p2 in points: + if p2 is point or p2 is p1: + continue + else: + direction = get_orientation(point, far_point, p2) + if direction > 0: + far_point = p2 + + hull_points.append(far_point) + point = far_point + return hull_points + + +def minimum_bounding_box(points): + """ Given a list of 2D points, it returns the minimum area rectangle bounding all + the points in the point cloud. + Returns + ------ + returns a namedtuple that contains: + area: area of the rectangle + length_parallel: length of the side that is parallel to unit_vector + length_orthogonal: length of the side that is orthogonal to unit_vector + rectangle_center: coordinates of the rectangle center + unit_vector: direction of the length_parallel side. RADIANS + unit_vector_angle: angle of the unit vector + corner_points: set that contains the corners of the rectangle + """ + + if len(points) <= 2: raise ValueError('More than two points required.') + + hull_ordered = [points[index] for index in ConvexHull(points).vertices] + hull_ordered.append(hull_ordered[0]) + #hull_ordered = compute_hull(points) + hull_ordered = tuple(hull_ordered) + + min_rectangle = bounding_area(0, hull_ordered) + for i in range(1, len(hull_ordered)-1): + rectangle = bounding_area(i, hull_ordered) + if rectangle['area'] < min_rectangle['area']: + min_rectangle = rectangle + + min_rectangle['unit_vector_angle'] = atan2(min_rectangle['unit_vector'][1], min_rectangle['unit_vector'][0]) + min_rectangle['rectangle_center'] = to_xy_coordinates(min_rectangle['unit_vector_angle'], min_rectangle['rectangle_center']) + + return bounding_box_tuple( + area = min_rectangle['area'], + length_parallel = min_rectangle['length_parallel'], + length_orthogonal = min_rectangle['length_orthogonal'], + rectangle_center = min_rectangle['rectangle_center'], + unit_vector = min_rectangle['unit_vector'], + unit_vector_angle = min_rectangle['unit_vector_angle'], + corner_points = set(rectangle_corners(min_rectangle)) + ) + + +def get_center(im): + """ Given image, returns the location of center pixel + Returns + ------- + (int, int): center of the image + """ + center_x = im.size[0] / 2 + center_y = im.size[1] / 2 + return int(center_x), int(center_y) + + +def get_horizontal_angle(unit_vector_angle): + """ Given an angle in radians, returns angle of the unit vector in + first or fourth quadrant. + Returns + ------ + (float): updated angle of the unit vector to be in radians. + It is only in first or fourth quadrant. + """ + if unit_vector_angle > pi / 2 and unit_vector_angle <= pi: + unit_vector_angle = unit_vector_angle - pi + elif unit_vector_angle > -pi and unit_vector_angle < -pi / 2: + unit_vector_angle = unit_vector_angle + pi + + return unit_vector_angle + + +def get_smaller_angle(bounding_box): + """ Given a rectangle, returns its smallest absolute angle from horizontal axis. + Returns + ------ + (float): smallest angle of the rectangle to be in radians. + """ + unit_vector = bounding_box.unit_vector + unit_vector_angle = bounding_box.unit_vector_angle + ortho_vector = orthogonal_vector(unit_vector) + ortho_vector_angle = atan2(ortho_vector[1], ortho_vector[0]) + + unit_vector_angle_updated = get_horizontal_angle(unit_vector_angle) + ortho_vector_angle_updated = get_horizontal_angle(ortho_vector_angle) + + if abs(unit_vector_angle_updated) < abs(ortho_vector_angle_updated): + return unit_vector_angle_updated + else: + return ortho_vector_angle_updated + + +def rotated_points(bounding_box, center): + """ Given the rectangle, returns corner points of rotated rectangle. + It rotates the rectangle around the center by its smallest angle. + Returns + ------- + [(int, int)]: 4 corner points of rectangle. + """ + p1, p2, p3, p4 = bounding_box.corner_points + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + x4, y4 = p4 + center_x, center_y = center + rotation_angle_in_rad = -get_smaller_angle(bounding_box) + x_dash_1 = (x1 - center_x) * cos(rotation_angle_in_rad) - (y1 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_2 = (x2 - center_x) * cos(rotation_angle_in_rad) - (y2 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_3 = (x3 - center_x) * cos(rotation_angle_in_rad) - (y3 - center_y) * sin(rotation_angle_in_rad) + center_x + x_dash_4 = (x4 - center_x) * cos(rotation_angle_in_rad) - (y4 - center_y) * sin(rotation_angle_in_rad) + center_x + + y_dash_1 = (y1 - center_y) * cos(rotation_angle_in_rad) + (x1 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_2 = (y2 - center_y) * cos(rotation_angle_in_rad) + (x2 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_3 = (y3 - center_y) * cos(rotation_angle_in_rad) + (x3 - center_x) * sin(rotation_angle_in_rad) + center_y + y_dash_4 = (y4 - center_y) * cos(rotation_angle_in_rad) + (x4 - center_x) * sin(rotation_angle_in_rad) + center_y + return x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 + + +def pad_image(image): + """ Given an image, returns a padded image around the border. + This routine save the code from crashing if bounding boxes that are + slightly outside the page boundary. + Returns + ------- + image: page image + """ + offset = int(args.padding // 2) + padded_image = Image.new('RGB', (image.size[0] + int(args.padding), image.size[1] + int(args.padding)), "white") + padded_image.paste(im = image, box = (offset, offset)) + return padded_image + +def update_minimum_bounding_box_input(bounding_box_input): + """ Given list of 2D points, returns list of 2D points shifted by an offset. + Returns + ------ + points [(float, float)]: points, a list or tuple of 2D coordinates + """ + updated_minimum_bounding_box_input = [] + offset = int(args.padding // 2) + for point in bounding_box_input: + x, y = point + new_x = x + offset + new_y = y + offset + word_coordinate = (new_x, new_y) + updated_minimum_bounding_box_input.append(word_coordinate) + + return updated_minimum_bounding_box_input + + +### main ### +csv_count = 0 +for filename in sorted(os.listdir(args.csv_dir)): + if filename.endswith('.csv') and (csv_count < args.head or args.head < 0): + csv_count = csv_count + 1 + with open(os.path.join(args.csv_dir, filename), 'r', encoding='utf-8') as f: + image_file = os.path.join(args.image_dir, os.path.splitext(filename)[0] + '.' + args.im_format) + if not os.path.isfile(image_file): + continue + csv_out_file = os.path.join(args.out_dir, 'truth_csv', filename) + csv_out_fh = open(csv_out_file, 'w', encoding='utf-8') + csv_out_writer = csv.writer(csv_out_fh) + im = Image.open(image_file) + im = pad_image(im) + count = 1 + for row in itertools.islice(csv.reader(f), 0, None): + if count == 1: + count = 0 + continue + + points = [] + points.append((int(row[2]), int(row[3]))) + points.append((int(row[4]), int(row[5]))) + points.append((int(row[6]), int(row[7]))) + points.append((int(row[8]), int(row[9]))) + + x = [int(row[2]), int(row[4]), int(row[6]), int(row[8])] + y = [int(row[3]), int(row[5]), int(row[7]), int(row[9])] + min_x, min_y = min(x), min(y) + max_x, max_y = max(x), max(y) + if min_x == max_x or min_y == max_y: + continue + + try: + updated_mbb_input = update_minimum_bounding_box_input(points) + bounding_box = minimum_bounding_box(updated_mbb_input) + except Exception as e: + print("Error: Skipping Image " + row[1]) + continue + + p1, p2, p3, p4 = bounding_box.corner_points + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + x4, y4 = p4 + min_x = int(min(x1, x2, x3, x4)) + min_y = int(min(y1, y2, y3, y4)) + max_x = int(max(x1, x2, x3, x4)) + max_y = int(max(y1, y2, y3, y4)) + box = (min_x, min_y, max_x, max_y) + region_initial = im.crop(box) + rot_points = [] + p1_new = (x1 - min_x, y1 - min_y) + p2_new = (x2 - min_x, y2 - min_y) + p3_new = (x3 - min_x, y3 - min_y) + p4_new = (x4 - min_x, y4 - min_y) + rot_points.append(p1_new) + rot_points.append(p2_new) + rot_points.append(p3_new) + rot_points.append(p4_new) + + cropped_bounding_box = bounding_box_tuple(bounding_box.area, + bounding_box.length_parallel, + bounding_box.length_orthogonal, + bounding_box.length_orthogonal, + bounding_box.unit_vector, + bounding_box.unit_vector_angle, + set(rot_points)) + + rotation_angle_in_rad = get_smaller_angle(cropped_bounding_box) + img2 = region_initial.rotate(degrees(rotation_angle_in_rad), resample = Image.BICUBIC) + x_dash_1, y_dash_1, x_dash_2, y_dash_2, x_dash_3, y_dash_3, x_dash_4, y_dash_4 = rotated_points( + cropped_bounding_box, get_center(region_initial)) + + min_x = int(min(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + min_y = int(min(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + max_x = int(max(x_dash_1, x_dash_2, x_dash_3, x_dash_4)) + max_y = int(max(y_dash_1, y_dash_2, y_dash_3, y_dash_4)) + box = (min_x, min_y, max_x, max_y) + region_final = img2.crop(box) + csv_out_writer.writerow(row) + image_out_file = os.path.join(args.out_dir, 'truth_line_image', row[1]) + region_final.save(image_out_file) diff --git a/egs/yomdle_zh/v1/local/extract_features.sh b/egs/yomdle_zh/v1/local/extract_features.sh new file mode 100755 index 00000000000..f75837ae5b3 --- /dev/null +++ b/egs/yomdle_zh/v1/local/extract_features.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright 2017 Yiwen Shao +# 2018 Ashish Arora + +nj=4 +cmd=run.pl +feat_dim=40 +fliplr=false +augment='no_aug' +num_channels=3 +echo "$0 $@" + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh || exit 1; + +data=$1 +featdir=$data/data +scp=$data/images.scp +logdir=$data/log + +mkdir -p $logdir +mkdir -p $featdir + +# make $featdir an absolute pathname +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +for n in $(seq $nj); do + split_scps="$split_scps $logdir/images.$n.scp" +done + +# split images.scp +utils/split_scp.pl $scp $split_scps || exit 1; + +$cmd JOB=1:$nj $logdir/extract_features.JOB.log \ + image/ocr/make_features.py $logdir/images.JOB.scp \ + --allowed_len_file_path $data/allowed_lengths.txt \ + --feat-dim $feat_dim --num-channels $num_channels --fliplr $fliplr --augment_type $augment \| \ + copy-feats --compress=true --compression-method=7 \ + ark:- ark,scp:$featdir/images.JOB.ark,$featdir/images.JOB.scp + +## aggregates the output scp's to get feats.scp +for n in $(seq $nj); do + cat $featdir/images.$n.scp || exit 1; +done > $data/feats.scp || exit 1 diff --git a/egs/yomdle_zh/v1/local/gedi2csv.py b/egs/yomdle_zh/v1/local/gedi2csv.py new file mode 100755 index 00000000000..43a07421dd1 --- /dev/null +++ b/egs/yomdle_zh/v1/local/gedi2csv.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 + +""" +GEDI2CSV +Convert GEDI-type bounding boxes to CSV format + +GEDI Format Example: + + + + + + + + + +CSV Format Example +ID,name,col1,row1,col2,row2,col3,row3,col4,row4,confidence,truth,pgrot,bbrot,qual,script,lang +0,chinese_scanned_books_0001_0.png,99,41,99,14,754,14,754,41,100,凡我的邻人说是好的,有一大部分在我灵魂中却,0,0.0,0,,zh-cn +""" + +import logging +import os +import sys +import time +import glob +import csv +import imghdr +from PIL import Image +import argparse +import pdb +import cv2 +import numpy as np +import xml.etree.ElementTree as ET + +sin = np.sin +cos = np.cos +pi = np.pi + +def Rotate2D(pts, cnt, ang=90): + M = np.array([[cos(ang),-sin(ang)],[sin(ang),cos(ang)]]) + res = np.dot(pts-cnt,M)+cnt + return M, res + +def npbox2string(npar): + if np.shape(npar)[0] != 1: + print('Error during CSV conversion\n') + c1,r1 = npar[0][0],npar[0][1] + c2,r2 = npar[0][2],npar[0][3] + c3,r3 = npar[0][4],npar[0][5] + c4,r4 = npar[0][6],npar[0][7] + + return c1,r1,c2,r2,c3,r3,c4,r4 + +# cv2.minAreaRect() returns a Box2D structure which contains following detals - ( center (x,y), (width, height), angle of rotation ) +# Get 4 corners of the rectangle using cv2.boxPoints() + +class GEDI2CSV(): + + """ Initialize the extractor""" + def __init__(self, logger, args): + self._logger = logger + self._args = args + + """ + Segment image with GEDI bounding box information + """ + def csvfile(self, coords, polys, baseName, pgrot): + + """ for writing the files """ + writePath = self._args.outputDir + writePath = os.path.join(writePath,'') + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + rotlist = [] + + header=['ID','name','col1','row1','col2','row2','col3','row3','col4','row4','confidence','truth','pgrot','bbrot','qual','script','text_type'] + conf=100 + write_ctr = 0 + if len(coords) == 0 and len(polys) == 0: + self._logger.info('Found %s with no text content',(baseName)) + print('...Found %s with no text content' % (baseName)) + return + + strPos = writePath + baseName + + """ for each group of coordinates """ + for i in coords: + + [id,x,y,w,h,degrees,text,qual,script,text_type] = i + + contour = np.array([(x,y),(x+w,y),(x+w,y+h),(x,y+h)]) + + """ + First rotate around upper left corner based on orientationD keyword + """ + M, rot = Rotate2D(contour, np.array([x,y]), degrees*pi/180) + rot = np.int0(rot) + + # rot is the 8 points rotated by degrees + # pgrot is the rotation after extraction, so save + + # save rotated points to list or array + rot = np.reshape(rot,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(rot) + + text = text.replace(u'\ufeff','') + + bbrot = degrees + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,text_type]) + + # if there are polygons, first save the text + for j in polys: + arr = [] + [id,poly_val,text,qual,script,text_type] = j + for i in poly_val: + arr.append(eval(i)) + + contour = np.asarray(arr) + convex = cv2.convexHull(contour) + rect = cv2.minAreaRect(convex) + box = cv2.boxPoints(rect) + box = np.int0(box) + box = np.reshape(box,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(box) + + bbrot = 0.0 + + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,text_type]) + + # then write out all of list to file + with open(strPos + ".csv", "w", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(header) + for row in rotlist: + writer.writerow(row) + write_ctr += 1 + + return write_ctr + + +def main(args): + + startTime = time.clock() + + writePath = args.outputDir + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + """ Setup logging """ + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + if args.log: + handler = logging.FileHandler(args.log) + handler.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + gtconverter = GEDI2CSV(logger, args) + namespaces = {"gedi" : "http://lamp.cfar.umd.edu/media/projects/GEDI/"} + keyCnt=0 + + fileCnt = 0 + line_write_ctr = 0 + line_error_ctr = 0 + + """ + Get all XML files in the directory and sub folders + """ + for root, dirnames, filenames in os.walk(args.inputDir, followlinks=True): + for file in filenames: + if file.lower().endswith('.xml'): + fullName = os.path.join(root,file) + baseName = os.path.splitext(fullName) + + fileCnt += 1 + + """ read the XML file """ + tree = ET.parse(fullName) + gedi_root = tree.getroot() + child = gedi_root.findall('gedi:DL_DOCUMENT',namespaces)[0] + totalpages = int(child.attrib['NrOfPages']) + coordinates=[] + polygons = [] + if args.ftype == 'boxed': + fileTypeStr = 'col' + elif args.ftype == 'transcribed': + fileTypeStr = 'Text_Content' + else: + print('Filetype must be either boxed or transcribed!') + logger.info('Filetype must be either boxed or transcribed!') + sys.exit(-1) + + if args.quality == 'both': + qualset = {'Regular','Low-Quality'} + elif args.quality == 'low': + qualset = {'Low-Quality'} + elif args.quality == 'regular': + qualset = {'Regular'} + else: + print('Quality must be both, low or regular!') + logger.info('Quality must be both, low or regular!') + sys.exit(-1) + + + + """ and for each page """ + for i, pgs in enumerate(child.iterfind('gedi:DL_PAGE',namespaces)): + + if 'GEDI_orientation' not in pgs.attrib: + pageRot=0 + else: + pageRot = int(pgs.attrib['GEDI_orientation']) + logger.info(' PAGE ROTATION %s, %s' % (fullName, str(pageRot))) + + """ find children for each page """ + for zone in pgs.findall('gedi:DL_ZONE',namespaces): + + if zone.attrib['gedi_type']=='Text' and zone.attrib['Type'] in \ + ('Machine_Print','Confusable_Allograph','Handwriting') and zone.attrib['Quality'] in qualset: + if zone.get('polygon'): + keyCnt+=1 + polygons.append([zone.attrib['id'],zone.get('polygon').split(';'), + zone.get('Text_Content'),zone.get('Quality'),zone.get('Script'),zone.get('Type')]) + elif zone.get(fileTypeStr) != None: + keyCnt+=1 + coord = [zone.attrib['id'],int(zone.attrib['col']),int(zone.attrib['row']), + int(zone.attrib['width']), int(zone.attrib['height']), + float(zone.get('orientationD',0.0)), + zone.get('Text_Content'),zone.get('Quality'),zone.get('Script'),zone.get('Type')] + coordinates.append(coord) + + if len(coordinates) > 0 or len(polygons) > 0: + line_write_ctr += gtconverter.csvfile(coordinates, polygons, os.path.splitext(file)[0], pageRot) + else: + print('...%s has no applicable content' % (baseName[0])) + + print('complete...total files %d, lines written %d' % (fileCnt, line_write_ctr)) + + +def parse_arguments(argv): + """ Args and defaults """ + parser = argparse.ArgumentParser() + + parser.add_argument('--inputDir', type=str, help='Input directory', required=True) + parser.add_argument('--outputDir', type=str, help='Output directory', required=True) + parser.add_argument('--ftype', type=str, help='GEDI file type (either "boxed" or "transcribed")', default='transcribed') + parser.add_argument('--quality', type=str, help='GEDI file quality (either "both" or "low" or "regular")', default='regular') + parser.add_argument('--log', type=str, help='Log directory', default='./GEDI2CSV_enriched.log') + + return parser.parse_args(argv) + +if __name__ == '__main__': + """ Run """ + main(parse_arguments(sys.argv[1:])) + + + + + + diff --git a/egs/yomdle_zh/v1/local/prepare_dict.sh b/egs/yomdle_zh/v1/local/prepare_dict.sh new file mode 100755 index 00000000000..65b2e7aa901 --- /dev/null +++ b/egs/yomdle_zh/v1/local/prepare_dict.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +# Copyright 2017 Hossein Hadian +# 2017 Chun Chieh Chang +# 2017 Ashish Arora + +# This script prepares the dictionary. + +set -e +dir=data/local/dict +data_dir=data + +. ./utils/parse_options.sh || exit 1; + +base_dir=$(echo "$DIRECTORY" | cut -d "/" -f2) + +mkdir -p $dir + +local/prepare_lexicon.py --data-dir $data_dir $dir + +cut -d' ' -f2- $dir/lexicon.txt | sed 's/SIL//g' | tr ' ' '\n' | sort -u | sed '/^$/d' >$dir/nonsilence_phones.txt || exit 1; + +echo ' SIL' >> $dir/lexicon.txt + +echo SIL > $dir/silence_phones.txt + +echo SIL >$dir/optional_silence.txt + +echo -n "" >$dir/extra_questions.txt diff --git a/egs/yomdle_zh/v1/local/prepare_lexicon.py b/egs/yomdle_zh/v1/local/prepare_lexicon.py new file mode 100755 index 00000000000..3ebb52e38f4 --- /dev/null +++ b/egs/yomdle_zh/v1/local/prepare_lexicon.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# Chun-Chieh Chang + +import argparse +import os + +parser = argparse.ArgumentParser(description="""Creates the list of characters and words in lexicon""") +parser.add_argument('dir', type=str, help='output path') +parser.add_argument('--data-dir', type=str, default='data', help='Path to text file') +args = parser.parse_args() + +### main ### +lex = {} +text_path = os.path.join(args.data_dir, 'train', 'text') +text_fh = open(text_path, 'r', encoding='utf-8') + +# Used specially for Chinese. +# Uses the ChangJie keyboard input method to create subword units for Chinese. +cj5_table = {} +with open('download/cj5-cc.txt', 'r', encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split() + if not line_vect[0].startswith('yyy') and not line_vect[0].startswith('z'): + cj5_table[line_vect[1]] = "cj5_" + " cj5_".join(list(line_vect[0])) + +with open(text_path, 'r', encoding='utf-8') as f: + for line in f: + line_vect = line.strip().split() + for i in range(1, len(line_vect)): + characters = list(line_vect[i]) + # Put SIL instead of "|". Because every "|" in the beginning of the words is for initial-space of that word + characters = " ".join([ 'SIL' if char == '|' else cj5_table[char] if char in cj5_table else char for char in characters]) + characters = characters.replace('#','') + lex[line_vect[i]] = characters + +with open(os.path.join(args.dir, 'lexicon.txt'), 'w', encoding='utf-8') as fp: + for key in sorted(lex): + fp.write(key + " " + lex[key] + "\n") diff --git a/egs/yomdle_zh/v1/local/process_data.py b/egs/yomdle_zh/v1/local/process_data.py new file mode 100755 index 00000000000..8964af8890a --- /dev/null +++ b/egs/yomdle_zh/v1/local/process_data.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Ashish Arora +# 2018 Chun Chieh Chang + +""" This script reads the extracted Farsi OCR (yomdle and slam) database files + and creates the following files (for the data subset selected via --dataset): + text, utt2spk, images.scp. + Eg. local/process_data.py data/download/ data/local/splits/train.txt data/train + Eg. text file: english_phone_books_0001_1 To sum up, then, it would appear that + utt2spk file: english_phone_books_0001_0 english_phone_books_0001 + images.scp file: english_phone_books_0001_0 \ + data/download/truth_line_image/english_phone_books_0001_0.png +""" + +import argparse +import os +import sys +import csv +import itertools +import unicodedata + +parser = argparse.ArgumentParser(description="Creates text, utt2spk, and images.scp files") +parser.add_argument('database_path', type=str, help='Path to data') +parser.add_argument('out_dir', type=str, help='directory to output files') +parser.add_argument('--head', type=int, default=-1, help='limit on number of synth data') +args = parser.parse_args() + +### main ### +print("Processing '{}' data...".format(args.out_dir)) + +text_file = os.path.join(args.out_dir, 'text') +text_fh = open(text_file, 'w', encoding='utf-8') +utt2spk_file = os.path.join(args.out_dir, 'utt2spk') +utt2spk_fh = open(utt2spk_file, 'w', encoding='utf-8') +image_file = os.path.join(args.out_dir, 'images.scp') +image_fh = open(image_file, 'w', encoding='utf-8') + +count = 0 +for filename in sorted(os.listdir(os.path.join(args.database_path, 'truth_csv'))): + if filename.endswith('.csv') and (count < args.head or args.head < 0): + count = count + 1 + csv_filepath = os.path.join(args.database_path, 'truth_csv', filename) + csv_file = open(csv_filepath, 'r', encoding='utf-8') + row_count = 0 + for row in csv.reader(csv_file): + if row_count == 0: + row_count = 1 + continue + image_id = os.path.splitext(row[1])[0] + image_filepath = os.path.join(args.database_path, 'truth_line_image', row[1]) + text = unicodedata.normalize('NFC', row[11]).replace('\n', '') + if os.path.isfile(image_filepath) and os.stat(image_filepath).st_size != 0 and text: + text_fh.write(image_id + ' ' + text + '\n') + utt2spk_fh.write(image_id + ' ' + '_'.join(image_id.split('_')[:-1]) + '\n') + image_fh.write(image_id + ' ' + image_filepath + ' ' + row[13] + '\n') diff --git a/egs/yomdle_zh/v1/local/score.sh b/egs/yomdle_zh/v1/local/score.sh new file mode 100755 index 00000000000..f2405205f02 --- /dev/null +++ b/egs/yomdle_zh/v1/local/score.sh @@ -0,0 +1,5 @@ +#!/bin/bash + + +steps/scoring/score_kaldi_wer.sh --max-lmwt 10 "$@" +steps/scoring/score_kaldi_cer.sh --max-lmwt 10 --stage 2 "$@" diff --git a/egs/yomdle_zh/v1/local/train_lm.sh b/egs/yomdle_zh/v1/local/train_lm.sh new file mode 100755 index 00000000000..bc738f217da --- /dev/null +++ b/egs/yomdle_zh/v1/local/train_lm.sh @@ -0,0 +1,110 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the YOMDLE training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +data_dir=data + +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +#bypass_metaparam_optim_opt= +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + nr=`cat $data_dir/train/text | wc -l` + nr_dev=$(($nr / 10 )) + nr_train=$(( $nr - $nr_dev )) + + # use the training data as an additional data source. + # we can later fold the dev data into this. + head -n $nr_train $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + tail -n $nr_dev $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < $data_dir/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from MADCAT text + cat ${dir}/data/text/train.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +order=3 + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=5 --warm-start-ratio=1 \ + --min-counts="$min_counts" \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi diff --git a/egs/yomdle_zh/v1/local/train_lm_lr.sh b/egs/yomdle_zh/v1/local/train_lm_lr.sh new file mode 100755 index 00000000000..b95e6474b18 --- /dev/null +++ b/egs/yomdle_zh/v1/local/train_lm_lr.sh @@ -0,0 +1,113 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# 2017 Ashish Arora +# 2017 Hossein Hadian +# Apache 2.0 +# +# This script trains a LM on the YOMDLE+Extra training transcriptions. +# It is based on the example scripts distributed with PocoLM + +# It will check if pocolm is installed and if not will proceed with installation + +set -e +stage=0 +dir=data/local/local_lm +data_dir=data +extra_lm=download/extra_lm.txt +order=3 + +echo "$0 $@" # Print the command line for logging +. ./utils/parse_options.sh || exit 1; + +lm_dir=${dir}/data + + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. + else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +bypass_metaparam_optim_opt= +# If you want to bypass the metaparameter optimization steps with specific metaparameters +# un-comment the following line, and change the numbers to some appropriate values. +# You can find the values from output log of train_lm.py. +# These example numbers of metaparameters is for 4-gram model (with min-counts) +# running with train_lm.py. +# The dev perplexity should be close to the non-bypassed model. +#bypass_metaparam_optim_opt= +# Note: to use these example parameters, you may need to remove the .done files +# to make sure the make_lm_dir.py be called and tain only 3-gram model +#for order in 3; do +#rm -f ${lm_dir}/${num_word}_${order}.pocolm/.done + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + cat ${extra_lm} | utils/lang/bpe/prepend_words.py | python3 utils/lang/bpe/apply_bpe.py -c $data_dir/train/bpe.out | sed 's/@@//g' > ${dir}/data/text/extra_lm.txt + + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + nr=`cat $data_dir/train/text | wc -l` + nr_dev=$(($nr / 10 )) + nr_train=$(( $nr - $nr_dev )) + + # use the training data as an additional data source. + # we can later fold the dev data into this. + head -n $nr_train $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/train.txt + tail -n $nr_dev $data_dir/train/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (the validation data is used as ${dir}/data/text/dev.txt to work + # out interpolation weights.) + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < $data_dir/test/text > ${dir}/data/real_dev_set.txt + + # get the wordlist from MADCAT text + cat ${dir}/data/text/{train,extra_lm}.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + #cat ${dir}/data/text/extra_fa.txt | tr '[:space:]' '[\n*]' | grep -v "^\s*$" | sort | uniq -c | sort -bnr > ${dir}/data/word_count + cat ${dir}/data/word_count | awk '{print $2}' > ${dir}/data/wordlist +fi + +if [ $stage -le 1 ]; then + # decide on the vocabulary. + # Note: you'd use --wordlist if you had a previously determined word-list + # that you wanted to use. + # Note: if you have more than one order, use a certain amount of words as the + # vocab and want to restrict max memory for 'sort', + echo "$0: training the unpruned LM" + min_counts='extra_lm=10 train=1' + wordlist=${dir}/data/wordlist + + lm_name="`basename ${wordlist}`_${order}" + if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "=" "-"`" + fi + unpruned_lm_dir=${lm_dir}/${lm_name}.pocolm + train_lm.py --wordlist=${wordlist} --num-splits=30 --warm-start-ratio=1 \ + --min-counts="$min_counts" \ + --limit-unk-history=true \ + ${bypass_metaparam_optim_opt} \ + ${dir}/data/text ${order} ${lm_dir}/work ${unpruned_lm_dir} + + get_data_prob.py ${dir}/data/real_dev_set.txt ${unpruned_lm_dir} 2>&1 | grep -F '[perplexity' + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${unpruned_lm_dir} | gzip -c > ${dir}/data/arpa/${order}gram_unpruned.arpa.gz +fi diff --git a/egs/yomdle_zh/v1/local/wer_output_filter b/egs/yomdle_zh/v1/local/wer_output_filter new file mode 100755 index 00000000000..08d5563bca4 --- /dev/null +++ b/egs/yomdle_zh/v1/local/wer_output_filter @@ -0,0 +1,151 @@ +#!/usr/bin/env perl +# Copyright 2012-2014 Johns Hopkins University (Author: Yenda Trmal) +# Apache 2.0 + +use utf8; + +use open qw(:encoding(utf8)); +binmode STDIN, ":utf8"; +binmode STDOUT, ":utf8"; +binmode STDERR, ":utf8"; + +# Arabic-specific normalization +while (<>) { + @F = split " "; + print "$F[0] "; + foreach $s (@F[1..$#F]) { + # Normalize tabs, spaces, and no-break spaces + $s =~ s/[\x{0009}\x{0020}\x{00A0}]+/ /g; + # Normalize "dots"/"filled-circles" to periods + $s =~ s/[\x{25CF}\x{u2022}\x{2219}]+/\x{002E}/g; + # Normalize dashes to regular hyphen + $s =~ s/[\x{2010}\x{2011}\x{2012}\x{2013}\x{2014}\x{2015}]+/\x{002D}/g; + # Normalize various parenthesis to regular parenthesis + $s =~ s/\x{UFF09}/\x{0029}/g; + $s =~ s/\x{UFF08}/\x{0028}/g; + + # Convert various presentation forms to base form + $s =~ s/[\x{FED1}\x{FED3}\x{FED4}\x{FED2}]+/\x{0641}/g; + $s =~ s/[\x{FBB0}\x{FBB1}]+/\x{06D3}/g; + $s =~ s/[\x{FECD}\x{FECF}\x{FED0}\x{FECE}]+/\x{063A}/g; + $s =~ s/[\x{FBDD}]+/\x{0677}/g; + $s =~ s/[\x{FBA6}\x{FBA8}\x{FBA9}\x{FBA7}]+/\x{06C1}/g; + $s =~ s/[\x{FEC1}\x{FEC3}\x{FEC4}\x{FEC2}]+/\x{0637}/g; + $s =~ s/[\x{FE85}\x{FE86}]+/\x{0624}/g; + $s =~ s/[\x{FEA5}\x{FEA7}\x{FEA8}\x{FEA6}]+/\x{062E}/g; + $s =~ s/[\x{FBD9}\x{FBDA}]+/\x{06C6}/g; + $s =~ s/[\x{FE8F}\x{FE91}\x{FE92}\x{FE90}]+/\x{0628}/g; + $s =~ s/[\x{FEED}\x{FEEE}]+/\x{0648}/g; + $s =~ s/[\x{FE99}\x{FE9B}\x{FE9C}\x{FE9A}]+/\x{062B}/g; + $s =~ s/[\x{FEBD}\x{FEBF}\x{FEC0}\x{FEBE}]+/\x{0636}/g; + $s =~ s/[\x{FEE5}\x{FEE7}\x{FEE8}\x{FEE6}]+/\x{0646}/g; + $s =~ s/[\x{FBFC}\x{FBFE}\x{FBFF}\x{FBFD}]+/\x{06CC}/g; + $s =~ s/[\x{FBA4}\x{FBA5}]+/\x{06C0}/g; + $s =~ s/[\x{FB72}\x{FB74}\x{FB75}\x{FB73}]+/\x{0684}/g; + $s =~ s/[\x{FBD3}\x{FBD5}\x{FBD6}\x{FBD4}]+/\x{06AD}/g; + $s =~ s/[\x{FB6A}\x{FB6C}\x{FB6D}\x{FB6B}]+/\x{06A4}/g; + $s =~ s/[\x{FB66}\x{FB68}\x{FB69}\x{FB67}]+/\x{0679}/g; + $s =~ s/[\x{FB5E}\x{FB60}\x{FB61}\x{FB5F}]+/\x{067A}/g; + $s =~ s/[\x{FB88}\x{FB89}]+/\x{0688}/g; + $s =~ s/[\x{FB7E}\x{FB80}\x{FB81}\x{FB7F}]+/\x{0687}/g; + $s =~ s/[\x{FB8E}\x{FB90}\x{FB91}\x{FB8F}]+/\x{06A9}/g; + $s =~ s/[\x{FB86}\x{FB87}]+/\x{068E}/g; + $s =~ s/[\x{FE83}\x{FE84}]+/\x{0623}/g; + $s =~ s/[\x{FB8A}\x{FB8B}]+/\x{0698}/g; + $s =~ s/[\x{FED5}\x{FED7}\x{FED8}\x{FED6}]+/\x{0642}/g; + $s =~ s/[\x{FED9}\x{FEDB}\x{FEDC}\x{FEDA}]+/\x{0643}/g; + $s =~ s/[\x{FBE0}\x{FBE1}]+/\x{06C5}/g; + $s =~ s/[\x{FEB9}\x{FEBB}\x{FEBC}\x{FEBA}]+/\x{0635}/g; + $s =~ s/[\x{FEC5}\x{FEC7}\x{FEC8}\x{FEC6}]+/\x{0638}/g; + $s =~ s/[\x{FE8D}\x{FE8E}]+/\x{0627}/g; + $s =~ s/[\x{FB9A}\x{FB9C}\x{FB9D}\x{FB9B}]+/\x{06B1}/g; + $s =~ s/[\x{FEAD}\x{FEAE}]+/\x{0631}/g; + $s =~ s/[\x{FEF1}\x{FEF3}\x{FEF4}\x{FEF2}]+/\x{064A}/g; + $s =~ s/[\x{FE93}\x{FE94}]+/\x{0629}/g; + $s =~ s/[\x{FBE4}\x{FBE6}\x{FBE7}\x{FBE5}]+/\x{06D0}/g; + $s =~ s/[\x{FE89}\x{FE8B}\x{FE8C}\x{FE8A}]+/\x{0626}/g; + $s =~ s/[\x{FB84}\x{FB85}]+/\x{068C}/g; + $s =~ s/[\x{FE9D}\x{FE9F}\x{FEA0}\x{FE9E}]+/\x{062C}/g; + $s =~ s/[\x{FB82}\x{FB83}]+/\x{068D}/g; + $s =~ s/[\x{FEA1}\x{FEA3}\x{FEA4}\x{FEA2}]+/\x{062D}/g; + $s =~ s/[\x{FB52}\x{FB54}\x{FB55}\x{FB53}]+/\x{067B}/g; + $s =~ s/[\x{FB92}\x{FB94}\x{FB95}\x{FB93}]+/\x{06AF}/g; + $s =~ s/[\x{FB7A}\x{FB7C}\x{FB7D}\x{FB7B}]+/\x{0686}/g; + $s =~ s/[\x{FBDB}\x{FBDC}]+/\x{06C8}/g; + $s =~ s/[\x{FB56}\x{FB58}\x{FB59}\x{FB57}]+/\x{067E}/g; + $s =~ s/[\x{FEB5}\x{FEB7}\x{FEB8}\x{FEB6}]+/\x{0634}/g; + $s =~ s/[\x{FBE2}\x{FBE3}]+/\x{06C9}/g; + $s =~ s/[\x{FB96}\x{FB98}\x{FB99}\x{FB97}]+/\x{06B3}/g; + $s =~ s/[\x{FE80}]+/\x{0621}/g; + $s =~ s/[\x{FBAE}\x{FBAF}]+/\x{06D2}/g; + $s =~ s/[\x{FB62}\x{FB64}\x{FB65}\x{FB63}]+/\x{067F}/g; + $s =~ s/[\x{FEE9}\x{FEEB}\x{FEEC}\x{FEEA}]+/\x{0647}/g; + $s =~ s/[\x{FE81}\x{FE82}]+/\x{0622}/g; + $s =~ s/[\x{FBDE}\x{FBDF}]+/\x{06CB}/g; + $s =~ s/[\x{FE87}\x{FE88}]+/\x{0625}/g; + $s =~ s/[\x{FB6E}\x{FB70}\x{FB71}\x{FB6F}]+/\x{06A6}/g; + $s =~ s/[\x{FBA0}\x{FBA2}\x{FBA3}\x{FBA1}]+/\x{06BB}/g; + $s =~ s/[\x{FBAA}\x{FBAC}\x{FBAD}\x{FBAB}]+/\x{06BE}/g; + $s =~ s/[\x{FEA9}\x{FEAA}]+/\x{062F}/g; + $s =~ s/[\x{FEE1}\x{FEE3}\x{FEE4}\x{FEE2}]+/\x{0645}/g; + $s =~ s/[\x{FEEF}\x{FBE8}\x{FBE9}\x{FEF0}]+/\x{0649}/g; + $s =~ s/[\x{FB8C}\x{FB8D}]+/\x{0691}/g; + $s =~ s/[\x{FB76}\x{FB78}\x{FB79}\x{FB77}]+/\x{0683}/g; + $s =~ s/[\x{FB5A}\x{FB5C}\x{FB5D}\x{FB5B}]+/\x{0680}/g; + $s =~ s/[\x{FB9E}\x{FB9F}]+/\x{06BA}/g; + $s =~ s/[\x{FEC9}\x{FECB}\x{FECC}\x{FECA}]+/\x{0639}/g; + $s =~ s/[\x{FEDD}\x{FEDF}\x{FEE0}\x{FEDE}]+/\x{0644}/g; + $s =~ s/[\x{FB50}\x{FB51}]+/\x{0671}/g; + $s =~ s/[\x{FEB1}\x{FEB3}\x{FEB4}\x{FEB2}]+/\x{0633}/g; + $s =~ s/[\x{FE95}\x{FE97}\x{FE98}\x{FE96}]+/\x{062A}/g; + $s =~ s/[\x{FBD7}\x{FBD8}]+/\x{06C7}/g; + $s =~ s/[\x{FEAF}\x{FEB0}]+/\x{0632}/g; + $s =~ s/[\x{FEAB}\x{FEAC}]+/\x{0630}/g; + + # Remove tatweel + $s =~ s/\x{0640}//g; + # Remove vowels and hamza + $s =~ s/[\x{064B}-\x{0655}]+//g; + # Remove right-to-left and left-to-right + $s =~ s/[\x{200F}\x{200E}]+//g; + # Arabic Keheh to Arabic Kaf + $s =~ s/\x{06A9}/\x{0643}/g; + # Arabic Yeh to Farsi Yeh + $s =~ s/\x{064A}/\x{06CC}/g; + # Decompose RIAL + $s =~ s/\x{FDFC}/\x{0631}\x{06CC}\x{0627}\x{0644}/g; + # Farsi arabic-indic digits to arabic-indic digits + $s =~ s/\x{06F0}/\x{0660}/g; + $s =~ s/\x{06F1}/\x{0661}/g; + $s =~ s/\x{06F2}/\x{0662}/g; + $s =~ s/\x{06F3}/\x{0663}/g; + $s =~ s/\x{06F4}/\x{0664}/g; + $s =~ s/\x{06F5}/\x{0665}/g; + $s =~ s/\x{06F6}/\x{0666}/g; + $s =~ s/\x{06F7}/\x{0667}/g; + $s =~ s/\x{06F8}/\x{0668}/g; + $s =~ s/\x{06F9}/\x{0669}/g; + # Arabic-indic digits to digits + $s =~ s/\x{0660}/0/g; + $s =~ s/\x{0661}/1/g; + $s =~ s/\x{0662}/2/g; + $s =~ s/\x{0663}/3/g; + $s =~ s/\x{0664}/4/g; + $s =~ s/\x{0665}/5/g; + $s =~ s/\x{0666}/6/g; + $s =~ s/\x{0667}/7/g; + $s =~ s/\x{0668}/8/g; + $s =~ s/\x{0669}/9/g; + # Arabic comma to comma + $s =~ s/\x{060C}/\x{002C}/g; + + $s =~ s/\|/ /g; + if ($s ne "") { + print "$s"; + } else { + print ""; + } + } + print "\n"; +} + diff --git a/egs/yomdle_zh/v1/local/yomdle2csv.py b/egs/yomdle_zh/v1/local/yomdle2csv.py new file mode 100755 index 00000000000..3641de90324 --- /dev/null +++ b/egs/yomdle_zh/v1/local/yomdle2csv.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 + +""" +GEDI2CSV +Convert GEDI-type bounding boxes to CSV format + +GEDI Format Example: + + + + + + + + + +CSV Format Example +ID,name,col1,row1,col2,row2,col3,row3,col4,row4,confidence,truth,pgrot,bbrot,qual,script,lang +0,chinese_scanned_books_0001_0.png,99,41,99,14,754,14,754,41,100,凡我的邻人说是好的,有一大部分在我灵魂中却,0,0.0,0,,zh-cn +""" + +import logging +import os +import sys +import time +import glob +import csv +import imghdr +from PIL import Image +import argparse +import pdb +import cv2 +import numpy as np +import xml.etree.ElementTree as ET + +sin = np.sin +cos = np.cos +pi = np.pi + +def Rotate2D(pts, cnt, ang=90): + M = np.array([[cos(ang),-sin(ang)],[sin(ang),cos(ang)]]) + res = np.dot(pts-cnt,M)+cnt + return M, res + +def npbox2string(npar): + if np.shape(npar)[0] != 1: + print('Error during CSV conversion\n') + c1,r1 = npar[0][0],npar[0][1] + c2,r2 = npar[0][2],npar[0][3] + c3,r3 = npar[0][4],npar[0][5] + c4,r4 = npar[0][6],npar[0][7] + + return c1,r1,c2,r2,c3,r3,c4,r4 + +# cv2.minAreaRect() returns a Box2D structure which contains following detals - ( center (x,y), (width, height), angle of rotation ) +# Get 4 corners of the rectangle using cv2.boxPoints() + +class GEDI2CSV(): + + """ Initialize the extractor""" + def __init__(self, logger, args): + self._logger = logger + self._args = args + + """ + Segment image with GEDI bounding box information + """ + def csvfile(self, coords, polys, baseName, pgrot): + + """ for writing the files """ + writePath = self._args.outputDir + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + rotlist = [] + + header=['ID','name','col1','row1','col2','row2','col3','row3','col4','row4','confidence','truth','pgrot','bbrot','qual','script','lang'] + conf=100 + pgrot = 0 + bbrot = 0 + qual = 0 + script = '' + + write_ctr = 0 + if len(coords) == 0 and len(polys) == 0: + self._logger.info('Found %s with no text content',(baseName)) + print('...Found %s with no text content' % (baseName)) + return + + strPos = writePath + baseName + + for j in polys: + try: + arr = [] + [id,poly_val,text,qual,lang] = j + script=None + #print(j) + for i in poly_val: + if len(i.strip()) > 0: + #print(i) + arr.append(eval(i)) + + contour = np.asarray(arr) + #print(contour) + convex = cv2.convexHull(contour) + rect = cv2.minAreaRect(convex) + box = cv2.boxPoints(rect) + box = np.int0(box) + box = np.reshape(box,(-1,1)).T + c1,r1,c2,r2,c3,r3,c4,r4 = npbox2string(box) + + bbrot = 0.0 + + rotlist.append([id,baseName + '_' + id + '.png',c1,r1,c2,r2,c3,r3,c4,r4,conf,text,pgrot,bbrot,qual,script,lang]) + + except: + print('...polygon error %s, %s' % (j, baseName)) + continue + + # then write out all of list to file + with open(strPos + ".csv", "w", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(header) + for row in rotlist: + writer.writerow(row) + write_ctr += 1 + + return write_ctr + + +def main(args): + + startTime = time.clock() + + writePath = args.outputDir + print('write to %s' % (writePath)) + if os.path.isdir(writePath) != True: + os.makedirs(writePath) + + """ Setup logging """ + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + if args.log: + handler = logging.FileHandler(args.log) + handler.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + gtconverter = GEDI2CSV(logger, args) + namespaces = {"gedi" : "http://lamp.cfar.umd.edu/media/projects/GEDI/"} + keyCnt=0 + + fileCnt = 0 + line_write_ctr = 0 + line_error_ctr = 0 + file_error_ctr = 0 + """ + Get all XML files in the directory and sub folders + """ + print('reading %s' % (args.inputDir)) + for root, dirnames, filenames in os.walk(args.inputDir, followlinks=True): + for file in filenames: + if file.lower().endswith('.xml'): + fullName = os.path.join(root,file) + baseName = os.path.splitext(fullName) + + fileCnt += 1 + + try: + """ read the XML file """ + tree = ET.parse(fullName) + except: + print('...ERROR parsing %s' % (fullName)) + file_error_ctr += 1 + continue + + gedi_root = tree.getroot() + child = gedi_root.findall('gedi:DL_DOCUMENT',namespaces)[0] + totalpages = int(child.attrib['NrOfPages']) + coordinates=[] + polygons = [] + + """ and for each page """ + for i, pgs in enumerate(child.iterfind('gedi:DL_PAGE',namespaces)): + + if 'GEDI_orientation' not in pgs.attrib: + pageRot=0 + else: + pageRot = int(pgs.attrib['GEDI_orientation']) + logger.info(' PAGE ROTATION %s, %s' % (fullName, str(pageRot))) + + """ find children for each page """ + for zone in pgs.findall('gedi:DL_ZONE',namespaces): + + if zone.attrib['gedi_type']=='Text' : + if zone.get('polygon'): + keyCnt+=1 + polygons.append([zone.attrib['id'],zone.get('polygon').split(';'), + zone.get('Text_Content'),zone.get('Illegible'),zone.get('Language')]) + else: + print('...Not polygon') + + + if len(coordinates) > 0 or len(polygons) > 0: + line_write_ctr += gtconverter.csvfile(coordinates, polygons, os.path.splitext(file)[0], pageRot) + else: + print('...%s has no text content' % (baseName[0])) + + + print('complete...total files %d, lines written %d, img errors %d, line error %d' % (fileCnt, line_write_ctr, file_error_ctr, line_error_ctr)) + + +def parse_arguments(argv): + """ Args and defaults """ + parser = argparse.ArgumentParser() + + parser.add_argument('--inputDir', type=str, help='Input directory', default='/data/YOMDLE/final_arabic/xml') + parser.add_argument('--outputDir', type=str, help='Output directory', default='/exp/YOMDLE/final_arabic/csv_truth/') + parser.add_argument('--log', type=str, help='Log directory', default='/exp/logs.txt') + + return parser.parse_args(argv) + + +if __name__ == '__main__': + """ Run """ + main(parse_arguments(sys.argv[1:])) diff --git a/egs/yomdle_zh/v1/path.sh b/egs/yomdle_zh/v1/path.sh new file mode 100644 index 00000000000..2d17b17a84a --- /dev/null +++ b/egs/yomdle_zh/v1/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/yomdle_zh/v1/run.sh b/egs/yomdle_zh/v1/run.sh new file mode 100755 index 00000000000..128f15694cc --- /dev/null +++ b/egs/yomdle_zh/v1/run.sh @@ -0,0 +1,120 @@ +#!/bin/bash + +set -e +stage=0 +nj=60 + +database_slam=/export/corpora5/slam/SLAM/Chinese/transcribed +database_yomdle=/export/corpora5/slam/YOMDLE/final_chinese +download_dir=data_yomdle_chinese/download/ +extra_lm=download/extra_lm.txt +data_dir=data_yomdle_chinese +exp_dir=exp_yomdle_chinese + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + local/create_download.sh --database-slam $database_slam \ + --database-yomdle $database_yomdle \ + --slam-dir download/slam_chinese \ + --yomdle-dir download/yomdle_chinese +fi + +if [ $stage -le 0 ]; then + mkdir -p data_slam_chinese/slam + mkdir -p data_yomdle_chinese/yomdle + local/process_data.py download/slam_chinese data_slam_chinese/slam + local/process_data.py download/yomdle_chinese data_yomdle_chinese/yomdle + ln -s ../data_slam_chinese/slam ${data_dir}/test + ln -s ../data_yomdle_chinese/yomdle ${data_dir}/train + image/fix_data_dir.sh ${data_dir}/test + image/fix_data_dir.sh ${data_dir}/train +fi + +mkdir -p $data_dir/{train,test}/data +if [ $stage -le 1 ]; then + echo "$0: Obtaining image groups. calling get_image2num_frames" + echo "Date: $(date)." + image/get_image2num_frames.py --feat-dim 60 $data_dir/train + image/get_allowed_lengths.py --frame-subsampling-factor 4 10 $data_dir/train + + for datasplit in train test; do + echo "$0: Extracting features and calling compute_cmvn_stats for dataset: $datasplit. " + echo "Date: $(date)." + local/extract_features.sh --nj $nj --cmd "$cmd" \ + --feat-dim 60 --num-channels 3 \ + $data_dir/${datasplit} + steps/compute_cmvn_stats.sh $data_dir/${datasplit} || exit 1; + done + + echo "$0: Fixing data directory for train dataset" + echo "Date: $(date)." + utils/fix_data_dir.sh $data_dir/train +fi + +if [ $stage -le 2 ]; then + for datasplit in train; do + echo "$(date) stage 2: Performing augmentation, it will double training data" + local/augment_data.sh --nj $nj --cmd "$cmd" --feat-dim 60 $data_dir/${datasplit} $data_dir/${datasplit}_aug $data_dir + steps/compute_cmvn_stats.sh $data_dir/${datasplit}_aug || exit 1; + done +fi + +if [ $stage -le 3 ]; then + echo "$0: Preparing dictionary and lang..." + if [ ! -f $data_dir/train/bpe.out ]; then + cut -d' ' -f2- $data_dir/train/text | utils/lang/bpe/prepend_words.py | python3 utils/lang/bpe/learn_bpe.py -s 700 > $data_dir/train/bpe.out + for datasplit in test train train_aug; do + cut -d' ' -f1 $data_dir/$datasplit/text > $data_dir/$datasplit/ids + cut -d' ' -f2- $data_dir/$datasplit/text | utils/lang/bpe/prepend_words.py | python3 utils/lang/bpe/apply_bpe.py -c $data_dir/train/bpe.out | sed 's/@@//g' > $data_dir/$datasplit/bpe_text + mv $data_dir/$datasplit/text $data_dir/$datasplit/text.old + paste -d' ' $data_dir/$datasplit/ids $data_dir/$datasplit/bpe_text > $data_dir/$datasplit/text + done + fi + + local/prepare_dict.sh --data-dir $data_dir --dir $data_dir/local/dict + # This recipe uses byte-pair encoding, the silences are part of the words' pronunciations. + # So we set --sil-prob to 0.0 + utils/prepare_lang.sh --num-sil-states 4 --num-nonsil-states 8 --sil-prob 0.0 --position-dependent-phones false \ + $data_dir/local/dict "" $data_dir/lang/temp $data_dir/lang + utils/lang/bpe/add_final_optional_silence.sh --final-sil-prob 0.5 $data_dir/lang +fi + +if [ $stage -le 4 ]; then + echo "$0: Estimating a language model for decoding..." + local/train_lm.sh --data-dir $data_dir --dir $data_dir/local/local_lm + utils/format_lm.sh $data_dir/lang $data_dir/local/local_lm/data/arpa/3gram_unpruned.arpa.gz \ + $data_dir/local/dict/lexicon.txt $data_dir/lang_test +fi + +if [ $stage -le 5 ]; then + echo "$0: Calling the flat-start chain recipe..." + echo "Date: $(date)." + local/chain/run_flatstart_cnn1a.sh --nj $nj --train-set train_aug --data-dir $data_dir --exp-dir $exp_dir +fi + +if [ $stage -le 6 ]; then + echo "$0: Aligning the training data using the e2e chain model..." + echo "Date: $(date)." + steps/nnet3/align.sh --nj $nj --cmd "$cmd" --use-gpu false \ + --scale-opts '--transition-scale=1.0 --acoustic-scale=1.0 --self-loop-scale=1.0' \ + $data_dir/train_aug $data_dir/lang $exp_dir/chain/e2e_cnn_1a $exp_dir/chain/e2e_ali_train +fi + +if [ $stage -le 7 ]; then + echo "$0: Building a tree and training a regular chain model using the e2e alignments..." + echo "Date: $(date)." + local/chain/run_cnn_e2eali_1b.sh --nj $nj --train-set train_aug --data-dir $data_dir --exp-dir $exp_dir +fi + +if [ $stage -le 8 ]; then + echo "$0: Estimating a language model for lattice rescoring...$(date)" + local/train_lm_lr.sh --data-dir $data_dir --dir $data_dir/local/local_lm_lr --extra-lm $extra_lm --order 6 + + utils/build_const_arpa_lm.sh $data_dir/local/local_lm_lr/data/arpa/6gram_unpruned.arpa.gz \ + $data_dir/lang_test $data_dir/lang_test_lr + steps/lmrescore_const_arpa.sh $data_dir/lang_test $data_dir/lang_test_lr \ + $data_dir/test $exp_dir/chain/cnn_e2eali_1b/decode_test $exp_dir/chain/cnn_e2eali_1b/decode_test_lr +fi diff --git a/egs/yomdle_zh/v1/steps b/egs/yomdle_zh/v1/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/yomdle_zh/v1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/yomdle_zh/v1/utils b/egs/yomdle_zh/v1/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/yomdle_zh/v1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/scripts/rnnlm/lmrescore_nbest.sh b/scripts/rnnlm/lmrescore_nbest.sh index 6f28c960dd9..58b19b9fa79 100755 --- a/scripts/rnnlm/lmrescore_nbest.sh +++ b/scripts/rnnlm/lmrescore_nbest.sh @@ -29,7 +29,7 @@ if [ $# != 6 ]; then echo "This version applies an RNNLM and mixes it with the LM scores" echo "previously in the lattices., controlled by the first parameter (rnnlm-weight)" echo "" - echo "Usage: utils/rnnlmrescore.sh " + echo "Usage: $0 [options] " echo "Main options:" echo " --inv-acwt # default 12. e.g. --inv-acwt 17. Equivalent to LM scale to use." echo " # for N-best list generation... note, we'll score at different acwt's" diff --git a/scripts/rnnlm/lmrescore_nbest_back.sh b/scripts/rnnlm/lmrescore_nbest_back.sh index 9b62456573c..7531d99b0a4 100755 --- a/scripts/rnnlm/lmrescore_nbest_back.sh +++ b/scripts/rnnlm/lmrescore_nbest_back.sh @@ -32,7 +32,7 @@ if [ $# != 6 ]; then echo "This version applies an RNNLM and mixes it with the LM scores" echo "previously in the lattices., controlled by the first parameter (rnnlm-weight)" echo "" - echo "Usage: utils/rnnlmrescore.sh " + echo "Usage: $0 [options] " echo "Main options:" echo " --inv-acwt # default 12. e.g. --inv-acwt 17. Equivalent to LM scale to use." echo " # for N-best list generation... note, we'll score at different acwt's" diff --git a/src/base/Makefile b/src/base/Makefile index 583c6badcf2..49af4f87ff4 100644 --- a/src/base/Makefile +++ b/src/base/Makefile @@ -18,7 +18,7 @@ include ../kaldi.mk TESTFILES = kaldi-math-test io-funcs-test kaldi-error-test timer-test -OBJFILES = kaldi-math.o kaldi-error.o io-funcs.o kaldi-utils.o +OBJFILES = kaldi-math.o kaldi-error.o io-funcs.o kaldi-utils.o timer.o LIBNAME = kaldi-base diff --git a/src/base/kaldi-common.h b/src/base/kaldi-common.h index e0002d91bb7..264565d1812 100644 --- a/src/base/kaldi-common.h +++ b/src/base/kaldi-common.h @@ -36,5 +36,6 @@ #include "base/kaldi-types.h" #include "base/io-funcs.h" #include "base/kaldi-math.h" +#include "base/timer.h" #endif // KALDI_BASE_KALDI_COMMON_H_ diff --git a/src/base/timer.cc b/src/base/timer.cc new file mode 100644 index 00000000000..ce4ef292783 --- /dev/null +++ b/src/base/timer.cc @@ -0,0 +1,85 @@ +// base/timer.cc + +// Copyright 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/timer.h" +#include "base/kaldi-error.h" +#include +#include +#include +#include + +namespace kaldi { + +class ProfileStats { + public: + void AccStats(const char *function_name, double elapsed) { + std::unordered_map::iterator + iter = map_.find(function_name); + if (iter == map_.end()) { + map_[function_name] = ProfileStatsEntry(function_name); + map_[function_name].total_time = elapsed; + } else { + iter->second.total_time += elapsed; + } + } + ~ProfileStats() { + // This map makes sure we agglomerate the time if there were any duplicate + // addresses of strings. + std::unordered_map total_time; + for (auto iter = map_.begin(); iter != map_.end(); iter++) + total_time[iter->second.name] += iter->second.total_time; + + ReverseSecondComparator comp; + std::vector > pairs(total_time.begin(), + total_time.end()); + std::sort(pairs.begin(), pairs.end(), comp); + for (size_t i = 0; i < pairs.size(); i++) { + KALDI_LOG << "Time taken in " << pairs[i].first << " is " + << std::fixed << std::setprecision(2) << pairs[i].second << "s."; + } + } + private: + + struct ProfileStatsEntry { + std::string name; + double total_time; + ProfileStatsEntry() { } + ProfileStatsEntry(const char *name): name(name) { } + }; + + struct ReverseSecondComparator { + bool operator () (const std::pair &a, + const std::pair &b) { + return a.second > b.second; + } + }; + + // Note: this map is keyed on the address of the string, there is no proper + // hash function. The assumption is that the strings are compile-time + // constants. + std::unordered_map map_; +}; + +ProfileStats g_profile_stats; + +Profiler::~Profiler() { + g_profile_stats.AccStats(name_, tim_.Elapsed()); +} + +} // namespace kaldi diff --git a/src/base/timer.h b/src/base/timer.h index 7889c4a258b..96c5babb305 100644 --- a/src/base/timer.h +++ b/src/base/timer.h @@ -20,7 +20,7 @@ #define KALDI_BASE_TIMER_H_ #include "base/kaldi-utils.h" -// Note: Sleep(float secs) is included in base/kaldi-utils.h. +#include "base/kaldi-error.h" #if defined(_MSC_VER) || defined(MINGW) @@ -87,7 +87,27 @@ class Timer { struct timeval time_start_; struct timezone time_zone_; }; -} + +class Profiler { + public: + // Caution: the 'const char' should always be a string constant; for speed, + // internally the profiling code uses the address of it as a lookup key. + Profiler(const char *function_name): name_(function_name) { } + ~Profiler(); + private: + Timer tim_; + const char *name_; +}; + +// To add timing info for a function, you just put +// KALDI_PROFILE; +// at the beginning of the function. Caution: this doesn't +// include the class name. +#define KALDI_PROFILE Profiler _profiler(__func__) + + + +} // namespace kaldi #endif diff --git a/src/bin/Makefile b/src/bin/Makefile index c2b9eb48830..7cb01b50120 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -21,7 +21,8 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ post-to-pdf-post logprob-to-post prob-to-post copy-post \ matrix-sum build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ - transform-vec align-text matrix-dim post-to-smat compile-graph + transform-vec align-text matrix-dim post-to-smat compile-graph \ + compare-int-vector OBJFILES = @@ -29,8 +30,8 @@ OBJFILES = ADDLIBS = ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a ../lm/kaldi-lm.a \ ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a TESTFILES = diff --git a/src/bin/ali-to-phones.cc b/src/bin/ali-to-phones.cc index 2a76000cfae..602e32e9768 100644 --- a/src/bin/ali-to-phones.cc +++ b/src/bin/ali-to-phones.cc @@ -38,7 +38,7 @@ int main(int argc, char *argv[]) { " ali-to-phones 1.mdl ark:1.ali ark:-\n" "or:\n" " ali-to-phones --ctm-output 1.mdl ark:1.ali 1.ctm\n" - "See also: show-alignments lattice-align-phones\n"; + "See also: show-alignments lattice-align-phones, compare-int-vector\n"; ParseOptions po(usage); bool per_frame = false; bool write_lengths = false; @@ -137,5 +137,3 @@ int main(int argc, char *argv[]) { return -1; } } - - diff --git a/src/bin/compare-int-vector.cc b/src/bin/compare-int-vector.cc new file mode 100644 index 00000000000..5f80ff5ee6c --- /dev/null +++ b/src/bin/compare-int-vector.cc @@ -0,0 +1,184 @@ +// bin/compare-int-vector.cc + +// Copyright 2018 Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-vector.h" +#include "transform/transform-common.h" +#include + + +namespace kaldi { +void AddToCount(int32 location_to_add, + double value_to_add, + std::vector *counts) { + if (location_to_add < 0) + KALDI_ERR << "Contents of vectors cannot be " + "negative if --write-tot-counts or --write-diff-counts " + "options are provided."; + if (counts->size() <= static_cast(location_to_add)) + counts->resize(location_to_add + 1, 0.0); + (*counts)[location_to_add] += value_to_add; +} + +void AddToConfusionMatrix(int32 phone1, int32 phone2, + Matrix *counts) { + if (phone1 < 0 || phone2 < 0) + KALDI_ERR << "Contents of vectors cannot be " + "negative if --write-confusion-matrix option is " + "provided."; + int32 max_size = std::max(phone1, phone2) + 1; + if (counts->NumRows() < max_size) + counts->Resize(max_size, max_size, kCopyData); + (*counts)(phone1, phone2) += 1.0; +} + + +void WriteAsKaldiVector(const std::vector &counts, + std::string &wxfilename, + bool binary) { + Vector counts_vec(counts.size()); + for (size_t i = 0; i < counts.size(); i++) + counts_vec(i) = counts[i]; + WriteKaldiObject(counts_vec, wxfilename, binary); +} + +} // namespace kaldi + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Compare vectors of integers (e.g. phone alignments)\n" + "Prints to stdout fields of the form:\n" + " \n" + "\n" + "e.g.:\n" + " SWB1_A_31410_32892 420 36\n" + "\n" + "Usage:\n" + "compare-int-vector [options] \n" + "\n" + "e.g. compare-int-vector scp:foo.scp scp:bar.scp > comparison\n" + "E.g. the inputs might come from ali-to-phones.\n" + "Warnings are printed if the vector lengths differ for a given utterance-id,\n" + "and in those cases, the number of frames printed will be the smaller of the\n" + "\n" + "See also: ali-to-phones, copy-int-vector\n"; + + + ParseOptions po(usage); + + std::string tot_wxfilename, + diff_wxfilename, + confusion_matrix_wxfilename; + bool binary = true; + + po.Register("binary", &binary, "If true, write in binary mode (only applies " + "if --write-tot-counts or --write-diff-counts options are supplied)."); + po.Register("write-tot-counts", &tot_wxfilename, "Filename to write " + "vector of total counts. These may be summed with 'vector-sum'."); + po.Register("write-diff-counts", &diff_wxfilename, "Filename to write " + "vector of counts of phones (or whatever is in the inputs) " + "that differ from one vector to the other. Each time a pair differs, " + "0.5 will be added to each one's location."); + po.Register("write-confusion-matrix", &confusion_matrix_wxfilename, + "Filename to write confusion matrix, indexed by [phone1][phone2]." + "These may be summed by 'matrix-sum'."); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string vector1_rspecifier = po.GetArg(1), + vector2_rspecifier = po.GetArg(2); + + int64 num_done = 0, + num_not_found = 0, + num_mismatched_lengths = 0, + tot_frames = 0, tot_difference = 0; + + std::vector diff_counts; + std::vector tot_counts; + Matrix confusion_matrix; + + SequentialInt32VectorReader reader1(vector1_rspecifier); + RandomAccessInt32VectorReader reader2(vector2_rspecifier); + + for (; !reader1.Done(); reader1.Next(), num_done++) { + const std::string &key = reader1.Key(); + if (!reader2.HasKey(key)) { + KALDI_WARN << "No key " << key << " found in second input."; + num_not_found++; + continue; + } + const std::vector &value1 = reader1.Value(), + &value2 = reader2.Value(key); + size_t len1 = value1.size(), len2 = value2.size(); + if (len1 != len2) { + KALDI_WARN << "For utterance " << key << ", lengths differ " + << len1 << " vs. " << len2; + num_mismatched_lengths++; + } + size_t len = std::min(len1, len2), + difference = 0; + for (size_t i = 0; i < len; i++) { + int32 phone1 = value1[i], phone2 = value2[i]; + if (phone1 != phone2) { + difference++; + if (!diff_wxfilename.empty()) { + AddToCount(phone1, 0.5, &diff_counts); + AddToCount(phone2, 0.5, &diff_counts); + } + } + if (!tot_wxfilename.empty()) + AddToCount(phone1, 1.0, &tot_counts); + if (!confusion_matrix_wxfilename.empty()) + AddToConfusionMatrix(phone1, phone2, &confusion_matrix); + } + num_done++; + std::cout << key << " " << len << " " << difference << "\n"; + tot_frames += len; + tot_difference += difference; + } + + BaseFloat difference_percent = tot_difference * 100.0 / tot_frames; + KALDI_LOG << "Computed difference for " << num_done << " utterances, of which " + << num_mismatched_lengths << " had mismatched lengths; corresponding " + "utterance not found for " << num_not_found; + KALDI_LOG << "Average p(different) is " << std::setprecision(4) << difference_percent + << "%, over " << tot_frames << " frames."; + + if (!tot_wxfilename.empty()) + WriteAsKaldiVector(tot_counts, tot_wxfilename, binary); + if (!diff_wxfilename.empty()) + WriteAsKaldiVector(diff_counts, diff_wxfilename, binary); + if (!confusion_matrix_wxfilename.empty()) + WriteKaldiObject(confusion_matrix, confusion_matrix_wxfilename, binary); + + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/bin/copy-post.cc b/src/bin/copy-post.cc index 6d0d351a594..d5ca3f42980 100644 --- a/src/bin/copy-post.cc +++ b/src/bin/copy-post.cc @@ -26,13 +26,13 @@ int main(int argc, char *argv[]) { try { using namespace kaldi; - typedef kaldi::int32 int32; + typedef kaldi::int32 int32; const char *usage = "Copy archives of posteriors, with optional scaling\n" - "(Also see rand-prune-post and sum-post)\n" "\n" - "Usage: copy-post \n"; + "Usage: copy-post \n" + "See also: post-to-weights, scale-post, sum-post, weight-post ...\n"; BaseFloat scale = 1.0; ParseOptions po(usage); @@ -43,15 +43,15 @@ int main(int argc, char *argv[]) { po.PrintUsage(); exit(1); } - + std::string post_rspecifier = po.GetArg(1), post_wspecifier = po.GetArg(2); kaldi::SequentialPosteriorReader posterior_reader(post_rspecifier); - kaldi::PosteriorWriter posterior_writer(post_wspecifier); + kaldi::PosteriorWriter posterior_writer(post_wspecifier); int32 num_done = 0; - + for (; !posterior_reader.Done(); posterior_reader.Next()) { std::string key = posterior_reader.Key(); @@ -71,4 +71,3 @@ int main(int argc, char *argv[]) { return -1; } } - diff --git a/src/bin/matrix-sum.cc b/src/bin/matrix-sum.cc index 8a7b5a39e00..3c93dfd0d39 100644 --- a/src/bin/matrix-sum.cc +++ b/src/bin/matrix-sum.cc @@ -238,17 +238,20 @@ int32 TypeThreeUsage(const ParseOptions &po, << "tables, the intermediate arguments must not be tables."; } - bool add = true; - Matrix mat; + Matrix sum; for (int32 i = 1; i < po.NumArgs(); i++) { - bool binary_in; - Input ki(po.GetArg(i), &binary_in); - // this Read function will throw if there is a size mismatch. - mat.Read(ki.Stream(), binary_in, add); + Matrix this_mat; + ReadKaldiObject(po.GetArg(i), &this_mat); + if (sum.NumRows() < this_mat.NumRows() || + sum.NumCols() < this_mat.NumCols()) + sum.Resize(std::max(sum.NumRows(), this_mat.NumRows()), + std::max(sum.NumCols(), this_mat.NumCols()), + kCopyData); + sum.AddMat(1.0, this_mat); } if (average) - mat.Scale(1.0 / (po.NumArgs() - 1)); - WriteKaldiObject(mat, po.GetArg(po.NumArgs()), binary); + sum.Scale(1.0 / (po.NumArgs() - 1)); + WriteKaldiObject(sum, po.GetArg(po.NumArgs()), binary); KALDI_LOG << "Summed " << (po.NumArgs() - 1) << " matrices; " << "wrote sum to " << PrintableWxfilename(po.GetArg(po.NumArgs())); return 0; @@ -335,4 +338,3 @@ int main(int argc, char *argv[]) { return -1; } } - diff --git a/src/bin/vector-sum.cc b/src/bin/vector-sum.cc index 42404e38384..3e622cafdc7 100644 --- a/src/bin/vector-sum.cc +++ b/src/bin/vector-sum.cc @@ -1,7 +1,7 @@ // bin/vector-sum.cc -// Copyright 2014 Vimal Manohar -// 2014 Johns Hopkins University (author: Daniel Povey) +// Copyright 2014 Vimal Manohar +// 2014-2018 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -43,7 +43,7 @@ int32 TypeOneUsage(const ParseOptions &po) { // Input vectors SequentialBaseFloatVectorReader vector_reader1(vector_in_fn1); - std::vector vector_readers(num_args-2, + std::vector vector_readers(num_args-2, static_cast(NULL)); std::vector vector_in_fns(num_args-2); for (int32 i = 2; i < num_args; ++i) { @@ -51,7 +51,7 @@ int32 TypeOneUsage(const ParseOptions &po) { vector_in_fns[i-2] = po.GetArg(i); } - int32 n_utts = 0, n_total_vectors = 0, + int32 n_utts = 0, n_total_vectors = 0, n_success = 0, n_missing = 0, n_other_errors = 0; for (; !vector_reader1.Done(); vector_reader1.Next()) { @@ -70,10 +70,10 @@ int32 TypeOneUsage(const ParseOptions &po) { if (vector2.Dim() == vector_out.Dim()) { vector_out.AddVec(1.0, vector2); } else { - KALDI_WARN << "Dimension mismatch for utterance " << key + KALDI_WARN << "Dimension mismatch for utterance " << key << " : " << vector2.Dim() << " for " << "system " << (i + 2) << ", rspecifier: " - << vector_in_fns[i] << " vs " << vector_out.Dim() + << vector_in_fns[i] << " vs " << vector_out.Dim() << " primary vector, rspecifier:" << vector_in_fn1; n_other_errors++; } @@ -94,9 +94,9 @@ int32 TypeOneUsage(const ParseOptions &po) { << " different systems"; KALDI_LOG << "Produced output for " << n_success << " utterances; " << n_missing << " total missing vectors"; - + DeletePointers(&vector_readers); - + return (n_success != 0 && n_missing < (n_success - n_missing)) ? 0 : 1; } @@ -108,13 +108,13 @@ int32 TypeTwoUsage(const ParseOptions &po, "vector-sum: first argument must be an rspecifier"); // if next assert fails it would be bug in the code as otherwise we shouldn't // be called. - KALDI_ASSERT(ClassifyWspecifier(po.GetArg(2), NULL, NULL, NULL) == + KALDI_ASSERT(ClassifyWspecifier(po.GetArg(2), NULL, NULL, NULL) == kNoWspecifier); SequentialBaseFloatVectorReader vec_reader(po.GetArg(1)); Vector sum; - + int32 num_done = 0, num_err = 0; for (; !vec_reader.Done(); vec_reader.Next()) { @@ -134,7 +134,7 @@ int32 TypeTwoUsage(const ParseOptions &po, } } } - + if (num_done > 0 && average) sum.Scale(1.0 / num_done); Vector sum_float(sum); @@ -157,21 +157,21 @@ int32 TypeThreeUsage(const ParseOptions &po, << "tables, the intermediate arguments must not be tables."; } } - if (ClassifyWspecifier(po.GetArg(po.NumArgs()), NULL, NULL, NULL) != + if (ClassifyWspecifier(po.GetArg(po.NumArgs()), NULL, NULL, NULL) != kNoWspecifier) { KALDI_ERR << "Wrong usage (type 3): if first and last arguments are not " << "tables, the intermediate arguments must not be tables."; } - bool add = true; - Vector vec; + Vector sum; for (int32 i = 1; i < po.NumArgs(); i++) { - bool binary_in; - Input ki(po.GetArg(i), &binary_in); - // this Read function will throw if there is a size mismatch. - vec.Read(ki.Stream(), binary_in, add); + Vector this_vec; + ReadKaldiObject(po.GetArg(i), &this_vec); + if (sum.Dim() < this_vec.Dim()) + sum.Resize(this_vec.Dim(), kCopyData);; + sum.AddVec(1.0, this_vec); } - WriteKaldiObject(vec, po.GetArg(po.NumArgs()), binary); + WriteKaldiObject(sum, po.GetArg(po.NumArgs()), binary); KALDI_LOG << "Summed " << (po.NumArgs() - 1) << " vectors; " << "wrote sum to " << PrintableWxfilename(po.GetArg(po.NumArgs())); return 0; @@ -201,15 +201,15 @@ int main(int argc, char *argv[]) { " \n" " e.g.: vector-sum --binary=false 1.vec 2.vec 3.vec sum.vec\n" "See also: copy-vector, dot-weights\n"; - + bool binary, average = false; - + ParseOptions po(usage); po.Register("binary", &binary, "If true, write output as binary (only " "relevant for usage types two or three"); po.Register("average", &average, "Do average instead of sum"); - + po.Read(argc, argv); int32 N = po.NumArgs(), exit_status; @@ -226,11 +226,11 @@ int main(int argc, char *argv[]) { exit_status = TypeTwoUsage(po, binary, average); } else if (po.NumArgs() >= 2 && ClassifyRspecifier(po.GetArg(1), NULL, NULL) == kNoRspecifier && - ClassifyWspecifier(po.GetArg(N), NULL, NULL, NULL) == + ClassifyWspecifier(po.GetArg(N), NULL, NULL, NULL) == kNoWspecifier) { // summing flat files. exit_status = TypeThreeUsage(po, binary); - } else { + } else { po.PrintUsage(); exit(1); } diff --git a/src/chain/Makefile b/src/chain/Makefile index 2a735c2ca2d..fbad28f7de6 100644 --- a/src/chain/Makefile +++ b/src/chain/Makefile @@ -18,8 +18,7 @@ LIBNAME = kaldi-chain ADDLIBS = ../cudamatrix/kaldi-cudamatrix.a ../lat/kaldi-lat.a \ ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a \ - ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a # Make sure we have CUDA_ARCH from kaldi.mk, ifeq ($(CUDA), true) diff --git a/src/chain/chain-denominator.cc b/src/chain/chain-denominator.cc index e41e942e266..b644e429b67 100644 --- a/src/chain/chain-denominator.cc +++ b/src/chain/chain-denominator.cc @@ -61,10 +61,11 @@ DenominatorComputation::DenominatorComputation( num_sequences_).SetZero(); KALDI_ASSERT(nnet_output.NumRows() % num_sequences == 0); - // the kStrideEqualNumCols argument means we'll allocate a contiguous block of - // memory for this; it is added to ensure that the same block of memory - // (cached in the allocator) can be used for xent_output_deriv when allocated - // from chain-training.cc. + // the kStrideEqualNumCols argument is so that we can share the same + // memory block with xent_output_deriv (see chain-training.cc, search for + // kStrideEqualNumCols). This depends on how the allocator works, and + // actually might not happen, but anyway, the impact on speed would + // likely be un-measurably small. exp_nnet_output_transposed_.Resize(nnet_output.NumCols(), nnet_output.NumRows(), kUndefined, kStrideEqualNumCols); diff --git a/src/chainbin/Makefile b/src/chainbin/Makefile index 61f653f174f..41ac7342d17 100644 --- a/src/chainbin/Makefile +++ b/src/chainbin/Makefile @@ -25,7 +25,7 @@ ADDLIBS = ../nnet3/kaldi-nnet3.a ../chain/kaldi-chain.a \ ../cudamatrix/kaldi-cudamatrix.a ../decoder/kaldi-decoder.a \ ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/cudamatrix/Makefile b/src/cudamatrix/Makefile index ca831390ea9..45c2ba44fd7 100644 --- a/src/cudamatrix/Makefile +++ b/src/cudamatrix/Makefile @@ -18,8 +18,7 @@ endif LIBNAME = kaldi-cudamatrix -ADDLIBS = ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a +ADDLIBS = ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a # Make sure we have CUDA_ARCH from kaldi.mk, ifeq ($(CUDA), true) diff --git a/src/cudamatrix/cu-allocator.cc b/src/cudamatrix/cu-allocator.cc index cfbc6757530..d1617bfedd4 100644 --- a/src/cudamatrix/cu-allocator.cc +++ b/src/cudamatrix/cu-allocator.cc @@ -398,7 +398,7 @@ void* CuMemoryAllocator::MallocPitch(size_t row_bytes, } void CuMemoryAllocator::Free(void *ptr) { - CuTimer tim; + Timer tim; if (!opts_.cache_memory) { CU_SAFE_CALL(cudaFree(ptr)); tot_time_taken_ += tim.Elapsed(); @@ -586,6 +586,23 @@ void CuMemoryAllocator::SortSubregions() { } } +CuMemoryAllocator::~CuMemoryAllocator() { + // We mainly free these blocks of memory so that cuda-memcheck doesn't report + // spurious errors. + for (size_t i = 0; i < memory_regions_.size(); i++) { + // No need to check the return status here-- the program is exiting anyway. + cudaFree(memory_regions_[i].begin); + } + for (size_t i = 0; i < subregions_.size(); i++) { + SubRegion *subregion = subregions_[i]; + for (auto iter = subregion->free_blocks.begin(); + iter != subregion->free_blocks.end(); ++iter) + delete iter->second; + delete subregion; + } +} + + CuMemoryAllocator g_cuda_allocator; diff --git a/src/cudamatrix/cu-allocator.h b/src/cudamatrix/cu-allocator.h index 20425704a2b..9dd2bb82aea 100644 --- a/src/cudamatrix/cu-allocator.h +++ b/src/cudamatrix/cu-allocator.h @@ -54,7 +54,7 @@ struct CuAllocatorOptions { bool cache_memory; // The proportion of the device's memory that the CuAllocator allocates to - // start with; by default this is 0.8, although if you want to share the + // start with; by default this is 0.5, although if you want to share the // device (not recommended!) you should set this lower. BaseFloat memory_proportion; @@ -187,6 +187,8 @@ class CuMemoryAllocator { // by the user (c.f. RegisterCuAllocatorOptions()) before the options are read. void SetOptions(const CuAllocatorOptions &opts) { opts_ = opts; } + ~CuMemoryAllocator(); + private: struct SubRegion; diff --git a/src/cudamatrix/cu-device.cc b/src/cudamatrix/cu-device.cc index 37912ea8adf..49c179b3673 100644 --- a/src/cudamatrix/cu-device.cc +++ b/src/cudamatrix/cu-device.cc @@ -102,7 +102,7 @@ void CuDevice::Initialize() { if (!multi_threaded_) { multi_threaded_ = true; KALDI_WARN << "For multi-threaded code that might use GPU, you should call " - "CuDevice()::Instantiate().AllowMultithreading() at the start of " + "CuDevice::Instantiate().AllowMultithreading() at the start of " "the program."; } device_id_copy_ = device_id_; diff --git a/src/cudamatrix/cu-matrix-inl.h b/src/cudamatrix/cu-matrix-inl.h index 9b7a707d2e5..0e182d4e72a 100644 --- a/src/cudamatrix/cu-matrix-inl.h +++ b/src/cudamatrix/cu-matrix-inl.h @@ -36,6 +36,7 @@ inline CuSubMatrix::CuSubMatrix(const CuMatrixBase &mat, // initializer, so nothing to do. } else { KALDI_ASSERT(row_offset >= 0 && col_offset >= 0 && + num_rows >= 0 && num_cols >= 0 && row_offset + num_rows <= mat.num_rows_ && col_offset + num_cols <= mat.num_cols_); this->data_ = mat.data_ + static_cast(col_offset) + @@ -68,5 +69,3 @@ inline CuSubMatrix::CuSubMatrix(const Real *data, } // namespace kaldi #endif - - diff --git a/src/decoder/Makefile b/src/decoder/Makefile index 35c84758779..020fe358fe9 100644 --- a/src/decoder/Makefile +++ b/src/decoder/Makefile @@ -7,14 +7,13 @@ TESTFILES = OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \ lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \ - decoder-wrappers.o grammar-fst.o + decoder-wrappers.o grammar-fst.o decodable-matrix.o LIBNAME = kaldi-decoder -ADDLIBS = ../lat/kaldi-lat.a ../hmm/kaldi-hmm.a \ +ADDLIBS = ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../fstext/kaldi-fstext.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/decoder/decodable-matrix.cc b/src/decoder/decodable-matrix.cc new file mode 100644 index 00000000000..3cc7b87f2d7 --- /dev/null +++ b/src/decoder/decodable-matrix.cc @@ -0,0 +1,107 @@ +// decoder/decodable-matrix.cc + +// Copyright 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/decodable-matrix.h" + +namespace kaldi { + +DecodableMatrixMapped::DecodableMatrixMapped( + const TransitionModel &tm, + const MatrixBase &likes, + int32 frame_offset): + trans_model_(tm), likes_(&likes), likes_to_delete_(NULL), + frame_offset_(frame_offset) { + stride_ = likes.Stride(); + raw_data_ = likes.Data() - (stride_ * frame_offset); + + if (likes.NumCols() != tm.NumPdfs()) + KALDI_ERR << "Mismatch, matrix has " + << likes.NumCols() << " rows but transition-model has " + << tm.NumPdfs() << " pdf-ids."; +} + +DecodableMatrixMapped::DecodableMatrixMapped( + const TransitionModel &tm, const Matrix *likes, + int32 frame_offset): + trans_model_(tm), likes_(likes), likes_to_delete_(likes), + frame_offset_(frame_offset) { + stride_ = likes->Stride(); + raw_data_ = likes->Data() - (stride_ * frame_offset_); + if (likes->NumCols() != tm.NumPdfs()) + KALDI_ERR << "Mismatch, matrix has " + << likes->NumCols() << " rows but transition-model has " + << tm.NumPdfs() << " pdf-ids."; +} + + +BaseFloat DecodableMatrixMapped::LogLikelihood(int32 frame, int32 tid) { + int32 pdf_id = trans_model_.TransitionIdToPdfFast(tid); +#ifdef KALDI_PARANOID + return (*likes_)(frame - frame_offset_, pdf_id); +#else + return raw_data_[frame * stride_ + pdf_id]; +#endif +} + +int32 DecodableMatrixMapped::NumFramesReady() const { + return frame_offset_ + likes_->NumRows(); +} + +bool DecodableMatrixMapped::IsLastFrame(int32 frame) const { + KALDI_ASSERT(frame < NumFramesReady()); + return (frame == NumFramesReady() - 1); +} + +// Indices are one-based! This is for compatibility with OpenFst. +int32 DecodableMatrixMapped::NumIndices() const { + return trans_model_.NumTransitionIds(); +} + +DecodableMatrixMapped::~DecodableMatrixMapped() { + delete likes_to_delete_; +} + + +void DecodableMatrixMappedOffset::AcceptLoglikes( + Matrix *loglikes, int32 frames_to_discard) { + if (loglikes->NumRows() == 0) return; + KALDI_ASSERT(loglikes->NumCols() == trans_model_.NumPdfs()); + KALDI_ASSERT(frames_to_discard <= loglikes_.NumRows() && + frames_to_discard >= 0); + if (frames_to_discard == loglikes_.NumRows()) { + loglikes_.Swap(loglikes); + loglikes->Resize(0, 0); + } else { + int32 old_rows_kept = loglikes_.NumRows() - frames_to_discard, + new_num_rows = old_rows_kept + loglikes->NumRows(); + Matrix new_loglikes(new_num_rows, loglikes->NumCols()); + new_loglikes.RowRange(0, old_rows_kept).CopyFromMat( + loglikes_.RowRange(frames_to_discard, old_rows_kept)); + new_loglikes.RowRange(old_rows_kept, loglikes->NumRows()).CopyFromMat( + *loglikes); + loglikes_.Swap(&new_loglikes); + } + frame_offset_ += frames_to_discard; + stride_ = loglikes_.Stride(); + raw_data_ = loglikes_.Data() - (frame_offset_ * stride_); +} + + + +} // end namespace kaldi. diff --git a/src/decoder/decodable-matrix.h b/src/decoder/decodable-matrix.h index de70ea82753..f32a007e6ca 100644 --- a/src/decoder/decodable-matrix.h +++ b/src/decoder/decodable-matrix.h @@ -32,8 +32,7 @@ namespace kaldi { class DecodableMatrixScaledMapped: public DecodableInterface { public: - // This constructor creates an object that will not delete "likes" - // when done. + // This constructor creates an object that will not delete "likes" when done. DecodableMatrixScaledMapped(const TransitionModel &tm, const Matrix &likes, BaseFloat scale): trans_model_(tm), likes_(&likes), @@ -55,7 +54,7 @@ class DecodableMatrixScaledMapped: public DecodableInterface { KALDI_ERR << "DecodableMatrixScaledMapped: mismatch, matrix has " << likes->NumCols() << " rows but transition-model has " << tm.NumPdfs() << " pdf-ids."; - } + } virtual int32 NumFramesReady() const { return likes_->NumRows(); } @@ -66,7 +65,7 @@ class DecodableMatrixScaledMapped: public DecodableInterface { // Note, frames are numbered from zero. virtual BaseFloat LogLikelihood(int32 frame, int32 tid) { - return scale_ * (*likes_)(frame, trans_model_.TransitionIdToPdf(tid)); + return scale_ * (*likes_)(frame, trans_model_.TransitionIdToPdfFast(tid)); } // Indices are one-based! This is for compatibility with OpenFst. @@ -83,6 +82,59 @@ class DecodableMatrixScaledMapped: public DecodableInterface { KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixScaledMapped); }; +/** + This is like DecodableMatrixScaledMapped, but it doesn't support an acoustic + scale, and it does support a frame offset, whereby you can state that the + first row of 'likes' is actually the n'th row of the matrix of available + log-likelihoods. It's useful if the neural net output comes in chunks for + different frame ranges. + + Note: DecodableMatrixMappedOffset solves the same problem in a slightly + different way, where you use the same decodable object. This one, unlike + DecodableMatrixMappedOffset, is compatible with when the loglikes are in a + SubMatrix. + */ +class DecodableMatrixMapped: public DecodableInterface { + public: + // This constructor creates an object that will not delete "likes" when done. + // the frame_offset is the frame the row 0 of 'likes' corresponds to, would be + // greater than one if this is not the first chunk of likelihoods. + DecodableMatrixMapped(const TransitionModel &tm, + const MatrixBase &likes, + int32 frame_offset = 0); + + // This constructor creates an object that will delete "likes" + // when done. + DecodableMatrixMapped(const TransitionModel &tm, + const Matrix *likes, + int32 frame_offset = 0); + + virtual int32 NumFramesReady() const; + + virtual bool IsLastFrame(int32 frame) const; + + virtual BaseFloat LogLikelihood(int32 frame, int32 tid); + + // Note: these indices are 1-based. + virtual int32 NumIndices() const; + + virtual ~DecodableMatrixMapped(); + + private: + const TransitionModel &trans_model_; // for tid to pdf mapping + const MatrixBase *likes_; + const Matrix *likes_to_delete_; + int32 frame_offset_; + + // raw_data_ and stride_ are a kind of fast look-aside for 'likes_', to be + // used when KALDI_PARANOID is false. + const BaseFloat *raw_data_; + int32 stride_; + + KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixMapped); +}; + + /** This decodable class returns log-likes stored in a matrix; it supports repeatedly writing to the matrix and setting a time-offset representing the @@ -91,68 +143,51 @@ class DecodableMatrixScaledMapped: public DecodableInterface { code will call SetLoglikes() each time more log-likelihods are available. If you try to access a log-likelihood that's no longer available because the frame index is less than the current offset, it is of course an error. + + See also DecodableMatrixMapped, which supports the same type of thing but + with a different interface where you are expected to re-construct the + object each time you want to decode. */ class DecodableMatrixMappedOffset: public DecodableInterface { public: DecodableMatrixMappedOffset(const TransitionModel &tm): - trans_model_(tm), frame_offset_(0), input_is_finished_(false) { } - - + trans_model_(tm), frame_offset_(0), input_is_finished_(false) { } virtual int32 NumFramesReady() { return frame_offset_ + loglikes_.NumRows(); } // this is not part of the generic Decodable interface. int32 FirstAvailableFrame() { return frame_offset_; } - + + // Logically, this function appends 'loglikes' (interpreted as newly available + // frames) to the log-likelihoods stored in the class. + // // This function is destructive of the input "loglikes" because it may // under some circumstances do a shallow copy using Swap(). This function // appends loglikes to any existing likelihoods you've previously supplied. - // frames_to_discard, if nonzero, will discard that number of previously - // available frames, from the left, advancing FirstAvailableFrame() by - // a number equal to frames_to_discard. You should only set frames_to_discard - // to nonzero if you know your decoder won't want to access the loglikes - // for older frames. void AcceptLoglikes(Matrix *loglikes, - int32 frames_to_discard) { - if (loglikes->NumRows() == 0) return; - KALDI_ASSERT(loglikes->NumCols() == trans_model_.NumPdfs()); - KALDI_ASSERT(frames_to_discard <= loglikes_.NumRows() && - frames_to_discard >= 0); - if (frames_to_discard == loglikes_.NumRows()) { - loglikes_.Swap(loglikes); - loglikes->Resize(0, 0); - } else { - int32 old_rows_kept = loglikes_.NumRows() - frames_to_discard, - new_num_rows = old_rows_kept + loglikes->NumRows(); - Matrix new_loglikes(new_num_rows, loglikes->NumCols()); - new_loglikes.RowRange(0, old_rows_kept).CopyFromMat( - loglikes_.RowRange(frames_to_discard, old_rows_kept)); - new_loglikes.RowRange(old_rows_kept, loglikes->NumRows()).CopyFromMat( - *loglikes); - loglikes_.Swap(&new_loglikes); - } - frame_offset_ += frames_to_discard; - } + int32 frames_to_discard); void InputIsFinished() { input_is_finished_ = true; } - + virtual int32 NumFramesReady() const { return loglikes_.NumRows() + frame_offset_; } - + virtual bool IsLastFrame(int32 frame) const { KALDI_ASSERT(frame < NumFramesReady()); return (frame == NumFramesReady() - 1 && input_is_finished_); } virtual BaseFloat LogLikelihood(int32 frame, int32 tid) { - int32 index = frame - frame_offset_; - KALDI_ASSERT(index >= 0 && index < loglikes_.NumRows()); - return loglikes_(index, trans_model_.TransitionIdToPdf(tid)); + int32 pdf_id = trans_model_.TransitionIdToPdfFast(tid); +#ifdef KALDI_PARANOID + return loglikes_(frame - frame_offset_, pdf_id); +#else + // This does no checking, so will be faster. + return raw_data_[frame * stride_ + pdf_id]; +#endif } - - virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); } // nothing special to do in destructor. @@ -162,6 +197,15 @@ class DecodableMatrixMappedOffset: public DecodableInterface { Matrix loglikes_; int32 frame_offset_; bool input_is_finished_; + + // 'raw_data_' and 'stride_' are intended as a fast look-aside which is an + // alternative to accessing data_. raw_data_ is a faked version of + // data_->Data() as if it started from frame zero rather than frame_offset_. + // This simplifies the code of LogLikelihood(), in cases where KALDI_PARANOID + // is not defined. + BaseFloat *raw_data_; + int32 stride_; + KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixMappedOffset); }; @@ -171,20 +215,20 @@ class DecodableMatrixScaled: public DecodableInterface { DecodableMatrixScaled(const Matrix &likes, BaseFloat scale): likes_(likes), scale_(scale) { } - + virtual int32 NumFramesReady() const { return likes_.NumRows(); } - + virtual bool IsLastFrame(int32 frame) const { KALDI_ASSERT(frame < NumFramesReady()); return (frame == NumFramesReady() - 1); } - + // Note, frames are numbered from zero. virtual BaseFloat LogLikelihood(int32 frame, int32 index) { - if (index > likes_.NumCols() || index <= 0 || + if (index > likes_.NumCols() || index <= 0 || frame < 0 || frame >= likes_.NumRows()) - KALDI_ERR << "Invalid (frame, index - 1) = (" - << frame << ", " << index - 1 << ") for matrix of size " + KALDI_ERR << "Invalid (frame, index - 1) = (" + << frame << ", " << index - 1 << ") for matrix of size " << likes_.NumRows() << " x " << likes_.NumCols(); return scale_ * likes_(frame, index - 1); } diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index 76f95dab7cc..ff573c74d15 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -382,7 +382,7 @@ bool DecodeUtteranceLatticeSimple( for (size_t i = 0; i < words.size(); i++) { std::string s = word_syms->Find(words[i]); if (s == "") - KALDI_ERR << "Word-id " << words[i] <<" not in symbol table."; + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; std::cerr << s << ' '; } std::cerr << '\n'; diff --git a/src/decoder/grammar-fst.h b/src/decoder/grammar-fst.h index 70ceadc8daa..f66933c132d 100644 --- a/src/decoder/grammar-fst.h +++ b/src/decoder/grammar-fst.h @@ -87,6 +87,9 @@ template<> class ArcIterator; sub-FSTs that represent nonterminals in the grammar; and multiple return points whenever we invoke a nonterminal. For more information see \ref grammar (i.e. ../doc/grammar.dox). + + Caution: this class is not thread safe, i.e. you shouldn't access the same + GrammarFst from multiple threads. We can fix this later if needed. */ class GrammarFst { public: diff --git a/src/decoder/lattice-faster-decoder.h b/src/decoder/lattice-faster-decoder.h index 766ad051e10..c611ec9dc05 100644 --- a/src/decoder/lattice-faster-decoder.h +++ b/src/decoder/lattice-faster-decoder.h @@ -83,6 +83,7 @@ struct LatticeFasterDecoderConfig { } void Check() const { KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 + && min_active <= max_active && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0 && prune_scale > 0.0 && prune_scale < 1.0); } diff --git a/src/feat/Makefile b/src/feat/Makefile index 2af9da2ec59..dcd029f7f94 100644 --- a/src/feat/Makefile +++ b/src/feat/Makefile @@ -16,7 +16,7 @@ OBJFILES = feature-functions.o feature-mfcc.o feature-plp.o feature-fbank.o \ LIBNAME = kaldi-feat ADDLIBS = ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/feat/feature-common-inl.h b/src/feat/feature-common-inl.h index 546f272e821..ad8fa244982 100644 --- a/src/feat/feature-common-inl.h +++ b/src/feat/feature-common-inl.h @@ -49,8 +49,7 @@ void OfflineFeatureTpl::ComputeFeatures( new_sample_freq, &downsampled_wave); Compute(downsampled_wave, vtln_warp, output); } else - KALDI_ERR << "The waveform is allowed to get downsampled." - << "New sample Frequency " << new_sample_freq + KALDI_ERR << "New sample Frequency " << new_sample_freq << " is larger than waveform original sampling frequency " << sample_freq; diff --git a/src/feat/online-feature.cc b/src/feat/online-feature.cc index 267a4724580..f35cf631752 100644 --- a/src/feat/online-feature.cc +++ b/src/feat/online-feature.cc @@ -143,7 +143,9 @@ void OnlineCmvnState::Read(std::istream &is, bool binary) { OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts, const OnlineCmvnState &cmvn_state, OnlineFeatureInterface *src): - opts_(opts), src_(src) { + opts_(opts), temp_stats_(2, src->Dim() + 1), + temp_feats_(src->Dim()), temp_feats_dbl_(src->Dim()), + src_(src) { SetState(cmvn_state); if (!SplitStringToIntegers(opts.skip_dims, ":", false, &skip_dims_)) KALDI_ERR << "Bad --skip-dims option (should be colon-separated list of " @@ -151,7 +153,10 @@ OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts, } OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts, - OnlineFeatureInterface *src): opts_(opts), src_(src) { + OnlineFeatureInterface *src): + opts_(opts), temp_stats_(2, src->Dim() + 1), + temp_feats_(src->Dim()), temp_feats_dbl_(src->Dim()), + src_(src) { if (!SplitStringToIntegers(opts.skip_dims, ":", false, &skip_dims_)) KALDI_ERR << "Bad --skip-dims option (should be colon-separated list of " << "integers)"; @@ -160,7 +165,7 @@ OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts, void OnlineCmvn::GetMostRecentCachedFrame(int32 frame, int32 *cached_frame, - Matrix *stats) { + MatrixBase *stats) { KALDI_ASSERT(frame >= 0); InitRingBufferIfNeeded(); // look for a cached frame on a previous frame as close as possible in time @@ -174,7 +179,7 @@ void OnlineCmvn::GetMostRecentCachedFrame(int32 frame, int32 index = t % opts_.ring_buffer_size; if (cached_stats_ring_[index].first == t) { *cached_frame = t; - *stats = cached_stats_ring_[index].second; + stats->CopyFromMat(cached_stats_ring_[index].second); return; } } @@ -182,7 +187,7 @@ void OnlineCmvn::GetMostRecentCachedFrame(int32 frame, if (n >= cached_stats_modulo_.size()) { if (cached_stats_modulo_.size() == 0) { *cached_frame = -1; - stats->Resize(2, this->Dim() + 1); + stats->SetZero(); return; } else { n = static_cast(cached_stats_modulo_.size() - 1); @@ -190,7 +195,7 @@ void OnlineCmvn::GetMostRecentCachedFrame(int32 frame, } *cached_frame = n * opts_.modulus; KALDI_ASSERT(cached_stats_modulo_[n] != NULL); - *stats = *(cached_stats_modulo_[n]); + stats->CopyFromMat(*(cached_stats_modulo_[n])); } // Initialize ring buffer for caching stats. @@ -202,7 +207,7 @@ void OnlineCmvn::InitRingBufferIfNeeded() { } } -void OnlineCmvn::CacheFrame(int32 frame, const Matrix &stats) { +void OnlineCmvn::CacheFrame(int32 frame, const MatrixBase &stats) { KALDI_ASSERT(frame >= 0); if (frame % opts_.modulus == 0) { // store in cached_stats_modulo_. int32 n = frame / opts_.modulus; @@ -239,18 +244,18 @@ void OnlineCmvn::ComputeStatsForFrame(int32 frame, KALDI_ASSERT(frame >= 0 && frame < src_->NumFramesReady()); int32 dim = this->Dim(), cur_frame; - Matrix stats(2, dim + 1); - GetMostRecentCachedFrame(frame, &cur_frame, &stats); + GetMostRecentCachedFrame(frame, &cur_frame, stats_out); - Vector feats(dim); - Vector feats_dbl(dim); + Vector &feats(temp_feats_); + Vector &feats_dbl(temp_feats_dbl_); while (cur_frame < frame) { cur_frame++; src_->GetFrame(cur_frame, &feats); feats_dbl.CopyFromVec(feats); - stats.Row(0).Range(0, dim).AddVec(1.0, feats_dbl); - stats.Row(1).Range(0, dim).AddVec2(1.0, feats_dbl); - stats(0, dim) += 1.0; + stats_out->Row(0).Range(0, dim).AddVec(1.0, feats_dbl); + if (opts_.normalize_variance) + stats_out->Row(1).Range(0, dim).AddVec2(1.0, feats_dbl); + (*stats_out)(0, dim) += 1.0; // it's a sliding buffer; a frame at the back may be // leaving the buffer so we have to subtract that. int32 prev_frame = cur_frame - opts_.cmn_window; @@ -258,13 +263,13 @@ void OnlineCmvn::ComputeStatsForFrame(int32 frame, // we need to subtract frame prev_f from the stats. src_->GetFrame(prev_frame, &feats); feats_dbl.CopyFromVec(feats); - stats.Row(0).Range(0, dim).AddVec(-1.0, feats_dbl); - stats.Row(1).Range(0, dim).AddVec2(-1.0, feats_dbl); - stats(0, dim) -= 1.0; + stats_out->Row(0).Range(0, dim).AddVec(-1.0, feats_dbl); + if (opts_.normalize_variance) + stats_out->Row(1).Range(0, dim).AddVec2(-1.0, feats_dbl); + (*stats_out)(0, dim) -= 1.0; } - CacheFrame(cur_frame, stats); + CacheFrame(cur_frame, (*stats_out)); } - stats_out->CopyFromMat(stats); } @@ -273,6 +278,16 @@ void OnlineCmvn::SmoothOnlineCmvnStats(const MatrixBase &speaker_stats, const MatrixBase &global_stats, const OnlineCmvnOptions &opts, MatrixBase *stats) { + if (speaker_stats.NumRows() == 2 && !opts.normalize_variance) { + // this is just for efficiency: don't operate on the variance if it's not + // needed. + int32 cols = speaker_stats.NumCols(); // dim + 1 + SubMatrix stats_temp(*stats, 0, 1, 0, cols); + SmoothOnlineCmvnStats(speaker_stats.RowRange(0, 1), + global_stats.RowRange(0, 1), + opts, &stats_temp); + return; + } int32 dim = stats->NumCols() - 1; double cur_count = (*stats)(0, dim); // If count exceeded cmn_window it would be an error in how "window_stats" @@ -311,7 +326,8 @@ void OnlineCmvn::GetFrame(int32 frame, src_->GetFrame(frame, feat); KALDI_ASSERT(feat->Dim() == this->Dim()); int32 dim = feat->Dim(); - Matrix stats(2, dim + 1); + Matrix &stats(temp_stats_); + stats.Resize(2, dim + 1, kUndefined); // Will do nothing if size was correct. if (frozen_state_.NumRows() != 0) { // the CMVN state has been frozen. stats.CopyFromMat(frozen_state_); } else { @@ -329,14 +345,13 @@ void OnlineCmvn::GetFrame(int32 frame, // call the function ApplyCmvn declared in ../transform/cmvn.h, which // requires a matrix. - Matrix feat_mat(1, dim); - feat_mat.Row(0).CopyFromVec(*feat); + // 1 row; num-cols == dim; stride == dim. + SubMatrix feat_mat(feat->Data(), 1, dim, dim); // the function ApplyCmvn takes a matrix, so form a one-row matrix to give it. if (opts_.normalize_mean) ApplyCmvn(stats, opts_.normalize_variance, &feat_mat); else KALDI_ASSERT(!opts_.normalize_variance); - feat->CopyFromVec(feat_mat.Row(0)); } void OnlineCmvn::Freeze(int32 cur_frame) { @@ -430,6 +445,17 @@ void OnlineTransform::GetFrame(int32 frame, VectorBase *feat) { feat->AddMatVec(1.0, linear_term_, kNoTrans, input_feat, 1.0); } +void OnlineTransform::GetFrames( + const std::vector &frames, MatrixBase *feats) { + KALDI_ASSERT(static_cast(frames.size()) == feats->NumRows()); + int32 num_frames = feats->NumRows(), + input_dim = linear_term_.NumCols(); + Matrix input_feats(num_frames, input_dim, kUndefined); + src_->GetFrames(frames, &input_feats); + feats->CopyRowsFromVec(offset_); + feats->AddMatMat(1.0, input_feats, kNoTrans, linear_term_, kTrans, 1.0); +} + int32 OnlineDeltaFeature::Dim() const { int32 src_dim = src_->Dim(); @@ -493,6 +519,44 @@ void OnlineCacheFeature::GetFrame(int32 frame, VectorBase *feat) { } } +void OnlineCacheFeature::GetFrames( + const std::vector &frames, MatrixBase *feats) { + int32 num_frames = frames.size(); + // non_cached_frames will be the subset of 't' values in 'frames' which were + // not previously cached, which we therefore need to get from src_. + std::vector non_cached_frames; + // 'non_cached_indexes' stores the indexes 'i' into 'frames' corresponding to + // the corresponding frames in 'non_cached_frames'. + std::vector non_cached_indexes; + non_cached_frames.reserve(frames.size()); + non_cached_indexes.reserve(frames.size()); + for (int32 i = 0; i < num_frames; i++) { + int32 t = frames[i]; + if (static_cast(t) < cache_.size() && cache_[t] != NULL) { + feats->Row(i).CopyFromVec(*(cache_[t])); + } else { + non_cached_frames.push_back(t); + non_cached_indexes.push_back(i); + } + } + if (non_cached_frames.empty()) + return; + int32 num_non_cached_frames = non_cached_frames.size(), + dim = this->Dim(); + Matrix non_cached_feats(num_non_cached_frames, dim, + kUndefined); + src_->GetFrames(non_cached_frames, &non_cached_feats); + for (int32 i = 0; i < num_non_cached_frames; i++) { + SubVector this_feat(non_cached_feats, i); + feats->Row(non_cached_indexes[i]).CopyFromVec(this_feat); + int32 t = non_cached_frames[i]; + if (static_cast(t) >= cache_.size()) + cache_.resize(t + 1, NULL); + cache_[t] = new Vector(this_feat); + } +} + + void OnlineCacheFeature::ClearCache() { for (size_t i = 0; i < cache_.size(); i++) delete cache_[i]; @@ -500,7 +564,6 @@ void OnlineCacheFeature::ClearCache() { } - void OnlineAppendFeature::GetFrame(int32 frame, VectorBase *feat) { KALDI_ASSERT(feat->Dim() == Dim()); diff --git a/src/feat/online-feature.h b/src/feat/online-feature.h index 11d170972fa..d41bb6747c7 100644 --- a/src/feat/online-feature.h +++ b/src/feat/online-feature.h @@ -182,7 +182,8 @@ struct OnlineCmvnOptions { // class computes the cmvn internally. smaller->more // time-efficient but less memory-efficient. Must be >= 1. int32 ring_buffer_size; // not configurable from command line; size of ring - // buffer used for caching CMVN stats. + // buffer used for caching CMVN stats. Must be >= + // modulus. std::string skip_dims; // Colon-separated list of dimensions to skip normalization // of, e.g. 13:14:15. @@ -371,10 +372,10 @@ class OnlineCmvn: public OnlineFeatureInterface { /// were cached, sets up empty stats for frame zero and returns that]. void GetMostRecentCachedFrame(int32 frame, int32 *cached_frame, - Matrix *stats); + MatrixBase *stats); /// Cache this frame of stats. - void CacheFrame(int32 frame, const Matrix &stats); + void CacheFrame(int32 frame, const MatrixBase &stats); /// Initialize ring buffer for caching stats. inline void InitRingBufferIfNeeded(); @@ -403,6 +404,12 @@ class OnlineCmvn: public OnlineFeatureInterface { // frame index. std::vector > > cached_stats_ring_; + // Some temporary variables used inside functions of this class, which + // put here to avoid reallocation. + Matrix temp_stats_; + Vector temp_feats_; + Vector temp_feats_dbl_; + OnlineFeatureInterface *src_; // Not owned here }; @@ -472,6 +479,9 @@ class OnlineTransform: public OnlineFeatureInterface { virtual void GetFrame(int32 frame, VectorBase *feat); + virtual void GetFrames(const std::vector &frames, + MatrixBase *feats); + // // Next, functions that are not in the interface. // @@ -537,6 +547,9 @@ class OnlineCacheFeature: public OnlineFeatureInterface { virtual void GetFrame(int32 frame, VectorBase *feat); + virtual void GetFrames(const std::vector &frames, + MatrixBase *feats); + virtual ~OnlineCacheFeature() { ClearCache(); } // Things that are not in the shared interface: diff --git a/src/featbin/Makefile b/src/featbin/Makefile index 8e72d0f744c..861ba3f7a93 100644 --- a/src/featbin/Makefile +++ b/src/featbin/Makefile @@ -25,7 +25,7 @@ TESTFILES = ADDLIBS = ../hmm/kaldi-hmm.a ../feat/kaldi-feat.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/featbin/extract-feature-segments.cc b/src/featbin/extract-feature-segments.cc index 93f599feb3a..f6cdcb96b18 100644 --- a/src/featbin/extract-feature-segments.cc +++ b/src/featbin/extract-feature-segments.cc @@ -25,7 +25,7 @@ #include "matrix/kaldi-matrix.h" /** @brief This is a program for extracting segments from feature files/archives - - usage : + - usage : - extract-feature-segments [options ..] - "segments-file" should have the information of the segments that needs to be extracted from the feature files - the format of the segments file : speaker_name filename start_time(in secs) end_time(in secs) @@ -37,6 +37,10 @@ int main(int argc, char *argv[]) { const char *usage = "Create feature files by segmenting input files.\n" + "Note: this program should no longer be needed now that\n" + "'ranges' in scp files are supported; search for 'ranges' in\n" + "http://kaldi-asr.org/doc/io_tut.html, or see the script\n" + "utils/data/subsegment_data_dir.sh.\n" "Usage: " "extract-feature-segments [options...] " " \n" @@ -144,9 +148,9 @@ int main(int argc, char *argv[]) { } } - /* check whether a segment start time and end time exists in utterance + /* check whether a segment start time and end time exists in utterance * if fails , skips the segment. - */ + */ if (!feat_reader.HasKey(utterance)) { KALDI_WARN << "Did not find features for utterance " << utterance << ", skipping segment " << segment; @@ -167,7 +171,7 @@ int main(int argc, char *argv[]) { end_samp -= snip_length; } - /* start sample must be less than total number of samples + /* start sample must be less than total number of samples * otherwise skip the segment */ if (start_samp < 0 || start_samp >= num_samp) { @@ -177,7 +181,7 @@ int main(int argc, char *argv[]) { continue; } - /* end sample must be less than total number samples + /* end sample must be less than total number samples * otherwise skip the segment */ if (end_samp > num_samp) { @@ -221,4 +225,3 @@ int main(int argc, char *argv[]) { return -1; } } - diff --git a/src/fgmmbin/Makefile b/src/fgmmbin/Makefile index baa4cd9be33..5db252477b5 100644 --- a/src/fgmmbin/Makefile +++ b/src/fgmmbin/Makefile @@ -18,7 +18,6 @@ TESTFILES = ADDLIBS = ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a ../hmm/kaldi-hmm.a \ ../feat/kaldi-feat.a ../transform/kaldi-transform.a \ ../gmm/kaldi-gmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/fstbin/Makefile b/src/fstbin/Makefile index 4236282b3fc..a22c014a7d5 100644 --- a/src/fstbin/Makefile +++ b/src/fstbin/Makefile @@ -26,6 +26,6 @@ TESTFILES = LIBFILE = ADDLIBS = ../decoder/kaldi-decoder.a ../fstext/kaldi-fstext.a \ - ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/fstbin/fstrand.cc b/src/fstbin/fstrand.cc index 9344b538d9c..f0bc3938051 100644 --- a/src/fstbin/fstrand.cc +++ b/src/fstbin/fstrand.cc @@ -45,6 +45,8 @@ int main(int argc, char *argv[]) { po.Register("allow-empty", &opts.allow_empty, "If true, we may generate an empty FST."); + po.Read(argc, argv); + if (po.NumArgs() > 1) { po.PrintUsage(); exit(1); diff --git a/src/fstext/Makefile b/src/fstext/Makefile index dc25ddae95b..b76bd413c42 100644 --- a/src/fstext/Makefile +++ b/src/fstext/Makefile @@ -24,7 +24,7 @@ LIBNAME = kaldi-fstext # tree and matrix archives needed for test-context-fst # matrix archive needed for push-special. -ADDLIBS = ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a +ADDLIBS = ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/gmm/Makefile b/src/gmm/Makefile index d8aedadfd93..caee6734afe 100644 --- a/src/gmm/Makefile +++ b/src/gmm/Makefile @@ -14,8 +14,8 @@ OBJFILES = diag-gmm.o diag-gmm-normal.o mle-diag-gmm.o am-diag-gmm.o \ LIBNAME = kaldi-gmm -ADDLIBS = ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a +ADDLIBS = ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a diff --git a/src/gmm/diag-gmm.h b/src/gmm/diag-gmm.h index 1243d7a6bfd..4a10ea34471 100644 --- a/src/gmm/diag-gmm.h +++ b/src/gmm/diag-gmm.h @@ -100,7 +100,7 @@ class DiagGmm { const std::vector &indices, Vector *loglikes) const; - /// Get gaussian selection information for one frame. Returns og-like + /// Get gaussian selection information for one frame. Returns log-like /// this frame. Output is the best "num_gselect" indices, sorted from best to /// worst likelihood. If "num_gselect" > NumGauss(), sets it to NumGauss(). BaseFloat GaussianSelection(const VectorBase &data, diff --git a/src/gmmbin/Makefile b/src/gmmbin/Makefile index 72a0fa15e73..82d10abe9ce 100644 --- a/src/gmmbin/Makefile +++ b/src/gmmbin/Makefile @@ -37,8 +37,8 @@ TESTFILES = ADDLIBS = ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a \ ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a ../feat/kaldi-feat.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/gmmbin/gmm-decode-simple.cc b/src/gmmbin/gmm-decode-simple.cc index b408afafdff..5ef35552dc0 100644 --- a/src/gmmbin/gmm-decode-simple.cc +++ b/src/gmmbin/gmm-decode-simple.cc @@ -38,8 +38,9 @@ int main(int argc, char *argv[]) { typedef kaldi::int32 int32; using fst::SymbolTable; using fst::VectorFst; + using fst::Fst; using fst::StdArc; - using fst::ReadFstKaldi; + using fst::ReadFstKaldiGeneric; const char *usage = "Decode features using GMM-based model.\n" @@ -86,7 +87,7 @@ int main(int argc, char *argv[]) { am_gmm.Read(ki.Stream(), binary); } - VectorFst *decode_fst = ReadFstKaldi(fst_in_filename); + Fst *decode_fst = ReadFstKaldiGeneric(fst_in_filename); Int32VectorWriter words_writer(words_wspecifier); diff --git a/src/gmmbin/gmm-init-biphone.cc b/src/gmmbin/gmm-init-biphone.cc index d1c789a620e..e5cc182f94c 100644 --- a/src/gmmbin/gmm-init-biphone.cc +++ b/src/gmmbin/gmm-init-biphone.cc @@ -51,10 +51,12 @@ void ReadSharedPhonesList(std::string rxfilename, std::vector EventMap *GetFullBiphoneStubMap(const std::vector > &phone_sets, const std::vector &phone2num_pdf_classes, - const std::vector &share_roots) { + const std::vector &share_roots, + const std::vector &ci_phones_list) { { // Check the inputs - KALDI_ASSERT(!phone_sets.empty() && share_roots.size() == phone_sets.size()); + KALDI_ASSERT(!phone_sets.empty() && + share_roots.size() == phone_sets.size()); std::set all_phones; for (size_t i = 0; i < phone_sets.size(); i++) { KALDI_ASSERT(IsSortedAndUniq(phone_sets[i])); @@ -66,9 +68,18 @@ EventMap } } + int32 numpdfs_per_phone = phone2num_pdf_classes[1]; int32 current_pdfid = 0; std::map level1_map; // key is 1 + + for (size_t i = 0; i < ci_phones_list.size(); i++) { + std::map level2_map; + level2_map[0] = current_pdfid++; + if (numpdfs_per_phone == 2) level2_map[1] = current_pdfid++; + level1_map[ci_phones_list[i]] = new TableEventMap(kPdfClass, level2_map); + } + for (size_t i = 0; i < phone_sets.size(); i++) { if (numpdfs_per_phone == 1) { @@ -99,9 +110,11 @@ EventMap level3_map[0] = current_pdfid++; level3_map[1] = current_pdfid++; level2_map[0] = new TableEventMap(kPdfClass, level3_map); // no-left-context case + for (size_t i = 0; i < ci_phones_list.size(); i++) // ci-phone left-context cases + level2_map[ci_phones_list[i]] = new TableEventMap(kPdfClass, level3_map); } for (size_t j = 0; j < phone_sets.size(); j++) { - std::map level3_map; // key is -1 + std::map level3_map; // key is kPdfClass level3_map[0] = current_pdfid++; level3_map[1] = current_pdfid++; @@ -121,17 +134,35 @@ EventMap return new TableEventMap(1, level1_map); } + ContextDependency* -BiphoneContextDependencyFull(const std::vector > phone_sets, - const std::vector phone2num_pdf_classes) { - std::vector share_roots(phone_sets.size(), false); // Don't share roots +BiphoneContextDependencyFull(std::vector > phone_sets, + const std::vector phone2num_pdf_classes, + const std::vector &ci_phones_list) { + // Remove all the CI phones from the phone sets + std::set ci_phones; + for (size_t i = 0; i < ci_phones_list.size(); i++) + ci_phones.insert(ci_phones_list[i]); + for (int32 i = phone_sets.size() - 1; i >= 0; i--) { + for (int32 j = phone_sets[i].size() - 1; j >= 0; j--) { + if (ci_phones.find(phone_sets[i][j]) != ci_phones.end()) { // Delete it + phone_sets[i].erase(phone_sets[i].begin() + j); + if (phone_sets[i].empty()) // If empty, delete the whole entry + phone_sets.erase(phone_sets.begin() + i); + } + } + } + + std::vector share_roots(phone_sets.size(), false); // Don't share roots // N is context size, P = position of central phone (must be 0). int32 P = 1, N = 2; EventMap *pdf_map = GetFullBiphoneStubMap(phone_sets, - phone2num_pdf_classes, share_roots); + phone2num_pdf_classes, + share_roots, ci_phones_list); return new ContextDependency(N, P, pdf_map); } + } // end namespace kaldi int main(int argc, char *argv[]) { @@ -148,11 +179,17 @@ int main(int argc, char *argv[]) { bool binary = true; std::string shared_phones_rxfilename; + std::string ci_phones_str; + std::vector ci_phones; // Sorted, uniqe vector of + // context-independent phones. + ParseOptions po(usage); po.Register("binary", &binary, "Write output in binary mode"); po.Register("shared-phones", &shared_phones_rxfilename, "rxfilename containing, on each line, a list of phones " "whose pdfs should be shared."); + po.Register("ci-phones", &ci_phones_str, "Colon-separated list of " + "integer indices of context-independent phones."); po.Read(argc, argv); if (po.NumArgs() != 4) { @@ -169,6 +206,14 @@ int main(int argc, char *argv[]) { std::string model_filename = po.GetArg(3); std::string tree_filename = po.GetArg(4); + if (!ci_phones_str.empty()) { + SplitStringToIntegers(ci_phones_str, ":", false, &ci_phones); + std::sort(ci_phones.begin(), ci_phones.end()); + if (!IsSortedAndUniq(ci_phones) || ci_phones.empty() || ci_phones[0] == 0) + KALDI_ERR << "Invalid --ci-phones option: " << ci_phones_str; + } + + Vector glob_inv_var(dim); glob_inv_var.Set(1.0); Vector glob_mean(dim); @@ -200,7 +245,8 @@ int main(int argc, char *argv[]) { ReadSharedPhonesList(shared_phones_rxfilename, &shared_phones); // ReadSharedPhonesList crashes on error. } - ctx_dep = BiphoneContextDependencyFull(shared_phones, phone2num_pdf_classes); + ctx_dep = BiphoneContextDependencyFull(shared_phones, phone2num_pdf_classes, + ci_phones); int32 num_pdfs = ctx_dep->NumPdfs(); diff --git a/src/hmm/Makefile b/src/hmm/Makefile index 6da3b7b7757..0ad5da74c28 100644 --- a/src/hmm/Makefile +++ b/src/hmm/Makefile @@ -9,8 +9,8 @@ OBJFILES = hmm-topology.o transition-model.o hmm-utils.o tree-accu.o \ posterior.o hmm-test-utils.o LIBNAME = kaldi-hmm -ADDLIBS = ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a +ADDLIBS = ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/hmm/posterior-test.cc b/src/hmm/posterior-test.cc index b6958674f9b..0906cb3d0dc 100644 --- a/src/hmm/posterior-test.cc +++ b/src/hmm/posterior-test.cc @@ -33,12 +33,12 @@ void TestVectorToPosteriorEntry() { std::vector > post_entry; - BaseFloat ans = VectorToPosteriorEntry(loglikes, gselect, min_post, &post_entry); + VectorToPosteriorEntry(loglikes, gselect, min_post, &post_entry); KALDI_ASSERT(post_entry.size() <= gselect); int32 max_elem; - BaseFloat max_val = loglikes.Max(&max_elem); + loglikes.Max(&max_elem); KALDI_ASSERT(post_entry[0].first == max_elem); KALDI_ASSERT(post_entry.back().second >= min_post); @@ -48,7 +48,6 @@ void TestVectorToPosteriorEntry() { for (size_t i = 0; i < post_entry.size(); i++) sum += post_entry[i].second; KALDI_ASSERT(fabs(sum - 1.0) < 0.01); - KALDI_ASSERT(ans >= max_val); } void TestPosteriorIo() { @@ -92,4 +91,3 @@ int main() { } std::cout << "Test OK.\n"; } - diff --git a/src/hmm/posterior.cc b/src/hmm/posterior.cc index 42db6e99cf4..860a979a0ce 100644 --- a/src/hmm/posterior.cc +++ b/src/hmm/posterior.cc @@ -402,7 +402,7 @@ void WeightSilencePostDistributed(const TransitionModel &trans_model, for (size_t i = 0; i < post->size(); i++) { std::vector > this_post; this_post.reserve((*post)[i].size()); - BaseFloat sil_weight = 0.0, nonsil_weight = 0.0; + BaseFloat sil_weight = 0.0, nonsil_weight = 0.0; for (size_t j = 0; j < (*post)[i].size(); j++) { int32 tid = (*post)[i][j].first, phone = trans_model.TransitionIdToPhone(tid); @@ -418,12 +418,23 @@ void WeightSilencePostDistributed(const TransitionModel &trans_model, if (frame_scale != 0.0) { for (size_t j = 0; j < (*post)[i].size(); j++) { int32 tid = (*post)[i][j].first; - BaseFloat weight = (*post)[i][j].second; + BaseFloat weight = (*post)[i][j].second; this_post.push_back(std::make_pair(tid, weight * frame_scale)); } } - (*post)[i].swap(this_post); + (*post)[i].swap(this_post); + } +} + +inline static BaseFloat GetTotalPosterior( + const std::vector > &post_entry) { + BaseFloat tot = 0.0; + std::vector >::const_iterator + iter = post_entry.begin(), end = post_entry.end(); + for (; iter != end; ++iter) { + tot += iter->second; } + return tot; } BaseFloat VectorToPosteriorEntry( @@ -434,39 +445,66 @@ BaseFloat VectorToPosteriorEntry( KALDI_ASSERT(num_gselect > 0 && min_post >= 0 && min_post < 1.0); // we name num_gauss assuming each entry in log_likes represents a Gaussian; // it doesn't matter if they don't. + int32 num_gauss = log_likes.Dim(); KALDI_ASSERT(num_gauss > 0); if (num_gselect > num_gauss) num_gselect = num_gauss; - Vector log_likes_normalized(log_likes); - BaseFloat ans = log_likes_normalized.ApplySoftMax(); - std::vector > temp_post(num_gauss); - for (int32 g = 0; g < num_gauss; g++) - temp_post[g] = std::pair(g, log_likes_normalized(g)); + std::vector > temp_post; + BaseFloat max_like = log_likes.Max(); + if (min_post != 0.0) { + BaseFloat like_cutoff = max_like + Log(min_post); + for (int32 g = 0; g < num_gauss; g++) { + BaseFloat like = log_likes(g); + if (like > like_cutoff) { + BaseFloat post = exp(like - max_like); + temp_post.push_back(std::pair(g, post)); + } + } + } + if (temp_post.empty()) { + // we reach here if min_post was 0.0 or if no posteriors reached the + // threshold min_post (we need at least one). + temp_post.resize(num_gauss); + for (int32 g = 0; g < num_gauss; g++) + temp_post[g] = std::pair(g, Exp(log_likes(g) - max_like)); + } + CompareReverseSecond compare; - // Sort in decreasing order on posterior. For efficiency we - // first do nth_element and then sort, as we only need the part we're - // going to output, to be sorted. - std::nth_element(temp_post.begin(), - temp_post.begin() + num_gselect, temp_post.end(), - compare); - std::sort(temp_post.begin(), temp_post.begin() + num_gselect, - compare); + if (static_cast(temp_post.size()) > num_gselect * 2) { + // Sort in decreasing order on posterior. For efficiency we + // first do nth_element and then sort, as we only need the part we're + // going to output, to be sorted. + std::nth_element(temp_post.begin(), + temp_post.begin() + num_gselect, temp_post.end(), + compare); + std::sort(temp_post.begin(), temp_post.begin() + num_gselect, + compare); + } else { + std::sort(temp_post.begin(), temp_post.end(), compare); + } + + size_t num_to_insert = std::min(temp_post.size(), + num_gselect); post_entry->clear(); post_entry->insert(post_entry->end(), - temp_post.begin(), temp_post.begin() + num_gselect); - while (post_entry->size() > 1 && post_entry->back().second < min_post) - post_entry->pop_back(); + temp_post.begin(), temp_post.begin() + num_to_insert); + + BaseFloat tot_post = GetTotalPosterior(*post_entry), + cutoff = min_post * tot_post; + + while (post_entry->size() > 1 && post_entry->back().second < cutoff) { + tot_post -= post_entry->back().second; + post_entry->pop_back(); + } // Now renormalize to sum to one after pruning. - BaseFloat tot = 0.0; - size_t size = post_entry->size(); - for (size_t i = 0; i < size; i++) - tot += (*post_entry)[i].second; - BaseFloat inv_tot = 1.0 / tot; - for (size_t i = 0; i < size; i++) - (*post_entry)[i].second *= inv_tot; - return ans; + BaseFloat inv_tot = 1.0 / tot_post; + auto end = post_entry->end(); + for (auto iter = post_entry->begin(); iter != end; ++iter) + iter->second *= inv_tot; + + return max_like + log(tot_post); } diff --git a/src/hmm/posterior.h b/src/hmm/posterior.h index cfe3fc44572..0c255845dd5 100644 --- a/src/hmm/posterior.h +++ b/src/hmm/posterior.h @@ -190,8 +190,9 @@ struct CompareReverseSecond { /// by applying Softmax(), then prunes the posteriors using "gselect" and /// "min_post" (keeping at least one), and outputs the result into /// "post_entry", sorted from greatest to least posterior. -/// Returns the total log-likelihood (the output of calling ApplySoftMax() -/// on a copy of log_likes). +/// +/// It returns the log of the sum of the selected log-likes that contributed +/// to the posterior. BaseFloat VectorToPosteriorEntry( const VectorBase &log_likes, int32 num_gselect, diff --git a/src/hmm/transition-model.cc b/src/hmm/transition-model.cc index 83edbaf5805..5ecb7776f00 100644 --- a/src/hmm/transition-model.cc +++ b/src/hmm/transition-model.cc @@ -166,7 +166,7 @@ void TransitionModel::ComputeDerived() { id2state_.resize(cur_transition_id); // cur_transition_id is #transition-ids+1. id2pdf_id_.resize(cur_transition_id); - for (int32 tstate = 1; tstate <= static_cast(tuples_.size()); tstate++) + for (int32 tstate = 1; tstate <= static_cast(tuples_.size()); tstate++) { for (int32 tid = state2id_[tstate]; tid < state2id_[tstate+1]; tid++) { id2state_[tid] = tstate; if (IsSelfLoop(tid)) @@ -174,6 +174,17 @@ void TransitionModel::ComputeDerived() { else id2pdf_id_[tid] = tuples_[tstate-1].forward_pdf; } + } + + // The following statements put copies a large number in the region of memory + // past the end of the id2pdf_id_ array, while leaving the aray as it was + // before. The goal of this is to speed up decoding by disabling a check + // inside TransitionIdToPdf() that the transition-id was within the correct + // range. + int32 num_big_numbers = std::min(2000, cur_transition_id); + id2pdf_id_.resize(cur_transition_id + num_big_numbers, + std::numeric_limits::max()); + id2pdf_id_.resize(cur_transition_id); } void TransitionModel::InitializeProbs() { diff --git a/src/hmm/transition-model.h b/src/hmm/transition-model.h index 9843dff946b..f03b54e8b71 100644 --- a/src/hmm/transition-model.h +++ b/src/hmm/transition-model.h @@ -156,6 +156,10 @@ class TransitionModel { // this state doesn't have a self-loop. inline int32 TransitionIdToPdf(int32 trans_id) const; + // TransitionIdToPdfFast is as TransitionIdToPdf but skips an assertion + // (unless we're in paranoid mode). + inline int32 TransitionIdToPdfFast(int32 trans_id) const; + int32 TransitionIdToPhone(int32 trans_id) const; int32 TransitionIdToPdfClass(int32 trans_id) const; int32 TransitionIdToHmmState(int32 trans_id) const; @@ -316,14 +320,26 @@ class TransitionModel { /// of pdfs). int32 num_pdfs_; - KALDI_DISALLOW_COPY_AND_ASSIGN(TransitionModel); - }; inline int32 TransitionModel::TransitionIdToPdf(int32 trans_id) const { - KALDI_ASSERT(static_cast(trans_id) < id2pdf_id_.size() && - "Likely graph/model mismatch (graph built from wrong model?)"); + KALDI_ASSERT( + static_cast(trans_id) < id2pdf_id_.size() && + "Likely graph/model mismatch (graph built from wrong model?)"); + return id2pdf_id_[trans_id]; +} + +inline int32 TransitionModel::TransitionIdToPdfFast(int32 trans_id) const { + // Note: it's a little dangerous to assert this only in paranoid mode. + // However, this function is called in the inner loop of decoders and + // the assertion likely takes a significant amount of time. We make + // sure that past the end of thd id2pdf_id_ array there are big + // numbers, which will make the calling code more likely to segfault + // (rather than silently die) if this is called for out-of-range values. + KALDI_PARANOID_ASSERT( + static_cast(trans_id) < id2pdf_id_.size() && + "Likely graph/model mismatch (graph built from wrong model?)"); return id2pdf_id_[trans_id]; } diff --git a/src/itf/decodable-itf.h b/src/itf/decodable-itf.h index 9852861969d..9f1f2f62e2b 100644 --- a/src/itf/decodable-itf.h +++ b/src/itf/decodable-itf.h @@ -72,19 +72,18 @@ namespace kaldi { always just return the number of frames in the file, and IsLastFrame() will return true for the last frame. - For truly online decoding, the "old" online decodable objects in ../online/ have a - "blocking" IsLastFrame() and will crash if you call NumFramesReady(). + For truly online decoding, the "old" online decodable objects in ../online/ + have a "blocking" IsLastFrame() and will crash if you call NumFramesReady(). The "new" online decodable objects in ../online2/ return the number of frames currently accessible if you call NumFramesReady(). You will likely not need to call IsLastFrame(), but we implement it to only return true for the last frame of the file once we've decided to terminate decoding. */ - class DecodableInterface { public: /// Returns the log likelihood, which will be negated in the decoder. - /// The "frame" starts from zero. You should verify that IsLastFrame(frame-1) - /// returns false before calling this. + /// The "frame" starts from zero. You should verify that NumFramesReady() > frame + /// before calling this. virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0; /// Returns true if this is the last frame. Frames are zero-based, so the diff --git a/src/itf/online-feature-itf.h b/src/itf/online-feature-itf.h index 3837024ab55..22c1c392450 100644 --- a/src/itf/online-feature-itf.h +++ b/src/itf/online-feature-itf.h @@ -45,11 +45,11 @@ namespace kaldi { implementing a child class you must not make assumptions about the order in which the user makes these calls. */ - + class OnlineFeatureInterface { public: virtual int32 Dim() const = 0; /// returns the feature dimension. - + /// Returns the total number of frames, since the start of the utterance, that /// are now available. In an online-decoding context, this will likely /// increase with time as more data becomes available. @@ -65,7 +65,7 @@ class OnlineFeatureInterface { /// many frames are in the decodable object (as it used to be, and for backward /// compatibility, still is, in the Decodable interface). virtual bool IsLastFrame(int32 frame) const = 0; - + /// Gets the feature vector for this frame. Before calling this for a given /// frame, it is assumed that you called NumFramesReady() and it returned a /// number greater than "frame". Otherwise this call will likely crash with @@ -74,6 +74,21 @@ class OnlineFeatureInterface { /// the class. virtual void GetFrame(int32 frame, VectorBase *feat) = 0; + + /// This is like GetFrame() but for a collection of frames. There is a + /// default implementation that just gets the frames one by one, but it + /// may be overridden for efficiency by child classes (since sometimes + /// it's more efficient to do things in a batch). + virtual void GetFrames(const std::vector &frames, + MatrixBase *feats) { + KALDI_ASSERT(static_cast(frames.size()) == feats->NumRows()); + for (size_t i = 0; i < frames.size(); i++) { + SubVector feat(*feats, i); + GetFrame(frames[i], &feat); + } + } + + // Returns frame shift in seconds. Helps to estimate duration from frame // counts. virtual BaseFloat FrameShiftInSeconds() const = 0; @@ -81,8 +96,8 @@ class OnlineFeatureInterface { /// Virtual destructor. Note: constructors that take another member of /// type OnlineFeatureInterface are not expected to take ownership of /// that pointer; the caller needs to keep track of that manually. - virtual ~OnlineFeatureInterface() { } - + virtual ~OnlineFeatureInterface() { } + }; diff --git a/src/ivector/Makefile b/src/ivector/Makefile index 408018befa4..1154da6880b 100644 --- a/src/ivector/Makefile +++ b/src/ivector/Makefile @@ -13,8 +13,8 @@ OBJFILES = ivector-extractor.o voice-activity-detection.o plda.o \ LIBNAME = kaldi-ivector ADDLIBS = ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/ivector/ivector-extractor.cc b/src/ivector/ivector-extractor.cc index aaba3837698..c3a122281c2 100644 --- a/src/ivector/ivector-extractor.cc +++ b/src/ivector/ivector-extractor.cc @@ -578,10 +578,96 @@ void OnlineIvectorEstimationStats::AccStats( quadratic_term_.AddToDiag(prior_scale_change); } } + num_frames_ += tot_weight; +} + + +// This is used in OnlineIvectorEstimationStats::AccStats(). +struct GaussInfo { + // total weight for this Gaussian. + BaseFloat tot_weight; + // vector of pairs of (frame-index, weight for this Gaussian) + std::vector > frame_weights; + GaussInfo(): tot_weight(0.0) { } +}; + +static void ConvertPostToGaussInfo( + const std::vector > > &gauss_post, + std::unordered_map *gauss_info) { + int32 num_frames = gauss_post.size(); + for (int32 t = 0; t < num_frames; t++) { + const std::vector > &this_post = gauss_post[t]; + auto iter = this_post.begin(), end = this_post.end(); + for (; iter != end; ++iter) { + int32 gauss_idx = iter->first; + GaussInfo &info = (*gauss_info)[gauss_idx]; + BaseFloat weight = iter->second; + info.tot_weight += weight; + info.frame_weights.push_back(std::pair(t, weight)); + } + } +} + +void OnlineIvectorEstimationStats::AccStats( + const IvectorExtractor &extractor, + const MatrixBase &features, + const std::vector > > &gauss_post) { + KALDI_ASSERT(extractor.IvectorDim() == this->IvectorDim()); + KALDI_ASSERT(!extractor.IvectorDependentWeights()); + + int32 feat_dim = features.NumCols(); + std::unordered_map gauss_info; + ConvertPostToGaussInfo(gauss_post, &gauss_info); + + Vector weighted_feats(feat_dim, kUndefined); + double tot_weight = 0.0; + int32 ivector_dim = this->IvectorDim(), + quadratic_term_dim = (ivector_dim * (ivector_dim + 1)) / 2; + SubVector quadratic_term_vec(quadratic_term_.Data(), + quadratic_term_dim); + + std::unordered_map::const_iterator + iter = gauss_info.begin(), end = gauss_info.end(); + for (; iter != end; ++iter) { + int32 gauss_idx = iter->first; + const GaussInfo &info = iter->second; + + weighted_feats.SetZero(); + std::vector >::const_iterator + f_iter = info.frame_weights.begin(), f_end = info.frame_weights.end(); + for (; f_iter != f_end; ++f_iter) { + int32 t = f_iter->first; + BaseFloat weight = f_iter->second; + weighted_feats.AddVec(weight, features.Row(t)); + } + BaseFloat this_tot_weight = info.tot_weight; + linear_term_.AddMatVec(1.0, extractor.Sigma_inv_M_[gauss_idx], kTrans, + weighted_feats, 1.0); + SubVector U_g(extractor.U_, gauss_idx); + quadratic_term_vec.AddVec(this_tot_weight, U_g); + tot_weight += this_tot_weight; + } + if (max_count_ > 0.0) { + // see comments in header RE max_count for explanation. It relates to + // prior scaling when the count exceeds max_count_ + double old_num_frames = num_frames_, + new_num_frames = num_frames_ + tot_weight; + double old_prior_scale = std::max(old_num_frames, max_count_) / max_count_, + new_prior_scale = std::max(new_num_frames, max_count_) / max_count_; + // The prior_scales are the inverses of the scales we would put on the stats + // if we were implementing this by scaling the stats. Instead we + // scale the prior term. + double prior_scale_change = new_prior_scale - old_prior_scale; + if (prior_scale_change != 0.0) { + linear_term_(0) += prior_offset_ * prior_scale_change; + quadratic_term_.AddToDiag(prior_scale_change); + } + } num_frames_ += tot_weight; } + void OnlineIvectorEstimationStats::Scale(double scale) { KALDI_ASSERT(scale >= 0.0 && scale <= 1.0); double old_num_frames = num_frames_; diff --git a/src/ivector/ivector-extractor.h b/src/ivector/ivector-extractor.h index 9641d9d79e8..3b9b6f3eb5c 100644 --- a/src/ivector/ivector-extractor.h +++ b/src/ivector/ivector-extractor.h @@ -323,10 +323,17 @@ class OnlineIvectorEstimationStats { OnlineIvectorEstimationStats(const OnlineIvectorEstimationStats &other); + // Accumulate stats for one frame. void AccStats(const IvectorExtractor &extractor, const VectorBase &feature, const std::vector > &gauss_post); + // Accumulate stats for a sequence (or collection) of frames. + void AccStats(const IvectorExtractor &extractor, + const MatrixBase &features, + const std::vector > > &gauss_post); + + int32 IvectorDim() const { return linear_term_.Dim(); } /// This function gets the current estimate of the iVector. Internally it diff --git a/src/ivectorbin/Makefile b/src/ivectorbin/Makefile index 75a17708c43..5a738352d9c 100644 --- a/src/ivectorbin/Makefile +++ b/src/ivectorbin/Makefile @@ -26,7 +26,7 @@ TESTFILES = ADDLIBS = ../ivector/kaldi-ivector.a ../hmm/kaldi-hmm.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/kws/Makefile b/src/kws/Makefile index a5b74ea2229..c4367eb2958 100644 --- a/src/kws/Makefile +++ b/src/kws/Makefile @@ -10,8 +10,7 @@ OBJFILES = kws-functions.o kws-functions2.o kws-scoring.o LIBNAME = kaldi-kws ADDLIBS = ../lat/kaldi-lat.a ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a \ - ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/kwsbin/Makefile b/src/kwsbin/Makefile index cade044e153..bcc2685b7f3 100644 --- a/src/kwsbin/Makefile +++ b/src/kwsbin/Makefile @@ -17,7 +17,6 @@ TESTFILES = ADDLIBS = ../kws/kaldi-kws.a ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a \ ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/lat/Makefile b/src/lat/Makefile index bba2329fdf6..56521486826 100644 --- a/src/lat/Makefile +++ b/src/lat/Makefile @@ -16,8 +16,7 @@ OBJFILES = kaldi-lattice.o lattice-functions.o word-align-lattice.o \ LIBNAME = kaldi-lat ADDLIBS = ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/latbin/Makefile b/src/latbin/Makefile index afff54cb845..9809cdcbb85 100644 --- a/src/latbin/Makefile +++ b/src/latbin/Makefile @@ -32,10 +32,9 @@ OBJFILES = TESTFILES = -ADDLIBS = ../rnnlm/kaldi-rnnlm.a ../lat/kaldi-lat.a ../nnet3/kaldi-nnet3.a ../lm/kaldi-lm.a \ +ADDLIBS = ../rnnlm/kaldi-rnnlm.a ../nnet3/kaldi-nnet3.a \ + ../cudamatrix/kaldi-cudamatrix.a ../lat/kaldi-lat.a ../lm/kaldi-lm.a \ ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a \ - ../util/kaldi-util.a \ - ../cudamatrix/kaldi-cudamatrix.a ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/lm/Makefile b/src/lm/Makefile index 3dfb409f970..c0654fa83b2 100644 --- a/src/lm/Makefile +++ b/src/lm/Makefile @@ -12,7 +12,6 @@ OBJFILES = arpa-file-parser.o arpa-lm-compiler.o const-arpa-lm.o \ LIBNAME = kaldi-lm ADDLIBS = ../fstext/kaldi-fstext.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/lmbin/Makefile b/src/lmbin/Makefile index c88f6151a8f..108ddab50c5 100644 --- a/src/lmbin/Makefile +++ b/src/lmbin/Makefile @@ -10,7 +10,7 @@ OBJFILES = TESTFILES = -ADDLIBS = ../lm/kaldi-lm.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a +ADDLIBS = ../lm/kaldi-lm.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/matrix/kaldi-vector.h b/src/matrix/kaldi-vector.h index 3eb4a932095..383d8ca2862 100644 --- a/src/matrix/kaldi-vector.h +++ b/src/matrix/kaldi-vector.h @@ -514,8 +514,9 @@ class SubVector : public VectorBase { /// Constructor from a pointer to memory and a length. Keeps a pointer /// to the data but does not take ownership (will never delete). - SubVector(Real *data, MatrixIndexT length) : VectorBase () { - VectorBase::data_ = data; + /// Caution: this constructor enables you to evade const constraints. + SubVector(const Real *data, MatrixIndexT length) : VectorBase () { + VectorBase::data_ = const_cast(data); VectorBase::dim_ = length; } @@ -594,4 +595,3 @@ Real VecMatVec(const VectorBase &v1, const MatrixBase &M, #endif // KALDI_MATRIX_KALDI_VECTOR_H_ - diff --git a/src/nnet/Makefile b/src/nnet/Makefile index 99f54ae2af2..7f324479a0f 100644 --- a/src/nnet/Makefile +++ b/src/nnet/Makefile @@ -15,8 +15,8 @@ OBJFILES = nnet-nnet.o nnet-component.o nnet-loss.o \ LIBNAME = kaldi-nnet ADDLIBS = ../cudamatrix/kaldi-cudamatrix.a ../hmm/kaldi-hmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/nnet2/Makefile b/src/nnet2/Makefile index 5fc27419ec1..7c19ec2603c 100644 --- a/src/nnet2/Makefile +++ b/src/nnet2/Makefile @@ -27,7 +27,7 @@ LIBNAME = kaldi-nnet2 ADDLIBS = ../cudamatrix/kaldi-cudamatrix.a ../lat/kaldi-lat.a \ ../hmm/kaldi-hmm.a ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/nnet2/decodable-am-nnet.h b/src/nnet2/decodable-am-nnet.h index e3dedb33727..6c40b11bf9d 100644 --- a/src/nnet2/decodable-am-nnet.h +++ b/src/nnet2/decodable-am-nnet.h @@ -76,14 +76,14 @@ class DecodableAmNnet: public DecodableInterface { // from one (this routine is called by FSTs). virtual BaseFloat LogLikelihood(int32 frame, int32 transition_id) { return log_probs_(frame, - trans_model_.TransitionIdToPdf(transition_id)); + trans_model_.TransitionIdToPdfFast(transition_id)); } virtual int32 NumFramesReady() const { return log_probs_.NumRows(); } - + // Indices are one-based! This is for compatibility with OpenFst. virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); } - + virtual bool IsLastFrame(int32 frame) const { KALDI_ASSERT(frame < NumFramesReady()); return (frame == NumFramesReady() - 1); @@ -139,7 +139,7 @@ class DecodableAmNnetParallel: public DecodableInterface { virtual BaseFloat LogLikelihood(int32 frame, int32 transition_id) { if (feats_) Compute(); // this function sets feats_ to NULL. return log_probs_(frame, - trans_model_.TransitionIdToPdf(transition_id)); + trans_model_.TransitionIdToPdfFast(transition_id)); } int32 NumFramesReady() const { @@ -155,10 +155,10 @@ class DecodableAmNnetParallel: public DecodableInterface { return log_probs_.NumRows(); } } - + // Indices are one-based! This is for compatibility with OpenFst. virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); } - + virtual bool IsLastFrame(int32 frame) const { KALDI_ASSERT(frame < NumFramesReady()); return (frame == NumFramesReady() - 1); @@ -180,7 +180,7 @@ class DecodableAmNnetParallel: public DecodableInterface { - + } // namespace nnet2 } // namespace kaldi diff --git a/src/nnet2bin/Makefile b/src/nnet2bin/Makefile index 3280acfc968..b7e2c385006 100644 --- a/src/nnet2bin/Makefile +++ b/src/nnet2bin/Makefile @@ -38,7 +38,7 @@ ADDLIBS = ../nnet2/kaldi-nnet2.a ../nnet/kaldi-nnet.a \ ../cudamatrix/kaldi-cudamatrix.a ../decoder/kaldi-decoder.a \ ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/nnet3/Makefile b/src/nnet3/Makefile index 135853cadc3..aac16fb1c86 100644 --- a/src/nnet3/Makefile +++ b/src/nnet3/Makefile @@ -31,15 +31,16 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \ nnet-compile-looped.o decodable-simple-looped.o \ decodable-online-looped.o convolution.o \ nnet-convolutional-component.o attention.o \ - nnet-attention-component.o nnet-tdnn-component.o + nnet-attention-component.o nnet-tdnn-component.o nnet-batch-compute.o LIBNAME = kaldi-nnet3 ADDLIBS = ../chain/kaldi-chain.a ../cudamatrix/kaldi-cudamatrix.a \ - ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ + ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a \ + ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/nnet3/decodable-online-looped.cc b/src/nnet3/decodable-online-looped.cc index f231a2d5b62..5817df5fd25 100644 --- a/src/nnet3/decodable-online-looped.cc +++ b/src/nnet3/decodable-online-looped.cc @@ -244,7 +244,7 @@ BaseFloat DecodableAmNnetLoopedOnline::LogLikelihood(int32 subsampled_frame, EnsureFrameIsComputed(subsampled_frame); return current_log_post_( subsampled_frame - current_log_post_subsampled_offset_, - trans_model_.TransitionIdToPdf(index)); + trans_model_.TransitionIdToPdfFast(index)); } diff --git a/src/nnet3/decodable-simple-looped.cc b/src/nnet3/decodable-simple-looped.cc index d4edb440d5a..0452304cf55 100644 --- a/src/nnet3/decodable-simple-looped.cc +++ b/src/nnet3/decodable-simple-looped.cc @@ -257,7 +257,7 @@ DecodableAmNnetSimpleLooped::DecodableAmNnetSimpleLooped( BaseFloat DecodableAmNnetSimpleLooped::LogLikelihood(int32 frame, int32 transition_id) { - int32 pdf_id = trans_model_.TransitionIdToPdf(transition_id); + int32 pdf_id = trans_model_.TransitionIdToPdfFast(transition_id); return decodable_nnet_.GetOutput(frame, pdf_id); } diff --git a/src/nnet3/nnet-am-decodable-simple.cc b/src/nnet3/nnet-am-decodable-simple.cc index d66e24830c6..9682bd96bc7 100644 --- a/src/nnet3/nnet-am-decodable-simple.cc +++ b/src/nnet3/nnet-am-decodable-simple.cc @@ -77,7 +77,7 @@ DecodableAmNnetSimple::DecodableAmNnetSimple( BaseFloat DecodableAmNnetSimple::LogLikelihood(int32 frame, int32 transition_id) { - int32 pdf_id = trans_model_.TransitionIdToPdf(transition_id); + int32 pdf_id = trans_model_.TransitionIdToPdfFast(transition_id); return decodable_nnet_.GetOutput(frame, pdf_id); } @@ -204,7 +204,7 @@ void DecodableNnetSimple::GetCurrentIvector(int32 output_t_start, << ", only available till frame " << online_ivector_feats_->NumRows() << " * ivector-period=" << online_ivector_period_ - << " (mismatched --ivector-period?)"; + << " (mismatched --online-ivector-period?)"; } ivector_frame = online_ivector_feats_->NumRows() - 1; } @@ -357,7 +357,7 @@ void DecodableAmNnetSimpleParallel::DeletePointers() { BaseFloat DecodableAmNnetSimpleParallel::LogLikelihood(int32 frame, int32 transition_id) { - int32 pdf_id = trans_model_.TransitionIdToPdf(transition_id); + int32 pdf_id = trans_model_.TransitionIdToPdfFast(transition_id); return decodable_nnet_->GetOutput(frame, pdf_id); } diff --git a/src/nnet3/nnet-batch-compute.cc b/src/nnet3/nnet-batch-compute.cc new file mode 100644 index 00000000000..6db046796be --- /dev/null +++ b/src/nnet3/nnet-batch-compute.cc @@ -0,0 +1,1313 @@ +// nnet3/nnet-batch-compute.cc + +// Copyright 2012-2018 Johns Hopkins University (author: Daniel Povey) +// 2018 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "nnet3/nnet-batch-compute.h" +#include "nnet3/nnet-utils.h" +#include "decoder/decodable-matrix.h" + +namespace kaldi { +namespace nnet3 { + + +NnetBatchComputer::NnetBatchComputer( + const NnetBatchComputerOptions &opts, + const Nnet &nnet, + const VectorBase &priors): + opts_(opts), + nnet_(nnet), + compiler_(nnet_, opts.optimize_config), + log_priors_(priors), + num_full_minibatches_(0) { + log_priors_.ApplyLog(); + CheckAndFixConfigs(); + ComputeSimpleNnetContext(nnet, &nnet_left_context_, + &nnet_right_context_); + input_dim_ = nnet.InputDim("input"); + ivector_dim_ = std::max(0, nnet.InputDim("ivector")); + output_dim_ = nnet.OutputDim("output"); + KALDI_ASSERT(input_dim_ > 0 && output_dim_ > 0); +} + +void NnetBatchComputer::PrintMinibatchStats() { + int32 max_stats_to_print = 10; + int64 tot_tasks = 0, tot_minibatches = 0; + double tot_time = 0.0; + std::ostringstream os; + struct MinibatchStats { + int32 num_frames_out; + int32 num_frames_in; + int32 minibatch_size; + int32 num_done; + int32 percent_full; + BaseFloat seconds_taken; + + bool operator < (const MinibatchStats &other) const { + return seconds_taken > other.seconds_taken; // sort from most to least time. + } + }; + std::vector all_stats; + os << "Minibatch stats: seconds-taken,frames-in:frames-out*minibatch-size=num-done(percent-full%) "; + + for (MapType::const_iterator iter = tasks_.begin(); + iter != tasks_.end(); ++iter) { + for (std::map::const_iterator + miter = iter->second.minibatch_info.begin(); + miter != iter->second.minibatch_info.end(); ++miter) { + const ComputationGroupKey &key = iter->first; + const MinibatchSizeInfo &minfo = miter->second; + MinibatchStats stats; + stats.num_frames_in = key.num_input_frames; + stats.num_frames_out = key.num_output_frames; + stats.minibatch_size = miter->first; + stats.num_done = minfo.num_done; + stats.seconds_taken = minfo.seconds_taken; + + tot_tasks += minfo.tot_num_tasks; + tot_minibatches += minfo.num_done; + tot_time += minfo.seconds_taken; + stats.percent_full = int32(minfo.tot_num_tasks * 100.0 / + (stats.minibatch_size * stats.num_done)); + all_stats.push_back(stats); + } + } + + std::sort(all_stats.begin(), all_stats.end()); + os << std::fixed << std::setprecision(2); + int32 num_stats = all_stats.size(); + for (int32 i = 0; i < std::min(num_stats, max_stats_to_print); i++) { + MinibatchStats &stats = all_stats[i]; + os << stats.seconds_taken << ',' << stats.num_frames_in << ':' + << stats.num_frames_out << '*' << stats.minibatch_size + << '=' << stats.num_done << '(' << stats.percent_full << "%) "; + } + if (num_stats > max_stats_to_print) + os << "..."; + KALDI_LOG << os.str(); + KALDI_LOG << "Did " << tot_tasks << " tasks in " << tot_minibatches + << " minibatches, taking " << tot_time << " seconds."; +} + +NnetBatchComputer::~NnetBatchComputer() { + PrintMinibatchStats(); + // the destructor shouldn't be called while the mutex is locked; if it is, it + // likely means the program has already crashed, or it's a programming error. + if (!mutex_.try_lock()) + KALDI_ERR << "Destructor called while object locked."; + int32 num_pending_tasks = 0; + for (auto iter = tasks_.begin(); iter != tasks_.end(); ++iter) + num_pending_tasks += iter->second.tasks.size(); + if (num_pending_tasks > 0) + KALDI_ERR << "Tasks are pending but object is being destroyed"; + for (auto iter = no_more_than_n_minibatches_full_.begin(); + iter != no_more_than_n_minibatches_full_.end(); ++iter) { + std::condition_variable *cond = iter->second; + // the next call will notify any threads that were waiting on this condition + // variable-- there shouldn't be any, though, as it would be a programming + // error, but better to wake them up so we can see any messages they print. + cond->notify_all(); + delete cond; + } + KALDI_ASSERT(num_full_minibatches_ == 0); // failure would be a coding error. +} + +NnetBatchComputer::MinibatchSizeInfo* +NnetBatchComputer::GetHighestPriorityComputation( + bool allow_partial_minibatch, + int32 *minibatch_size_out, + std::vector *tasks) { + tasks->clear(); + std::unique_lock(mutex_); + MapType::iterator iter = tasks_.begin(), end = tasks_.end(), + best_iter = tasks_.end(); + double highest_priority = -std::numeric_limits::infinity(); + + for (; iter != end; ++iter) { + ComputationGroupInfo &info = iter->second; + double this_priority = GetPriority(allow_partial_minibatch, info); + if (this_priority > highest_priority) { + highest_priority = this_priority; + best_iter = iter; + } + } + if (best_iter == tasks_.end()) { + // either allow_partial_minibatch == false and there were no full + // minibatches, or there were no pending tasks at all. + return NULL; + } + ComputationGroupInfo &info = best_iter->second; + int32 actual_minibatch_size = GetActualMinibatchSize(info); + *minibatch_size_out = actual_minibatch_size; + MinibatchSizeInfo *minfo = &(info.minibatch_info[actual_minibatch_size]); + if (minfo->computation == NULL) + minfo->computation = GetComputation(info, actual_minibatch_size); + GetHighestPriorityTasks(actual_minibatch_size, &info, tasks); + return minfo; +} + + +void NnetBatchComputer::GetHighestPriorityTasks( + int32 num_tasks_needed, + ComputationGroupInfo *info, + std::vector *tasks) { + int32 num_tasks_present = info->tasks.size(), + minibatch_size = GetMinibatchSize(*info); + KALDI_ASSERT(tasks->empty()); + if (num_tasks_needed >= num_tasks_present) { + tasks->swap(info->tasks); + } else { + int32 num_tasks_not_needed = num_tasks_present - num_tasks_needed; + // We don't sort the tasks with a comparator that dereferences the pointers, + // because the priorities can change asynchronously, and we're concerned that + // something weird might happen in the sorting if the things it's comparing + // are changing. + std::vector > pairs(num_tasks_present); + for (int32 i = 0; i < num_tasks_present; i++) { + pairs[i].first = info->tasks[i]->priority; + pairs[i].second = info->tasks[i]; + } + std::nth_element(pairs.begin(), pairs.begin() + num_tasks_not_needed, + pairs.end()); + + // The lowest-priority 'num_tasks_not_needed' stay in the 'info' struct. + info->tasks.clear(); + for (int32 i = 0; i < num_tasks_not_needed; i++) + info->tasks.push_back(pairs[i].second); + // The highest-priority 'num_tasks_needed' tasks go to the output 'tasks' + // array. + for (int32 i = num_tasks_not_needed; i < num_tasks_present; i++) + tasks->push_back(pairs[i].second); + // The following assertion checks that the is_edge and is_irregular values + // are the same for the entire minibatch, which they should always be. + KALDI_ASSERT(GetMinibatchSize(*info) == minibatch_size); + } + + { + // This block updates num_full_minibatches_ and notifies threads waiting on + // any related condition variable. + int32 new_num_tasks_present = info->tasks.size(), + full_minibatch_reduction = + (num_tasks_present / minibatch_size) - + (new_num_tasks_present / minibatch_size); + for (int32 i = 0; i < full_minibatch_reduction; i++) { + num_full_minibatches_--; + KALDI_ASSERT(num_full_minibatches_ >= 0); + std::unordered_map::const_iterator + iter = no_more_than_n_minibatches_full_.find(num_full_minibatches_); + if (iter != no_more_than_n_minibatches_full_.end()) { + std::condition_variable *cond = iter->second; + cond->notify_all(); + } + } + } +} + + +int32 NnetBatchComputer::GetMinibatchSize( + const ComputationGroupInfo &info) const { + if (info.tasks.empty()) { + return opts_.minibatch_size; // actually it shouldn't matter what we return + // in this case. + } + const NnetInferenceTask &task = *(info.tasks[0]); + if (task.is_irregular) + return 1; + else if (task.is_edge) + return opts_.edge_minibatch_size; + else + return opts_.minibatch_size; +} + +int32 NnetBatchComputer::GetActualMinibatchSize( + const ComputationGroupInfo &info) const { + KALDI_ASSERT(!info.tasks.empty()); + int32 num_tasks = info.tasks.size(), + this_minibatch_size = GetMinibatchSize(info); + KALDI_ASSERT(num_tasks > 0); + while (num_tasks < + int32(opts_.partial_minibatch_factor * this_minibatch_size)) + this_minibatch_size *= opts_.partial_minibatch_factor; + return int32(this_minibatch_size); +} + + +std::shared_ptr NnetBatchComputer::GetComputation( + const ComputationGroupInfo &info, + int32 minibatch_size) { + KALDI_ASSERT(!info.tasks.empty()); + // note: all the tasks will have the same structure, in the respects that + // would affect the computation. + NnetInferenceTask *example_task = info.tasks[0]; + ComputationRequest request; + GetComputationRequest(*example_task, minibatch_size, &request); + return compiler_.Compile(request); +} + + +double NnetBatchComputer::GetPriority(bool allow_partial_minibatch, + const ComputationGroupInfo &info) const { + if (info.tasks.empty()) + return -std::numeric_limits::infinity(); + int32 this_minibatch_size = GetMinibatchSize(info); + int32 num_tasks = info.tasks.size(); + + if (!allow_partial_minibatch && num_tasks < this_minibatch_size) + return -std::numeric_limits::infinity(); + + // penalty_for_not_full will be negative if the minibatch is not full, up to a + // maximum of 10. the 10 is a heuristic; it could be changed. + // Note: the penalty is effectively infinity if allow_partial_minibatch == false; + // see the 'return' above. + double proportion_full = std::min(num_tasks, this_minibatch_size) / + double(this_minibatch_size), + penalty_for_not_full = 10.0 * (proportion_full - 1.0), + task_priority_sum = 0.0; + + + if (num_tasks > this_minibatch_size) { + // Get the average of the priorities of the highest-priority tasks (no more + // than 'minibatch_size' of them. + std::vector priorities; + priorities.resize(num_tasks); + for (int32 i = 0; i < num_tasks; i++) + priorities[i] = info.tasks[i]->priority; + // sort from greatest to least. + std::nth_element(priorities.begin(), + priorities.begin() + this_minibatch_size, + priorities.end(), + std::greater()); + for (int32 i = 0; i < this_minibatch_size; i++) + task_priority_sum += priorities[i]; + return penalty_for_not_full + task_priority_sum / this_minibatch_size; + } else { + for (int32 i = 0; i < num_tasks; i++) + task_priority_sum += info.tasks[i]->priority; + return penalty_for_not_full + task_priority_sum / num_tasks; + } +} + + +// static +void NnetBatchComputer::GetComputationRequest( + const NnetInferenceTask &task, + int32 minibatch_size, + ComputationRequest *request) { + request->need_model_derivative = false; + request->store_component_stats = false; + request->inputs.reserve(2); + + int32 num_input_frames = task.input.NumRows(), + first_input_t = task.first_input_t, + num_output_frames = task.num_output_frames, + output_t_stride = task.output_t_stride; + bool has_ivector = (task.ivector.Dim() != 0); + + std::vector input_indexes, ivector_indexes, output_indexes; + input_indexes.reserve(minibatch_size * num_input_frames); + output_indexes.reserve(minibatch_size * num_output_frames); + if (has_ivector) + ivector_indexes.reserve(minibatch_size); + + for (int32 n = 0; n < minibatch_size; n++) { + for (int32 t = first_input_t; t < first_input_t + num_input_frames; t++) + input_indexes.push_back(Index(n, t, 0)); + if (has_ivector) + ivector_indexes.push_back(Index(n, 0, 0)); + for (int32 t = 0; t < num_output_frames; t++) + output_indexes.push_back(Index(n, t * output_t_stride, 0)); + } + request->inputs.push_back(IoSpecification("input", input_indexes)); + if (has_ivector) + request->inputs.push_back(IoSpecification("ivector", ivector_indexes)); + request->outputs.push_back(IoSpecification("output", output_indexes)); +} + + + +void NnetBatchComputer::CheckAndFixConfigs() { + static bool warned_frames_per_chunk = false; + int32 nnet_modulus = nnet_.Modulus(); + if (opts_.frame_subsampling_factor < 1 || + opts_.frames_per_chunk < 1) { + KALDI_ERR << "--frame-subsampling-factor and " + << "--frames-per-chunk must be > 0"; + } + KALDI_ASSERT(nnet_modulus > 0); + int32 n = Lcm(opts_.frame_subsampling_factor, nnet_modulus); + + if (opts_.frames_per_chunk % n != 0) { + // round up to the nearest multiple of n. + int32 frames_per_chunk = n * ((opts_.frames_per_chunk + n - 1) / n); + if (!warned_frames_per_chunk) { + warned_frames_per_chunk = true; + if (nnet_modulus == 1) { + // simpler error message. + KALDI_LOG << "Increasing --frames-per-chunk from " + << opts_.frames_per_chunk << " to " + << frames_per_chunk << " to make it a multiple of " + << "--frame-subsampling-factor=" + << opts_.frame_subsampling_factor; + } else { + KALDI_LOG << "Increasing --frames-per-chunk from " + << opts_.frames_per_chunk << " to " + << frames_per_chunk << " due to " + << "--frame-subsampling-factor=" + << opts_.frame_subsampling_factor << " and " + << "nnet shift-invariance modulus = " << nnet_modulus; + } + } + opts_.frames_per_chunk = frames_per_chunk; + } + KALDI_ASSERT(opts_.minibatch_size >= 1 && + opts_.edge_minibatch_size >= 1 && + opts_.partial_minibatch_factor < 1.0 && + opts_.partial_minibatch_factor >= 0.0); +} + + +void NnetBatchComputer::FormatInputs( + int32 minibatch_size, + const std::vector &tasks, + CuMatrix *input, + CuMatrix *ivector) { + int32 num_input_frames = tasks[0]->input.NumRows(), + input_dim = tasks[0]->input.NumCols(), + ivector_dim = tasks[0]->ivector.Dim(), + num_tasks = tasks.size(); + KALDI_ASSERT(num_tasks > 0 && num_tasks <= minibatch_size); + + // We first aggregate the input frames and i-vectors in matrices on the CPU, + // and then transfer them to the GPU. Later on we'll change this code to + // used pinned memory. + Matrix input_cpu(num_tasks * num_input_frames, input_dim, + kUndefined); + + + for (int32 n = 0; n < num_tasks; n++) { + SubMatrix input_part(input_cpu, + n * num_input_frames, num_input_frames, + 0, input_dim); + input_part.CopyFromMat(tasks[n]->input); + } + input->Resize(minibatch_size * num_input_frames, input_dim, + kUndefined); + input->RowRange(0, num_tasks * num_input_frames).CopyFromMat(input_cpu); + if (num_tasks < minibatch_size) { + // The following will make things easier to debug if something fails, but + // shouldn't be strictly necessary. + // the -1 means 'take all remaining rows'. + input->RowRange(num_tasks * num_input_frames, + (minibatch_size - num_tasks) * num_input_frames).SetZero(); + } + + if (ivector_dim != 0) { + Matrix ivectors_cpu(num_tasks, ivector_dim, kUndefined); + for (int32 n = 0; n < num_tasks; n++) + ivectors_cpu.Row(n).CopyFromVec(tasks[n]->ivector); + + ivector->Resize(minibatch_size, ivector_dim, kUndefined); + ivector->RowRange(0, num_tasks).CopyFromMat(ivectors_cpu); + + if (num_tasks < minibatch_size) { + // The following will make things easier to debug if something fails, but + // shouldn't be strictly necessary. + // the -1 means 'take all remaining rows'. + ivector->RowRange(num_tasks, minibatch_size - num_tasks).SetZero(); + } + } +} + +void NnetBatchComputer::FormatOutputs( + const CuMatrix &output, + const std::vector &tasks) { + KALDI_ASSERT(!tasks.empty()); + int32 num_output_frames = tasks[0]->num_output_frames, + output_dim = output.NumCols(), + num_tasks = tasks.size(); + bool did_output_to_gpu = false; + + // Note: it may not be optimal to do so many individual calls to copy the + // output to CPU; we'd have to test that, as I'm not sure how much the latency + // of a GPU call is. On the other hand, the downsides of one big call are + // that we'd have to make another copy in CPU memory; and also we might not be + // able to take advantage if not all frames of the output are used. + + // Also, we should probably used pinned memory. + + // We don't bother zeroing frames of the output that are unused, but you could + // un-comment the commented lines of code below to do so. + for (int32 n = 0; n < num_tasks; n++) { + NnetInferenceTask *task = tasks[n]; + + int32 left_unused = task->num_initial_unused_output_frames, + used = task->num_used_output_frames; + // int32 right_unused = num_output_frames - used - left_unused; + + if (task->output_to_cpu) { + task->output_cpu.Resize(num_output_frames, output_dim, + kUndefined); + // if (left_unused > 0) + // task->output_cpu.RowRange(0, left_unused).SetZero(); + task->output_cpu.RowRange(left_unused, used).CopyFromMat( + output.RowRange(n * num_output_frames + left_unused, used)); + // if (right_unused > 0) + // task->output_cpu.RowRange(0, left_unused + used, right_unused).SetZero(); + } else { + did_output_to_gpu = true; + task->output.Resize(num_output_frames, output_dim, + kUndefined); + // if (left_unused > 0) + // task->output.RowRange(0, left_unused).SetZero(); + task->output.RowRange(left_unused, used).CopyFromMat( + output.RowRange(n * num_output_frames + left_unused, used)); + // if (right_unused > 0) + // task->output.RowRange(0, left_unused + used, right_unused).SetZero(); + } + } + // The output of this function will likely be consumed by another thread. + // The following call will make sure the relevant kernels complete before + // any kernels from the other thread use the output. + if (did_output_to_gpu) + SynchronizeGpu(); +} + +void NnetBatchComputer::AcceptTask(NnetInferenceTask *task, + int32 max_minibatches_full) { + std::unique_lock lock(mutex_); + + if (max_minibatches_full > 0 && num_full_minibatches_ > max_minibatches_full) { + std::unordered_map::iterator + iter = no_more_than_n_minibatches_full_.find(max_minibatches_full); + std::condition_variable *cond; + if (iter != no_more_than_n_minibatches_full_.end()) { + cond = iter->second; + } else { + cond = new std::condition_variable(); + no_more_than_n_minibatches_full_[max_minibatches_full] = cond; + } + while (num_full_minibatches_ > max_minibatches_full) + cond->wait(lock); + } + ComputationGroupKey key(*task); + ComputationGroupInfo &info = tasks_[key]; + info.tasks.push_back(task); + int32 minibatch_size = GetMinibatchSize(info); + if (static_cast(info.tasks.size()) % minibatch_size == 0) + num_full_minibatches_++; +} + +bool NnetBatchComputer::Compute(bool allow_partial_minibatch) { + int32 minibatch_size; + std::vector tasks; + MinibatchSizeInfo *minfo = + GetHighestPriorityComputation(allow_partial_minibatch, + &minibatch_size, + &tasks); + if (minfo == NULL) + return false; + + Timer tim; + Nnet *nnet_to_update = NULL; // we're not doing any update + NnetComputer computer(opts_.compute_config, *(minfo->computation), + nnet_, nnet_to_update); + + + CuMatrix input; + CuMatrix ivector; + FormatInputs(minibatch_size, tasks, &input, &ivector); + computer.AcceptInput("input", &input); + if (ivector.NumRows() != 0) + computer.AcceptInput("ivector", &ivector); + computer.Run(); + CuMatrix output; + computer.GetOutputDestructive("output", &output); + if (log_priors_.Dim() != 0) { + output.AddVecToRows(-1.0, log_priors_); + } + output.Scale(opts_.acoustic_scale); + FormatOutputs(output, tasks); + + // Update the stats, for diagnostics. + minfo->num_done++; + minfo->tot_num_tasks += static_cast(tasks.size()); + minfo->seconds_taken += tim.Elapsed(); + + + SynchronizeGpu(); + + for (size_t i = 0; i < tasks.size(); i++) + tasks[i]->semaphore.Signal(); + + return true; +} + + +/** + This namespace contains things needed for the implementation of + the function NnetBatchComputer::SplitUtteranceIntoTasks(). + */ +namespace utterance_splitting { +/** + This function figures out how many chunks are needed for this utterance, + sets 'tasks' to a vector with that many elements, and sets up the + following elements in 'tasks': + output_t_stride + num_output_frames + num_initial_unused_output_frames + num_used_output_frames + @param [in] opts Options class + @param [in] num_subsampled_frames The number of output frames in this + utterance. Must be > 0. + @param [in] num_subsampled_frames_per_chunk The number of output frames + per chunk + @param [out] The 'tasks' array is output to here; it will have one + task per chunk, with only the members 'output_t_stride', + 'num_output_frames', 'num_initial_unused_output_frames', + 'num_used_output_frames' and 'is_irregular' set up. +*/ +void GetOutputFrameInfoForTasks( + const NnetBatchComputerOptions &opts, + int32 num_subsampled_frames, + int32 num_subsampled_frames_per_chunk, + std::vector *tasks) { + KALDI_ASSERT(num_subsampled_frames > 0); + int32 fpc = num_subsampled_frames_per_chunk; + int32 num_tasks = (num_subsampled_frames + fpc - 1) / fpc; + tasks->resize(num_tasks); + for (int32 i = 0; i < num_tasks; i++) { + (*tasks)[i].output_t_stride = opts.frame_subsampling_factor; + } + if (num_subsampled_frames <= fpc) { // there is one chunk. + KALDI_ASSERT(num_tasks == 1); // TODO: remove this. + NnetInferenceTask &task = (*tasks)[0]; + task.first_used_output_frame_index = 0; + if (opts.ensure_exact_final_context) { + task.num_output_frames = num_subsampled_frames; + task.num_initial_unused_output_frames = 0; + task.num_used_output_frames = num_subsampled_frames; + task.is_irregular = true; + } else { + task.num_output_frames = fpc; + task.num_initial_unused_output_frames = 0; + task.num_used_output_frames = num_subsampled_frames; + task.is_irregular = false; + } + } else { + for (int32 i = 0; i + 1 < num_tasks; i++) { + NnetInferenceTask &task = (*tasks)[i]; + task.num_output_frames = fpc; + task.num_initial_unused_output_frames = 0; + task.num_used_output_frames = fpc; + task.first_used_output_frame_index = i * fpc; + task.is_irregular = false; + } + // The last chunk will end on the last frame of the file, but we won't use + // the part of its output that overlaps with the preceding chunk. + NnetInferenceTask &task = (*tasks)[num_tasks - 1]; + task.num_output_frames = fpc; + task.num_initial_unused_output_frames = ((num_tasks - 1) * fpc) - + (num_subsampled_frames - fpc); + task.num_used_output_frames = + num_subsampled_frames - ((num_tasks - 1) * fpc); + task.first_used_output_frame_index = (num_tasks - 1) * fpc; + task.is_irregular = false; + } + + if (true) { + // Do some checking. TODO: remove this. + KALDI_ASSERT((*tasks)[0].first_used_output_frame_index == 0); + for (int32 i = 1; i < num_tasks; i++) { + KALDI_ASSERT((*tasks)[i].first_used_output_frame_index == + (*tasks)[i-1].first_used_output_frame_index + + (*tasks)[i-1].num_used_output_frames); + } + KALDI_ASSERT((*tasks)[num_tasks-1].first_used_output_frame_index + + (*tasks)[num_tasks-1].num_used_output_frames == + num_subsampled_frames); + for (int32 i = 0; i < num_tasks; i++) { + const NnetInferenceTask &task = (*tasks)[i]; + KALDI_ASSERT(task.num_used_output_frames + + task.num_initial_unused_output_frames <= + task.num_output_frames); + } + } +} + +void AddOnlineIvectorsToTasks( + const NnetBatchComputerOptions &opts, + const Matrix &online_ivectors, + int32 online_ivector_period, + std::vector *tasks) { + int32 f = opts.frame_subsampling_factor, + num_tasks = tasks->size(); + for (int32 i = 0; i < num_tasks; i++) { + NnetInferenceTask &task = (*tasks)[i]; + // begin_output_t and end_output_t are the subsampled frame indexes at + // the output; you'd have to multiply them by f to get real frame indexes. + int32 begin_output_t = task.first_used_output_frame_index - + task.num_initial_unused_output_frames, + mid_output_t = begin_output_t + (task.num_output_frames / 2), + mid_input_t = mid_output_t * f, + ivector_frame = mid_input_t / online_ivector_period, + num_ivector_frames = online_ivectors.NumRows(), + margin_in_frames = 20, + margin_in_ivector_frames = + (margin_in_frames + online_ivector_period - 1) / online_ivector_period; + // the 'margin' is our tolerance for when the number of rows of + // 'online_ivectors' is less than what we expected; we allow 20 frames of + // tolerance in the numbering of the original (input) features. + if (ivector_frame >= num_ivector_frames) { + if (num_ivector_frames > 0 && ivector_frame > num_ivector_frames - + margin_in_ivector_frames) { + ivector_frame = num_ivector_frames - 1; // Just take the last available one. + } else { + KALDI_ERR << "Could not get iVector for frame " << ivector_frame + << ", online-ivectors matrix has " + << online_ivectors.NumRows() + << " rows. Mismatched --online-ivector-period?"; + } + } + task.ivector = online_ivectors.Row(ivector_frame); + } +} + + + +/** + This function sets up the 'input' and 'first_input_t' and 'is_edge' members + of the 'tasks' array; it is responsible for working out, for each task, + which input frames it needs (including left-context and right-context). + + The 'nnet_left_context' and 'nnet_right_context' are the inherent left + and right context of the network (num-frames required on left and right + to compute an output frame), and may be computed by doing: + ComputeSimpleNnetContext(nnet, &nnet_left_context_, &nnet_right_context_) +*/ +static void SplitInputToTasks(const NnetBatchComputerOptions &opts, + int32 nnet_left_context, + int32 nnet_right_context, + const Matrix &input, + std::vector *tasks) { + int32 num_input_frames = input.NumRows(), + f = opts.frame_subsampling_factor, + num_subsampled_frames = (num_input_frames + f - 1) / f, + extra_left_context_initial = (opts.extra_left_context_initial < 0 ? + opts.extra_left_context : + opts.extra_left_context_initial), + extra_right_context_final = (opts.extra_right_context_final < 0 ? + opts.extra_right_context : + opts.extra_right_context_final), + num_tasks = tasks->size(); + for (int32 i = 0; i < num_tasks; i++) { + NnetInferenceTask &task = (*tasks)[i]; + // begin_output_t and end_output_t are the subsampled frame indexes at + // the output; you'd have to multiply them by f to get real frame indexes. + int32 begin_output_t = task.first_used_output_frame_index - + task.num_initial_unused_output_frames, + end_output_t = begin_output_t + task.num_output_frames; + // begin_input_t and end_input_t are the real 't' values corresponding to + // begin_output_t and end_output_t; they are the beginning and end + // (i.e. first and last-plus-one) frame indexes without any left or right + // context. + int32 begin_input_t = begin_output_t * f, + end_input_t = end_output_t * f; + // Detect whether the left and right edges touch (or pass over) the left + // and right boundaries. Note: we don't expect begin_output_t to ever be + // negative. + bool left_edge = (begin_output_t <= 0), + right_edge = (end_output_t >= num_subsampled_frames); + int32 tot_left_context = nnet_left_context + + (left_edge ? extra_left_context_initial : opts.extra_left_context), + tot_right_context = nnet_right_context + + (right_edge ? extra_right_context_final : opts.extra_right_context); + + // 'is_edge' is only true if it's an edge minibatch *and* its being an + // edge actually made a difference to the structure of the example. + task.is_edge = + (tot_left_context != nnet_left_context + opts.extra_left_context || + tot_right_context != nnet_right_context + opts.extra_right_context); + + int32 begin_input_t_padded = begin_input_t - tot_left_context, + end_input_t_padded = end_input_t + tot_right_context; + + // 'task.first_input_t' is a representation of 'begin_input_t_padded' in a + // shifted/normalized numbering where the output time indexes start from + // zero. + task.first_input_t = begin_input_t_padded - (begin_output_t * f); + + task.input.Resize(end_input_t_padded - begin_input_t_padded, + input.NumCols(), kUndefined); + // the 't' value below is in the numbering of 'input'. + for (int32 t = begin_input_t_padded; t < end_input_t_padded; t++) { + int32 t_clipped = t; + if (t_clipped < 0) t_clipped = 0; + if (t_clipped >= num_input_frames) t_clipped = num_input_frames - 1; + SubVector dest(task.input, + t - begin_input_t_padded), + src(input, t_clipped); + dest.CopyFromVec(src); + } + } +} + +} // namespace utterance_splitting + + +void NnetBatchComputer::SplitUtteranceIntoTasks( + bool output_to_cpu, + const Matrix &input, + const Vector *ivector, + const Matrix *online_ivectors, + int32 online_ivector_period, + std::vector *tasks) { + using namespace utterance_splitting; + + + { // This block does some checking. + if (input.NumCols() != input_dim_) { + KALDI_ERR << "Input features did not have expected dimension: expected " + << input_dim_ << ", got " << input.NumCols(); + } + int32 ivector_dim = (ivector != NULL ? ivector->Dim() : + (online_ivectors != NULL ? + online_ivectors->NumCols() : 0)); + if (ivector_dim_ != 0 && ivector_dim == 0) + KALDI_ERR << "Model expects i-vectors but none were supplied"; + else if (ivector_dim_ == 0 && ivector_dim != 0) + KALDI_ERR << "You supplied i-vectors but model does not expect them."; + else if (ivector_dim != ivector_dim_) + KALDI_ERR << "I-vector dimensions mismatch: model expects " + << ivector_dim_ << ", you supplied " << ivector_dim; + } + + + int32 num_input_frames = input.NumRows(), + f = opts_.frame_subsampling_factor, + num_subsampled_frames = (num_input_frames + f - 1) / f, + num_subsampled_frames_per_chunk = opts_.frames_per_chunk / f; + + GetOutputFrameInfoForTasks(opts_, num_subsampled_frames, + num_subsampled_frames_per_chunk, + tasks); + + SplitInputToTasks(opts_, nnet_left_context_, nnet_right_context_, + input, tasks); + + if (ivector != NULL) { + KALDI_ASSERT(online_ivectors == NULL); + for (size_t i = 0; i < tasks->size(); i++) + (*tasks)[i].ivector = *ivector; + } else if (online_ivectors != NULL) { + AddOnlineIvectorsToTasks(opts_, *online_ivectors, + online_ivector_period, tasks); + } + + for (size_t i = 0; i < tasks->size(); i++) { + (*tasks)[i].output_to_cpu = output_to_cpu; + // The priority will be set by the user; this just avoids undefined + // behavior. + (*tasks)[i].priority = 0.0; + } +} + + +void MergeTaskOutput( + const std::vector &tasks, + Matrix *output) { + int32 num_tasks = tasks.size(), + num_output_frames = 0, + output_dim = -1; + for (int32 i = 0; i < num_tasks; i++) { + const NnetInferenceTask &task = tasks[i]; + num_output_frames += task.num_used_output_frames; + if (i == 0) { + output_dim = (task.output_to_cpu ? + task.output_cpu.NumCols() : + task.output.NumCols()); + } + } + KALDI_ASSERT(num_output_frames != 0 && output_dim != 0); + int32 cur_output_frame = 0; + output->Resize(num_output_frames, output_dim); + for (int32 i = 0; i < num_tasks; i++) { + const NnetInferenceTask &task = tasks[i]; + int32 skip = task.num_initial_unused_output_frames, + num_used = task.num_used_output_frames; + KALDI_ASSERT(cur_output_frame == task.first_used_output_frame_index); + if (task.output_to_cpu) { + output->RowRange(cur_output_frame, num_used).CopyFromMat( + task.output_cpu.RowRange(skip, num_used)); + } else { + output->RowRange(cur_output_frame, num_used).CopyFromMat( + task.output.RowRange(skip, num_used)); + } + cur_output_frame += num_used; + } + KALDI_ASSERT(cur_output_frame == num_output_frames); +} + + +NnetBatchInference::NnetBatchInference( + const NnetBatchComputerOptions &opts, + const Nnet &nnet, + const VectorBase &priors): + computer_(opts, nnet, priors), + is_finished_(false), + utterance_counter_(0) { + // 'thread_' will run the Compute() function in the background. + compute_thread_ = std::thread(ComputeFunc, this); +} + + +void NnetBatchInference::AcceptInput( + const std::string &utterance_id, + const Matrix &input, + const Vector *ivector, + const Matrix *online_ivectors, + int32 online_ivector_period) { + + UtteranceInfo *info = new UtteranceInfo(); + info->utterance_id = utterance_id; + info->num_tasks_finished = 0; + bool output_to_cpu = true; // This wrapper is for when you need the nnet + // output on CPU, e.g. because you want it + // written to disk. If this needs to be + // configurable in the future, we can make changes + // then. + computer_.SplitUtteranceIntoTasks( + output_to_cpu, input, ivector, online_ivectors, + online_ivector_period, &(info->tasks)); + + // Setting this to a nonzero value will cause the AcceptTask() call below to + // hang until the computation thread has made some progress, if too much + // data is already queued. + int32 max_full_minibatches = 2; + + // Earlier utterances have higher priority, which is important to make sure + // that their corresponding tasks are completed and they can be output to disk. + double priority = -1.0 * (utterance_counter_++); + for (size_t i = 0; i < info->tasks.size(); i++) { + info->tasks[i].priority = priority; + computer_.AcceptTask(&(info->tasks[i]), max_full_minibatches); + } + utts_.push_back(info); + tasks_ready_semaphore_.Signal(); +} + +bool NnetBatchInference::GetOutput(std::string *utterance_id, + Matrix *output) { + if (utts_.empty()) + return false; + + UtteranceInfo *info = *utts_.begin(); + std::vector &tasks = info->tasks; + int32 num_tasks = tasks.size(); + for (; info->num_tasks_finished < num_tasks; ++info->num_tasks_finished) { + Semaphore &semaphore = tasks[info->num_tasks_finished].semaphore; + if (is_finished_) { + semaphore.Wait(); + } else { + if (!semaphore.TryWait()) { + // If not all of the tasks of this utterance are ready yet, + // just return false. + return false; + } + } + } + MergeTaskOutput(tasks, output); + *utterance_id = info->utterance_id; + delete info; + utts_.pop_front(); + return true; +} + +NnetBatchInference::~NnetBatchInference() { + if (!is_finished_) + KALDI_ERR << "Object destroyed before Finished() was called."; + if (!utts_.empty()) + KALDI_ERR << "You should get all output before destroying this object."; + compute_thread_.join(); +} + +void NnetBatchInference::Finished() { + is_finished_ = true; + tasks_ready_semaphore_.Signal(); +} + +// This is run as the thread of class NnetBatchInference. +void NnetBatchInference::Compute() { + bool allow_partial_minibatch = false; + while (true) { + // keep calling Compute() as long as it makes progress. + while (computer_.Compute(allow_partial_minibatch)); + + // ... then wait on tasks_ready_semaphore_. + tasks_ready_semaphore_.Wait(); + if (is_finished_) { + allow_partial_minibatch = true; + while (computer_.Compute(allow_partial_minibatch)); + return; + } + } +} + + +NnetBatchDecoder::NnetBatchDecoder( + const fst::Fst &fst, + const LatticeFasterDecoderConfig &decoder_opts, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + bool allow_partial, + int32 num_threads, + NnetBatchComputer *computer): + fst_(fst), decoder_opts_(decoder_opts), + trans_model_(trans_model), word_syms_(word_syms), + allow_partial_(allow_partial), computer_(computer), + is_finished_(false), tasks_finished_(false), priority_offset_(0.0), + tot_like_(0.0), frame_count_(0), num_success_(0), num_fail_(0), + num_partial_(0) { + KALDI_ASSERT(num_threads > 0); + for (int32 i = 0; i < num_threads; i++) + decode_threads_.push_back(new std::thread(DecodeFunc, this)); + compute_thread_ = std::thread(ComputeFunc, this); +} + +void NnetBatchDecoder::SetPriorities(std::vector *tasks) { + size_t num_tasks = tasks->size(); + double priority_offset = priority_offset_; + for (size_t i = 0; i < num_tasks; i++) + (*tasks)[i].priority = priority_offset - (double)i; +} + +void NnetBatchDecoder::UpdatePriorityOffset(double priority) { + size_t num_tasks = decode_threads_.size(), + new_weight = 1.0 / num_tasks, + old_weight = 1.0 - new_weight; + // The next line is vulnerable to a race condition but if it happened it + // wouldn't matter. + priority_offset_ = priority_offset_ * old_weight + priority * new_weight; +} + +void NnetBatchDecoder::AcceptInput( + const std::string &utterance_id, + const Matrix &input, + const Vector *ivector, + const Matrix *online_ivectors, + int32 online_ivector_period){ + // This function basically does a handshake with one of the decoder threads. + // It may have to wait till one of the decoder threads becomes ready. + input_utterance_.utterance_id = utterance_id; + input_utterance_.input = &input; + input_utterance_.ivector = ivector; + input_utterance_.online_ivectors = online_ivectors; + input_utterance_.online_ivector_period = online_ivector_period; + + + UtteranceOutput *this_output = new UtteranceOutput(); + this_output->utterance_id = utterance_id; + pending_utts_.push_back(this_output); + + input_ready_semaphore_.Signal(); + input_consumed_semaphore_.Wait(); +} + +int32 NnetBatchDecoder::Finished() { + is_finished_ = true; + for (size_t i = 0; i < decode_threads_.size(); i++) + input_ready_semaphore_.Signal(); + for (size_t i = 0; i < decode_threads_.size(); i++) { + decode_threads_[i]->join(); + delete decode_threads_[i]; + decode_threads_[i] = NULL; + } + // don't clear decode_threads_, since its size is needed in the destructor to + // compute timing. + + tasks_finished_ = true; + tasks_ready_semaphore_.Signal(); + compute_thread_.join(); + return num_success_; +} + + +bool NnetBatchDecoder::GetOutput( + std::string *utterance_id, + CompactLattice *clat, + std::string *sentence) { + if (!decoder_opts_.determinize_lattice) + KALDI_ERR << "Don't call this version of GetOutput if you are " + "not determinizing."; + while (true) { + if (pending_utts_.empty()) + return false; + if (!pending_utts_.front()->finished) + return false; + UtteranceOutput *this_output = pending_utts_.front(); + pending_utts_.pop_front(); + if (this_output->compact_lat.NumStates() == 0) { + delete this_output; + // ... and continue round the loop, without returning any output to the + // user for this utterance. Something went wrong in decoding: for + // example, the user specified allow_partial == false and no final-states + // were active on the last frame, or something more unexpected. A warning + // would have been printed in the decoder thread. + } else { + *clat = this_output->compact_lat; + utterance_id->swap(this_output->utterance_id); + sentence->swap(this_output->sentence); + delete this_output; + return true; + } + } +} + + +bool NnetBatchDecoder::GetOutput( + std::string *utterance_id, + Lattice *lat, + std::string *sentence) { + if (decoder_opts_.determinize_lattice) + KALDI_ERR << "Don't call this version of GetOutput if you are " + "determinizing."; + while (true) { + if (pending_utts_.empty()) + return false; + if (!pending_utts_.front()->finished) + return false; + UtteranceOutput *this_output = pending_utts_.front(); + pending_utts_.pop_front(); + if (this_output->compact_lat.NumStates() == 0) { + delete this_output; + // ... and continue round the loop, without returning any output to the + // user for this utterance. Something went wrong in decoding: for + // example, the user specified allow_partial == false and no final-states + // were active on the last frame, or something more unexpected. A warning + // would have been printed in the decoder thread. + } else { + *lat = this_output->lat; // OpenFST has shallow copy so no need to swap. + utterance_id->swap(this_output->utterance_id); + sentence->swap(this_output->sentence); + delete this_output; + return true; + } + } +} + +void NnetBatchDecoder::Compute() { + while (!tasks_finished_) { + tasks_ready_semaphore_.Wait(); + bool allow_partial_minibatch = true; + while (computer_->Compute(allow_partial_minibatch)); + } +} + +void NnetBatchDecoder::Decode() { + while (true) { + input_ready_semaphore_.Wait(); + if (is_finished_) + return; + + std::vector tasks; + std::string utterance_id; + // we can be confident that the last element of 'pending_utts_' is the one + // for this utterance, as we know exactly at what point in the code the main + // thread will be in AcceptInput(). + UtteranceOutput *output_utterance = pending_utts_.back(); + { + UtteranceInput input_utterance(input_utterance_); + utterance_id = input_utterance.utterance_id; + bool output_to_cpu = true; + computer_->SplitUtteranceIntoTasks(output_to_cpu, + *(input_utterance.input), + input_utterance.ivector, + input_utterance.online_ivectors, + input_utterance.online_ivector_period, + &tasks); + KALDI_ASSERT(output_utterance->utterance_id == utterance_id); + input_consumed_semaphore_.Signal(); + // Now let input_utterance go out of scope; it's no longer valid as it may + // be overwritten by something else. + } + + SetPriorities(&tasks); + for (size_t i = 0; i < tasks.size(); i++) + computer_->AcceptTask(&(tasks[i])); + tasks_ready_semaphore_.Signal(); + + { + int32 frame_offset = 0; + LatticeFasterDecoder decoder(fst_, decoder_opts_); + decoder.InitDecoding(); + + + for (size_t i = 0; i < tasks.size(); i++) { + NnetInferenceTask &task = tasks[i]; + task.semaphore.Wait(); + UpdatePriorityOffset(task.priority); + + SubMatrix post(task.output_cpu, + task.num_initial_unused_output_frames, + task.num_used_output_frames, + 0, task.output_cpu.NumCols()); + DecodableMatrixMapped decodable(trans_model_, post, frame_offset); + frame_offset += post.NumRows(); + decoder.AdvanceDecoding(&decodable); + task.output.Resize(0, 0); // Free some memory. + } + + bool use_final_probs = true; + if (!decoder.ReachedFinal()) { + if (allow_partial_) { + KALDI_WARN << "Outputting partial output for utterance " + << utterance_id << " since no final-state reached\n"; + use_final_probs = false; + std::unique_lock lock(stats_mutex_); + num_partial_++; + } else { + KALDI_WARN << "Not producing output for utterance " << utterance_id + << " since no final-state reached and " + << "--allow-partial=false.\n"; + std::unique_lock lock(stats_mutex_); + num_fail_++; + continue; + } + } + // if we reached this point, we are getting a lattice. + decoder.GetRawLattice(&output_utterance->lat, use_final_probs); + // Let the decoder and the decodable object go out of scope, to save + // memory. + } + ProcessOutputUtterance(output_utterance); + } +} + + +void NnetBatchDecoder::UtteranceFailed() { + std::unique_lock lock(stats_mutex_); + num_fail_++; +} + +void NnetBatchDecoder::ProcessOutputUtterance(UtteranceOutput *output) { + fst::Connect(&(output->lat)); + if (output->lat.NumStates() == 0) { + KALDI_WARN << "Unexpected problem getting lattice for utterance " + << output->utterance_id; + std::unique_lock lock(stats_mutex_); + num_fail_++; + return; + } + + { // This block accumulates diagnostics, prints log messages, and sets + // output->sentence. + Lattice best_path; + LatticeWeight weight; + ShortestPath(output->lat, &best_path); + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(best_path, &alignment, &words, &weight); + int32 num_frames = alignment.size(); + if (word_syms_ != NULL) { + std::ostringstream os; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms_->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + os << s << ' '; + } + output->sentence = os.str(); + } + double likelihood = -(weight.Value1() + weight.Value2()); + // Note: these logging messages will be out-of-order w.r.t. the transcripts + // that are printed to cerr; we keep those transcripts in the same order + // that the utterances were in, but these logging messages may be out of + // order (due to multiple threads). + KALDI_LOG << "Log-like per frame for utterance " << output->utterance_id + << " is " << (likelihood / num_frames) << " over " + << num_frames << " frames."; + KALDI_VLOG(2) << "Cost for utterance " << output->utterance_id << " is " + << weight.Value1() << " + " << weight.Value2(); + + std::unique_lock lock(stats_mutex_); + tot_like_ += likelihood; + frame_count_ += num_frames; + num_success_ += 1; + } + + if (decoder_opts_.determinize_lattice) { + if (!DeterminizeLatticePhonePrunedWrapper( + trans_model_, + &output->lat, + decoder_opts_.lattice_beam, + &(output->compact_lat), + decoder_opts_.det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam for " + << "utterance " << output->utterance_id; + output->lat.DeleteStates(); // Save memory. + } + + // We'll write the lattice without acoustic scaling, so we need to reverse + // the scale that we applied when decoding. + BaseFloat acoustic_scale = computer_->GetOptions().acoustic_scale; + if (acoustic_scale != 0.0) { + if (decoder_opts_.determinize_lattice) + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), + &(output->compact_lat)); + else + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), + &(output->lat)); + } + output->finished = true; +} + + + +NnetBatchDecoder::~NnetBatchDecoder() { + if (!is_finished_ || !pending_utts_.empty()) { + // At this point the application is bound to fail so raising another + // exception is not a big problem. + KALDI_ERR << "Destroying NnetBatchDecoder object without calling " + "Finished() and consuming the remaining output"; + } + // Print diagnostics. + + kaldi::int64 input_frame_count = + frame_count_ * computer_->GetOptions().frame_subsampling_factor; + int32 num_threads = static_cast(decode_threads_.size()); + + KALDI_LOG << "Overall likelihood per frame was " + << tot_like_ / std::max(1, frame_count_) + << " over " << frame_count_ << " frames."; + + double elapsed = timer_.Elapsed(); + // the +1 below is just to avoid division-by-zero errors. + KALDI_LOG << "Time taken "<< elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (num_threads * elapsed * 100.0 / + std::max(input_frame_count, 1)) + << " (per thread; with " << num_threads << " threads)."; + KALDI_LOG << "Done " << num_success_ << " utterances (" + << num_partial_ << " forced out); failed for " + << num_fail_; +} + + +} // namespace nnet3 +} // namespace kaldi diff --git a/src/nnet3/nnet-batch-compute.h b/src/nnet3/nnet-batch-compute.h new file mode 100644 index 00000000000..9861a28976c --- /dev/null +++ b/src/nnet3/nnet-batch-compute.h @@ -0,0 +1,836 @@ +// nnet3/nnet-batch-compute.h + +// Copyright 2012-2018 Johns Hopkins University (author: Daniel Povey) +// 2018 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_NNET3_NNET_BATCH_COMPUTE_H_ +#define KALDI_NNET3_NNET_BATCH_COMPUTE_H_ + +#include +#include +#include +#include +#include +#include "base/kaldi-common.h" +#include "gmm/am-diag-gmm.h" +#include "hmm/transition-model.h" +#include "itf/decodable-itf.h" +#include "nnet3/nnet-optimize.h" +#include "nnet3/nnet-compute.h" +#include "nnet3/am-nnet-simple.h" +#include "nnet3/nnet-am-decodable-simple.h" +#include "decoder/lattice-faster-decoder.h" +#include "util/stl-utils.h" + + +namespace kaldi { +namespace nnet3 { + + +/** + class NnetInferenceTask represents a chunk of an utterance that is + requested to be computed. This will be given to NnetBatchComputer, which + will aggregate the tasks and complete them. + */ +struct NnetInferenceTask { + // The copy constructor is required to exist because of std::vector's resize() + // function, but in practice should never be used. + NnetInferenceTask(const NnetInferenceTask &other) { + KALDI_ERR << "NnetInferenceTask was not designed to be copied."; + } + NnetInferenceTask() { } + + + // The input frames, which are treated as being numbered t=0, t=1, etc. (If + // the lowest t value was originally nonzero in the 'natural' numbering, this + // just means we conceptually shift the 't' values; the only real constraint + // is that the 't' values are contiguous. + Matrix input; + + // The index of the first output frame (in the shifted numbering where the + // first output frame is numbered zero. This will typically be less than one, + // because most network topologies require left context. If this was an + // 'interior' chunk of a recurrent topology like LSTMs, first_input_t may be + // substantially less than zero, due to 'extra_left_context'. + int32 first_input_t; + + // The stride of output 't' values: e.g., will be 1 for normal-frame-rate + // models, and 3 for low-frame-rate models such as chain models. + int32 output_t_stride; + + // The number of output 't' values (they will start from zero and be separated + // by output_t_stride). This will be the num-rows of 'output'. + int32 num_output_frames; + + // 'num_initial_unused_output_frames', which will normally be zero, is the + // number of rows of the output matrix ('output' or 'output_cpu') which won't + // actually be needed by the user, usually because they overlap with a + // previous chunk. This can happen because the number of outputs isn't a + // multiple of the number of chunks. + int32 num_initial_unused_output_frames; + + // 0 < num_used_output_frames <= num_output_frames - num_initial_unused_output_frames + // is the number of output frames which are actually going to be used by the + // user. (Due to edge effects, not all are necessarily used). + int32 num_used_output_frames; + + // first_used_output_frame_index is provided for the convenience of the user + // so that they can know how this chunk relates to the utterance which it is + // a part of. + // It represents an output frame index in the original utterance-- after + // subsampling; so not a 't' value but a 't' value divided by + // frame-subsampling-factor. Specifically, it tells you the row index in the + // full utterance's output which corresponds to the first 'used' frame index + // at the output of this chunk, specifically: the row numbered + // 'num_initial_unused_output_frames' in the 'output' or 'output_cpu' data + // member. + int32 first_used_output_frame_index; + + // True if this chunk is an 'edge' (the beginning or end of an utterance) AND + // is structurally different somehow from non-edge chunk, e.g. requires less + // context. This is present only so that NnetBatchComputer will know the + // appropriate minibatch size to use. + bool is_edge; + + // True if this task represents an irregular-sized chunk. These can happen + // only for utterances that are shorter than the requested minibatch size, and + // it should be quite rare. We use a minibatch size of 1 in this case. + bool is_irregular; + + // The i-vector for this chunk, if this network accepts i-vector inputs. + Vector ivector; + + // A priority (higher is more urgent); may be either sign. May be updated + // after this object is provided to class NnetBatchComputer. + double priority; + + // This semaphore will be incremented by class NnetBatchComputer when this + // chunk is done. After this semaphore is incremented, class + // NnetBatchComputer will no longer hold any pointers to this class. + Semaphore semaphore; + + // Will be set to true by the caller if they want the output of the neural net + // to be copied to CPU (to 'output'). If false, the output will stay on + // the GPU (if used)- in cu_output. + bool output_to_cpu; + + // The neural net output, of dimension num_output_frames by the output-dim of + // the neural net, will be written to 'output_cpu' if 'output_to_cpu' is true. + // This is expected to be empty when this task is provided to class + // NnetBatchComputer, and will be nonempty (if output_to_cpu == true) when the + // task is completed and the semaphore is signaled. + Matrix output_cpu; + + // The output goes here instead of 'output_to_cpu' is false. + CuMatrix output; +}; + + +struct NnetBatchComputerOptions: public NnetSimpleComputationOptions { + int32 minibatch_size; + int32 edge_minibatch_size; + bool ensure_exact_final_context; + BaseFloat partial_minibatch_factor; + + NnetBatchComputerOptions(): minibatch_size(128), + edge_minibatch_size(32), + ensure_exact_final_context(false), + partial_minibatch_factor(0.5) { + } + + void Register(OptionsItf *po) { + NnetSimpleComputationOptions::Register(po); + po->Register("minibatch-size", &minibatch_size, "Number of chunks per " + "minibatch (see also edge-minibatch-size)"); + po->Register("edge-minibatch-size", &edge_minibatch_size, "Number of " + "chunks per minibatch: this applies to chunks at the " + "beginnings and ends of utterances, in cases (such as " + "recurrent models) when the computation would be different " + "from the usual one."); + po->Register("ensure-exact-final-context", &ensure_exact_final_context, + "If true, for utterances shorter than --frames-per-chunk, " + "use exact-length, special computations. If false, " + "pad with repeats of the last frame. Would only affect " + "the output for backwards-recurrent models, but would " + "negatively impact speed in all cases."); + po->Register("partial-minibatch-factor", &partial_minibatch_factor, + "Factor that controls how small partial minibatches will be " + "they become necessary. We will potentially do the computation " + "for sizes: int(partial_minibatch_factor^n * minibatch_size " + ", for n = 0, 1, 2.... Set it to 0.0 if you want to use " + "only the specified minibatch sizes."); + } +}; + + +/** + Merges together the 'output_cpu' (if the 'output_to_cpu' members are true) or + the 'output' members of 'tasks' into a single CPU matrix 'output'. Requires that + those outputs are nonempty (i.e. that those tasks must have been completed). + + @param [in] tasks The vector of tasks whose outputs are to be merged. + The tasks must have already been completed. + @param [output output The spliced-together output matrix + + TODO: in the future, maybe start from GPU and use pinned matrices for the + transfer. + */ +void MergeTaskOutput( + const std::vector &tasks, + Matrix *output); + +/** + This class does neural net inference in a way that is optimized for GPU use: + it combines chunks of multiple utterances into minibatches for more efficient + computation. It does the computation in one background thread that accesses + the GPU. It is thread safe, i.e. you can call it from multiple threads + without having to worry about data races and the like. +*/ +class NnetBatchComputer { + public: + /** Constructor. It stores references to all the arguments, so don't delete + them till this object goes out of scop. + + \param [in] opts Options struct + \param [in] nnet The neural net which we'll be doing the computation with + \param [in] priors Either the empty vector, or a vector of prior + probabilities which we'll take the log of and subtract + from the neural net outputs (e.g. used in non-chain + systems). + */ + NnetBatchComputer(const NnetBatchComputerOptions &opts, + const Nnet &nnet, + const VectorBase &priors); + + + /// Accepts a task, meaning the task will be queued. (Note: the pointer is + /// still owned by the caller. + /// If the max_minibatches_full >= 0, then the calling thread will block until + /// no more than that many full minibatches are waiting to be computed. This + /// is a mechanism to prevent too many requests from piling up in memory. + void AcceptTask(NnetInferenceTask *task, + int32 max_minibatches_full = -1); + + /// Returns the number of full minibatches waiting to be computed. + int32 NumFullPendingMinibatches() const { return num_full_minibatches_; } + + + /** + Does some kind of computation, choosing the highest-priority thing to + compute. It returns true if it did some kind of computation, and false + otherwise. This function locks the class, but not for the entire time + it's being called: only at the beginning and at the end. + @param [in] allow_partial_minibatch If false, then this will only + do the computation if a full minibatch is ready; if true, it + is allowed to do computation on partial (not-full) minibatches. + */ + bool Compute(bool allow_partial_minibatch); + + + /** + Split a single utterance into a list of separate tasks which can then + be given to this class by AcceptTask(). + + @param [in] output_to_cpu Will become the 'output_to_cpu' member of the + output tasks; this controls whether the computation code should transfer + the outputs to CPU (which is to save GPU memory). + @param [in] ivector If non-NULL, and i-vector for the whole utterance is + expected to be supplied here (and online_ivectors should be NULL). + This is relevant if you estimate i-vectors per speaker instead of + online. + @param [in] online_ivectors Matrix of ivectors, one every 'online_ivector_period' frames. + @param [in] online_ivector_period Affects the interpretation of 'online_ivectors'. + @param [out] tasks The tasks created will be output to here. The + priorities will be set to zero; setting them to a meaningful + value is up to the caller. + */ + void SplitUtteranceIntoTasks( + bool output_to_cpu, + const Matrix &input, + const Vector *ivector, + const Matrix *online_ivectors, + int32 online_ivector_period, + std::vector *tasks); + + const NnetBatchComputerOptions &GetOptions() { return opts_; } + + ~NnetBatchComputer(); + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(NnetBatchComputer); + + // Information about a specific minibatch size for a group of tasks sharing a + // specific structure (in terms of left and right context, etc.) + struct MinibatchSizeInfo { + // the computation for this minibatch size. + std::shared_ptr computation; + int32 num_done; // The number of minibatches computed: for diagnostics. + int64 tot_num_tasks; // The total number of tasks in those minibatches, + // also for diagnostics... can be used to compute + // how 'full', on average, these minibatches were. + double seconds_taken; // The total time elapsed in computation for this + // minibatch type. + MinibatchSizeInfo(): computation(NULL), num_done(0), + tot_num_tasks(0), seconds_taken(0.0) { } + }; + + + // A computation group is a group of tasks that have the same structure + // (number of input and output frames, left and right context). + struct ComputationGroupInfo { + // The tasks to be completed. This array is added-to by AcceptTask(), + // and removed-from by GetHighestPriorityComputation(), which is called + // from Compute(). + std::vector tasks; + + // Map from minibatch-size to information specific to this minibatch-size, + // including the NnetComputation. This is set up by + // GetHighestPriorityComputation(), which is called from Compute(). + std::map minibatch_info; + }; + + // This struct allows us to arrange the tasks into groups that can be + // computed in the same minibatch. + struct ComputationGroupKey { + ComputationGroupKey(const NnetInferenceTask &task): + num_input_frames(task.input.NumRows()), + first_input_t(task.first_input_t), + num_output_frames(task.num_output_frames) {} + + bool operator == (const ComputationGroupKey &other) const { + return num_input_frames == other.num_input_frames && + first_input_t == other.first_input_t && + num_output_frames == other.num_output_frames; + } + int32 num_input_frames; + int32 first_input_t; + int32 num_output_frames; + }; + + struct ComputationGroupKeyHasher { + int32 operator () (const ComputationGroupKey &key) const { + return key.num_input_frames + 18043 * key.first_input_t + + 6413 * key.num_output_frames; + } + }; + + + typedef unordered_map MapType; + + // Gets the priority for a group, higher means higher priority. (A group is a + // list of tasks that may be computed in the same minibatch). What this + // function does is a kind of heuristic. + // If allow_partial_minibatch == false, it will set the priority for + // any minibatches that are not full to negative infinity. + inline double GetPriority(bool allow_partial_minibatch, + const ComputationGroupInfo &info) const; + + // Returns the minibatch size for this group of tasks, i.e. the size of a full + // minibatch for this type of task, which is what we'd ideally like to + // compute. Note: the is_edge and is_irregular options should be the same + // for for all tasks in the group. + // - If 'tasks' is empty or info.is_edge and info.is_irregular are both, + // false, then return opts_.minibatch_size + // - If 'tasks' is nonempty and tasks[0].is_irregular is true, then + // returns 1. + // - If 'tasks' is nonempty and tasks[0].is_irregular is false and + // tasks[0].is_edge is true, then returns opts_.edge_minibatch_size. + inline int32 GetMinibatchSize(const ComputationGroupInfo &info) const; + + + // This function compiles, and returns, a computation for tasks of + // the structure present in info.tasks[0], and the specified minibatch + // size. + std::shared_ptr GetComputation( + const ComputationGroupInfo &info, + int32 minibatch_size); + + + // Returns the actual minibatch size we'll use for this computation. In most + // cases it will be opts_.minibatch_size (or opts_.edge_minibatch_size if + // appropriate; but if the number of available tasks is much less than the + // appropriate minibatch size, it may be less. The minibatch size may be + // greater than info.tasks.size(); in that case, the remaining 'n' values in + // the minibatch are not used. (It may also be less than info.tasks.size(), + // in which case we only do some of them). + int32 GetActualMinibatchSize(const ComputationGroupInfo &info) const; + + + // This function gets the highest-priority 'num_tasks' tasks from 'info', + // removes them from the array info->tasks, and puts them into the array + // 'tasks' (which is assumed to be initially empty). + // This function also updates the num_full_minibatches_ variable if + // necessary, and takes care of notifying any related condition variables. + void GetHighestPriorityTasks( + int32 num_tasks, + ComputationGroupInfo *info, + std::vector *tasks); + + /** + This function finds and returns the computation corresponding to the + highest-priority group of tasks. + + @param [in] allow_partial_minibatch If this is true, then this + function may return a computation corresponding to a partial + minibatch-- i.e. the minibatch size in the computation may be + less than the minibatch size in the options class, and/or + the number of tasks may not be as many as the minibatch size + in the computation. + @param [out] minibatch_size If this function returns non-NULL, then + this will be set to the minibatch size that the returned + computation expects. This may be less than tasks->size(), + in cases where the minibatch was not 'full'. + @param [out] tasks The tasks which we'll be doing the computation + for in this minibatch are put here (and removed from tasks_, + in cases where this function returns non-NULL. + @return This function returns a pointer to the appropriate + 'MinibatchSizeInfo' object corresponding to the computation + that we'll be doing for this minibatch, or NULL if there is nothing + to compute. + */ + MinibatchSizeInfo *GetHighestPriorityComputation( + bool allow_partial_minibatch, + int32 *minibatch_size, + std::vector *tasks); + + /** + formats the inputs to the computation and transfers them to GPU. + @param [in] minibatch_size The number of parallel sequences + we're doing this computation for. This will be + more than tasks.size() in some cases. + @param [in] tasks The tasks we're doing the computation for. + The input comes from here. + @param [out] input The main feature input to the computation is + put into here. + @param [out] ivector If we're using i-vectors, the i-vectors are + put here. + */ + void FormatInputs(int32 minibatch_size, + const std::vector &tasks, + CuMatrix *input, + CuMatrix *ivector); + + + // Copies 'output', piece by piece, to the 'output_cpu' or 'output' + // members of 'tasks', depending on their 'output_to_cpu' value. + void FormatOutputs(const CuMatrix &output, + const std::vector &tasks); + + + // Changes opts_.frames_per_chunk to be a multiple of + // opts_.frame_subsampling_factor, if needed. + void CheckAndFixConfigs(); + + // this function creates and returns the computation request which is to be + // compiled. + static void GetComputationRequest(const NnetInferenceTask &task, + int32 minibatch_size, + ComputationRequest *request); + + // Prints some logging information about what we computed, with breakdown by + // minibatch type. + void PrintMinibatchStats(); + + NnetBatchComputerOptions opts_; + const Nnet &nnet_; + CachingOptimizingCompiler compiler_; + CuVector log_priors_; + + // Mutex that guards this object. It is only held for fairly quick operations + // (not while the actual computation is being done). + std::mutex mutex_; + + // tasks_ contains all the queued tasks. + // Each key contains a vector of NnetInferenceTask* pointers, of the same + // structure (i.e., IsCompatible() returns true). + MapType tasks_; + + // num_full_minibatches_ is a function of the data in tasks_ (and the + // minibatch sizes, specified in opts_. It is the number of full minibatches + // of tasks that are pending, meaning: for each group of tasks, the number of + // pending tasks divided by the minibatch-size for that group in integer + // arithmetic. This is kept updated for thread synchronization reasons, because + // it is the shared variable + int32 num_full_minibatches_; + + // a map from 'n' to a condition variable corresponding to the condition: + // num_full_minibatches_ <= n. Any time the number of full minibatches drops + // below n, the corresponding condition variable is notified (if it exists). + std::unordered_map no_more_than_n_minibatches_full_; + + // some static information about the neural net, computed at the start. + int32 nnet_left_context_; + int32 nnet_right_context_; + int32 input_dim_; + int32 ivector_dim_; + int32 output_dim_; +}; + + +/** + This class implements a simplified interface to class NnetBatchComputer, + which is suitable for programs like 'nnet3-compute' where you want to support + fast GPU-based inference on a sequence of utterances, and get them back + from the object in the same order. + */ +class NnetBatchInference { + public: + + NnetBatchInference( + const NnetBatchComputerOptions &opts, + const Nnet &nnet, + const VectorBase &priors); + + /** + The user should call this one by one for the utterances that this class + needs to compute (interspersed with calls to GetOutput()). This call + will block when enough ready-to-be-computed data is present. + + @param [in] utterance_id The string representing the utterance-id; + it will be provided back to the user when GetOutput() is + called. + @param [in] input The input features (e.g. MFCCs) + @param [in] ivector If non-NULL, this is expected to be the + i-vector for this utterance (and 'online_ivectors' should + be NULL). + @param [in] online_ivector_period Only relevant if + 'online_ivector' is non-NULL, this says how many + frames of 'input' is covered by each row of + 'online_ivectors'. + */ + void AcceptInput(const std::string &utterance_id, + const Matrix &input, + const Vector *ivector, + const Matrix *online_ivectors, + int32 online_ivector_period); + + /** + The user should call this after the last input has been provided + via AcceptInput(). This will force the last utterances to be + flushed out (to be retrieved by GetOutput()), rather than waiting + until the relevant minibatches are full. + */ + void Finished(); + + /** + The user should call this to obtain output. It's guaranteed to + be in the same order as the input was provided, but it may be + delayed. 'output' will be the output of the neural net, spliced + together over the chunks (and with acoustic scaling applied if + it was specified in the options; the subtraction of priors will + depend whether you supplied a non-empty vector of priors to the + constructor. + + This call does not block (i.e. does not wait on any semaphores) unless you + have previously called Finished(). It returns true if it actually got any + output; if none was ready it will return false. + */ + bool GetOutput(std::string *utterance_id, + Matrix *output); + + ~NnetBatchInference(); + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(NnetBatchInference); + + // This is the computation thread, which is run in the background. It will + // exit once the user calls Finished() and all computation is completed. + void Compute(); + // static wrapper for Compute(). + static void ComputeFunc(NnetBatchInference *object) { object->Compute(); } + + + // This object implements the internals of what this class does. It is + // accessed both by the main thread (from where AcceptInput(), Finished() and + // GetOutput() are called), and from the background thread in which Compute() + // is called. + NnetBatchComputer computer_; + + // This is set to true when the user calls Finished(); the computation thread + // sees it and knows to flush + bool is_finished_; + + // This semaphore is signaled by the main thread (the thread in which + // AcceptInput() is called) every time a new utterance is added, and waited on + // in the background thread in which Compute() is called. + Semaphore tasks_ready_semaphore_; + + struct UtteranceInfo { + std::string utterance_id; + // The tasks into which we split this utterance. + std::vector tasks; + // 'num_tasks_finished' is the number of tasks which are known to be + // finished, meaning we successfully waited for those tasks' 'semaphore' + // member. When this reaches tasks.size(), we are ready to consolidate + // the output into a single matrix and return it to the user. + size_t num_tasks_finished; + }; + + // This list is only accessed directly by the main thread, by AcceptInput() + // and GetOutput(). It is a list of utterances, with more recently added ones + // at the back. When utterances are given to the user by GetOutput(), + std::list utts_; + + int32 utterance_counter_; // counter that increases on every utterance. + + // The thread running the Compute() process. + std::thread compute_thread_; +}; + + +/** + Decoder object that uses multiple CPU threads for the graph search, plus a + GPU for the neural net inference (that's done by a separate + NnetBatchComputer object). The interface of this object should + accessed from only one thread, though-- presumably the main thread of the + program. + */ +class NnetBatchDecoder { + public: + /** + Constructor. + @param [in] fst FST that we are decoding with, will be shared between + all decoder threads. + @param [in] decoder_config Configuration object for the decoders. + @param [in] trans_model The transition model-- needed to construct the decoders, + and for determinization. + @param [in] word_syms A pointer to a symbol table of words, used for printing + the decoded words to stderr. If NULL, the word-level output will not + be logged. + @param [in] allow_partial If true, in cases where no final-state was reached + on the final frame of the decoding, we still output a lattice; + it just may contain partial words (words that are cut off in + the middle). If false, we just won't output anything for + those lattices. + @param [in] num_threads The number of decoder threads to use. It will use + two more threads on top of this: the main thread, for I/O, + and a thread for possibly-GPU-based inference. + @param [in] computer The NnetBatchComputer object, through which the + neural net will be evaluated. + */ + NnetBatchDecoder(const fst::Fst &fst, + const LatticeFasterDecoderConfig &decoder_config, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + bool allow_partial, + int32 num_threads, + NnetBatchComputer *computer); + + /** + The user should call this one by one for the utterances that + it needs to compute (interspersed with calls to GetOutput()). This + call will block when no threads are ready to start processing this + utterance. + + @param [in] utterance_id The string representing the utterance-id; + it will be provided back to the user when GetOutput() is + called. + @param [in] input The input features (e.g. MFCCs) + @param [in] ivector If non-NULL, this is expected to be the + i-vector for this utterance (and 'online_ivectors' should + be NULL). + @param [in] online_ivector_period Only relevant if + 'online_ivector' is non-NULL, this says how many + frames of 'input' is covered by each row of + 'online_ivectors'. + */ + void AcceptInput(const std::string &utterance_id, + const Matrix &input, + const Vector *ivector, + const Matrix *online_ivectors, + int32 online_ivector_period); + + /* + The user should call this function each time there was a problem with an utterance + prior to being able to call AcceptInput()-- e.g. missing i-vectors. This will + update the num-failed-utterances stats which are stored in this class. + */ + void UtteranceFailed(); + + /* + The user should call this when all input has been provided, e.g. + when AcceptInput will not be called any more. It will block until + all threads have terminated; after that, you can call GetOutput() + until it returns false, which will guarantee that nothing remains + to compute. + It returns the number of utterances that have been successfully decoded. + */ + int32 Finished(); + + /** + The user should call this to obtain output (This version should + only be called if config.determinize_lattice == true (w.r.t. the + config provided to the constructor). The output is guaranteed to + be in the same order as the input was provided, but it may be + delayed, *and* some outputs may be missing, for example because + of search failures (allow_partial will affect this). + + The acoustic scores in the output lattice will already be divided by + the acoustic scale we decoded with. + + This call does not block (i.e. does not wait on any semaphores). It + returns true if it actually got any output; if none was ready it will + return false. + @param [out] utterance_id If an output was ready, its utterance-id is written to here. + @param [out] clat If an output was ready, it compact lattice will be + written to here. + @param [out] sentence If an output was ready and a nonempty symbol table + was provided to the constructor of this class, contains + the word-sequence decoded as a string. Otherwise will + be empty. + @return Returns true if a decoded output was ready. (These appear asynchronously + as the decoding is done in background threads). + */ + bool GetOutput(std::string *utterance_id, + CompactLattice *clat, + std::string *sentence); + + // This version of GetOutput is for where config.determinize_lattice == false + // (w.r.t. the config provided to the constructor). It is the same as the + // other version except it outputs to a normal Lattice, not a CompactLattice. + bool GetOutput(std::string *utterance_id, + Lattice *lat, + std::string *sentence); + + ~NnetBatchDecoder(); + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(NnetBatchDecoder); + + struct UtteranceInput { + std::string utterance_id; + const Matrix *input; + const Vector *ivector; + const Matrix *online_ivectors; + int32 online_ivector_period; + }; + + // This object is created when a thread finished an utterance. For utterances + // where decoding failed somehow, the relevant lattice (compact_lat, if + // opts_.determinize == true, or lat otherwise) will be empty (have no + // states). + struct UtteranceOutput { + std::string utterance_id; + bool finished; + CompactLattice compact_lat; + Lattice lat; + std::string sentence; // 'sentence' is only nonempty if a non-NULL symbol + // table was provided to the constructor of class + // NnetBatchDecoder; it's the sentence as a string (a + // sequence of words separated by space). It's used + // for printing the sentence to stderr, which we do + // in the main thread to keep the order consistent. + }; + + // This is the decoding thread, several copies of which are run in the + // background. It will exit once the user calls Finished() and all + // computation is completed. + void Decode(); + // static wrapper for Compute(). + static void DecodeFunc(NnetBatchDecoder *object) { object->Decode(); } + + // This is the computation thread; it handles the neural net inference. + void Compute(); + // static wrapper for Compute(). + static void ComputeFunc(NnetBatchDecoder *object) { object->Compute(); } + + + // Sets the priorities of the tasks in a newly provided utterance. + void SetPriorities(std::vector *tasks); + + // In the single-thread case, this sets priority_offset_ to 'priority'. + // In the multi-threaded case it causes priority_offset_ to approach + // 'priority' at a rate that depends on the nunber of threads. + void UpdatePriorityOffset(double priority); + + // This function does the determinization (if needed) and finds the best path through + // the lattice to update the stats. It is expected that when it is called, 'output' must + // have its 'lat' member set up. + void ProcessOutputUtterance(UtteranceOutput *output); + + const fst::Fst &fst_; + const LatticeFasterDecoderConfig &decoder_opts_; + const TransitionModel &trans_model_; + const fst::SymbolTable *word_syms_; // May be NULL. Owned here. + bool allow_partial_; + NnetBatchComputer *computer_; + std::vector decode_threads_; + std::thread compute_thread_; // Thread that calls computer_->Compute(). + + + // 'input_utterance', together with utterance_ready_semaphore_ and + // utterance_consumed_semaphore_, use used to 'hand off' information about a + // newly provided utterance from AcceptInput() to a decoder thread that is + // ready to process a new utterance. + UtteranceInput input_utterance_; + Semaphore input_ready_semaphore_; // Is signaled by the main thread when + // AcceptInput() is called and a new + // utterance is being provided (or when the + // input is finished), and waited on in + // decoder thread. + Semaphore input_consumed_semaphore_; // Is signaled in decoder thread when it + // has finished consuming the input, so + // the main thread can know when it + // should continue (to avoid letting + // 'input' go out of scope while it's + // still needed). + + Semaphore tasks_ready_semaphore_; // Is signaled when new tasks are added to + // the computer_ object (or when we're finished). + + bool is_finished_; // True if the input is finished. If this is true, a + // signal to input_ready_semaphore_ indicates to the + // decoder thread that it should terminate. + + bool tasks_finished_; // True if we know that no more tasks will be given + // to the computer_ object. + + + // pending_utts_ is a list of utterances that have been provided via + // AcceptInput(), but their decoding has not yet finished. AcceptInput() will + // push_back to it, and GetOutput() will pop_front(). When a decoding thread + // has finished an utterance it will set its 'finished' member to true. There + // is no need to synchronize or use mutexes here. + std::list pending_utts_; + + // priority_offset_ is something used in determining the priorities of nnet + // computation tasks. It starts off at zero and becomes more negative with + // time, with the aim being that the priority of the first task (i.e. the + // leftmost chunk) of a new utterance should be at about the same priority as + // whatever chunks we are just now getting around to decoding. + double priority_offset_; + + // Some statistics accumulated by this class, for logging and timing purposes. + double tot_like_; // Total likelihood (of best path) over all lattices that + // we output. + int64 frame_count_; // Frame count over all latices that we output. + int32 num_success_; // Number of successfully decoded files. + int32 num_fail_; // Number of files where decoding failed. + int32 num_partial_; // Number of files that were successfully decoded but + // reached no final-state (can only be nonzero if + // allow_partial_ is true). + std::mutex stats_mutex_; // Mutex that guards the statistics from tot_like_ + // through num_partial_. + Timer timer_; // Timer used to print real-time info. +}; + + +} // namespace nnet3 +} // namespace kaldi + +#endif // KALDI_NNET3_NNET_BATCH_COMPUTE_H_ diff --git a/src/nnet3/nnet-chain-training.cc b/src/nnet3/nnet-chain-training.cc index 87eacf75327..a798cb597f5 100644 --- a/src/nnet3/nnet-chain-training.cc +++ b/src/nnet3/nnet-chain-training.cc @@ -95,78 +95,6 @@ void NnetChainTrainer::Train(const NnetChainExample &chain_eg) { num_minibatches_processed_++; } -// This object exists to help avoid memory fragmentation: it allocates, -// but does not use, the exact sizes of memory that are going to be needed -// in ComputeChainObjfAndDeriv(). -class ChainTrainerMemoryHolder { - public: - ChainTrainerMemoryHolder(const Nnet &nnet, - int32 num_den_graph_states, - const NnetChainExample &eg); - private: - CuMatrix nnet_output_deriv_; - CuMatrix xent_output_deriv_; - CuMatrix beta_; - CuMatrix alpha_; - -}; - -ChainTrainerMemoryHolder::ChainTrainerMemoryHolder(const Nnet &nnet, - int32 den_graph_states, - const NnetChainExample &eg) { - - std::vector::const_iterator iter = eg.outputs.begin(), - end = eg.outputs.end(); - - int32 max_rows = 0, - max_cols = 0; - - size_t max_frames_per_sequence = 0, - max_sequence_size = 0, - max_alpha_matrix_size = 0; - - for (; iter != end; ++iter) { - // there will normally be just one of these things; we'll normally loop once. - const NnetChainSupervision &sup = *iter; - - int32 output_rows = sup.supervision.num_sequences * sup.supervision.frames_per_sequence; - int32 output_cols = nnet.OutputDim("output"); - - size_t curr_frames_per_sequence = output_rows / sup.supervision.num_sequences + 1; - size_t den_graph_size = den_graph_states + 1; - size_t curr_sequence_size = den_graph_size * sup.supervision.num_sequences; - size_t curr_alpha_matrix_size = curr_frames_per_sequence * curr_sequence_size; - - if (curr_alpha_matrix_size > max_alpha_matrix_size) { - max_alpha_matrix_size = curr_alpha_matrix_size; - max_frames_per_sequence = curr_frames_per_sequence; - max_sequence_size = curr_sequence_size; - } - - size_t matrix_size = output_rows * output_cols; - if (matrix_size > (max_rows * max_cols)) { - max_rows = output_rows; - max_cols = output_cols; - } - } - - // the sequence of resizes is in a specific order (bigger to smaller) - // so that the cudaMalloc won't trash the memory it has already - // alloc'd in the previous iterations - alpha_.Resize(max_frames_per_sequence, - max_sequence_size, - kUndefined); - - - nnet_output_deriv_.Resize(max_rows, max_cols, kUndefined); - // note: the same block of memory can be used for xent_output_deriv_ as is - // used for exp_nnet_output_transposed_ in chain-training.cc. - xent_output_deriv_.Resize(max_rows, max_cols, - kUndefined, kStrideEqualNumCols); - - beta_.Resize(2, max_sequence_size, kUndefined); -} - void NnetChainTrainer::TrainInternal(const NnetChainExample &eg, const NnetComputation &computation) { const NnetTrainerOptions &nnet_config = opts_.nnet_config; @@ -176,20 +104,10 @@ void NnetChainTrainer::TrainInternal(const NnetChainExample &eg, NnetComputer computer(nnet_config.compute_config, computation, nnet_, delta_nnet_); - // reserve the memory needed in ProcessOutputs (before memory gets fragmented - // by the call to computer.Run(). - ChainTrainerMemoryHolder *memory_holder = - new ChainTrainerMemoryHolder(*nnet_, den_graph_.NumStates(), eg); - // give the inputs to the computer object. computer.AcceptInputs(*nnet_, eg.inputs); computer.Run(); - // 'this->ProcessOutputs()' is going to need the same sizes as are stored in - // 'memory_holder'. - delete memory_holder; - - // Probably could be merged in a single call PreallocateChainTrainerMemory(*nnet_, eg) ? this->ProcessOutputs(false, eg, &computer); computer.Run(); diff --git a/src/nnet3/nnet-compile-utils.cc b/src/nnet3/nnet-compile-utils.cc index 49012e08884..b1f9d0b0e2b 100644 --- a/src/nnet3/nnet-compile-utils.cc +++ b/src/nnet3/nnet-compile-utils.cc @@ -25,351 +25,164 @@ namespace kaldi { namespace nnet3 { -// this comparator will be used to sort pairs using first_element -// we declare it as a struct as it will also be used by std::lower_bound -// method which will supply elements of different types to the function -struct FirstElementComparator { - int first_element(int32 t) const { - return t; - } - - int first_element(std::pair t) const { - return t.first; - } - template< typename T1, typename T2> - bool operator()( T1 const & t1, T2 const & t2) const { - return first_element(t1) < first_element(t2); - } -}; - -// This comparator is used with std::find_if function to search for pairs -// whose first element is equal to the given pair -struct FirstElementIsEqualComparator : - public std::unary_function, bool> -{ - explicit FirstElementIsEqualComparator(const int32 element): - element_(element) {} - bool operator() (std::pair const &arg) - { return (arg.first == element_); } - int32 element_; -}; - -// This comparator is used with std::find_if function to search for pairs -// whose .first and .second elements are equal to the given pair -struct PairIsEqualComparator : - public std::unary_function, bool> -{ - explicit PairIsEqualComparator(const std::pair pair): - pair_(pair) {} - bool operator() (std::pair const &arg) - { - if (pair_.first == arg.first) - return pair_.second == arg.second; - return false; +/** + Gets counts of submatrices (the 1st members of pairs) in submat_lists. + Also outputs, to 'submats_with_large_counts', a list of submatrix indexes + that have counts over half of submat_lists.size(). (These will be separated + out into their own AddRows() commands). + */ +void GetSubmatCounts( + const std::vector > > &submat_lists, + std::unordered_map *submat_counts, + std::vector *submats_with_large_counts) { + auto iter = submat_lists.begin(), end = submat_lists.end(); + for (; iter != end; ++iter) { + std::vector >::const_iterator + iter2 = iter->begin(), end2 = iter->end(); + for (; iter2 != end2; ++iter2) { + int32 submat_index = iter2->first; + KALDI_ASSERT(submat_index >= 0); // We don't expect -1's in submat_lists. + std::unordered_map::iterator + iter = submat_counts->find(submat_index); + if (iter == submat_counts->end()) + (*submat_counts)[submat_index] = 1; + else + iter->second++; + } } - std::pair pair_; -}; - -// this comparator will be used to sort pairs initially by second element in -// descending order and then by first element in descending order. -// note, std::sort accepts an actual function as an alternative to a -// function object. -bool SecondElementComparator(const std::pair& first_pair, - const std::pair& second_pair) { - if (first_pair.second == second_pair.second) - return first_pair.first > second_pair.first; - return first_pair.second > second_pair.second; + auto counts_iter = submat_counts->begin(), + counts_end = submat_counts->end(); + size_t cutoff = submat_lists.size() / 2; + for (; counts_iter != counts_end; ++counts_iter) + if (counts_iter->second > cutoff) + submats_with_large_counts->push_back(counts_iter->first); } -// Function to sort the lists in a vector of lists of pairs, by the first -// element of the pair -void SortSubmatLists( - // vector of list of location pairs - const std::vector > > submat_lists, - // a copy of the input submat_lists where the lists are sorted - // (this will be used in the caller function for sort and find functions) - std::vector > > * sorted_submat_lists, - // maximum size of the submat_lists - int32* max_submat_list_size - ) -{ - *max_submat_list_size = 0; - sorted_submat_lists->reserve(submat_lists.size()); - KALDI_ASSERT(submat_lists.size() > 0); - for (int32 i = 0; i < submat_lists.size(); i++) { - if (submat_lists[i].size() > *max_submat_list_size) - *max_submat_list_size = submat_lists[i].size(); - sorted_submat_lists->push_back(submat_lists[i]); - std::sort((*sorted_submat_lists)[i].begin(), - (*sorted_submat_lists)[i].end(), - FirstElementComparator()); +/** + This function, used in SplitLocations(), is used to make separate + 'split lists' for certain high-count submatrix indexes, specified by + the user in 'submats_to_separate'. These split + lists will be lists of pairs that are all either (-1, 1) or (submatrix_index, x) + for a particular submatrix index (constant within the split list). + These high-count lists will be written to 'split_lists'; they + will eventually compile to AddRows() commands. We write the remaining + members of the lists in 'submat_lists' (the ones that did not make it + into 'split_lists') to 'reduced_submat_lists'. + */ +void SeparateSubmatsWithLargeCounts( + const std::vector &submats_to_separate, + const std::vector > > &submat_lists, + std::vector > > *reduced_submat_lists, + std::vector > > *split_lists) { + KALDI_ASSERT(split_lists->empty() && !submats_to_separate.empty()); + size_t num_to_separate = submats_to_separate.size(), + num_rows = submat_lists.size(); + std::unordered_map submat_to_index; + reduced_submat_lists->clear(); + reduced_submat_lists->resize(num_rows); + split_lists->resize(num_to_separate); + for (size_t i = 0; i < num_to_separate; i++) { + (*split_lists)[i].resize(num_rows, std::pair(-1, -1)); + int32 submat = submats_to_separate[i]; + submat_to_index[submat] = i; } -} - -// Function to compute a histogram of the submat_index, -// which is the first_element in the location pair, given vector of list of -// location pairs -void ComputeSubmatIndexHistogram( - // vector of list of pairs of location pairs where the lists are sorted - // by submat_indexes (.first element) - const std::vector > > - sorted_submat_lists, - // a histogram of submat_indexes where - // the keys are submat_indexes and values are a vector of frequencies - // of first occurrence, second occurrence, etc. of a submat_index - // in a submat_list - unordered_map >* submat_histogram - ) { - KALDI_ASSERT(sorted_submat_lists.size() > 0); - // computing the submat_histogram - // counting the occurrences of each element in the current submat_list; - // each new occurrence of the same element, in this list, is counted - // as a seperate symbol for frequency counts - for (int32 i = 0; i < sorted_submat_lists.size(); i++) { - int j = 0; - unordered_map >::iterator histogram_iterator - = submat_histogram->end(); - int32 repetition_count = 0; - while (j < sorted_submat_lists[i].size()) { - if ((histogram_iterator == submat_histogram->end()) || - (histogram_iterator->first != sorted_submat_lists[i][j].first)) { - histogram_iterator = - submat_histogram->find(sorted_submat_lists[i][j].first); - repetition_count = 0; - // if a histogram entry was not found for this submat_index, add one - if (histogram_iterator == submat_histogram->end()) { - (*submat_histogram)[sorted_submat_lists[i][j].first]; - histogram_iterator = submat_histogram->find( - sorted_submat_lists[i][j].first); - } + for (size_t row = 0; row < submat_lists.size(); row++) { + std::vector >::const_iterator + iter = submat_lists[row].begin(), end = submat_lists[row].end(); + std::vector > + &reduced_list = (*reduced_submat_lists)[row]; + // 'reduced_lists' will contain the pairs that don't make it into + // 'split_lists'. + for (; iter != end; ++iter) { + int32 submat_index = iter->first; + std::unordered_map::const_iterator map_iter = + submat_to_index.find(submat_index); + if (map_iter == submat_to_index.end()) { // not a large-count submatrix. + reduced_list.push_back(*iter); + continue; } - - if (repetition_count >= (histogram_iterator->second).size()) { - // this is the first time the submat_index repeated this many times - // so add an entry for this in the count vector - (histogram_iterator->second).push_back(1); - } else { - (histogram_iterator->second)[repetition_count]++; + size_t index = map_iter->second; + std::pair &p = (*split_lists)[index][row]; + if (p.first >= 0) { + // we'd only reach here if the same submat index repeated in the same + // row, which is possible but rare. + reduced_list.push_back(*iter); + continue; } - repetition_count++; - j++; - } - } -} - - -// Function to find the first occurrence of a submat_index in list of location -// pairs from a vector of list of locations pairs. -// The occurrences are returned as a list of vector iterators, -// pointing to the position of the pair in the list or to the -// end of the list (when the pair is not present) -void FindSubmatIndexInSubmatLists( - // submat_index to search in the submat_lists - int32 submat_index, - // sorted_submat_lists is a pointer as we want non-const iterators in the - // output - std::vector > > *sorted_submat_lists, - // a vector of iterators to store the location of the pairs - std::vector >::iterator> - *output_iterator_list, - // the max size of the submat_lists if the found pairs have been removed - int32 *max_remaining_submat_list_size) { - - output_iterator_list->reserve(sorted_submat_lists->size()); - *max_remaining_submat_list_size = 0; - for (int32 i = 0; i < sorted_submat_lists->size(); i++) { - std::vector< std::pair > & submat_list = - (*sorted_submat_lists)[i]; - output_iterator_list->push_back( - std::find_if(submat_list.begin(), submat_list.end(), - FirstElementIsEqualComparator(submat_index))); - int32 remaining_submat_list_size = submat_list.size(); - if (output_iterator_list->back() != submat_list.end()) { - // since the submat_index is present in this submat_list - // if submat_index was deleted from the list - // the remaining submat_list's size is reduced by 1 - remaining_submat_list_size--; - } - *max_remaining_submat_list_size = remaining_submat_list_size - > *max_remaining_submat_list_size ? remaining_submat_list_size : - *max_remaining_submat_list_size; - } -} - -// Function to extract the identified pairs (identified with an iterator) -// from a vector of list of pairs, "to extract" means to copy into -// a list and erase the original pair from the submat_lists -void ExtractGivenPairsFromSubmatLists( - std::vector >::iterator> - &input_iterator_list, - std::vector > > *sorted_submat_lists, - std::vector > *list_of_pairs) { - list_of_pairs->reserve(sorted_submat_lists->size()); - for (int32 i = 0; i < input_iterator_list.size(); i++) { - if (input_iterator_list[i] != (*sorted_submat_lists)[i].end()) { - // there was an element with the submat_index in the current list - list_of_pairs->push_back(*input_iterator_list[i]); - (*sorted_submat_lists)[i].erase(input_iterator_list[i]); - } else { - // insert a dummy element. Callers of this function expect the dummy - // element to be (-1, -1) - list_of_pairs->push_back(std::make_pair(-1, -1)); - } - } -} - -// Function to extract the last pairs from a vector of list of pairs -// a dummy is added when the list is empty -static void ExtractLastPairFromSubmatLists( - std::vector > > *sorted_submat_lists, - std::vector > *list_of_pairs) { - list_of_pairs->reserve(sorted_submat_lists->size()); - for (int32 i = 0; i < sorted_submat_lists->size(); i++) { - if ((*sorted_submat_lists)[i].size() == 0) { - // the value of the dummy has to be (-1, -1) as down stream code has - // expects -1 values for dummies - list_of_pairs->push_back(std::make_pair(-1, -1)); - continue; + p.first = submat_index; + int32 src_row_index = iter->second; + p.second = src_row_index; } - list_of_pairs->push_back((*sorted_submat_lists)[i].back()); - (*sorted_submat_lists)[i].pop_back(); } } -// Function which does the actual splitting of submat_lists. But it operates on -// sorted submat_lists and uses submat_histogram_vector. -// See SplitLocations, below for the algorithm -static void SplitLocationsUsingSubmatHistogram( - // maximum size of the lists in the sorted_submat_lists - int32 max_submat_list_size, - // a vector of list of pairs where each list is expected to be sorted - // this is a pointer as the lists will be modified - std::vector > > *sorted_submat_lists, - // a vector of pairs to represent a histogram - // this is a pointer as the vector will be sorted - std::vector > *submat_histogram_vector, - // a vector of lists of pairs with rearranged pairs - std::vector > > *split_lists) { - - // sort the submat_histogram_vector based on second element of pair - // in descending order then first element of pair in descending order - std::sort(submat_histogram_vector->begin(), - submat_histogram_vector->end(), SecondElementComparator); - - int32 prev_max_remaining_submat_list_size = max_submat_list_size; - std::vector >::iterator iter; - for (iter = submat_histogram_vector->begin(); - iter != submat_histogram_vector->end(); - ++iter) { - std::pair submat_index_and_count = *iter; - std::vector >::iterator> - output_iterator_list; - int32 max_remaining_submat_list_size = 0; - FindSubmatIndexInSubmatLists(submat_index_and_count.first, - sorted_submat_lists, - &output_iterator_list, - &max_remaining_submat_list_size); - if (max_remaining_submat_list_size - < prev_max_remaining_submat_list_size) { - // since we will have a smaller max_remaining_submat_list_size by - // splitting this submat_index into a seperate list, - // we will split it; - std::vector > list_of_pairs; - ExtractGivenPairsFromSubmatLists(output_iterator_list, - sorted_submat_lists, - &list_of_pairs); - split_lists->push_back(list_of_pairs); - prev_max_remaining_submat_list_size = max_remaining_submat_list_size; - } - } - - // rearrange the remaining pairs into lists where - // pairs with multiple first elements are allowed - // Note : we don't yet know if there is any advantage of having multiple - // calls to the same submat in kAddRowsMulti. If this is actually helpful - // then use the sorted_histogram_vector to first copy submat_indexes which - // did not make it to kAddRows calls - for (int32 i = 0; i < prev_max_remaining_submat_list_size; i++) { - std::vector > list_of_pairs; - ExtractLastPairFromSubmatLists(sorted_submat_lists, - &list_of_pairs); - split_lists->push_back(list_of_pairs); - } -} - -// Function rearranges the submat_lists (see nnet-compute-utils.h for -// description of submat_lists), into lists that can be used as inputs -// for kAddRows and kAddRowsMulti calls. -// kAddRows requires a list of pairs where all the first elements correspond to -// the same submat_index. -// kAddRowsMulti uses a list of pairs where the first elements can correspond to -// multiple submat_index locations. -// ------------------------ -// The maximum size of a list in submat_lists is the minimum number of -// kAddRowsMulti calls necessary. -// In the current implementation we replace kAddRowsMulti calls with -// kAddRows calls wherever possible, while not increasing the number of calls. -// -// Algorithm : -// The function computes a histogram of submat_indexes and spans through the -// submat_indexes in descending order of frequency. For each submat_index a -// decision is made to copy it using a kAddRows call or not. -// A kAddRow call is made for a submat_index if splitting it into a seperate -// list reduces the max_submat_list_size by one, i.e., reduces the number of -// remaining kAddRowsMulti calls. -// submat_indexes which cannot be assigned to kAddRow calls are rearranged into -// lists for kAddRowsMulti calls. -// -// Note : To decide splits we could have solved a combinatorial -// optimization problem where we find the best set of -// kAddRows + kAddRowsMulti calls; -// but given that both these calls have similar costs, -// and that the average number of elements in a submat_list is around 4, -// it does not make sense to -// choose a kAddRows call unless it is able to immediately reduce a -// kAddRowsMulti call. So we simplify the process and stay away from any -// complex search algorithms. We might implement a solution where a more -// elaborate search is done,if the length of the lists increases -// for newer NN architectures, as even minor savings in speed due to increased -// number of kAddRows calls can accumulate compensating for the additional calls - void SplitLocations( const std::vector > > &submat_lists, std::vector > > *split_lists) { + size_t num_rows = submat_lists.size(), + num_output_lists = 0; + auto iter = submat_lists.begin(), end = submat_lists.end(); + for (; iter != end; ++iter) + if (iter->size() > num_output_lists) + num_output_lists = iter->size(); + split_lists->clear(); + if (num_output_lists == 0) // Odd, but could happen, maybe + return; + else if (num_output_lists == 1) { + split_lists->resize(1); + std::vector > &list = (*split_lists)[0]; + list.resize(num_rows, std::pair(-1, -1)); + for (size_t i = 0; i < num_rows; i++) { + if (!submat_lists[i].empty()) + list[i] = submat_lists[i][0]; + } + return; + } - // a histogram of the submat_indexes in the submat_lists - // each occurence in a given submat_list is considered unique so we maintain - // a vector to count each occurrence separately. - // The i'th element in the vector corresponds to the count of - // the (i+1)'th occurrence of a submat_index in a submat_list - unordered_map > submat_histogram; - - int32 max_submat_list_size = 0; - - // initializing a vector of list of pairs to store the sorted submat_lists - std::vector > > - sorted_submat_lists; - SortSubmatLists(submat_lists, &sorted_submat_lists, &max_submat_list_size); - ComputeSubmatIndexHistogram(sorted_submat_lists, &submat_histogram); - // the vector has same information as the submat_histogram, but it is - // suitable for sorting according to frequency. The first elements of pairs - // can be repeated, these correspond to different occurrences in the same list - std::vector > submat_histogram_vector; - // copy the key, occurence_counts from submat_histogram to a vector - unordered_map >::iterator hist_iter; - for (hist_iter = submat_histogram.begin(); - hist_iter != submat_histogram.end(); - ++hist_iter) { - for (int32 i = 0; i < (hist_iter->second).size(); i++) { - submat_histogram_vector.push_back( - std::make_pair(hist_iter->first, (hist_iter->second)[i])); + // counts for each submatrix index, of how many times it occurs. + std::unordered_map submat_counts; + std::vector submats_with_large_counts; + GetSubmatCounts(submat_lists, &submat_counts, &submats_with_large_counts); + if (!submats_with_large_counts.empty()) { + // There are submatrices with counts over half the num-rows. We assign these + // their own output lists. + + std::vector > > reduced_submat_lists; + SeparateSubmatsWithLargeCounts(submats_with_large_counts, + submat_lists, + &reduced_submat_lists, + split_lists); + // 'reduced_split_lists' is the result of recursing with input 'reduced_submat_lists'; + // we'll append its result to 'split_lists'. + std::vector > > reduced_split_lists; + SplitLocations(reduced_submat_lists, &reduced_split_lists); + size_t cur_num_lists = split_lists->size(), + num_extra_lists = reduced_split_lists.size(), + new_num_lists = cur_num_lists + num_extra_lists; + split_lists->resize(new_num_lists); + for (size_t i = 0; i < num_extra_lists; i++) + (*split_lists)[cur_num_lists + i].swap(reduced_split_lists[i]); + return; + // and we're done. + } else { + // All the counts of submatrix indexes seem to be small so we are resigned to + // only using AddRowsMulti commands. + split_lists->resize(num_output_lists); + for (size_t i = 0; i < num_output_lists; i++) + (*split_lists)[i].resize(num_rows, std::pair(-1, -1)); + for (size_t row = 0; row < num_rows; row++) { + const std::vector > &this_list = + submat_lists[row]; + size_t this_list_size = submat_lists[row].size(); + for (size_t i = 0; i < this_list_size; i++) { + (*split_lists)[i][row] = this_list[i]; + } } } - SplitLocationsUsingSubmatHistogram(max_submat_list_size, &sorted_submat_lists, - &submat_histogram_vector, split_lists); } + /* If it is the case for some i >= 0 that all the .first elements of "location_vector" are either i or -1, then output i to first_value and the .second elements into "second_values", and return true. Otherwise return diff --git a/src/nnet3/nnet-compile-utils.h b/src/nnet3/nnet-compile-utils.h index 124f40f3421..e21f81aecdd 100644 --- a/src/nnet3/nnet-compile-utils.h +++ b/src/nnet3/nnet-compile-utils.h @@ -32,11 +32,15 @@ namespace nnet3 { /** - The input to this function is a vector of lists of pairs, and this function - splits it up into a list of vectors of pairs. In order to make the lists all - the same length it may have to insert "dummy" pairs with value (-1, -1). - In addition, this function implement certain heuristics to break up the - list into pairs in a particular desirable way, which we will describe below. + The input to this function is a vector (indexed by matrix-row-index) of lists + of pairs (submat_index, row_index), and this function splits it up into a + list of vectors of pairs, where those vectors are indexed by + matrix-row-index. + + In order to make the lists all the same length it may have to insert "dummy" + pairs with value (-1, -1). In addition, this function implement certain + heuristics to break up the list into pairs in a particular desirable way, + which we will describe below. Let the input be `submat_lists`, and let `num_rows = submat_lists.size()`. The value -1 is not expected to appear as either the .first or .second @@ -74,7 +78,6 @@ namespace nnet3 { See documentation here: \ref dnn3_compile_compiler_split_locations */ - void SplitLocations( const std::vector > > &submat_lists, std::vector > > *split_lists); @@ -179,4 +182,3 @@ void GetNxList(const std::vector &indexes, #endif - diff --git a/src/nnet3/nnet-descriptor-test.cc b/src/nnet3/nnet-descriptor-test.cc index de6fe5247bd..94f9bc99f12 100644 --- a/src/nnet3/nnet-descriptor-test.cc +++ b/src/nnet3/nnet-descriptor-test.cc @@ -205,6 +205,9 @@ void UnitTestGeneralDescriptorSpecial() { names.push_back("d"); KALDI_ASSERT(NormalizeTextDescriptor(names, "a") == "a"); KALDI_ASSERT(NormalizeTextDescriptor(names, "Scale(-1.0, a)") == "Scale(-1, a)"); + KALDI_ASSERT(NormalizeTextDescriptor(names, "Scale(-1.0, Scale(-2.0, a))") == "Scale(2, a)"); + KALDI_ASSERT(NormalizeTextDescriptor(names, "Scale(2.0, Sum(Scale(2.0, a), b, c))") == + "Sum(Scale(4, a), Sum(Scale(2, b), Scale(2, c)))"); KALDI_ASSERT(NormalizeTextDescriptor(names, "Const(1.0, 512)") == "Const(1, 512)"); KALDI_ASSERT(NormalizeTextDescriptor(names, "Sum(Const(1.0, 512), Scale(-1.0, a))") == "Sum(Const(1, 512), Scale(-1, a))"); diff --git a/src/nnet3/nnet-descriptor.cc b/src/nnet3/nnet-descriptor.cc index 78fedc3b00c..a218d945d65 100644 --- a/src/nnet3/nnet-descriptor.cc +++ b/src/nnet3/nnet-descriptor.cc @@ -928,6 +928,29 @@ bool GeneralDescriptor::Normalize(GeneralDescriptor *desc) { std::swap(desc->value1_, child->value1_); std::swap(desc->value2_, child->value2_); changed = true; + } else if (child->descriptor_type_ == kSum) { + // Push the Scale() inside the sum expression. + desc->descriptors_.clear(); + for (size_t i = 0; i < child->descriptors_.size(); i++) { + GeneralDescriptor *new_child = + new GeneralDescriptor(kScale, -1, -1, desc->alpha_); + new_child->descriptors_.push_back(child->descriptors_[i]); + desc->descriptors_.push_back(new_child); + } + desc->descriptor_type_ = kSum; + desc->alpha_ = 0.0; + child->descriptors_.clear(); // prevent them being freed. + delete child; + changed = true; + } else if (child->descriptor_type_ == kScale) { + // Combine the 'scale' expressions. + KALDI_ASSERT(child->descriptors_.size() == 1); + GeneralDescriptor *grandchild = child->descriptors_[0]; + desc->alpha_ *= child->alpha_; + desc->descriptors_[0] = grandchild; + child->descriptors_.clear(); // prevent them being freed. + delete child; + changed = true; } else if (child->descriptor_type_ != kNodeName) { KALDI_ERR << "Unhandled case encountered when normalizing Descriptor; " "you can work around this by pushing Scale() inside " diff --git a/src/nnet3bin/Makefile b/src/nnet3bin/Makefile index d187a7b61aa..67d15d3c38a 100644 --- a/src/nnet3bin/Makefile +++ b/src/nnet3bin/Makefile @@ -19,7 +19,7 @@ BINFILES = nnet3-init nnet3-info nnet3-get-egs nnet3-copy-egs nnet3-subset-egs \ nnet3-discriminative-subset-egs nnet3-get-egs-simple \ nnet3-discriminative-compute-from-egs nnet3-latgen-faster-looped \ nnet3-egs-augment-image nnet3-xvector-get-egs nnet3-xvector-compute \ - nnet3-latgen-grammar + nnet3-latgen-grammar nnet3-compute-batch nnet3-latgen-faster-batch OBJFILES = @@ -32,7 +32,7 @@ ADDLIBS = ../nnet3/kaldi-nnet3.a ../chain/kaldi-chain.a \ ../cudamatrix/kaldi-cudamatrix.a ../decoder/kaldi-decoder.a \ ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/nnet3bin/nnet3-compute-batch.cc b/src/nnet3bin/nnet3-compute-batch.cc new file mode 100644 index 00000000000..b0001c96f57 --- /dev/null +++ b/src/nnet3bin/nnet3-compute-batch.cc @@ -0,0 +1,204 @@ +// nnet3bin/nnet3-compute-batch.cc + +// Copyright 2012-2018 Johns Hopkins University (author: Daniel Povey) +// 2018 Hang Lyu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "nnet3/nnet-batch-compute.h" +#include "base/timer.h" +#include "nnet3/nnet-utils.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Propagate the features through raw neural network model " + "and write the output. This version is optimized for GPU use. " + "If --apply-exp=true, apply the Exp() function to the output " + "before writing it out.\n" + "\n" + "Usage: nnet3-compute-batch [options] " + "\n" + " e.g.: nnet3-compute-batch final.raw scp:feats.scp " + "ark:nnet_prediction.ark\n"; + + ParseOptions po(usage); + Timer timer; + + NnetBatchComputerOptions opts; + opts.acoustic_scale = 1.0; // by default do no scaling + + bool apply_exp = false, use_priors = false; + std::string use_gpu = "yes"; + + std::string word_syms_filename; + std::string ivector_rspecifier, + online_ivector_rspecifier, + utt2spk_rspecifier; + int32 online_ivector_period = 0; + opts.Register(&po); + + po.Register("ivectors", &ivector_rspecifier, "Rspecifier for " + "iVectors as vectors (i.e. not estimated online); per " + "utterance by default, or per speaker if you provide the " + "--utt2spk option."); + po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for " + "utt2spk option used to get ivectors per speaker"); + po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for " + "iVectors estimated online, as matrices. If you supply this," + " you must set the --online-ivector-period option."); + po.Register("online-ivector-period", &online_ivector_period, "Number of " + "frames between iVectors in matrices supplied to the " + "--online-ivectors option"); + po.Register("apply-exp", &apply_exp, "If true, apply exp function to " + "output"); + po.Register("use-gpu", &use_gpu, + "yes|no|optional|wait, only has effect if compiled with CUDA"); + po.Register("use-priors", &use_priors, "If true, subtract the logs of the " + "priors stored with the model (in this case, " + "a .mdl file is expected as input)."); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + +#if HAVE_CUDA==1 + CuDevice::Instantiate().AllowMultithreading(); + CuDevice::Instantiate().SelectGpuId(use_gpu); +#endif + + std::string nnet_rxfilename = po.GetArg(1), + feature_rspecifier = po.GetArg(2), + matrix_wspecifier = po.GetArg(3); + + Nnet raw_nnet; + AmNnetSimple am_nnet; + if (use_priors) { + bool binary; + TransitionModel trans_model; + Input ki(nnet_rxfilename, &binary); + trans_model.Read(ki.Stream(), binary); + am_nnet.Read(ki.Stream(), binary); + } else { + ReadKaldiObject(nnet_rxfilename, &raw_nnet); + } + Nnet &nnet = (use_priors ? am_nnet.GetNnet() : raw_nnet); + SetBatchnormTestMode(true, &nnet); + SetDropoutTestMode(true, &nnet); + CollapseModel(CollapseModelConfig(), &nnet); + + Vector priors; + if (use_priors) + priors = am_nnet.Priors(); + + RandomAccessBaseFloatMatrixReader online_ivector_reader( + online_ivector_rspecifier); + RandomAccessBaseFloatVectorReaderMapped ivector_reader( + ivector_rspecifier, utt2spk_rspecifier); + + BaseFloatMatrixWriter matrix_writer(matrix_wspecifier); + + int32 num_success = 0, num_fail = 0; + std::string output_uttid; + Matrix output_matrix; + + + NnetBatchInference inference(opts, nnet, priors); + + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string utt = feature_reader.Key(); + const Matrix &features = feature_reader.Value(); + if (features.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + const Matrix *online_ivectors = NULL; + const Vector *ivector = NULL; + if (!ivector_rspecifier.empty()) { + if (!ivector_reader.HasKey(utt)) { + KALDI_WARN << "No iVector available for utterance " << utt; + num_fail++; + continue; + } else { + ivector = new Vector(ivector_reader.Value(utt)); + } + } + if (!online_ivector_rspecifier.empty()) { + if (!online_ivector_reader.HasKey(utt)) { + KALDI_WARN << "No online iVector available for utterance " << utt; + num_fail++; + continue; + } else { + online_ivectors = new Matrix( + online_ivector_reader.Value(utt)); + } + } + + inference.AcceptInput(utt, features, ivector, online_ivectors, + online_ivector_period); + + std::string output_key; + Matrix output; + while (inference.GetOutput(&output_key, &output)) { + if (apply_exp) + output.ApplyExp(); + matrix_writer.Write(output_key, output); + num_success++; + } + } + + inference.Finished(); + std::string output_key; + Matrix output; + while (inference.GetOutput(&output_key, &output)) { + if (apply_exp) + output.ApplyExp(); + matrix_writer.Write(output_key, output); + num_success++; + } +#if HAVE_CUDA==1 + CuDevice::Instantiate().PrintProfile(); +#endif + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken "<< elapsed << "s"; + KALDI_LOG << "Done " << num_success << " utterances, failed for " + << num_fail; + + if (num_success != 0) { + return 0; + } else { + return 1; + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/nnet3bin/nnet3-compute.cc b/src/nnet3bin/nnet3-compute.cc index f67167bc819..45fde99a4f5 100644 --- a/src/nnet3bin/nnet3-compute.cc +++ b/src/nnet3bin/nnet3-compute.cc @@ -48,7 +48,7 @@ int main(int argc, char *argv[]) { Timer timer; NnetSimpleComputationOptions opts; - opts.acoustic_scale = 1.0; // by default do no scaling in this recipe. + opts.acoustic_scale = 1.0; // by default do no scaling. bool apply_exp = false, use_priors = false; std::string use_gpu = "yes"; diff --git a/src/nnet3bin/nnet3-latgen-faster-batch.cc b/src/nnet3bin/nnet3-latgen-faster-batch.cc new file mode 100644 index 00000000000..fad2d5ed356 --- /dev/null +++ b/src/nnet3bin/nnet3-latgen-faster-batch.cc @@ -0,0 +1,227 @@ +// nnet3bin/nnet3-latgen-faster-parallel.cc + +// Copyright 2012-2016 Johns Hopkins University (author: Daniel Povey) +// 2014 Guoguo Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/timer.h" +#include "base/kaldi-common.h" +#include "decoder/decoder-wrappers.h" +#include "fstext/fstext-lib.h" +#include "hmm/transition-model.h" +#include "nnet3/nnet-batch-compute.h" +#include "nnet3/nnet-utils.h" +#include "util/kaldi-thread.h" +#include "tree/context-dep.h" +#include "util/common-utils.h" + +namespace kaldi { + +void HandleOutput(bool determinize, + const fst::SymbolTable *word_syms, + nnet3::NnetBatchDecoder *decoder, + CompactLatticeWriter *clat_writer, + LatticeWriter *lat_writer) { + // Write out any lattices that are ready. + std::string output_utterance_id, sentence; + if (determinize) { + CompactLattice clat; + while (decoder->GetOutput(&output_utterance_id, &clat, &sentence)) { + if (word_syms != NULL) + std::cerr << output_utterance_id << ' ' << sentence << '\n'; + clat_writer->Write(output_utterance_id, clat); + } + } else { + Lattice lat; + while (decoder->GetOutput(&output_utterance_id, &lat, &sentence)) { + if (word_syms != NULL) + std::cerr << output_utterance_id << ' ' << sentence << '\n'; + lat_writer->Write(output_utterance_id, lat); + } + } +} + +} // namespace kaldi + +int main(int argc, char *argv[]) { + // note: making this program work with GPUs is as simple as initializing the + // device, but it probably won't make a huge difference in speed for typical + // setups. + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::Fst; + using fst::StdArc; + + const char *usage = + "Generate lattices using nnet3 neural net model. This version is optimized\n" + "for GPU-based inference.\n" + "Usage: nnet3-latgen-faster-parallel [options] " + " \n"; + ParseOptions po(usage); + + bool allow_partial = false; + LatticeFasterDecoderConfig decoder_opts; + NnetBatchComputerOptions compute_opts; + std::string use_gpu = "yes"; + + std::string word_syms_filename; + std::string ivector_rspecifier, + online_ivector_rspecifier, + utt2spk_rspecifier; + int32 online_ivector_period = 0, num_threads = 1; + decoder_opts.Register(&po); + compute_opts.Register(&po); + po.Register("word-symbol-table", &word_syms_filename, + "Symbol table for words [for debug output]"); + po.Register("allow-partial", &allow_partial, + "If true, produce output even if end state was not reached."); + po.Register("ivectors", &ivector_rspecifier, "Rspecifier for " + "iVectors as vectors (i.e. not estimated online); per utterance " + "by default, or per speaker if you provide the --utt2spk option."); + po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for " + "iVectors estimated online, as matrices. If you supply this," + " you must set the --online-ivector-period option."); + po.Register("online-ivector-period", &online_ivector_period, "Number of frames " + "between iVectors in matrices supplied to the --online-ivectors " + "option"); + po.Register("num-threads", &num_threads, "Number of decoder (i.e. " + "graph-search) threads. The number of model-evaluation threads " + "is always 1; this is optimized for use with the GPU."); + po.Register("use-gpu", &use_gpu, + "yes|no|optional|wait, only has effect if compiled with CUDA"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + +#if HAVE_CUDA==1 + CuDevice::Instantiate().AllowMultithreading(); + CuDevice::Instantiate().SelectGpuId(use_gpu); +#endif + + std::string model_in_rxfilename = po.GetArg(1), + fst_in_rxfilename = po.GetArg(2), + feature_rspecifier = po.GetArg(3), + lattice_wspecifier = po.GetArg(4); + + TransitionModel trans_model; + AmNnetSimple am_nnet; + { + bool binary; + Input ki(model_in_rxfilename, &binary); + trans_model.Read(ki.Stream(), binary); + am_nnet.Read(ki.Stream(), binary); + SetBatchnormTestMode(true, &(am_nnet.GetNnet())); + SetDropoutTestMode(true, &(am_nnet.GetNnet())); + CollapseModel(CollapseModelConfig(), &(am_nnet.GetNnet())); + } + + bool determinize = decoder_opts.determinize_lattice; + CompactLatticeWriter compact_lattice_writer; + LatticeWriter lattice_writer; + if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier) + : lattice_writer.Open(lattice_wspecifier))) + KALDI_ERR << "Could not open table for writing lattices: " + << lattice_wspecifier; + + RandomAccessBaseFloatMatrixReader online_ivector_reader( + online_ivector_rspecifier); + RandomAccessBaseFloatVectorReaderMapped ivector_reader( + ivector_rspecifier, utt2spk_rspecifier); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_filename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_filename; + + + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + + Fst *decode_fst = fst::ReadFstKaldiGeneric(fst_in_rxfilename); + + int32 num_success; + { + NnetBatchComputer computer(compute_opts, am_nnet.GetNnet(), + am_nnet.Priors()); + NnetBatchDecoder decoder(*decode_fst, decoder_opts, + trans_model, word_syms, allow_partial, + num_threads, &computer); + + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string utt = feature_reader.Key(); + const Matrix &features (feature_reader.Value()); + + if (features.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + decoder.UtteranceFailed(); + continue; + } + const Matrix *online_ivectors = NULL; + const Vector *ivector = NULL; + if (!ivector_rspecifier.empty()) { + if (!ivector_reader.HasKey(utt)) { + KALDI_WARN << "No iVector available for utterance " << utt; + decoder.UtteranceFailed(); + continue; + } else { + ivector = &ivector_reader.Value(utt); + } + } + if (!online_ivector_rspecifier.empty()) { + if (!online_ivector_reader.HasKey(utt)) { + KALDI_WARN << "No online iVector available for utterance " << utt; + decoder.UtteranceFailed(); + continue; + } else { + online_ivectors = &online_ivector_reader.Value(utt); + } + } + + decoder.AcceptInput(utt, features, ivector, online_ivectors, + online_ivector_period); + + HandleOutput(decoder_opts.determinize_lattice, word_syms, &decoder, + &compact_lattice_writer, &lattice_writer); + } + num_success = decoder.Finished(); + HandleOutput(decoder_opts.determinize_lattice, word_syms, &decoder, + &compact_lattice_writer, &lattice_writer); + + // At this point the decoder and batch-computer objects will print + // diagnostics from their destructors (they are going out of scope). + } + delete decode_fst; + delete word_syms; + +#if HAVE_CUDA==1 + CuDevice::Instantiate().PrintProfile(); +#endif + + return (num_success != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/nnet3bin/nnet3-latgen-faster-parallel.cc b/src/nnet3bin/nnet3-latgen-faster-parallel.cc index 4858a9fcb14..e3d02410368 100644 --- a/src/nnet3bin/nnet3-latgen-faster-parallel.cc +++ b/src/nnet3bin/nnet3-latgen-faster-parallel.cc @@ -45,9 +45,11 @@ int main(int argc, char *argv[]) { using fst::StdArc; const char *usage = - "Generate lattices using nnet3 neural net model.\n" + "Generate lattices using nnet3 neural net model. This version supports\n" + "multiple decoding threads (using a shared decoding graph.)\n" "Usage: nnet3-latgen-faster-parallel [options] " - " [ [] ]\n"; + " [ [] ]\n" + "See also: nnet3-latgen-faster-batch (which supports GPUs)\n"; ParseOptions po(usage); Timer timer; diff --git a/src/nnet3bin/nnet3-latgen-faster.cc b/src/nnet3bin/nnet3-latgen-faster.cc index cb26745d808..42cd843cf15 100644 --- a/src/nnet3bin/nnet3-latgen-faster.cc +++ b/src/nnet3bin/nnet3-latgen-faster.cc @@ -33,7 +33,7 @@ int main(int argc, char *argv[]) { // note: making this program work with GPUs is as simple as initializing the // device, but it probably won't make a huge difference in speed for typical - // setups. + // setups. You should use nnet3-latgen-faster-batch if you want to use a GPU. try { using namespace kaldi; using namespace kaldi::nnet3; @@ -45,7 +45,8 @@ int main(int argc, char *argv[]) { const char *usage = "Generate lattices using nnet3 neural net model.\n" "Usage: nnet3-latgen-faster [options] " - " [ [] ]\n"; + " [ [] ]\n" + "See also: nnet3-latgen-faster-parallel, nnet3-latgen-faster-batch\n"; ParseOptions po(usage); Timer timer; bool allow_partial = false; diff --git a/src/nnetbin/Makefile b/src/nnetbin/Makefile index 49a174ec36e..86d59ae503e 100644 --- a/src/nnetbin/Makefile +++ b/src/nnetbin/Makefile @@ -25,7 +25,6 @@ TESTFILES = ADDLIBS = ../nnet/kaldi-nnet.a ../cudamatrix/kaldi-cudamatrix.a \ ../lat/kaldi-lat.a ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a \ - ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/online/Makefile b/src/online/Makefile index 8f2fe238111..32c99500750 100644 --- a/src/online/Makefile +++ b/src/online/Makefile @@ -37,8 +37,7 @@ LIBNAME = kaldi-online ADDLIBS = ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a ../hmm/kaldi-hmm.a \ ../feat/kaldi-feat.a ../transform/kaldi-transform.a \ ../gmm/kaldi-gmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/online2/Makefile b/src/online2/Makefile index 764fef3ab26..242c7be6da6 100644 --- a/src/online2/Makefile +++ b/src/online2/Makefile @@ -18,8 +18,8 @@ ADDLIBS = ../ivector/kaldi-ivector.a ../nnet3/kaldi-nnet3.a \ ../cudamatrix/kaldi-cudamatrix.a ../decoder/kaldi-decoder.a \ ../lat/kaldi-lat.a ../hmm/kaldi-hmm.a ../feat/kaldi-feat.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a diff --git a/src/online2/online-ivector-feature.cc b/src/online2/online-ivector-feature.cc index 4e64609d9ff..3356eb4b1c7 100644 --- a/src/online2/online-ivector-feature.cc +++ b/src/online2/online-ivector-feature.cc @@ -174,24 +174,54 @@ void OnlineIvectorFeature::UpdateFrameWeights( delta_weights_provided_ = true; } -void OnlineIvectorFeature::UpdateStatsForFrame(int32 t, - BaseFloat weight) { + +BaseFloat OnlineIvectorFeature::GetMinPost(BaseFloat weight) const { + BaseFloat min_post = info_.min_post; + BaseFloat abs_weight = fabs(weight); + // If we return 0.99, it will have the same effect as just picking the + // most probable Gaussian on that frame. + if (abs_weight == 0.0) + return 0.99; // I don't anticipate reaching here. + min_post /= abs_weight; + if (min_post > 0.99) + min_post = 0.99; + return min_post; +} + +void OnlineIvectorFeature::UpdateStatsForFrames( + const std::vector > &frame_weights) { + int32 num_frames = static_cast(frame_weights.size()); int32 feat_dim = lda_normalized_->Dim(); - Vector feat(feat_dim), // features given to iVector extractor - log_likes(info_.diag_ubm.NumGauss()); - lda_normalized_->GetFrame(t, &feat); - info_.diag_ubm.LogLikelihoods(feat, &log_likes); - // "posterior" stores the pruned posteriors for Gaussians in the UBM. - std::vector > posterior; - tot_ubm_loglike_ += weight * - VectorToPosteriorEntry(log_likes, info_.num_gselect, - info_.min_post, &posterior); - for (size_t i = 0; i < posterior.size(); i++) - posterior[i].second *= info_.posterior_scale * weight; - lda_->GetFrame(t, &feat); // get feature without CMN. - ivector_stats_.AccStats(info_.extractor, feat, posterior); + Matrix feats(num_frames, feat_dim, kUndefined), + log_likes; + + std::vector frames; + frames.reserve(frame_weights.size()); + for (int32 i = 0; i < num_frames; i++) + frames.push_back(frame_weights[i].first); + lda_normalized_->GetFrames(frames, &feats); + + info_.diag_ubm.LogLikelihoods(feats, &log_likes); + + // "posteriors" stores, for each frame index in the range of frames, the + // pruned posteriors for the Gaussians in the UBM. + std::vector > > posteriors(num_frames); + for (int32 i = 0; i < num_frames; i++) { + std::vector > &posterior = posteriors[i]; + BaseFloat weight = frame_weights[i].second; + if (weight != 0.0) { + tot_ubm_loglike_ += weight * + VectorToPosteriorEntry(log_likes.Row(i), info_.num_gselect, + GetMinPost(weight), &posterior); + for (size_t j = 0; j < posterior.size(); j++) + posterior[j].second *= info_.posterior_scale * weight; + } + } + lda_->GetFrames(frames, &feats); // get features without CMN. + ivector_stats_.AccStats(info_.extractor, feats, posteriors); } + void OnlineIvectorFeature::UpdateStatsUntilFrame(int32 frame) { KALDI_ASSERT(frame >= 0 && frame < this->NumFramesReady() && !delta_weights_provided_); @@ -200,11 +230,19 @@ void OnlineIvectorFeature::UpdateStatsUntilFrame(int32 frame) { int32 ivector_period = info_.ivector_period; int32 num_cg_iters = info_.num_cg_iters; + std::vector > frame_weights; + for (; num_frames_stats_ <= frame; num_frames_stats_++) { int32 t = num_frames_stats_; - UpdateStatsForFrame(t, 1.0); + BaseFloat frame_weight = 1.0; + frame_weights.push_back(std::pair(t, frame_weight)); if ((!info_.use_most_recent_ivector && t % ivector_period == 0) || (info_.use_most_recent_ivector && t == frame)) { + // The call below to UpdateStatsForFrames() is equivalent to doing, for + // all valid indexes i: + // UpdateStatsForFrame(cur_start_frame + i, frame_weights[i]) + UpdateStatsForFrames(frame_weights); + frame_weights.clear(); ivector_stats_.GetIvector(num_cg_iters, ¤t_ivector_); if (!info_.use_most_recent_ivector) { // need to cache iVectors. int32 ivec_index = t / ivector_period; @@ -213,6 +251,8 @@ void OnlineIvectorFeature::UpdateStatsUntilFrame(int32 frame) { } } } + if (!frame_weights.empty()) + UpdateStatsForFrames(frame_weights); } void OnlineIvectorFeature::UpdateStatsUntilFrameWeighted(int32 frame) { @@ -225,17 +265,19 @@ void OnlineIvectorFeature::UpdateStatsUntilFrameWeighted(int32 frame) { int32 ivector_period = info_.ivector_period; int32 num_cg_iters = info_.num_cg_iters; + std::vector > frame_weights; + frame_weights.reserve(delta_weights_.size()); + for (; num_frames_stats_ <= frame; num_frames_stats_++) { int32 t = num_frames_stats_; // Instead of just updating frame t, we update all frames that need updating - // with index <= 1, in case old frames were reclassified as silence/nonsilence. + // with index <= t, in case old frames were reclassified as silence/nonsilence. while (!delta_weights_.empty() && delta_weights_.top().first <= t) { - std::pair p = delta_weights_.top(); + int32 frame = delta_weights_.top().first; + BaseFloat weight = delta_weights_.top().second; + frame_weights.push_back(delta_weights_.top()); delta_weights_.pop(); - int32 frame = p.first; - BaseFloat weight = p.second; - UpdateStatsForFrame(frame, weight); if (debug_weights) { if (current_frame_weight_debug_.size() <= frame) current_frame_weight_debug_.resize(frame + 1, 0.0); @@ -244,6 +286,8 @@ void OnlineIvectorFeature::UpdateStatsUntilFrameWeighted(int32 frame) { } if ((!info_.use_most_recent_ivector && t % ivector_period == 0) || (info_.use_most_recent_ivector && t == frame)) { + UpdateStatsForFrames(frame_weights); + frame_weights.clear(); ivector_stats_.GetIvector(num_cg_iters, ¤t_ivector_); if (!info_.use_most_recent_ivector) { // need to cache iVectors. int32 ivec_index = t / ivector_period; @@ -252,6 +296,8 @@ void OnlineIvectorFeature::UpdateStatsUntilFrameWeighted(int32 frame) { } } } + if (!frame_weights.empty()) + UpdateStatsForFrames(frame_weights); } @@ -297,7 +343,7 @@ void OnlineIvectorFeature::PrintDiagnostics() const { Vector temp_ivector(current_ivector_); temp_ivector(0) -= info_.extractor.PriorOffset(); - KALDI_VLOG(3) << "By the end of the utterance, objf change/frame " + KALDI_VLOG(2) << "By the end of the utterance, objf change/frame " << "from estimating iVector (vs. default) was " << ivector_stats_.ObjfChange(current_ivector_) << " and iVector length was " @@ -308,12 +354,8 @@ void OnlineIvectorFeature::PrintDiagnostics() const { OnlineIvectorFeature::~OnlineIvectorFeature() { PrintDiagnostics(); // Delete objects owned here. - delete lda_normalized_; - delete splice_normalized_; - delete cmvn_; - delete lda_; - delete splice_; - // base_ is not owned here so don't delete it. + for (size_t i = 0; i < to_delete_.size(); i++) + delete to_delete_[i]; for (size_t i = 0; i < ivectors_history_.size(); i++) delete ivectors_history_[i]; } @@ -334,7 +376,8 @@ void OnlineIvectorFeature::GetAdaptationState( OnlineIvectorFeature::OnlineIvectorFeature( const OnlineIvectorExtractionInfo &info, OnlineFeatureInterface *base_feature): - info_(info), base_(base_feature), + info_(info), + base_(base_feature), ivector_stats_(info_.extractor.IvectorDim(), info_.extractor.PriorOffset(), info_.max_count), @@ -343,16 +386,33 @@ OnlineIvectorFeature::OnlineIvectorFeature( most_recent_frame_with_weight_(-1), tot_ubm_loglike_(0.0) { info.Check(); KALDI_ASSERT(base_feature != NULL); - splice_ = new OnlineSpliceFrames(info_.splice_opts, base_); - lda_ = new OnlineTransform(info.lda_mat, splice_); + OnlineFeatureInterface *splice_feature = new OnlineSpliceFrames(info_.splice_opts, base_feature); + to_delete_.push_back(splice_feature); + OnlineFeatureInterface *lda_feature = new OnlineTransform(info.lda_mat, splice_feature); + to_delete_.push_back(lda_feature); + OnlineFeatureInterface *lda_cache_feature = new OnlineCacheFeature(lda_feature); + lda_ = lda_cache_feature; + to_delete_.push_back(lda_cache_feature); + + OnlineCmvnState naive_cmvn_state(info.global_cmvn_stats); // Note: when you call this constructor the CMVN state knows nothing // about the speaker. If you want to inform this class about more specific // adaptation state, call this->SetAdaptationState(), most likely derived // from a call to GetAdaptationState() from a previous object of this type. - cmvn_ = new OnlineCmvn(info.cmvn_opts, naive_cmvn_state, base_); - splice_normalized_ = new OnlineSpliceFrames(info_.splice_opts, cmvn_); - lda_normalized_ = new OnlineTransform(info.lda_mat, splice_normalized_); + cmvn_ = new OnlineCmvn(info.cmvn_opts, naive_cmvn_state, base_feature); + to_delete_.push_back(cmvn_); + + OnlineFeatureInterface *splice_normalized = + new OnlineSpliceFrames(info_.splice_opts, cmvn_), + *lda_normalized = + new OnlineTransform(info.lda_mat, splice_normalized), + *cache_normalized = new OnlineCacheFeature(lda_normalized); + lda_normalized_ = cache_normalized; + + to_delete_.push_back(splice_normalized); + to_delete_.push_back(lda_normalized); + to_delete_.push_back(cache_normalized); // Set the iVector to its default value, [ prior_offset, 0, 0, ... ]. current_ivector_.Resize(info_.extractor.IvectorDim()); diff --git a/src/online2/online-ivector-feature.h b/src/online2/online-ivector-feature.h index d4a89fdc8d1..25e078f1a98 100644 --- a/src/online2/online-ivector-feature.h +++ b/src/online2/online-ivector-feature.h @@ -311,9 +311,19 @@ class OnlineIvectorFeature: public OnlineFeatureInterface { const std::vector > &delta_weights); private: - // this function adds "weight" to the stats for frame "frame". - void UpdateStatsForFrame(int32 frame, - BaseFloat weight); + + // This accumulates i-vector stats for a set of frames, specified as pairs + // (t, weight). The weights do not have to be positive. (In the online + // silence-weighting that we do, negative weights can occur if we change our + // minds about the assignment of a frame as silence vs. non-silence). + void UpdateStatsForFrames( + const std::vector > &frame_weights); + + // Returns a modified version of info_.min_post, which is opts_.min_post if + // weight is 1.0 or -1.0, but gets larger if fabs(weight) is small... but no + // larger than 0.99. (This is an efficiency thing, to not bother processing + // very small counts). + BaseFloat GetMinPost(BaseFloat weight) const; // This is the original UpdateStatsUntilFrame that is called when there is // no data-weighting involved. @@ -327,14 +337,16 @@ class OnlineIvectorFeature: public OnlineFeatureInterface { const OnlineIvectorExtractionInfo &info_; - // base_ is the base feature; it is not owned here. - OnlineFeatureInterface *base_; - // the following online-feature-extractor pointers are owned here: - OnlineSpliceFrames *splice_; // splice on top of raw features. - OnlineTransform *lda_; // LDA on top of raw+splice features. - OnlineCmvn *cmvn_; - OnlineSpliceFrames *splice_normalized_; // splice on top of CMVN feats. - OnlineTransform *lda_normalized_; // LDA on top of CMVN+splice + OnlineFeatureInterface *base_; // The feature this is built on top of + // (e.g. MFCC); not owned here + + OnlineFeatureInterface *lda_; // LDA on top of raw+splice features. + OnlineCmvn *cmvn_; // the CMVN that we give to the lda_normalized_. + OnlineFeatureInterface *lda_normalized_; // LDA on top of CMVN+splice + + // the following is the pointers to OnlineFeatureInterface objects that are + // owned here and which we need to delete. + std::vector to_delete_; /// the iVector estimation stats OnlineIvectorEstimationStats ivector_stats_; diff --git a/src/online2bin/Makefile b/src/online2bin/Makefile index 2731fbfae1d..8792cc5b11a 100644 --- a/src/online2bin/Makefile +++ b/src/online2bin/Makefile @@ -23,6 +23,5 @@ ADDLIBS = ../online2/kaldi-online2.a ../ivector/kaldi-ivector.a \ ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../feat/kaldi-feat.a ../transform/kaldi-transform.a \ ../gmm/kaldi-gmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/online2bin/ivector-extract-online2.cc b/src/online2bin/ivector-extract-online2.cc index 33aa990d1c3..e30d78620ad 100644 --- a/src/online2bin/ivector-extract-online2.cc +++ b/src/online2bin/ivector-extract-online2.cc @@ -23,6 +23,7 @@ #include "gmm/am-diag-gmm.h" #include "online2/online-ivector-feature.h" #include "util/kaldi-thread.h" +#include "base/timer.h" int main(int argc, char *argv[]) { using namespace kaldi; @@ -47,9 +48,9 @@ int main(int argc, char *argv[]) { "e.g.: \n" " ivector-extract-online2 --config=exp/nnet2_online/nnet_online/conf/ivector_extractor.conf \\\n" " ark:data/train/spk2utt scp:data/train/feats.scp ark,t:ivectors.1.ark\n"; - + ParseOptions po(usage); - + OnlineIvectorExtractionConfig ivector_config; ivector_config.Register(&po); @@ -57,7 +58,7 @@ int main(int argc, char *argv[]) { bool repeat = false; int32 length_tolerance = 0; std::string frame_weights_rspecifier; - + po.Register("num-threads", &g_num_threads, "Number of threads to use for computing derived variables " "of iVector extractor, at process start-up."); @@ -71,29 +72,28 @@ int main(int argc, char *argv[]) { "for feats and frame weights"); po.Read(argc, argv); - + if (po.NumArgs() != 3) { po.PrintUsage(); exit(1); } - + std::string spk2utt_rspecifier = po.GetArg(1), feature_rspecifier = po.GetArg(2), ivectors_wspecifier = po.GetArg(3); - + double tot_ubm_loglike = 0.0, tot_objf_impr = 0.0, tot_t = 0.0, tot_length = 0.0, tot_length_utt_end = 0.0; int32 num_done = 0, num_err = 0; - + ivector_config.use_most_recent_ivector = false; OnlineIvectorExtractionInfo ivector_info(ivector_config); - + SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); RandomAccessBaseFloatVectorReader frame_weights_reader(frame_weights_rspecifier); BaseFloatMatrixWriter ivector_writer(ivectors_wspecifier); - - + for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { std::string spk = spk2utt_reader.Key(); const std::vector &uttlist = spk2utt_reader.Value(); @@ -107,12 +107,12 @@ int main(int argc, char *argv[]) { continue; } const Matrix &feats = feature_reader.Value(utt); - + OnlineMatrixFeature matrix_feature(feats); OnlineIvectorFeature ivector_feature(ivector_info, &matrix_feature); - + ivector_feature.SetAdaptationState(adaptation_state); if (!frame_weights_rspecifier.empty()) { @@ -143,10 +143,10 @@ int main(int argc, char *argv[]) { int32 T = feats.NumRows(), n = (repeat ? 1 : ivector_config.ivector_period), num_ivectors = (T + n - 1) / n; - + Matrix ivectors(num_ivectors, ivector_feature.Dim()); - + for (int32 i = 0; i < num_ivectors; i++) { int32 t = i * n; SubVector ivector(ivectors, i); diff --git a/src/onlinebin/Makefile b/src/onlinebin/Makefile index 0999f4e7792..7c0550d0848 100644 --- a/src/onlinebin/Makefile +++ b/src/onlinebin/Makefile @@ -39,7 +39,7 @@ TESTFILES = ADDLIBS = ../online/kaldi-online.a ../decoder/kaldi-decoder.a \ ../lat/kaldi-lat.a ../hmm/kaldi-hmm.a ../feat/kaldi-feat.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/rnnlm/Makefile b/src/rnnlm/Makefile index 6ee52bbb1d7..d4b3f3ce0a8 100644 --- a/src/rnnlm/Makefile +++ b/src/rnnlm/Makefile @@ -15,7 +15,7 @@ OBJFILES = sampler.o rnnlm-example.o rnnlm-example-utils.o \ LIBNAME = kaldi-rnnlm ADDLIBS = ../nnet3/kaldi-nnet3.a ../cudamatrix/kaldi-cudamatrix.a \ - ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a \ - ../lm/kaldi-lm.a ../hmm/kaldi-hmm.a + ../lm/kaldi-lm.a ../hmm/kaldi-hmm.a ../util/kaldi-util.a \ + ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/rnnlm/rnnlm-example-utils.cc b/src/rnnlm/rnnlm-example-utils.cc index fd7cca5eadb..5aa2465d24d 100644 --- a/src/rnnlm/rnnlm-example-utils.cc +++ b/src/rnnlm/rnnlm-example-utils.cc @@ -284,11 +284,15 @@ static void ProcessRnnlmOutputNoSampling( CuMatrix word_probs(nnet_output.NumRows(), num_words - 1, kUndefined); word_probs.CopyFromMat(word_logprobs.ColRange(1, num_words - 1)); - word_probs.ApplyExp(); + word_probs.ApplyExpLimited(-80.0, 80.0); CuVector row_sums(nnet_output.NumRows()); row_sums.AddColSumMat(1.0, word_probs, 0.0); row_sums.ApplyLog(); - *objf_den_exact = -VecVec(row_sums, minibatch.output_weights); + BaseFloat ans = -VecVec(row_sums, minibatch.output_weights); + *objf_den_exact = ans; + if (fabs(ans) > 100) { + KALDI_WARN << "Big den objf " << ans; + } } // In preparation for computing the denominator objf, change 'word_logprobs' diff --git a/src/rnnlm/rnnlm-example.cc b/src/rnnlm/rnnlm-example.cc index 0be4d4ecb47..8dd36689fd6 100644 --- a/src/rnnlm/rnnlm-example.cc +++ b/src/rnnlm/rnnlm-example.cc @@ -346,7 +346,7 @@ RnnlmExampleCreator::~RnnlmExampleCreator() { num_minibatches_written_; KALDI_LOG << "Combined " << num_sequences_processed_ << "/" << num_chunks_processed_ - << " chunks/sequences into " << num_minibatches_written_ + << " sequences/chunks into " << num_minibatches_written_ << " minibatches (" << chunks_.size() << " chunks left over)"; KALDI_LOG << "Overall there were " diff --git a/src/rnnlm/rnnlm-example.h b/src/rnnlm/rnnlm-example.h index 1f3bcb957a9..3ac92701e36 100644 --- a/src/rnnlm/rnnlm-example.h +++ b/src/rnnlm/rnnlm-example.h @@ -401,7 +401,7 @@ class RnnlmExampleCreator { TableWriter > *writer): config_(config), minibatch_sampler_(NULL), sampling_sequencer_(TaskSequencerConfig()), - writer_(writer), + writer_(writer), num_sequences_processed_(0), num_chunks_processed_(0), num_words_processed_(0), num_minibatches_written_(0) { Check(); } diff --git a/src/rnnlmbin/Makefile b/src/rnnlmbin/Makefile index 4c4231c02c8..23a8eba6145 100644 --- a/src/rnnlmbin/Makefile +++ b/src/rnnlmbin/Makefile @@ -16,11 +16,11 @@ cuda-compiled.o: ../kaldi.mk TESTFILES = -ADDLIBS = ../rnnlm/kaldi-rnnlm.a ../lm/kaldi-lm.a ../nnet3/kaldi-nnet3.a \ +ADDLIBS = ../rnnlm/kaldi-rnnlm.a ../nnet3/kaldi-nnet3.a \ ../cudamatrix/kaldi-cudamatrix.a ../decoder/kaldi-decoder.a \ - ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ - ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../lat/kaldi-lat.a ../lm/kaldi-lm.a ../fstext/kaldi-fstext.a \ + ../hmm/kaldi-hmm.a ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/sgmm2/Makefile b/src/sgmm2/Makefile index d538c14c1a9..35a8d3a1f40 100644 --- a/src/sgmm2/Makefile +++ b/src/sgmm2/Makefile @@ -13,7 +13,7 @@ OBJFILES = am-sgmm2.o estimate-am-sgmm2.o estimate-am-sgmm2-ebw.o fmllr-sgmm2.o LIBNAME = kaldi-sgmm2 ADDLIBS = ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ - ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/sgmm2/decodable-am-sgmm2.h b/src/sgmm2/decodable-am-sgmm2.h index 75144650568..18498bf5b24 100644 --- a/src/sgmm2/decodable-am-sgmm2.h +++ b/src/sgmm2/decodable-am-sgmm2.h @@ -59,15 +59,15 @@ class DecodableAmSgmm2 : public DecodableInterface { sgmm_cache_(sgmm.NumGroups(), sgmm.NumPdfs()), delete_vars_(true) { KALDI_ASSERT(gselect->size() == static_cast(feats->NumRows())); } - + // Note, frames are numbered from zero, but transition indices are 1-based! // This is for compatibility with OpenFST. virtual BaseFloat LogLikelihood(int32 frame, int32 tid) { - return LogLikelihoodForPdf(frame, trans_model_.TransitionIdToPdf(tid)); + return LogLikelihoodForPdf(frame, trans_model_.TransitionIdToPdfFast(tid)); } int32 NumFramesReady() const { return feature_matrix_->NumRows(); } virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); } - + virtual bool IsLastFrame(int32 frame) const { KALDI_ASSERT(frame < NumFramesReady()); return (frame == NumFramesReady() - 1); @@ -81,17 +81,17 @@ class DecodableAmSgmm2 : public DecodableInterface { Sgmm2PerSpkDerivedVars *spk_; const TransitionModel &trans_model_; ///< for tid to pdf mapping const Matrix *feature_matrix_; - const std::vector > *gselect_; - + const std::vector > *gselect_; + BaseFloat log_prune_; - + int32 cur_frame_; Sgmm2PerFrameDerivedVars per_frame_vars_; Sgmm2LikelihoodCache sgmm_cache_; bool delete_vars_; // If true, we will delete feature_matrix_, gselect_, and // spk_ in the destructor. - + private: KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableAmSgmm2); }; @@ -121,10 +121,10 @@ class DecodableAmSgmm2Scaled : public DecodableAmSgmm2 { : DecodableAmSgmm2(sgmm, tm, feats, gselect, spk, log_prune), scale_(scale) {} - + // Note, frames are numbered from zero but transition-ids from one. virtual BaseFloat LogLikelihood(int32 frame, int32 tid) { - return LogLikelihoodForPdf(frame, trans_model_.TransitionIdToPdf(tid)) + return LogLikelihoodForPdf(frame, trans_model_.TransitionIdToPdfFast(tid)) * scale_; } private: diff --git a/src/sgmm2bin/Makefile b/src/sgmm2bin/Makefile index 34407a4f5ad..e973061ed8a 100644 --- a/src/sgmm2bin/Makefile +++ b/src/sgmm2bin/Makefile @@ -21,7 +21,6 @@ ADDLIBS = ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a \ ../fstext/kaldi-fstext.a ../sgmm2/kaldi-sgmm2.a ../hmm/kaldi-hmm.a \ ../feat/kaldi-feat.a ../transform/kaldi-transform.a \ ../gmm/kaldi-gmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/tfrnnlm/Makefile b/src/tfrnnlm/Makefile index 12e6c9494c9..db2b840b959 100644 --- a/src/tfrnnlm/Makefile +++ b/src/tfrnnlm/Makefile @@ -28,9 +28,8 @@ TESTFILES = LIBNAME = kaldi-tensorflow-rnnlm -ADDLIBS = ../lm/kaldi-lm.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a \ - +ADDLIBS = ../lm/kaldi-lm.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ + ../base/kaldi-base.a LDLIBS += -lz -ldl -fPIC -lrt LDLIBS += -L$(TENSORFLOW)/bazel-bin/tensorflow -ltensorflow_cc -ltensorflow_framework diff --git a/src/tfrnnlmbin/Makefile b/src/tfrnnlmbin/Makefile index f2a353c918c..4beeeb0d594 100644 --- a/src/tfrnnlmbin/Makefile +++ b/src/tfrnnlmbin/Makefile @@ -29,8 +29,8 @@ TESTFILES = ADDLIBS = ../lat/kaldi-lat.a ../lm/kaldi-lm.a ../fstext/kaldi-fstext.a \ ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a ../tfrnnlm/kaldi-tensorflow-rnnlm.a + ../matrix/kaldi-matrix.a ../base/kaldi-base.a \ + ../tfrnnlm/kaldi-tensorflow-rnnlm.a LDLIBS += -lz -ldl -fPIC -lrt LDLIBS += -L$(TENSORFLOW)/bazel-bin/tensorflow -ltensorflow_cc -ltensorflow_framework diff --git a/src/transform/Makefile b/src/transform/Makefile index 02f5d0ec396..a265db6ac37 100644 --- a/src/transform/Makefile +++ b/src/transform/Makefile @@ -14,8 +14,7 @@ OBJFILES = regression-tree.o regtree-mllr-diag-gmm.o lda-estimate.o \ LIBNAME = kaldi-transform -ADDLIBS = ../gmm/kaldi-gmm.a ../tree/kaldi-tree.a \ - ../util/kaldi-util.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a +ADDLIBS = ../gmm/kaldi-gmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \ + ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/transform/cmvn.cc b/src/transform/cmvn.cc index 8dfe016227a..76f6652eecd 100644 --- a/src/transform/cmvn.cc +++ b/src/transform/cmvn.cc @@ -74,41 +74,43 @@ void ApplyCmvn(const MatrixBase &stats, if (stats.NumRows() == 1 && var_norm) KALDI_ERR << "You requested variance normalization but no variance stats " << "are supplied."; - + double count = stats(0, dim); // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when // computing an offset and representing it as stats, we use a count of one. if (count < 1.0) KALDI_ERR << "Insufficient stats for cepstral mean and variance normalization: " << "count = " << count; - - Matrix norm(2, dim); // norm(0, d) = mean offset + + if (!var_norm) { + Vector offset(dim); + SubVector mean_stats(stats.RowData(0), dim); + offset.AddVec(-1.0 / count, mean_stats); + feats->AddVecToRows(1.0, offset); + return; + } + // norm(0, d) = mean offset; // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). + Matrix norm(2, dim); for (int32 d = 0; d < dim; d++) { double mean, offset, scale; mean = stats(0, d)/count; - if (!var_norm) { - scale = 1.0; - offset = -mean; - } else { - double var = (stats(1, d)/count) - mean*mean, - floor = 1.0e-20; - if (var < floor) { - KALDI_WARN << "Flooring cepstral variance from " << var << " to " - << floor; - var = floor; - } - scale = 1.0 / sqrt(var); - if (scale != scale || 1/scale == 0.0) - KALDI_ERR << "NaN or infinity in cepstral mean/variance computation"; - offset = -(mean*scale); + double var = (stats(1, d)/count) - mean*mean, + floor = 1.0e-20; + if (var < floor) { + KALDI_WARN << "Flooring cepstral variance from " << var << " to " + << floor; + var = floor; } + scale = 1.0 / sqrt(var); + if (scale != scale || 1/scale == 0.0) + KALDI_ERR << "NaN or infinity in cepstral mean/variance computation"; + offset = -(mean*scale); norm(0, d) = offset; norm(1, d) = scale; } // Apply the normalization. - if (var_norm) - feats->MulColsVec(norm.Row(1)); + feats->MulColsVec(norm.Row(1)); feats->AddVecToRows(1.0, norm.Row(0)); } @@ -125,14 +127,14 @@ void ApplyCmvnReverse(const MatrixBase &stats, if (stats.NumRows() == 1 && var_norm) KALDI_ERR << "You requested variance normalization but no variance stats " << "are supplied."; - + double count = stats(0, dim); // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when // computing an offset and representing it as stats, we use a count of one. if (count < 1.0) KALDI_ERR << "Insufficient stats for cepstral mean and variance normalization: " << "count = " << count; - + Matrix norm(2, dim); // norm(0, d) = mean offset // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). for (int32 d = 0; d < dim; d++) { diff --git a/src/transform/decodable-am-diag-gmm-regtree.h b/src/transform/decodable-am-diag-gmm-regtree.h index 9da4b7f1591..b6e7888ffdc 100644 --- a/src/transform/decodable-am-diag-gmm-regtree.h +++ b/src/transform/decodable-am-diag-gmm-regtree.h @@ -51,7 +51,7 @@ class DecodableAmDiagGmmRegtreeFmllr: public DecodableAmDiagGmmUnmapped { // Note, frames are numbered from zero but transition-ids (tid) from one. virtual BaseFloat LogLikelihood(int32 frame, int32 tid) { return scale_*LogLikelihoodZeroBased(frame, - trans_model_.TransitionIdToPdf(tid)); + trans_model_.TransitionIdToPdfFast(tid)); } virtual int32 NumFramesReady() const { return feature_matrix_.NumRows(); } @@ -94,7 +94,7 @@ class DecodableAmDiagGmmRegtreeMllr: public DecodableAmDiagGmmUnmapped { // Note, frames are numbered from zero but transition-ids (tid) from one. virtual BaseFloat LogLikelihood(int32 frame, int32 tid) { return scale_*LogLikelihoodZeroBased(frame, - trans_model_.TransitionIdToPdf(tid)); + trans_model_.TransitionIdToPdfFast(tid)); } virtual int32 NumFramesReady() const { return feature_matrix_.NumRows(); } diff --git a/src/util/Makefile b/src/util/Makefile index 80c57fd7435..acfab8b8de1 100644 --- a/src/util/Makefile +++ b/src/util/Makefile @@ -15,6 +15,6 @@ OBJFILES = text-utils.o kaldi-io.o kaldi-holder.o kaldi-table.o \ LIBNAME = kaldi-util -ADDLIBS = ../matrix/kaldi-matrix.a ../base/kaldi-base.a +ADDLIBS = ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/tools/extras/install_irstlm.sh b/tools/extras/install_irstlm.sh index b27f0f89897..7c88377ad87 100755 --- a/tools/extras/install_irstlm.sh +++ b/tools/extras/install_irstlm.sh @@ -11,6 +11,13 @@ errcho() { echo "$@" 1>&2; } errcho "****() Installing IRSTLM" +if [ ! -d ./extras ]; then + errcho "****** You are trying to install IRSTLM from the wrong directory. You should" + errcho "****** go to tools/ and type extras/install_irstlm.sh." + exit 1 +fi + + if [ ! -d ./irstlm ] ; then svn=`which git` if [ $? != 0 ] ; then diff --git a/tools/extras/install_mmseg.sh b/tools/extras/install_mmseg.sh index 586740b5cbc..3c50ddc7dad 100755 --- a/tools/extras/install_mmseg.sh +++ b/tools/extras/install_mmseg.sh @@ -18,7 +18,7 @@ if ! $(python -c "import distutils.sysconfig" &> /dev/null); then echo "Proceeding with installation." >&2 else # get include path for this python version - INCLUDE_PY=$(python -c "from distutils import sysconfig as s; print s.get_config_vars()['INCLUDEPY']") + INCLUDE_PY=$(python -c "from distutils import sysconfig as s; print(s.get_python_inc())") if [ ! -f "${INCLUDE_PY}/Python.h" ]; then echo "$0 : ERROR: python-devel/python-dev not installed" >&2 if which yum >&/dev/null; then diff --git a/tools/extras/install_phonetisaurus.sh b/tools/extras/install_phonetisaurus.sh index 617aa341b32..e7594233d52 100755 --- a/tools/extras/install_phonetisaurus.sh +++ b/tools/extras/install_phonetisaurus.sh @@ -21,7 +21,7 @@ if ! $(python -c "import distutils.sysconfig" &> /dev/null); then echo "Proceeding with installation." >&2 else # get include path for this python version - INCLUDE_PY=$(python -c "from distutils import sysconfig as s; print s.get_config_vars()['INCLUDEPY']") + INCLUDE_PY=$(python -c "from distutils import sysconfig as s; print(s.get_python_inc())") if [ ! -f "${INCLUDE_PY}/Python.h" ]; then echo "$0 : ERROR: python-devel/python-dev not installed" >&2 if which yum >&/dev/null; then diff --git a/tools/extras/install_sequitur.sh b/tools/extras/install_sequitur.sh index 6ee4d9f4336..4250006651b 100755 --- a/tools/extras/install_sequitur.sh +++ b/tools/extras/install_sequitur.sh @@ -18,7 +18,7 @@ if ! $(python -c "import distutils.sysconfig" &> /dev/null); then echo "Proceeding with installation." >&2 else # get include path for this python version - INCLUDE_PY=$(python -c "from distutils import sysconfig as s; print s.get_config_vars()['INCLUDEPY']") + INCLUDE_PY=$(python -c "from distutils import sysconfig as s; print(s.get_python_inc())") if [ ! -f "${INCLUDE_PY}/Python.h" ]; then echo "$0 : ERROR: python-devel/python-dev not installed" >&2 if which yum >&/dev/null; then