diff --git a/egs/rm/s5/RESULTS b/egs/rm/s5/RESULTS index 65a9840df71..ecafb588cfe 100644 --- a/egs/rm/s5/RESULTS +++ b/egs/rm/s5/RESULTS @@ -230,8 +230,9 @@ for x in exp/nnet2_online_wsj/nnet_ms_a_smbr_0.00005/1/decode_*; do grep WER $x/ %WER 7.36 [ 923 / 12533, 85 ins, 148 del, 690 sub ] exp/nnet2_online_wsj/nnet_ms_a_smbr_0.00005/1/decode_ug_epoch4/wer_13 ### chain results ### -# current best chain result with TDNN (check local/chain/run_tdnn_5f.sh) -%WER 2.94 [ 369 / 12533, 51 ins, 71 del, 247 sub ] exp/chain/tdnn_5f/decode/wer_3_0.5 +# current best chain result with TDNN (check local/chain/run_tdnn_5g.sh) +%WER 2.86 [ 358 / 12533, 46 ins, 61 del, 251 sub ] exp/chain/tdnn_5g/decode/wer_5_0.0 +%WER 2.71 [ 340 / 12533, 58 ins, 59 del, 223 sub ] exp/chain/tdnn_5n/decode/wer_4_0.0 ### nnet1 results ### diff --git a/egs/rm/s5/local/chain/run_tdnn_5g.sh b/egs/rm/s5/local/chain/run_tdnn_5g.sh new file mode 100755 index 00000000000..f6fbe070763 --- /dev/null +++ b/egs/rm/s5/local/chain/run_tdnn_5g.sh @@ -0,0 +1,155 @@ +#!/bin/bash + +# This is modified from run_tdnn_5f.sh, to use the old topology, as a baseline +# to test the modified transition-model code (by which we hope to be able to +# create more compact decoding graphs for chain models). + +set -e + +# configs for 'chain' +stage=0 +train_stage=-10 +get_egs_stage=-10 +dir=exp/chain/tdnn_5g + +# training options +num_epochs=12 +initial_effective_lrate=0.005 +final_effective_lrate=0.0005 +leftmost_questions_truncate=-1 +max_param_change=2.0 +final_layer_normalize_target=0.5 +num_jobs_initial=2 +num_jobs_final=4 +minibatch_size=128 +frames_per_eg=150 +remove_egs=false + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo +fi + +if [ $stage -le 6 ]; then + # Build a tree using our new topology. + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --leftmost-questions-truncate $leftmost_questions_truncate \ + --cmd "$train_cmd" 1200 data/train $lang $ali_dir $treedir +fi + +if [ $stage -le 7 ]; then + mkdir -p $dir + + echo "$0: creating neural net configs"; + + steps/nnet3/tdnn/make_configs.py \ + --self-repair-scale-nonlinearity 0.00001 \ + --feat-dir data/train \ + --ivector-dir exp/nnet2_online/ivectors \ + --tree-dir $treedir \ + --relu-dim 450 \ + --splice-indexes "-1,0,1 -2,-1,0,1 -3,0,3 -6,-3,0 0" \ + --use-presoftmax-prior-scale false \ + --xent-regularize 0.1 \ + --xent-separate-forward-affine true \ + --include-log-softmax false \ + --final-layer-normalize-target 1.0 \ + $dir/configs || exit 1; +fi + +if [ $stage -le 8 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir exp/nnet2_online/ivectors \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize 0.1 \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=200" \ + --egs.dir "$common_egs_dir" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $frames_per_eg \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1000000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate $initial_effective_lrate \ + --trainer.optimization.final-effective-lrate $final_effective_lrate \ + --trainer.max-param-change $max_param_change \ + --cleanup.remove-egs true \ + --feat-dir data/train \ + --tree-dir $treedir \ + --lat-dir exp/tri3b_lats \ + --dir $dir +fi + +if [ $stage -le 9 ]; then + steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj 4 \ + data/test exp/nnet2_online/extractor exp/nnet2_online/ivectors_test || exit 1; +fi + +if [ $stage -le 10 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --self-loop-scale 1.0 data/lang $dir $dir/graph + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --scoring-opts "--min-lmwt 1" \ + --nj 20 --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet2_online/ivectors_test \ + $dir/graph data/test $dir/decode || exit 1; +fi + +if [ $stage -le 11 ]; then + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_ug $dir $dir/graph_ug + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj 20 --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet2_online/ivectors_test \ + $dir/graph_ug data/test $dir/decode_ug || exit 1; +fi +wait; +exit 0; diff --git a/egs/rm/s5/local/chain/run_tdnn_5f.sh b/egs/rm/s5/local/chain/run_tdnn_5n.sh old mode 100644 new mode 100755 similarity index 62% rename from egs/rm/s5/local/chain/run_tdnn_5f.sh rename to egs/rm/s5/local/chain/run_tdnn_5n.sh index 0379d16fe13..7fd7b82aa1d --- a/egs/rm/s5/local/chain/run_tdnn_5f.sh +++ b/egs/rm/s5/local/chain/run_tdnn_5n.sh @@ -1,6 +1,9 @@ #!/bin/bash -# this script is a modified version of swbd/run_tdnn_5f.sh +# this script is a modified version of run_tdnn_5g.sh. It uses +# the new transition model and the python version of training scripts. + + set -e @@ -8,7 +11,7 @@ set -e stage=0 train_stage=-10 get_egs_stage=-10 -dir=exp/chain/tdnn_5f +dir=exp/chain/tdnn_5n # training options num_epochs=12 @@ -43,13 +46,13 @@ fi # run those things. ali_dir=exp/tri3b_ali -treedir=exp/chain/tri4_2y_tree -lang=data/lang_chain_2y +treedir=exp/chain/tri4_5n_tree +lang=data/lang_chain_5n local/online/run_nnet2_common.sh --stage $stage || exit 1; if [ $stage -le 4 ]; then - # Get the alignments as lattices (gives the CTC training more freedom). + # Get the alignments as lattices (gives the chain training more freedom). # use the same num-jobs as the alignments nj=$(cat exp/tri3b_ali/num_jobs) || exit 1; steps/align_fmllr_lats.sh --nj $nj --cmd "$train_cmd" data/train \ @@ -78,51 +81,73 @@ if [ $stage -le 6 ]; then fi if [ $stage -le 7 ]; then - steps/nnet3/chain/train_tdnn.sh --stage $train_stage \ + mkdir -p $dir + + echo "$0: creating neural net configs"; + + steps/nnet3/tdnn/make_configs.py \ + --self-repair-scale-nonlinearity 0.00001 \ + --feat-dir data/train \ + --ivector-dir exp/nnet2_online/ivectors \ + --tree-dir $treedir \ + --relu-dim 450 \ + --splice-indexes "-1,0,1 -2,-1,0,1 -3,0,3 -6,-3,0 0" \ + --use-presoftmax-prior-scale false \ --xent-regularize 0.1 \ - --leaky-hmm-coefficient 0.1 \ - --l2-regularize 0.00005 \ - --jesus-opts "--jesus-forward-input-dim 200 --jesus-forward-output-dim 500 --jesus-hidden-dim 2000 --jesus-stddev-scale 0.2 --final-layer-learning-rate-factor 0.25" \ - --splice-indexes "-1,0,1 -2,-1,0,1 -3,0,3 -6,-3,0" \ - --apply-deriv-weights false \ - --frames-per-iter 1000000 \ - --lm-opts "--num-extra-lm-states=200" \ - --get-egs-stage $get_egs_stage \ - --minibatch-size $minibatch_size \ - --egs-opts "--frames-overlap-per-eg 0" \ - --frames-per-eg $frames_per_eg \ - --num-epochs $num_epochs --num-jobs-initial $num_jobs_initial --num-jobs-final $num_jobs_final \ - --feat-type raw \ - --online-ivector-dir exp/nnet2_online/ivectors \ - --cmvn-opts "--norm-means=false --norm-vars=false" \ - --initial-effective-lrate $initial_effective_lrate --final-effective-lrate $final_effective_lrate \ - --max-param-change $max_param_change \ - --cmd "$decode_cmd" \ - --remove-egs $remove_egs \ - data/train $treedir exp/tri3b_lats $dir || exit 1; + --xent-separate-forward-affine true \ + --include-log-softmax false \ + --final-layer-normalize-target 1.0 \ + $dir/configs || exit 1; fi if [ $stage -le 8 ]; then + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.online-ivector-dir exp/nnet2_online/ivectors \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize 0.1 \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=200" \ + --egs.dir "$common_egs_dir" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $frames_per_eg \ + --trainer.num-chunk-per-minibatch $minibatch_size \ + --trainer.frames-per-iter 1000000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial $num_jobs_initial \ + --trainer.optimization.num-jobs-final $num_jobs_final \ + --trainer.optimization.initial-effective-lrate $initial_effective_lrate \ + --trainer.optimization.final-effective-lrate $final_effective_lrate \ + --trainer.max-param-change $max_param_change \ + --cleanup.remove-egs true \ + --feat-dir data/train \ + --tree-dir $treedir \ + --lat-dir exp/tri3b_lats \ + --dir $dir +fi + +if [ $stage -le 9 ]; then steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj 4 \ data/test exp/nnet2_online/extractor exp/nnet2_online/ivectors_test || exit 1; fi -if [ $stage -le 9 ]; then +if [ $stage -le 10 ]; then # Note: it might appear that this $lang directory is mismatched, and it is as # far as the 'topo' is concerned, but this script doesn't read the 'topo' from # the lang directory. utils/mkgraph.sh --self-loop-scale 1.0 data/lang $dir $dir/graph steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --extra-left-context 20 --scoring-opts "--min-lmwt 1" \ + --scoring-opts "--min-lmwt 1" \ --nj 20 --cmd "$decode_cmd" \ --online-ivector-dir exp/nnet2_online/ivectors_test \ $dir/graph data/test $dir/decode || exit 1; fi -if [ $stage -le 10 ]; then +if [ $stage -le 11 ]; then utils/mkgraph.sh --self-loop-scale 1.0 data/lang_ug $dir $dir/graph_ug steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --extra-left-context 20 \ --nj 20 --cmd "$decode_cmd" \ --online-ivector-dir exp/nnet2_online/ivectors_test \ $dir/graph_ug data/test $dir/decode_ug || exit 1; diff --git a/egs/wsj/s5/steps/nnet3/chain/gen_topo.py b/egs/wsj/s5/steps/nnet3/chain/gen_topo.py index fdd7a02fd88..b27cd9eff1c 100755 --- a/egs/wsj/s5/steps/nnet3/chain/gen_topo.py +++ b/egs/wsj/s5/steps/nnet3/chain/gen_topo.py @@ -2,6 +2,9 @@ # Copyright 2012 Johns Hopkins University (author: Daniel Povey) +# This script was modified around 11.11.2016, when the code was extended to +# support having a different pdf-class on the self loop. + # Generate a topology file. This allows control of the number of states in the # non-silence HMMs, and in the silence HMMs. This is a modified version of # 'utils/gen_topo.pl' that generates a different type of topology, one that we @@ -41,9 +44,8 @@ # We make the transition-probs 0.5 so they normalize, to keep the code happy. # In fact, we always set the transition probability scale to 0.0 in the 'chain' # code, so they are never used. -print(" 0 0 1 0.5 2 0.5 ") -print(" 1 1 1 0.5 2 0.5 ") -print(" 2 ") +print(" 0 0 1 0 0.5 1 0.5 ") +print(" 1 ") print("") print("") diff --git a/egs/wsj/s5/steps/nnet3/chain/gen_topo_orig.py b/egs/wsj/s5/steps/nnet3/chain/gen_topo_orig.py new file mode 100755 index 00000000000..01a715a9a23 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain/gen_topo_orig.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python + +# Copyright 2012 Johns Hopkins University (author: Daniel Povey) + +# This file is as ./gen_topo.py used to be (before we extended the transition-model +# code to support having a different self-loop pdf-class). It is included +# here for baseline and testing purposes. + + +# Generate a topology file. This allows control of the number of states in the +# non-silence HMMs, and in the silence HMMs. This is a modified version of +# 'utils/gen_topo.pl' that generates a different type of topology, one that we +# believe should be useful in the 'chain' model. Note: right now it doesn't +# have any real options, and it treats silence and nonsilence the same. The +# intention is that you write different versions of this script, or add options, +# if you experiment with it. + +from __future__ import print_function +import argparse + + +parser = argparse.ArgumentParser(description="Usage: steps/nnet3/chain/gen_topo.py " + " " + "e.g.: steps/nnet3/chain/gen_topo.pl 4:5:6:7:8:9:10 1:2:3\n", + epilog="See egs/swbd/s5c/local/chain/train_tdnn_a.sh for example of usage."); +parser.add_argument("nonsilence_phones", type=str, + help="List of non-silence phones as integers, separated by colons, e.g. 4:5:6:7:8:9"); +parser.add_argument("silence_phones", type=str, + help="List of silence phones as integers, separated by colons, e.g. 1:2:3"); + +args = parser.parse_args() + +silence_phones = [ int(x) for x in args.silence_phones.split(":") ] +nonsilence_phones = [ int(x) for x in args.nonsilence_phones.split(":") ] +all_phones = silence_phones + nonsilence_phones + +print("") +print("") +print("") +print(" ".join([str(x) for x in all_phones])) +print("") +# The next two lines may look like a bug, but they are as intended. State 0 has +# no self-loop, it happens exactly once. And it can go either to state 1 (with +# a self-loop) or to state 2, so we can have zero or more instances of state 1 +# following state 0. +# We make the transition-probs 0.5 so they normalize, to keep the code happy. +# In fact, we always set the transition probability scale to 0.0 in the 'chain' +# code, so they are never used. +print(" 0 0 1 0.5 2 0.5 ") +print(" 1 1 1 0.5 2 0.5 ") +print(" 2 ") +print("") +print("") diff --git a/src/bin/acc-tree-stats.cc b/src/bin/acc-tree-stats.cc index 90432c2e58a..8b9ce9065b4 100644 --- a/src/bin/acc-tree-stats.cc +++ b/src/bin/acc-tree-stats.cc @@ -128,5 +128,3 @@ int main(int argc, char *argv[]) { return -1; } } - - diff --git a/src/chain/chain-den-graph.cc b/src/chain/chain-den-graph.cc index b092b3de4d7..6f494a0c562 100644 --- a/src/chain/chain-den-graph.cc +++ b/src/chain/chain-den-graph.cc @@ -220,11 +220,6 @@ static void SortOnTransitionCount(fst::StdVectorFst *fst) { void DenGraphMinimizeWrapper(fst::StdVectorFst *fst) { for (int32 i = 1; i <= 3; i++) { - fst::PushSpecial(fst, fst::kDelta * 0.01); - MinimizeAcceptorNoPush(fst); - KALDI_LOG << "Number of states and arcs in transition-id FST after regular " - << "minimization is " << fst->NumStates() << " and " - << NumArcs(*fst) << " (pass " << i << ")"; fst::StdVectorFst fst_reversed; fst::Reverse(*fst, &fst_reversed); fst::PushSpecial(&fst_reversed, fst::kDelta * 0.01); @@ -233,6 +228,11 @@ void DenGraphMinimizeWrapper(fst::StdVectorFst *fst) { KALDI_LOG << "Number of states and arcs in transition-id FST after reversed " << "minimization is " << fst->NumStates() << " and " << NumArcs(*fst) << " (pass " << i << ")"; + fst::PushSpecial(fst, fst::kDelta * 0.01); + MinimizeAcceptorNoPush(fst); + KALDI_LOG << "Number of states and arcs in transition-id FST after regular " + << "minimization is " << fst->NumStates() << " and " + << NumArcs(*fst) << " (pass " << i << ")"; } fst::RmEpsilon(fst); KALDI_LOG << "Number of states and arcs in transition-id FST after " @@ -347,7 +347,7 @@ void CreateDenominatorFst(const ContextDependency &ctx_dep, BaseFloat self_loop_scale = 1.0; // We have to be careful to use the same // value in test time. - bool reorder = false; + bool reorder = true; // add self-loops to the FST with transition-ids as its labels. AddSelfLoops(trans_model, disambig_syms_h, self_loop_scale, reorder, &transition_id_fst); diff --git a/src/doc/hmm.dox b/src/doc/hmm.dox index 30873cfa9b0..a051caffd76 100644 --- a/src/doc/hmm.dox +++ b/src/doc/hmm.dox @@ -80,7 +80,7 @@ that the transition probabilities in the HmmTopology object are generally not used after initializing the TransitionModel object. There is an exception to this, however; for nonemitting states that are non-final (i.e. those that have transitions out of them but no \ entry), Kaldi does not train the -transition probabilities and instead it uses the probabilities given in the +7transition probabilities and instead it uses the probabilities given in the HmmTopology object. The decision not to support trainable transition probabilities for non-emitting states simplifies our training mechanisms, and since it is not normal to have non-emitting states with transitions, we felt that this was no @@ -92,7 +92,17 @@ great loss. The pdf-class is a concept that relates to the HmmTopology object. The HmmTopology object specifies a prototype HMM for each phone. Each numbered state of a -"prototype HMM" has a variable "pdf_class". If two states have the same +"prototype HMM" has two variables "forward_pdf_class" and "self_loop_pdf_class". +The "self_loop_pdf_class" is a kind of pdf-class that is associated +with self-loop transition. It is by default identical to "forward_pdf_class", +but it can be used to define less-convectional HMM topologies +where the pdfs on the self-loop and forward transitions are different. +The decision to allow the pdf-class on just the self-loop to be different, +while not embracing a fully "arc-based" representation where the pdfs on +all transitions in the HMM are potentially independent, was made as a compromise, +to allow for compatibility with previous versions of Kaldi while supporting the topology +used in our "chain models" AKA lattice-free MMI. +If two states have the same pdf_class variable, then they will always share the same probability distribution function (p.d.f.) if they are in the same phonetic context. This is because the decision-tree code does not get to "see" the HMM-state directly, @@ -121,11 +131,14 @@ object to get the pdf-ids associated with particular phonetic contexts). The decision that underlies a lot of the transition-modeling code is as follows: we have decided to make the transition probability of a -context dependent HMM state depend on the following four things (you could view -them as a 4-tuple): +context dependent HMM state depend on the following five things (you could view +them as a 5-tuple): - The phone (whose HMM we are in) - The source HMM-state (as interpreted by the HmmTopology object, i.e. normally 0, 1 or 2) - - The \ref pdf_id "pdf-id" (i.e. the index of the pdf associated with the state) + - The \ref pdf_id "forward-pdf-id" + (i.e. the index of the forward transition pdfs associated with the state) + - The \ref pdf_id "self-loop-pdf-id" + (i.e. the index of the self-loop pdfs associated with the state) - The index of the transition in the HmmTopology object. The last of these four items could be viewed as encoding the destination @@ -198,7 +211,7 @@ prototype HMM (as given in the HmmTopology object). from (transition-state, transition-index) to transition-id, and vice versa. There are also in the transition-modeling code reference to the following concepts: - - A triple means a triple (phone, hmm-state, pdf) which is mappable to and from a transition-state. + - A tuple means a 4-tuple (phone, hmm-state, forward pdf, self-loop pdf) which is mappable to and from a transition-state. - A pair means a pair (transition-state, transition-index) which is mappable to and from a transition-id. \section hmm_transition_training Training the transition model diff --git a/src/hmm/hmm-test-utils.cc b/src/hmm/hmm-test-utils.cc index 4cfebcd0d51..ceca116c828 100644 --- a/src/hmm/hmm-test-utils.cc +++ b/src/hmm/hmm-test-utils.cc @@ -203,7 +203,7 @@ void GeneratePathThroughHmm(const HmmTopology &topology, const HmmTopology::HmmState &cur_hmm_state = this_entry[cur_state]; int32 num_transitions = cur_hmm_state.transitions.size(), transition_index = RandInt(0, num_transitions - 1); - if (cur_hmm_state.pdf_class != -1) { + if (cur_hmm_state.forward_pdf_class != -1) { std::pair pr(cur_state, transition_index); if (!reorder) { path->push_back(pr); @@ -257,12 +257,15 @@ void GenerateRandomAlignment(const ContextDependencyInterface &ctx_dep, trans_model.GetTopo().TopologyForPhone(phone); int32 hmm_state = path[k].first, transition_index = path[k].second, - pdf_class = entry[hmm_state].pdf_class, - pdf_id; - bool ans = ctx_dep.Compute(context_window, pdf_class, &pdf_id); + forward_pdf_class = entry[hmm_state].forward_pdf_class, + self_loop_pdf_class = entry[hmm_state].self_loop_pdf_class, + forward_pdf_id, self_loop_pdf_id; + bool ans = ctx_dep.Compute(context_window, forward_pdf_class, &forward_pdf_id); KALDI_ASSERT(ans && "context-dependency computation failed."); - int32 transition_state = trans_model.TripleToTransitionState( - phone, hmm_state, pdf_id), + ans = ctx_dep.Compute(context_window, self_loop_pdf_class, &self_loop_pdf_id); + KALDI_ASSERT(ans && "context-dependency computation failed."); + int32 transition_state = trans_model.TupleToTransitionState( + phone, hmm_state, forward_pdf_id, self_loop_pdf_id), transition_id = trans_model.PairToTransitionId(transition_state, transition_index); alignment->push_back(transition_id); diff --git a/src/hmm/hmm-topology-test.cc b/src/hmm/hmm-topology-test.cc index 61cf13e17bc..14081d2355d 100644 --- a/src/hmm/hmm-topology-test.cc +++ b/src/hmm/hmm-topology-test.cc @@ -58,6 +58,17 @@ void TestHmmTopology() { " \n" " \n"; + std::string chain_input_str = "\n" + "\n" + " 1 2 3 4 5 6 7 8 9 \n" + " 0 0 1\n" + " 0 0.5\n" + " 1 0.5\n" + " \n" + " 1 \n" + "\n" + "\n"; + HmmTopology topo; if (RandInt(0, 1) == 0) { @@ -84,6 +95,13 @@ void TestHmmTopology() { KALDI_ASSERT(oss1.str() == oss2.str()); } + { // test chain topology + HmmTopology chain_topo; + std::istringstream chain_iss(chain_input_str); + chain_topo.Read(chain_iss, false); + KALDI_ASSERT(chain_topo.MinLength(3) == 1); + } + { // make sure GetDefaultTopology does not crash. std::vector phones; phones.push_back(1); diff --git a/src/hmm/hmm-topology.cc b/src/hmm/hmm-topology.cc index 54144326766..cf134065dbf 100644 --- a/src/hmm/hmm-topology.cc +++ b/src/hmm/hmm-topology.cc @@ -76,12 +76,24 @@ void HmmTopology::Read(std::istream &is, bool binary) { KALDI_ERR << "States are expected to be in order from zero, expected " << this_entry.size() << ", got " << state; ReadToken(is, binary, &token); - int32 pdf_class = kNoPdf; // -1 by default, means no pdf. + int32 forward_pdf_class = kNoPdf; // -1 by default, means no pdf. if (token == "") { - ReadBasicType(is, binary, &pdf_class); + ReadBasicType(is, binary, &forward_pdf_class); + this_entry.push_back(HmmState(forward_pdf_class)); ReadToken(is, binary, &token); - } - this_entry.push_back(HmmState(pdf_class)); + if (token == "") + KALDI_ERR << "pdf classes should be defined using " + << "or / pair"; + } else if (token == "") { + int32 self_loop_pdf_class = kNoPdf; + ReadBasicType(is, binary, &forward_pdf_class); + ReadToken(is, binary, &token); + KALDI_ASSERT(token == ""); + ReadBasicType(is, binary, &self_loop_pdf_class); + this_entry.push_back(HmmState(forward_pdf_class, self_loop_pdf_class)); + ReadToken(is, binary, &token); + } else + this_entry.push_back(HmmState(forward_pdf_class)); while (token == "") { int32 dst_state; BaseFloat trans_prob; @@ -118,13 +130,22 @@ void HmmTopology::Read(std::istream &is, bool binary) { ReadIntegerVector(is, binary, &phone2idx_); int32 sz; ReadBasicType(is, binary, &sz); + bool is_hmm = true; + if (sz == -1) { + is_hmm = false; + ReadBasicType(is, binary, &sz); + } entries_.resize(sz); for (int32 i = 0; i < sz; i++) { int32 thist_sz; ReadBasicType(is, binary, &thist_sz); entries_[i].resize(thist_sz); for (int32 j = 0 ; j < thist_sz; j++) { - ReadBasicType(is, binary, &(entries_[i][j].pdf_class)); + ReadBasicType(is, binary, &(entries_[i][j].forward_pdf_class)); + if (is_hmm) + entries_[i][j].self_loop_pdf_class = entries_[i][j].forward_pdf_class; + else + ReadBasicType(is, binary, &(entries_[i][j].self_loop_pdf_class)); int32 thiss_sz; ReadBasicType(is, binary, &thiss_sz); entries_[i][j].transitions.resize(thiss_sz); @@ -141,6 +162,7 @@ void HmmTopology::Read(std::istream &is, bool binary) { void HmmTopology::Write(std::ostream &os, bool binary) const { + bool is_hmm = IsHmm(); WriteToken(os, binary, ""); if (!binary) { // Text-mode write. os << "\n"; @@ -159,9 +181,17 @@ void HmmTopology::Write(std::ostream &os, bool binary) const { for (size_t j = 0; j < entries_[i].size(); j++) { WriteToken(os, binary, ""); WriteBasicType(os, binary, static_cast(j)); - if (entries_[i][j].pdf_class != kNoPdf) { - WriteToken(os, binary, ""); - WriteBasicType(os, binary, entries_[i][j].pdf_class); + if (entries_[i][j].forward_pdf_class != kNoPdf) { + if (is_hmm) { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, entries_[i][j].forward_pdf_class); + } else { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, entries_[i][j].forward_pdf_class); + KALDI_ASSERT(entries_[i][j].self_loop_pdf_class != kNoPdf); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, entries_[i][j].self_loop_pdf_class); + } } for (size_t k = 0; k < entries_[i][j].transitions.size(); k++) { WriteToken(os, binary, ""); @@ -177,11 +207,15 @@ void HmmTopology::Write(std::ostream &os, bool binary) const { } else { WriteIntegerVector(os, binary, phones_); WriteIntegerVector(os, binary, phone2idx_); + // -1 is put here as a signal that the object has the new, + // extended format with SelfLoopPdfClass + if (!is_hmm) WriteBasicType(os, binary, static_cast(-1)); WriteBasicType(os, binary, static_cast(entries_.size())); for (size_t i = 0; i < entries_.size(); i++) { WriteBasicType(os, binary, static_cast(entries_[i].size())); for (size_t j = 0; j < entries_[i].size(); j++) { - WriteBasicType(os, binary, entries_[i][j].pdf_class); + WriteBasicType(os, binary, entries_[i][j].forward_pdf_class); + if (!is_hmm) WriteBasicType(os, binary, entries_[i][j].self_loop_pdf_class); WriteBasicType(os, binary, static_cast(entries_[i][j].transitions.size())); for (size_t k = 0; k < entries_[i][j].transitions.size(); k++) { WriteBasicType(os, binary, entries_[i][j].transitions[k].first); @@ -215,7 +249,7 @@ void HmmTopology::Check() { if (!entries_[i][num_states-1].transitions.empty()) KALDI_ERR << "HmmTopology::Check(), last state must have no transitions."; // not sure how necessary this next stipulation is. - if (entries_[i][num_states-1].pdf_class != kNoPdf) + if (entries_[i][num_states-1].forward_pdf_class != kNoPdf) KALDI_ERR << "HmmTopology::Check(), last state must not be emitting."; std::vector has_trans_in(num_states, false); @@ -223,8 +257,10 @@ void HmmTopology::Check() { for (int32 j = 0; j < num_states; j++) { // j is the state-id. BaseFloat tot_prob = 0.0; - if (entries_[i][j].pdf_class != kNoPdf) - seen_pdf_classes.push_back(entries_[i][j].pdf_class); + if (entries_[i][j].forward_pdf_class != kNoPdf) { + seen_pdf_classes.push_back(entries_[i][j].forward_pdf_class); + seen_pdf_classes.push_back(entries_[i][j].self_loop_pdf_class); + } std::set seen_transition; for (int32 k = 0; static_cast(k) < entries_[i][j].transitions.size(); @@ -238,7 +274,7 @@ void HmmTopology::Check() { // that are being built, which enable the creation of phone-level lattices // and rescoring these with a different lexicon and LM. if (dst_state == num_states-1 // && j != 0 - && entries_[i][j].pdf_class == kNoPdf) + && entries_[i][j].forward_pdf_class == kNoPdf) KALDI_ERR << "We do not allow any state to be " "nonemitting and have a transition to the final-state (this would " "stop the SplitToPhones function from identifying the last state " @@ -248,7 +284,8 @@ void HmmTopology::Check() { if (seen_transition.count(dst_state) != 0) KALDI_ERR << "HmmTopology::Check(), duplicate transition found."; if (dst_state == k) { // self_loop... - KALDI_ASSERT(entries_[i][j].pdf_class != kNoPdf && "Nonemitting states cannot have self-loops."); + KALDI_ASSERT(entries_[i][j].self_loop_pdf_class != kNoPdf && + "Nonemitting states cannot have self-loops."); } seen_transition.insert(dst_state); has_trans_in[dst_state] = true; @@ -275,6 +312,22 @@ void HmmTopology::Check() { } } +bool HmmTopology::IsHmm() const { + const std::vector &phones = GetPhones(); + KALDI_ASSERT(!phones.empty()); + for (size_t i = 0; i < phones.size(); i++) { + int32 phone = phones[i]; + const TopologyEntry &entry = TopologyForPhone(phone); + for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... + int32 forward_pdf_class = entry[j].forward_pdf_class, + self_loop_pdf_class = entry[j].self_loop_pdf_class; + if (forward_pdf_class != self_loop_pdf_class) + return false; + } + } + return true; +} + const HmmTopology::TopologyEntry& HmmTopology::TopologyForPhone(int32 phone) const { // Will throw if phone not covered. if (static_cast(phone) >= phone2idx_.size() || phone2idx_[phone] == -1) { KALDI_ERR << "TopologyForPhone(), phone "<<(phone)<<" not covered."; @@ -286,8 +339,10 @@ int32 HmmTopology::NumPdfClasses(int32 phone) const { // will throw if phone not covered. const TopologyEntry &entry = TopologyForPhone(phone); int32 max_pdf_class = 0; - for (size_t i = 0; i < entry.size(); i++) - max_pdf_class = std::max(max_pdf_class, entry[i].pdf_class); + for (size_t i = 0; i < entry.size(); i++) { + max_pdf_class = std::max(max_pdf_class, entry[i].forward_pdf_class); + max_pdf_class = std::max(max_pdf_class, entry[i].self_loop_pdf_class); + } return max_pdf_class+1; } @@ -299,7 +354,7 @@ int32 HmmTopology::MinLength(int32 phone) const { std::numeric_limits::max()); KALDI_ASSERT(!entry.empty()); - min_length[0] = (entry[0].pdf_class == -1 ? 0 : 1); + min_length[0] = (entry[0].forward_pdf_class == -1 ? 0 : 1); int32 num_states = min_length.size(); bool changed = true; while (changed) { @@ -313,7 +368,7 @@ int32 HmmTopology::MinLength(int32 phone) const { int32 next_state = iter->first; KALDI_ASSERT(next_state < num_states); int32 next_state_min_length = min_length[s] + - (entry[next_state].pdf_class == -1 ? 0 : 1); + (entry[next_state].forward_pdf_class == -1 ? 0 : 1); if (next_state_min_length < min_length[next_state]) { min_length[next_state] = next_state_min_length; if (next_state < s) diff --git a/src/hmm/hmm-topology.h b/src/hmm/hmm-topology.h index 79b535e7d6b..edea02998c0 100644 --- a/src/hmm/hmm-topology.h +++ b/src/hmm/hmm-topology.h @@ -95,23 +95,38 @@ class HmmTopology { public: /// A structure defined inside HmmTopology to represent a HMM state. struct HmmState { - /// The \ref pdf_class pdf-class, typically 0, 1 or 2 (the same as the HMM-state index), + /// The \ref pdf_class forward-pdf-class, typically 0, 1 or 2 (the same as the HMM-state index), /// but may be different to enable us to hardwire sharing of state, and may be /// equal to \ref kNoPdf == -1 in order to specify nonemitting states (unusual). - int32 pdf_class; + int32 forward_pdf_class; + + /// The \ref pdf_class self-loop pdf-class, similar to \ref pdf_class forward-pdf-class. + /// They will either both be \ref kNoPdf, or neither be \ref kNoPdf. + int32 self_loop_pdf_class; /// A list of transitions, indexed by what we call a 'transition-index'. /// The first member of each pair is the index of the next HmmState, and the /// second is the default transition probability (before training). std::vector > transitions; - explicit HmmState(int32 p): pdf_class(p) { } + explicit HmmState(int32 pdf_class) { + this->forward_pdf_class = pdf_class; + this->self_loop_pdf_class = pdf_class; + } + explicit HmmState(int32 forward_pdf_class, int32 self_loop_pdf_class) { + KALDI_ASSERT((forward_pdf_class != kNoPdf && self_loop_pdf_class != kNoPdf) || + (forward_pdf_class == kNoPdf && self_loop_pdf_class == kNoPdf)); + this->forward_pdf_class = forward_pdf_class; + this->self_loop_pdf_class = self_loop_pdf_class; + } bool operator == (const HmmState &other) const { - return (pdf_class == other.pdf_class && transitions == other.transitions); + return (forward_pdf_class == other.forward_pdf_class && + self_loop_pdf_class == other.self_loop_pdf_class && + transitions == other.transitions); } - HmmState(): pdf_class(-1) { } + HmmState(): forward_pdf_class(-1), self_loop_pdf_class(-1) { } }; /// TopologyEntry is a typedef that represents the topology of @@ -124,6 +139,15 @@ class HmmTopology { // Checks that the object is valid, and throw exception otherwise. void Check(); + /// Returns true if this HmmTopology is really 'hmm-like', i.e. the pdf-class on + /// the self-loops and forward transitions of all states are identical. [note: in HMMs, + /// the densities are associated with the states.] We have extended this to + /// support 'non-hmm-like' topologies (where those pdf-classes are different), + /// in order to make for more compact decoding graphs in our so-called 'chain models' + /// (AKA lattice-free MMI), where we use 1-state topologies that have different pdf-classes + /// for the self-loop and the forward transition. Note that we always use the 'reorder=true' + /// option so the 'forward transition' actually comes before the self-loop. + bool IsHmm() const; /// Returns the topology entry (i.e. vector of HmmState) for this phone; /// will throw exception if phone not covered by the topology. diff --git a/src/hmm/hmm-utils.cc b/src/hmm/hmm-utils.cc index 04ec09d14b7..ab0b133f708 100644 --- a/src/hmm/hmm-utils.cc +++ b/src/hmm/hmm-utils.cc @@ -93,11 +93,16 @@ fst::VectorFst *GetHmmAsFst( for (int32 hmm_state = 0; hmm_state < static_cast(entry.size()); hmm_state++) { - int32 pdf_class = entry[hmm_state].pdf_class, pdf; - if (pdf_class == kNoPdf) pdf = kNoPdf; // nonemitting state. - else { - KALDI_ASSERT(pdf_class < static_cast(pdfs.size())); - pdf = pdfs[pdf_class]; + int32 forward_pdf_class = entry[hmm_state].forward_pdf_class, forward_pdf; + int32 self_loop_pdf_class = entry[hmm_state].self_loop_pdf_class, self_loop_pdf; + if (forward_pdf_class == kNoPdf) { // nonemitting state. + forward_pdf = kNoPdf; + self_loop_pdf = kNoPdf; + } else { + KALDI_ASSERT(forward_pdf_class < static_cast(pdfs.size())); + KALDI_ASSERT(self_loop_pdf_class < static_cast(pdfs.size())); + forward_pdf = pdfs[forward_pdf_class]; + self_loop_pdf = pdfs[self_loop_pdf_class]; } int32 trans_idx; for (trans_idx = 0; @@ -110,7 +115,7 @@ fst::VectorFst *GetHmmAsFst( if (is_self_loop) continue; // We will add self-loops in at a later stage of processing, // not in this function. - if (pdf_class == kNoPdf) { + if (forward_pdf_class == kNoPdf) { // no pdf, hence non-estimated probability. // [would not happen with normal topology] . There is no transition-state // involved in this case. @@ -118,7 +123,7 @@ fst::VectorFst *GetHmmAsFst( label = 0; } else { // normal probability. int32 trans_state = - trans_model.TripleToTransitionState(phone, hmm_state, pdf); + trans_model.TupleToTransitionState(phone, hmm_state, forward_pdf, self_loop_pdf); int32 trans_id = trans_model.PairToTransitionId(trans_state, trans_idx); log_prob = trans_model.GetTransitionLogProbIgnoringSelfLoops(trans_id); @@ -183,10 +188,15 @@ GetHmmAsFstSimple(std::vector phone_window, for (int32 hmm_state = 0; hmm_state < static_cast(entry.size()); hmm_state++) { - int32 pdf_class = entry[hmm_state].pdf_class, pdf; - if (pdf_class == kNoPdf) pdf = kNoPdf; // nonemitting state; not generally used. - else { - bool ans = ctx_dep.Compute(phone_window, pdf_class, &pdf); + int32 forward_pdf_class = entry[hmm_state].forward_pdf_class, forward_pdf; + int32 self_loop_pdf_class = entry[hmm_state].self_loop_pdf_class, self_loop_pdf; + if (forward_pdf_class == kNoPdf) { // nonemitting state; not generally used. + forward_pdf = kNoPdf; + self_loop_pdf = kNoPdf; + } else { + bool ans = ctx_dep.Compute(phone_window, forward_pdf_class, &forward_pdf); + KALDI_ASSERT(ans && "Context-dependency computation failed."); + ans = ctx_dep.Compute(phone_window, self_loop_pdf_class, &self_loop_pdf); KALDI_ASSERT(ans && "Context-dependency computation failed."); } int32 trans_idx; @@ -196,7 +206,7 @@ GetHmmAsFstSimple(std::vector phone_window, BaseFloat log_prob; Label label; int32 dest_state = entry[hmm_state].transitions[trans_idx].first; - if (pdf_class == kNoPdf) { + if (forward_pdf_class == kNoPdf) { // no pdf, hence non-estimated probability. very unusual case. [would // not happen with normal topology] . There is no transition-state // involved in this case. @@ -205,7 +215,7 @@ GetHmmAsFstSimple(std::vector phone_window, label = 0; } else { // normal probability. int32 trans_state = - trans_model.TripleToTransitionState(phone, hmm_state, pdf); + trans_model.TupleToTransitionState(phone, hmm_state, forward_pdf, self_loop_pdf); int32 trans_id = trans_model.PairToTransitionId(trans_state, trans_idx); log_prob = prob_scale * trans_model.GetTransitionLogProb(trans_id); @@ -652,8 +662,8 @@ static bool SplitToPhonesInternal(const TransitionModel &trans_model, int32 trans_state = trans_model.TransitionIdToTransitionState(alignment[cur_point]); int32 phone = trans_model.TransitionStateToPhone(trans_state); - int32 pdf_class = trans_model.GetTopo().TopologyForPhone(phone)[0].pdf_class; - if (pdf_class != kNoPdf) // initial-state of the current phone is emitting + int32 forward_pdf_class = trans_model.GetTopo().TopologyForPhone(phone)[0].forward_pdf_class; + if (forward_pdf_class != kNoPdf) // initial-state of the current phone is emitting if (trans_model.TransitionStateToHmmState(trans_state) != 0) was_ok = false; for (size_t j = cur_point; j < end_points[i]; j++) @@ -739,14 +749,19 @@ static inline void ConvertAlignmentForPhone( // the topologies and lengths match -> we can directly transfer // the alignment. for (int32 j = 0; j < alignment_size; j++) { - int32 old_tid = old_phone_alignment[j]; - int32 pdf_class = old_trans_model.TransitionIdToPdfClass(old_tid); + int32 old_tid = old_phone_alignment[j], + old_tstate = old_trans_model.TransitionIdToTransitionState(old_tid); + int32 forward_pdf_class = + old_trans_model.TransitionStateToForwardPdfClass(old_tstate), + self_loop_pdf_class = + old_trans_model.TransitionStateToSelfLoopPdfClass(old_tstate); int32 hmm_state = old_trans_model.TransitionIdToHmmState(old_tid); int32 trans_idx = old_trans_model.TransitionIdToTransitionIndex(old_tid); - int32 new_pdf = pdf_ids[pdf_class]; + int32 new_forward_pdf = pdf_ids[forward_pdf_class]; + int32 new_self_loop_pdf = pdf_ids[self_loop_pdf_class]; int32 new_trans_state = - new_trans_model.TripleToTransitionState(new_central_phone, hmm_state, - new_pdf); + new_trans_model.TupleToTransitionState(new_central_phone, hmm_state, + new_forward_pdf, new_self_loop_pdf); int32 new_tid = new_trans_model.PairToTransitionId(new_trans_state, trans_idx); (*new_phone_alignment)[j] = new_tid; diff --git a/src/hmm/transition-model.cc b/src/hmm/transition-model.cc index df22169cd25..83edbaf5805 100644 --- a/src/hmm/transition-model.cc +++ b/src/hmm/transition-model.cc @@ -24,13 +24,26 @@ namespace kaldi { -void TransitionModel::ComputeTriples(const ContextDependencyInterface &ctx_dep) { +void TransitionModel::ComputeTuples(const ContextDependencyInterface &ctx_dep) { + if (IsHmm()) + ComputeTuplesIsHmm(ctx_dep); + else + ComputeTuplesNotHmm(ctx_dep); + + // now tuples_ is populated with all possible tuples of (phone, hmm_state, pdf, self_loop_pdf). + std::sort(tuples_.begin(), tuples_.end()); // sort to enable reverse lookup. + // this sorting defines the transition-ids. +} + +void TransitionModel::ComputeTuplesIsHmm(const ContextDependencyInterface &ctx_dep) { const std::vector &phones = topo_.GetPhones(); - std::vector > > pdf_info; KALDI_ASSERT(!phones.empty()); + + // this is the case for normal models. but not fot chain models + std::vector > > pdf_info; std::vector num_pdf_classes( 1 + *std::max_element(phones.begin(), phones.end()), -1); for (size_t i = 0; i < phones.size(); i++) - num_pdf_classes[phones[i]] = topo_.NumPdfClasses(phones[i]); + num_pdf_classes[phones[i]] = topo_.NumPdfClasses(phones[i]); ctx_dep.GetPdfInfo(phones, num_pdf_classes, &pdf_info); // pdf_info is list indexed by pdf of which (phone, pdf_class) it // can correspond to. @@ -43,47 +56,108 @@ void TransitionModel::ComputeTriples(const ContextDependencyInterface &ctx_dep) int32 phone = phones[i]; const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... - int32 pdf_class = entry[j].pdf_class; + int32 pdf_class = entry[j].forward_pdf_class; if (pdf_class != kNoPdf) { to_hmm_state_list[std::make_pair(phone, pdf_class)].push_back(j); } } } + for (int32 pdf = 0; pdf < static_cast(pdf_info.size()); pdf++) { for (size_t j = 0; j < pdf_info[pdf].size(); j++) { int32 phone = pdf_info[pdf][j].first, - pdf_class = pdf_info[pdf][j].second; + pdf_class = pdf_info[pdf][j].second; const std::vector &state_vec = to_hmm_state_list[std::make_pair(phone, pdf_class)]; KALDI_ASSERT(!state_vec.empty()); // state_vec is a list of the possible HMM-states that emit this // pdf_class. for (size_t k = 0; k < state_vec.size(); k++) { int32 hmm_state = state_vec[k]; - triples_.push_back(Triple(phone, hmm_state, pdf)); + tuples_.push_back(Tuple(phone, hmm_state, pdf, pdf)); } } } +} - // now triples_ is populated with all possible triples of (phone, hmm_state, pdf). - std::sort(triples_.begin(), triples_.end()); // sort to enable reverse lookup. - // this sorting defines the transition-ids. +void TransitionModel::ComputeTuplesNotHmm(const ContextDependencyInterface &ctx_dep) { + const std::vector &phones = topo_.GetPhones(); + KALDI_ASSERT(!phones.empty()); + + // pdf_info is a set of lists indexed by phone. Each list is indexed by + // (pdf-class, self-loop pdf-class) of each state of that phone, and the element + // is a list of possible (pdf, self-loop pdf) pairs that that (pdf-class, self-loop pdf-class) + // pair generates. + std::vector > > > pdf_info; + // pdf_class_pairs is a set of lists indexed by phone. Each list stores + // (pdf-class, self-loop pdf-class) of each state of that phone. + std::vector > > pdf_class_pairs; + pdf_class_pairs.resize(1 + *std::max_element(phones.begin(), phones.end())); + for (size_t i = 0; i < phones.size(); i++) { + int32 phone = phones[i]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); + for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... + int32 forward_pdf_class = entry[j].forward_pdf_class, self_loop_pdf_class = entry[j].self_loop_pdf_class; + if (forward_pdf_class != kNoPdf) + pdf_class_pairs[phone].push_back(std::make_pair(forward_pdf_class, self_loop_pdf_class)); + } + } + ctx_dep.GetPdfInfo(phones, pdf_class_pairs, &pdf_info); + + std::vector, std::vector > > to_hmm_state_list; + to_hmm_state_list.resize(1 + *std::max_element(phones.begin(), phones.end())); + // to_hmm_state_list is a phone-indexed set of maps from (pdf-class, self-loop pdf_class) to the list + // of hmm-states in the HMM for that phone that that (pdf-class, self-loop pdf-class) + // can correspond to. + for (size_t i = 0; i < phones.size(); i++) { // setting up to_hmm_state_list. + int32 phone = phones[i]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); + std::map, std::vector > phone_to_hmm_state_list; + for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... + int32 forward_pdf_class = entry[j].forward_pdf_class, self_loop_pdf_class = entry[j].self_loop_pdf_class; + if (forward_pdf_class != kNoPdf) { + phone_to_hmm_state_list[std::make_pair(forward_pdf_class, self_loop_pdf_class)].push_back(j); + } + } + to_hmm_state_list[phone] = phone_to_hmm_state_list; + } + + for (int32 i = 0; i < phones.size(); i++) { + int32 phone = phones[i]; + for (int32 j = 0; j < static_cast(pdf_info[phone].size()); j++) { + int32 pdf_class = pdf_class_pairs[phone][j].first, + self_loop_pdf_class = pdf_class_pairs[phone][j].second; + const std::vector &state_vec = + to_hmm_state_list[phone][std::make_pair(pdf_class, self_loop_pdf_class)]; + KALDI_ASSERT(!state_vec.empty()); + for (size_t k = 0; k < state_vec.size(); k++) { + int32 hmm_state = state_vec[k]; + for (size_t m = 0; m < pdf_info[phone][j].size(); m++) { + int32 pdf = pdf_info[phone][j][m].first, + self_loop_pdf = pdf_info[phone][j][m].second; + tuples_.push_back(Tuple(phone, hmm_state, pdf, self_loop_pdf)); + } + } + } + } } void TransitionModel::ComputeDerived() { - state2id_.resize(triples_.size()+2); // indexed by transition-state, which + state2id_.resize(tuples_.size()+2); // indexed by transition-state, which // is one based, but also an entry for one past end of list. int32 cur_transition_id = 1; num_pdfs_ = 0; for (int32 tstate = 1; - tstate <= static_cast(triples_.size()+1); // not a typo. + tstate <= static_cast(tuples_.size()+1); // not a typo. tstate++) { state2id_[tstate] = cur_transition_id; - if (static_cast(tstate) <= triples_.size()) { - int32 phone = triples_[tstate-1].phone, - hmm_state = triples_[tstate-1].hmm_state, - pdf = triples_[tstate-1].pdf; - num_pdfs_ = std::max(num_pdfs_, 1+pdf); + if (static_cast(tstate) <= tuples_.size()) { + int32 phone = tuples_[tstate-1].phone, + hmm_state = tuples_[tstate-1].hmm_state, + forward_pdf = tuples_[tstate-1].forward_pdf, + self_loop_pdf = tuples_[tstate-1].self_loop_pdf; + num_pdfs_ = std::max(num_pdfs_, 1 + forward_pdf); + num_pdfs_ = std::max(num_pdfs_, 1 + self_loop_pdf); const HmmTopology::HmmState &state = topo_.TopologyForPhone(phone)[hmm_state]; int32 my_num_ids = static_cast(state.transitions.size()); cur_transition_id += my_num_ids; // # trans out of this state. @@ -91,20 +165,26 @@ void TransitionModel::ComputeDerived() { } id2state_.resize(cur_transition_id); // cur_transition_id is #transition-ids+1. - for (int32 tstate = 1; tstate <= static_cast(triples_.size()); tstate++) - for (int32 tid = state2id_[tstate]; tid < state2id_[tstate+1]; tid++) + id2pdf_id_.resize(cur_transition_id); + for (int32 tstate = 1; tstate <= static_cast(tuples_.size()); tstate++) + for (int32 tid = state2id_[tstate]; tid < state2id_[tstate+1]; tid++) { id2state_[tid] = tstate; - + if (IsSelfLoop(tid)) + id2pdf_id_[tid] = tuples_[tstate-1].self_loop_pdf; + else + id2pdf_id_[tid] = tuples_[tstate-1].forward_pdf; + } } + void TransitionModel::InitializeProbs() { log_probs_.Resize(NumTransitionIds()+1); // one-based array, zeroth element empty. for (int32 trans_id = 1; trans_id <= NumTransitionIds(); trans_id++) { int32 trans_state = id2state_[trans_id]; int32 trans_index = trans_id - state2id_[trans_state]; - const Triple &triple = triples_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(triple.phone); - KALDI_ASSERT(static_cast(triple.hmm_state) < entry.size()); - BaseFloat prob = entry[triple.hmm_state].transitions[trans_index].second; + const Tuple &tuple = tuples_[trans_state-1]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(tuple.phone); + KALDI_ASSERT(static_cast(tuple.hmm_state) < entry.size()); + BaseFloat prob = entry[tuple.hmm_state].transitions[trans_index].second; if (prob <= 0.0) KALDI_ERR << "TransitionModel::InitializeProbs, zero " "probability [should remove that entry in the topology]"; @@ -129,40 +209,55 @@ void TransitionModel::Check() const { KALDI_ASSERT(tid == PairToTransitionId(tstate, index)); int32 phone = TransitionStateToPhone(tstate), hmm_state = TransitionStateToHmmState(tstate), - pdf = TransitionStateToPdf(tstate); - KALDI_ASSERT(tstate == TripleToTransitionState(phone, hmm_state, pdf)); + forward_pdf = TransitionStateToForwardPdf(tstate), + self_loop_pdf = TransitionStateToSelfLoopPdf(tstate); + KALDI_ASSERT(tstate == TupleToTransitionState(phone, hmm_state, forward_pdf, self_loop_pdf)); KALDI_ASSERT(log_probs_(tid) <= 0.0 && log_probs_(tid) - log_probs_(tid) == 0.0); // checking finite and non-positive (and not out-of-bounds). } } +bool TransitionModel::IsHmm() const { + const std::vector &phones = topo_.GetPhones(); + KALDI_ASSERT(!phones.empty()); + for (size_t i = 0; i < phones.size(); i++) { + int32 phone = phones[i]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); + for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... + if (entry[j].forward_pdf_class != entry[j].self_loop_pdf_class) + return false; + } + } + return true; +} + TransitionModel::TransitionModel(const ContextDependencyInterface &ctx_dep, const HmmTopology &hmm_topo): topo_(hmm_topo) { - // First thing is to get all possible triples. - ComputeTriples(ctx_dep); + // First thing is to get all possible tuples. + ComputeTuples(ctx_dep); ComputeDerived(); InitializeProbs(); Check(); } -int32 TransitionModel::TripleToTransitionState(int32 phone, int32 hmm_state, int32 pdf) const { - Triple triple(phone, hmm_state, pdf); +int32 TransitionModel::TupleToTransitionState(int32 phone, int32 hmm_state, int32 pdf, int32 self_loop_pdf) const { + Tuple tuple(phone, hmm_state, pdf, self_loop_pdf); // Note: if this ever gets too expensive, which is unlikely, we can refactor // this code to sort first on pdf, and then index on pdf, so those // that have the same pdf are in a contiguous range. - std::vector::const_iterator iter = - std::lower_bound(triples_.begin(), triples_.end(), triple); - if (iter == triples_.end() || !(*iter == triple)) { - KALDI_ERR << "TransitionModel::TripleToTransitionState, triple not found." + std::vector::const_iterator iter = + std::lower_bound(tuples_.begin(), tuples_.end(), tuple); + if (iter == tuples_.end() || !(*iter == tuple)) { + KALDI_ERR << "TransitionModel::TupleToTransitionState, tuple not found." << " (incompatible tree and model?)"; } - // triples_ is indexed by transition_state-1, so add one. - return static_cast((iter - triples_.begin())) + 1; + // tuples_ is indexed by transition_state-1, so add one. + return static_cast((iter - tuples_.begin())) + 1; } int32 TransitionModel::NumTransitionIndices(int32 trans_state) const { - KALDI_ASSERT(static_cast(trans_state) <= triples_.size()); + KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); return static_cast(state2id_[trans_state+1]-state2id_[trans_state]); } @@ -177,32 +272,57 @@ int32 TransitionModel::TransitionIdToTransitionIndex(int32 trans_id) const { } int32 TransitionModel::TransitionStateToPhone(int32 trans_state) const { - KALDI_ASSERT(static_cast(trans_state) <= triples_.size()); - return triples_[trans_state-1].phone; + KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); + return tuples_[trans_state-1].phone; } -int32 TransitionModel::TransitionStateToPdf(int32 trans_state) const { - KALDI_ASSERT(static_cast(trans_state) <= triples_.size()); - return triples_[trans_state-1].pdf; +int32 TransitionModel::TransitionStateToForwardPdf(int32 trans_state) const { + KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); + return tuples_[trans_state-1].forward_pdf; +} + +int32 TransitionModel::TransitionStateToForwardPdfClass( + int32 trans_state) const { + KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); + const Tuple &t = tuples_[trans_state-1]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(t.phone); + KALDI_ASSERT(static_cast(t.hmm_state) < entry.size()); + return entry[t.hmm_state].forward_pdf_class; +} + + +int32 TransitionModel::TransitionStateToSelfLoopPdfClass( + int32 trans_state) const { + KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); + const Tuple &t = tuples_[trans_state-1]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(t.phone); + KALDI_ASSERT(static_cast(t.hmm_state) < entry.size()); + return entry[t.hmm_state].self_loop_pdf_class; +} + + +int32 TransitionModel::TransitionStateToSelfLoopPdf(int32 trans_state) const { + KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); + return tuples_[trans_state-1].self_loop_pdf; } int32 TransitionModel::TransitionStateToHmmState(int32 trans_state) const { - KALDI_ASSERT(static_cast(trans_state) <= triples_.size()); - return triples_[trans_state-1].hmm_state; + KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); + return tuples_[trans_state-1].hmm_state; } int32 TransitionModel::PairToTransitionId(int32 trans_state, int32 trans_index) const { - KALDI_ASSERT(static_cast(trans_state) <= triples_.size()); + KALDI_ASSERT(static_cast(trans_state) <= tuples_.size()); KALDI_ASSERT(trans_index < state2id_[trans_state+1] - state2id_[trans_state]); return state2id_[trans_state] + trans_index; } int32 TransitionModel::NumPhones() const { - int32 num_trans_state = triples_.size(); + int32 num_trans_state = tuples_.size(); int32 max_phone_id = 0; for (int32 i = 0; i < num_trans_state; ++i) { - if (triples_[i].phone > max_phone_id) - max_phone_id = triples_[i].phone; + if (tuples_[i].phone > max_phone_id) + max_phone_id = tuples_[i].phone; } return max_phone_id; } @@ -212,36 +332,25 @@ bool TransitionModel::IsFinal(int32 trans_id) const { KALDI_ASSERT(static_cast(trans_id) < id2state_.size()); int32 trans_state = id2state_[trans_id]; int32 trans_index = trans_id - state2id_[trans_state]; - const Triple &triple = triples_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(triple.phone); - KALDI_ASSERT(static_cast(triple.hmm_state) < entry.size()); - KALDI_ASSERT(static_cast(triple.hmm_state) < entry.size()); + const Tuple &tuple = tuples_[trans_state-1]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(tuple.phone); + KALDI_ASSERT(static_cast(tuple.hmm_state) < entry.size()); + KALDI_ASSERT(static_cast(tuple.hmm_state) < entry.size()); KALDI_ASSERT(static_cast(trans_index) < - entry[triple.hmm_state].transitions.size()); + entry[tuple.hmm_state].transitions.size()); // return true if the transition goes to the final state of the // topology entry. - return (entry[triple.hmm_state].transitions[trans_index].first + 1 == + return (entry[tuple.hmm_state].transitions[trans_index].first + 1 == static_cast(entry.size())); } -bool TransitionModel::IsSelfLoop(int32 trans_id) const { - KALDI_ASSERT(static_cast(trans_id) < id2state_.size()); - int32 trans_state = id2state_[trans_id]; - int32 trans_index = trans_id - state2id_[trans_state]; - const Triple &triple = triples_[trans_state-1]; - int32 phone = triple.phone, hmm_state = triple.hmm_state; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); - KALDI_ASSERT(static_cast(hmm_state) < entry.size()); - return (static_cast(trans_index) < entry[hmm_state].transitions.size() - && entry[hmm_state].transitions[trans_index].first == hmm_state); -} int32 TransitionModel::SelfLoopOf(int32 trans_state) const { // returns the self-loop transition-id, - KALDI_ASSERT(static_cast(trans_state-1) < triples_.size()); - const Triple &triple = triples_[trans_state-1]; + KALDI_ASSERT(static_cast(trans_state-1) < tuples_.size()); + const Tuple &tuple = tuples_[trans_state-1]; // or zero if does not exist. - int32 phone = triple.phone, hmm_state = triple.hmm_state; + int32 phone = tuple.phone, hmm_state = tuple.hmm_state; const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); KALDI_ASSERT(static_cast(hmm_state) < entry.size()); for (int32 trans_index = 0; @@ -274,16 +383,22 @@ void TransitionModel::ComputeDerivedOfProbs() { void TransitionModel::Read(std::istream &is, bool binary) { ExpectToken(is, binary, ""); topo_.Read(is, binary); - ExpectToken(is, binary, ""); + std::string token; + ReadToken(is, binary, &token); int32 size; ReadBasicType(is, binary, &size); - triples_.resize(size); + tuples_.resize(size); for (int32 i = 0; i < size; i++) { - ReadBasicType(is, binary, &(triples_[i].phone)); - ReadBasicType(is, binary, &(triples_[i].hmm_state)); - ReadBasicType(is, binary, &(triples_[i].pdf)); + ReadBasicType(is, binary, &(tuples_[i].phone)); + ReadBasicType(is, binary, &(tuples_[i].hmm_state)); + ReadBasicType(is, binary, &(tuples_[i].forward_pdf)); + if (token == "") + ReadBasicType(is, binary, &(tuples_[i].self_loop_pdf)); + else if (token == "") + tuples_[i].self_loop_pdf = tuples_[i].forward_pdf; } - ExpectToken(is, binary, ""); + ReadToken(is, binary, &token); + KALDI_ASSERT(token == "" || token == ""); ComputeDerived(); ExpectToken(is, binary, ""); log_probs_.Read(is, binary); @@ -294,19 +409,28 @@ void TransitionModel::Read(std::istream &is, bool binary) { } void TransitionModel::Write(std::ostream &os, bool binary) const { + bool is_hmm = IsHmm(); WriteToken(os, binary, ""); if (!binary) os << "\n"; topo_.Write(os, binary); - WriteToken(os, binary, ""); - WriteBasicType(os, binary, static_cast(triples_.size())); + if (is_hmm) + WriteToken(os, binary, ""); + else + WriteToken(os, binary, ""); + WriteBasicType(os, binary, static_cast(tuples_.size())); if (!binary) os << "\n"; - for (int32 i = 0; i < static_cast (triples_.size()); i++) { - WriteBasicType(os, binary, triples_[i].phone); - WriteBasicType(os, binary, triples_[i].hmm_state); - WriteBasicType(os, binary, triples_[i].pdf); + for (int32 i = 0; i < static_cast (tuples_.size()); i++) { + WriteBasicType(os, binary, tuples_[i].phone); + WriteBasicType(os, binary, tuples_[i].hmm_state); + WriteBasicType(os, binary, tuples_[i].forward_pdf); + if (!is_hmm) + WriteBasicType(os, binary, tuples_[i].self_loop_pdf); if (!binary) os << "\n"; } - WriteToken(os, binary, ""); + if (is_hmm) + WriteToken(os, binary, ""); + else + WriteToken(os, binary, ""); if (!binary) os << "\n"; WriteToken(os, binary, ""); if (!binary) os << "\n"; @@ -473,8 +597,12 @@ void TransitionModel::MleUpdateShared(const Vector &stats, std::map > pdf_to_tstate; for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) { - int32 pdf = TransitionStateToPdf(tstate); + int32 pdf = TransitionStateToForwardPdf(tstate); pdf_to_tstate[pdf].insert(tstate); + if (!IsHmm()) { + pdf = TransitionStateToSelfLoopPdf(tstate); + pdf_to_tstate[pdf].insert(tstate); + } } std::map >::iterator map_iter; for (map_iter = pdf_to_tstate.begin(); @@ -567,8 +695,12 @@ void TransitionModel::MapUpdateShared(const Vector &stats, std::map > pdf_to_tstate; for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) { - int32 pdf = TransitionStateToPdf(tstate); + int32 pdf = TransitionStateToForwardPdf(tstate); pdf_to_tstate[pdf].insert(tstate); + if (!IsHmm()) { + pdf = TransitionStateToSelfLoopPdf(tstate); + pdf_to_tstate[pdf].insert(tstate); + } } std::map >::iterator map_iter; for (map_iter = pdf_to_tstate.begin(); @@ -642,24 +774,27 @@ void TransitionModel::MapUpdateShared(const Vector &stats, int32 TransitionModel::TransitionIdToPhone(int32 trans_id) const { KALDI_ASSERT(trans_id != 0 && static_cast(trans_id) < id2state_.size()); int32 trans_state = id2state_[trans_id]; - return triples_[trans_state-1].phone; + return tuples_[trans_state-1].phone; } int32 TransitionModel::TransitionIdToPdfClass(int32 trans_id) const { KALDI_ASSERT(trans_id != 0 && static_cast(trans_id) < id2state_.size()); int32 trans_state = id2state_[trans_id]; - const Triple &t = triples_[trans_state-1]; + const Tuple &t = tuples_[trans_state-1]; const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(t.phone); KALDI_ASSERT(static_cast(t.hmm_state) < entry.size()); - return entry[t.hmm_state].pdf_class; + if (IsSelfLoop(trans_id)) + return entry[t.hmm_state].self_loop_pdf_class; + else + return entry[t.hmm_state].forward_pdf_class; } int32 TransitionModel::TransitionIdToHmmState(int32 trans_id) const { KALDI_ASSERT(trans_id != 0 && static_cast(trans_id) < id2state_.size()); int32 trans_state = id2state_[trans_id]; - const Triple &t = triples_[trans_state-1]; + const Tuple &t = tuples_[trans_state-1]; return t.hmm_state; } @@ -668,23 +803,34 @@ void TransitionModel::Print(std::ostream &os, const Vector *occs) { if (occs != NULL) KALDI_ASSERT(occs->Dim() == NumPdfs()); + bool is_hmm = IsHmm(); for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) { - const Triple &triple = triples_[tstate-1]; - KALDI_ASSERT(static_cast(triple.phone) < phone_names.size()); - std::string phone_name = phone_names[triple.phone]; + const Tuple &tuple = tuples_[tstate-1]; + KALDI_ASSERT(static_cast(tuple.phone) < phone_names.size()); + std::string phone_name = phone_names[tuple.phone]; os << "Transition-state " << tstate << ": phone = " << phone_name - << " hmm-state = " << triple.hmm_state << " pdf = " << triple.pdf << '\n'; + << " hmm-state = " << tuple.hmm_state; + if (is_hmm) + os << " pdf = " << tuple.forward_pdf << '\n'; + else + os << " forward-pdf = " << tuple.forward_pdf << " self-loop-pdf = " + << tuple.self_loop_pdf << '\n'; for (int32 tidx = 0; tidx < NumTransitionIndices(tstate); tidx++) { int32 tid = PairToTransitionId(tstate, tidx); BaseFloat p = GetTransitionProb(tid); os << " Transition-id = " << tid << " p = " << p; - if (occs != NULL) os << " count of pdf = " << (*occs)(triple.pdf); + if (occs != NULL) { + if (IsSelfLoop(tid)) + os << " count of pdf = " << (*occs)(tuple.self_loop_pdf); + else + os << " count of pdf = " << (*occs)(tuple.forward_pdf); + } // now describe what it's a transition to. if (IsSelfLoop(tid)) os << " [self-loop]\n"; else { - int32 hmm_state = triple.hmm_state; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(triple.phone); + int32 hmm_state = tuple.hmm_state; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(tuple.phone); KALDI_ASSERT(static_cast(hmm_state) < entry.size()); int32 next_hmm_state = entry[hmm_state].transitions[tidx].first; KALDI_ASSERT(next_hmm_state != hmm_state); @@ -702,14 +848,18 @@ bool GetPdfsForPhones(const TransitionModel &trans_model, pdfs->clear(); for (int32 tstate = 1; tstate <= trans_model.NumTransitionStates(); tstate++) { if (std::binary_search(phones.begin(), phones.end(), - trans_model.TransitionStateToPhone(tstate))) - pdfs->push_back(trans_model.TransitionStateToPdf(tstate)); + trans_model.TransitionStateToPhone(tstate))) { + pdfs->push_back(trans_model.TransitionStateToForwardPdf(tstate)); + pdfs->push_back(trans_model.TransitionStateToSelfLoopPdf(tstate)); + } } SortAndUniq(pdfs); for (int32 tstate = 1; tstate <= trans_model.NumTransitionStates(); tstate++) - if (std::binary_search(pdfs->begin(), pdfs->end(), - trans_model.TransitionStateToPdf(tstate)) + if ((std::binary_search(pdfs->begin(), pdfs->end(), + trans_model.TransitionStateToForwardPdf(tstate)) || + std::binary_search(pdfs->begin(), pdfs->end(), + trans_model.TransitionStateToSelfLoopPdf(tstate))) && !std::binary_search(phones.begin(), phones.end(), trans_model.TransitionStateToPhone(tstate))) return false; @@ -724,7 +874,9 @@ bool GetPhonesForPdfs(const TransitionModel &trans_model, phones->clear(); for (int32 tstate = 1; tstate <= trans_model.NumTransitionStates(); tstate++) { if (std::binary_search(pdfs.begin(), pdfs.end(), - trans_model.TransitionStateToPdf(tstate))) + trans_model.TransitionStateToForwardPdf(tstate)) || + std::binary_search(pdfs.begin(), pdfs.end(), + trans_model.TransitionStateToSelfLoopPdf(tstate))) phones->push_back(trans_model.TransitionStateToPhone(tstate)); } SortAndUniq(phones); @@ -732,16 +884,30 @@ bool GetPhonesForPdfs(const TransitionModel &trans_model, for (int32 tstate = 1; tstate <= trans_model.NumTransitionStates(); tstate++) if (std::binary_search(phones->begin(), phones->end(), trans_model.TransitionStateToPhone(tstate)) - && !std::binary_search(pdfs.begin(), pdfs.end(), - trans_model.TransitionStateToPdf(tstate))) + && !(std::binary_search(pdfs.begin(), pdfs.end(), + trans_model.TransitionStateToForwardPdf(tstate)) && + std::binary_search(pdfs.begin(), pdfs.end(), + trans_model.TransitionStateToSelfLoopPdf(tstate))) ) return false; return true; } bool TransitionModel::Compatible(const TransitionModel &other) const { - return (topo_ == other.topo_ && triples_ == other.triples_ && + return (topo_ == other.topo_ && tuples_ == other.tuples_ && state2id_ == other.state2id_ && id2state_ == other.id2state_ && num_pdfs_ == other.num_pdfs_); } +bool TransitionModel::IsSelfLoop(int32 trans_id) const { + KALDI_ASSERT(static_cast(trans_id) < id2state_.size()); + int32 trans_state = id2state_[trans_id]; + int32 trans_index = trans_id - state2id_[trans_state]; + const Tuple &tuple = tuples_[trans_state-1]; + int32 phone = tuple.phone, hmm_state = tuple.hmm_state; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(phone); + KALDI_ASSERT(static_cast(hmm_state) < entry.size()); + return (static_cast(trans_index) < entry[hmm_state].transitions.size() + && entry[hmm_state].transitions[trans_index].first == hmm_state); +} + } // End namespace kaldi diff --git a/src/hmm/transition-model.h b/src/hmm/transition-model.h index ff236e6de9e..33a0d55443e 100644 --- a/src/hmm/transition-model.h +++ b/src/hmm/transition-model.h @@ -53,7 +53,7 @@ namespace kaldi { // phone: a phone index (1, 2, 3 ...) // HMM-state: a number (0, 1, 2...) that indexes TopologyEntry (see hmm-topology.h) // pdf-id: a number output by the Compute function of ContextDependency (it -// indexes pdf's). Zero-based. +// indexes pdf's, either forward or self-loop). Zero-based. // transition-state: the states for which we estimate transition probabilities for transitions // out of them. In some topologies, will map one-to-one with pdf-ids. // One-based, since it appears on FSTs. @@ -66,14 +66,15 @@ namespace kaldi { // One-based, since it appears on FSTs. // // List of the possible mappings TransitionModel can do: -// (phone, HMM-state, pdf-id) -> transition-state -// (transition-state, transition-index) -> transition-id +// (phone, HMM-state, forward-pdf-id, self-loop-pdf-id) -> transition-state +// (transition-state, transition-index) -> transition-id // Reverse mappings: // transition-id -> transition-state // transition-id -> transition-index // transition-state -> phone // transition-state -> HMM-state -// transition-state -> pdf-id +// transition-state -> forward-pdf-id +// transition-state -> self-loop-pdf-id // // The main things the TransitionModel object can do are: // Get initialized (need ContextDependency and HmmTopology objects). @@ -141,13 +142,16 @@ class TransitionModel { /// \name Integer mapping functions /// @{ - int32 TripleToTransitionState(int32 phone, int32 hmm_state, int32 pdf) const; + int32 TupleToTransitionState(int32 phone, int32 hmm_state, int32 pdf, int32 self_loop_pdf) const; int32 PairToTransitionId(int32 trans_state, int32 trans_index) const; int32 TransitionIdToTransitionState(int32 trans_id) const; int32 TransitionIdToTransitionIndex(int32 trans_id) const; int32 TransitionStateToPhone(int32 trans_state) const; int32 TransitionStateToHmmState(int32 trans_state) const; - int32 TransitionStateToPdf(int32 trans_state) const; + int32 TransitionStateToForwardPdfClass(int32 trans_state) const; + int32 TransitionStateToSelfLoopPdfClass(int32 trans_state) const; + int32 TransitionStateToForwardPdf(int32 trans_state) const; + int32 TransitionStateToSelfLoopPdf(int32 trans_state) const; int32 SelfLoopOf(int32 trans_state) const; // returns the self-loop transition-id, or zero if // this state doesn't have a self-loop. @@ -172,7 +176,7 @@ class TransitionModel { int32 NumTransitionIndices(int32 trans_state) const; /// Returns the total number of transition-states (note, these are one-based). - int32 NumTransitionStates() const { return triples_.size(); } + int32 NumTransitionStates() const { return tuples_.size(); } // NumPdfs() actually returns the highest-numbered pdf we ever saw, plus one. // In normal cases this should equal the number of pdfs in the system, but if you @@ -249,30 +253,36 @@ class TransitionModel { void MapUpdateShared(const Vector &stats, const MapTransitionUpdateConfig &cfg, BaseFloat *objf_impr_out, BaseFloat *count_out); - void ComputeTriples(const ContextDependencyInterface &ctx_dep); // called from constructor. initializes triples_. + void ComputeTuples(const ContextDependencyInterface &ctx_dep); // called from constructor. initializes tuples_. + void ComputeTuplesIsHmm(const ContextDependencyInterface &ctx_dep); + void ComputeTuplesNotHmm(const ContextDependencyInterface &ctx_dep); void ComputeDerived(); // called from constructor and Read function: computes state2id_ and id2state_. void ComputeDerivedOfProbs(); // computes quantities derived from log-probs (currently just // non_self_loop_log_probs_; called whenever log-probs change. void InitializeProbs(); // called from constructor. void Check() const; + bool IsHmm() const; - struct Triple { + struct Tuple { int32 phone; int32 hmm_state; - int32 pdf; - Triple() { } - Triple(int32 phone, int32 hmm_state, int32 pdf): - phone(phone), hmm_state(hmm_state), pdf(pdf) { } - bool operator < (const Triple &other) const { + int32 forward_pdf; + int32 self_loop_pdf; + Tuple() { } + Tuple(int32 phone, int32 hmm_state, int32 forward_pdf, int32 self_loop_pdf): + phone(phone), hmm_state(hmm_state), forward_pdf(forward_pdf), self_loop_pdf(self_loop_pdf) { } + bool operator < (const Tuple &other) const { if (phone < other.phone) return true; else if (phone > other.phone) return false; else if (hmm_state < other.hmm_state) return true; else if (hmm_state > other.hmm_state) return false; - else return pdf < other.pdf; + else if (forward_pdf < other.forward_pdf) return true; + else if (forward_pdf > other.forward_pdf) return false; + else return (self_loop_pdf < other.self_loop_pdf); } - bool operator == (const Triple &other) const { + bool operator == (const Tuple &other) const { return (phone == other.phone && hmm_state == other.hmm_state - && pdf == other.pdf); + && forward_pdf == other.forward_pdf && self_loop_pdf == other.self_loop_pdf); } }; @@ -281,7 +291,7 @@ class TransitionModel { /// Triples indexed by transition state minus one; /// the triples are in sorted order which allows us to do the reverse mapping from /// triple to transition state - std::vector triples_; + std::vector tuples_; /// Gives the first transition_id of each transition-state; indexed by /// the transition-state. Array indexed 1..num-transition-states+1 (the last one @@ -292,6 +302,8 @@ class TransitionModel { /// state (indexed by transition-id). std::vector id2state_; + std::vector id2pdf_id_; + /// For each transition-id, the corresponding log-prob. Indexed by transition-id. Vector log_probs_; @@ -310,12 +322,9 @@ class TransitionModel { }; inline int32 TransitionModel::TransitionIdToPdf(int32 trans_id) const { - // If a lot of time is spent here we may create an extra array - // to handle this. - KALDI_ASSERT(static_cast(trans_id) < id2state_.size() && + KALDI_ASSERT(static_cast(trans_id) < id2pdf_id_.size() && "Likely graph/model mismatch (graph built from wrong model?)"); - int32 trans_state = id2state_[trans_id]; - return triples_[trans_state-1].pdf; + return id2pdf_id_[trans_id]; } /// Works out which pdfs might correspond to the given phones. Will return true diff --git a/src/itf/context-dep-itf.h b/src/itf/context-dep-itf.h index b989dd900ea..40681bb5ccd 100644 --- a/src/itf/context-dep-itf.h +++ b/src/itf/context-dep-itf.h @@ -63,9 +63,36 @@ class ContextDependencyInterface { /// GetPdfInfo returns a vector indexed by pdf-id, saying for each pdf which /// pairs of (phone, pdf-class) it can correspond to. (Usually just one). /// c.f. hmm/hmm-topology.h for meaning of pdf-class. - virtual void GetPdfInfo(const std::vector &phones, // list of phones - const std::vector &num_pdf_classes, // indexed by phone, - std::vector > > *pdf_info) + /// This is the old, simpler interface of GetPdfInfo(), and that this one can + /// only be called if the HmmTopology object's IsHmm() function call returns + /// true. + virtual void GetPdfInfo( + const std::vector &phones, // list of phones + const std::vector &num_pdf_classes, // indexed by phone, + std::vector > > *pdf_info) + const = 0; + + /// This function outputs information about what possible pdf-ids can + /// be generated for HMM-states; it covers the general case where + /// the self-loop pdf-class may be different from the forward-transition + /// pdf-class, so we are asking not about the set of possible pdf-ids + /// for a given (phone, pdf-class), but the set of possible ordered pairs + /// (forward-transition-pdf, self-loop-pdf) for a given (phone, + /// forward-transition-pdf-class, self-loop-pdf-class). + /// Note: 'phones' is a list of integer ids of phones, and + /// 'pdf-class-pairs', indexed by phone, is a list of pairs + /// (forward-transition-pdf-class, self-loop-pdf-class) that we can have for + /// that phone. + /// The output 'pdf_info' is indexed first by phone and then by the + /// same index that indexes each element of 'pdf_class_pairs', + /// and tells us for each pair in 'pdf_class_pairs', what is the + /// list of possible (forward-transition-pdf-id, self-loop-pdf-id) that + /// we can have. + /// This is less efficient than the other version of GetPdfInfo(). + virtual void GetPdfInfo( + const std::vector &phones, + const std::vector > > &pdf_class_pairs, + std::vector > > > *pdf_info) const = 0; diff --git a/src/sgmm2/Makefile b/src/sgmm2/Makefile index 41a4175aa3b..f0da85e48de 100644 --- a/src/sgmm2/Makefile +++ b/src/sgmm2/Makefile @@ -14,6 +14,6 @@ LIBNAME = kaldi-sgmm2 ADDLIBS = ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ ../tree/kaldi-tree.a ../util/kaldi-util.a ../thread/kaldi-thread.a \ - ../matrix/kaldi-matrix.a ../base/kaldi-base.a + ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/transform/Makefile b/src/transform/Makefile index 3ae8b1fa3a4..4df681f1ade 100644 --- a/src/transform/Makefile +++ b/src/transform/Makefile @@ -14,8 +14,8 @@ OBJFILES = regression-tree.o regtree-mllr-diag-gmm.o lda-estimate.o \ LIBNAME = kaldi-transform -ADDLIBS = ../gmm/kaldi-gmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../thread/kaldi-thread.a ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a +ADDLIBS = ../gmm/kaldi-gmm.a ../tree/kaldi-tree.a \ + ../util/kaldi-util.a ../thread/kaldi-thread.a \ + ../matrix/kaldi-matrix.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/tree/context-dep.cc b/src/tree/context-dep.cc index 81eee5bb4ee..03afe547ee4 100644 --- a/src/tree/context-dep.cc +++ b/src/tree/context-dep.cc @@ -178,9 +178,107 @@ void ContextDependency::Read (std::istream &is, bool binary) { to_pdf_ = to_pdf; } -void ContextDependency::GetPdfInfo(const std::vector &phones, - const std::vector &num_pdf_classes, // indexed by phone, - std::vector > > *pdf_info) const { +void ContextDependency::EnumeratePairs( + const std::vector &phones, + int32 self_loop_pdf_class, int32 forward_pdf_class, + const std::vector &phone_window, + unordered_set, PairHasher > *pairs) const { + std::vector new_phone_window(phone_window); + EventType vec; + + std::vector forward_pdfs, self_loop_pdfs; + + // get list of possible forward pdfs + vec.clear(); + for (size_t i = 0; i < N_; i++) + if (phone_window[i] >= 0) + vec.push_back(std::make_pair(static_cast(i), + static_cast(phone_window[i]))); + vec.push_back(std::make_pair(kPdfClass, static_cast(forward_pdf_class))); + std::sort(vec.begin(), vec.end()); + to_pdf_->MultiMap(vec, &forward_pdfs); + SortAndUniq(&forward_pdfs); + + // get list of possible self-loop pdfs + vec.clear(); + for (size_t i = 0; i < N_; i++) + if (phone_window[i] >= 0) + vec.push_back(std::make_pair(static_cast(i), + static_cast(phone_window[i]))); + vec.push_back(std::make_pair(kPdfClass, static_cast(self_loop_pdf_class))); + std::sort(vec.begin(), vec.end()); + to_pdf_->MultiMap(vec, &self_loop_pdfs); + SortAndUniq(&self_loop_pdfs); + + if (forward_pdfs.size() == 1 || self_loop_pdfs.size() == 1) { + for (size_t m = 0; m < forward_pdfs.size(); m++) + for (size_t n = 0; n < self_loop_pdfs.size(); n++) + pairs->insert(std::make_pair(forward_pdfs[m], self_loop_pdfs[n])); + } else { + // Choose 'position' as a phone position in 'context' that's currently + // -1, and that is as close as possible to the central position P. + int32 position = 0; + int32 min_dist = N_ - 1; + for (int32 i = 0; i < N_; i++) { + int32 dist = (P_ - i > 0) ? (P_ - i) : (i - P_); + if (phone_window[i] == -1 && dist < min_dist) { + position = i; + min_dist = dist; + } + } + KALDI_ASSERT(min_dist < N_); + KALDI_ASSERT(position != P_); + + new_phone_window[position] = 0; + EnumeratePairs(phones, self_loop_pdf_class, forward_pdf_class, + new_phone_window, pairs); + + for (size_t i = 0 ; i < phones.size(); i++) { + new_phone_window[position] = phones[i]; + EnumeratePairs(phones, self_loop_pdf_class, forward_pdf_class, + new_phone_window, pairs); + } + } +} + +void ContextDependency::GetPdfInfo( + const std::vector &phones, + const std::vector > > &pdf_class_pairs, + std::vector > > > *pdf_info) const { + + KALDI_ASSERT(pdf_info != NULL); + pdf_info->resize(1 + *std::max_element(phones.begin(), phones.end())); + std::vector phone_window(N_); + EventType vec; + for (size_t i = 0 ; i < phones.size(); i++) { + // loop over phones + int32 phone = phones[i]; + (*pdf_info)[phone].resize(pdf_class_pairs[phone].size()); + for (size_t j = 0; j < pdf_class_pairs[phone].size(); j++) { + // loop over pdf_class pairs + int32 pdf_class = pdf_class_pairs[phone][j].first, + self_loop_pdf_class = pdf_class_pairs[phone][j].second; + for (size_t win_start = 0; win_start < phone_window.size(); win_start++) { + if (win_start != P_) + phone_window[win_start] = -1; + else + phone_window[win_start] = phone; + } + unordered_set, PairHasher > pairs; + EnumeratePairs(phones, self_loop_pdf_class, pdf_class, phone_window, &pairs); + unordered_set, PairHasher >::iterator iter = pairs.begin(), + end = pairs.end(); + for (; iter != end; ++iter) + (*pdf_info)[phone][j].push_back(*iter); + std::sort( ((*pdf_info)[phone][j]).begin(), ((*pdf_info)[phone][j]).end()); + } + } +} + +void ContextDependency::GetPdfInfo( + const std::vector &phones, + const std::vector &num_pdf_classes, // indexed by phone, + std::vector > > *pdf_info) const { EventType vec; KALDI_ASSERT(pdf_info != NULL); diff --git a/src/tree/context-dep.h b/src/tree/context-dep.h index 08dc974570d..6342d89667b 100644 --- a/src/tree/context-dep.h +++ b/src/tree/context-dep.h @@ -20,6 +20,7 @@ #ifndef KALDI_TREE_CONTEXT_DEP_H_ #define KALDI_TREE_CONTEXT_DEP_H_ +#include "util/stl-utils.h" #include "itf/context-dep-itf.h" #include "tree/event-map.h" #include "matrix/matrix-lib.h" @@ -99,9 +100,36 @@ class ContextDependency: public ContextDependencyInterface { /// GetPdfInfo returns a vector indexed by pdf-id, saying for each pdf which /// pairs of (phone, pdf-class) it can correspond to. (Usually just one). /// c.f. hmm/hmm-topology.h for meaning of pdf-class. - virtual void GetPdfInfo(const std::vector &phones, // list of phones - const std::vector &num_pdf_classes, // indexed by phone, - std::vector > > *pdf_info) + /// This is the old, simpler interface of GetPdfInfo(), and that this one can + /// only be called if the HmmTopology object's IsHmm() function call returns + /// true. + virtual void GetPdfInfo( + const std::vector &phones, // list of phones + const std::vector &num_pdf_classes, // indexed by phone, + std::vector > > *pdf_info) + const; + + /// This function outputs information about what possible pdf-ids can + /// be generated for HMM-states; it covers the general case where + /// the self-loop pdf-class may be different from the forward-transition + /// pdf-class, so we are asking not about the set of possible pdf-ids + /// for a given (phone, pdf-class), but the set of possible ordered pairs + /// (forward-transition-pdf, self-loop-pdf) for a given (phone, + /// forward-transition-pdf-class, self-loop-pdf-class). + /// Note: 'phones' is a list of integer ids of phones, and + /// 'pdf-class-pairs', indexed by phone, is a list of pairs + /// (forward-transition-pdf-class, self-loop-pdf-class) that we can have for + /// that phone. + /// The output 'pdf_info' is indexed first by phone and then by the + /// same index that indexes each element of 'pdf_class_pairs', + /// and tells us for each pair in 'pdf_class_pairs', what is the + /// list of possible (forward-transition-pdf-id, self-loop-pdf-id) that + /// we can have. + /// This is less efficient than the other version of GetPdfInfo(). + virtual void GetPdfInfo( + const std::vector &phones, + const std::vector > > &pdf_class_pairs, + std::vector > > > *pdf_info) const; private: @@ -109,6 +137,20 @@ class ContextDependency: public ContextDependencyInterface { int32 P_; EventMap *to_pdf_; // owned here. + // 'context' is the context-window of phones, of + // length N, with -1 for those positions where phones + // that are currently unknown, treated as wildcards; at least + // the central phone [position P] must be a real phone, i.e. + // not -1. + // This function inserts any allowed pairs (forward_pdf, self_loop_pdf) + // to the set "pairs". + void EnumeratePairs( + const std::vector &phones, + int32 self_loop_pdf_class, int32 forward_pdf_class, + const std::vector &context, + unordered_set, PairHasher > *pairs) + const; + KALDI_DISALLOW_COPY_AND_ASSIGN(ContextDependency); };