diff --git a/egs/wsj/s5/steps/nnet3/chain/e2e/prepare_e2e.sh b/egs/wsj/s5/steps/nnet3/chain/e2e/prepare_e2e.sh index a060f0f3b36..c211381bf8b 100755 --- a/egs/wsj/s5/steps/nnet3/chain/e2e/prepare_e2e.sh +++ b/egs/wsj/s5/steps/nnet3/chain/e2e/prepare_e2e.sh @@ -18,6 +18,7 @@ treedir= # if specified, the tree and model will be copied from the # note that it may not be flat start anymore. type=mono # can be either mono or biphone -- either way # the resulting tree is full (i.e. it doesn't do any tying) +ci_silence=false # if true, silence phones will be treated as context independent scale_opts="--transition-scale=0.0 --self-loop-scale=0.0" # End configuration section. @@ -63,12 +64,17 @@ if $shared_phones; then shared_phones_opt="--shared-phones=$lang/phones/sets.int" fi +ciphonelist=`cat $lang/phones/context_indep.csl` || exit 1; +if $ci_silence; then + ci_opt="--ci-phones=$ciphonelist" +fi + if [ $stage -le 0 ]; then if [ -z $treedir ]; then echo "$0: Initializing $type system." # feat dim does not matter here. Just set it to 10 $cmd $dir/log/init_${type}_mdl_tree.log \ - gmm-init-$type $shared_phones_opt $lang/topo 10 \ + gmm-init-$type $ci_opt $shared_phones_opt $lang/topo 10 \ $dir/0.mdl $dir/tree || exit 1; else echo "$0: Copied tree/mdl from $treedir." 
>$dir/log/init_mdl_tree.log diff --git a/src/gmmbin/gmm-init-biphone.cc b/src/gmmbin/gmm-init-biphone.cc index d1c789a620e..e5cc182f94c 100644 --- a/src/gmmbin/gmm-init-biphone.cc +++ b/src/gmmbin/gmm-init-biphone.cc @@ -51,10 +51,12 @@ void ReadSharedPhonesList(std::string rxfilename, std::vector EventMap *GetFullBiphoneStubMap(const std::vector > &phone_sets, const std::vector &phone2num_pdf_classes, - const std::vector &share_roots) { + const std::vector &share_roots, + const std::vector &ci_phones_list) { { // Check the inputs - KALDI_ASSERT(!phone_sets.empty() && share_roots.size() == phone_sets.size()); + KALDI_ASSERT(!phone_sets.empty() && + share_roots.size() == phone_sets.size()); std::set all_phones; for (size_t i = 0; i < phone_sets.size(); i++) { KALDI_ASSERT(IsSortedAndUniq(phone_sets[i])); @@ -66,9 +68,18 @@ EventMap } } + int32 numpdfs_per_phone = phone2num_pdf_classes[1]; int32 current_pdfid = 0; std::map level1_map; // key is 1 + + for (size_t i = 0; i < ci_phones_list.size(); i++) { + std::map level2_map; + level2_map[0] = current_pdfid++; + if (numpdfs_per_phone == 2) level2_map[1] = current_pdfid++; + level1_map[ci_phones_list[i]] = new TableEventMap(kPdfClass, level2_map); + } + for (size_t i = 0; i < phone_sets.size(); i++) { if (numpdfs_per_phone == 1) { @@ -99,9 +110,11 @@ EventMap level3_map[0] = current_pdfid++; level3_map[1] = current_pdfid++; level2_map[0] = new TableEventMap(kPdfClass, level3_map); // no-left-context case + for (size_t i = 0; i < ci_phones_list.size(); i++) // ci-phone left-context cases + level2_map[ci_phones_list[i]] = new TableEventMap(kPdfClass, level3_map); } for (size_t j = 0; j < phone_sets.size(); j++) { - std::map level3_map; // key is -1 + std::map level3_map; // key is kPdfClass level3_map[0] = current_pdfid++; level3_map[1] = current_pdfid++; @@ -121,17 +134,35 @@ EventMap return new TableEventMap(1, level1_map); } + ContextDependency* -BiphoneContextDependencyFull(const std::vector > phone_sets, - const 
std::vector phone2num_pdf_classes) { - std::vector share_roots(phone_sets.size(), false); // Don't share roots +BiphoneContextDependencyFull(std::vector > phone_sets, + const std::vector phone2num_pdf_classes, + const std::vector &ci_phones_list) { + // Remove all the CI phones from the phone sets + std::set ci_phones; + for (size_t i = 0; i < ci_phones_list.size(); i++) + ci_phones.insert(ci_phones_list[i]); + for (int32 i = phone_sets.size() - 1; i >= 0; i--) { + for (int32 j = phone_sets[i].size() - 1; j >= 0; j--) { + if (ci_phones.find(phone_sets[i][j]) != ci_phones.end()) { // Delete it + phone_sets[i].erase(phone_sets[i].begin() + j); + if (phone_sets[i].empty()) // If empty, delete the whole entry + phone_sets.erase(phone_sets.begin() + i); + } + } + } + + std::vector share_roots(phone_sets.size(), false); // Don't share roots // N is context size, P = position of central phone (must be 0). int32 P = 1, N = 2; EventMap *pdf_map = GetFullBiphoneStubMap(phone_sets, - phone2num_pdf_classes, share_roots); + phone2num_pdf_classes, + share_roots, ci_phones_list); return new ContextDependency(N, P, pdf_map); } + } // end namespace kaldi int main(int argc, char *argv[]) { @@ -148,11 +179,17 @@ int main(int argc, char *argv[]) { bool binary = true; std::string shared_phones_rxfilename; + std::string ci_phones_str; + std::vector ci_phones; // Sorted, unique vector of + // context-independent phones. 
+ ParseOptions po(usage); po.Register("binary", &binary, "Write output in binary mode"); po.Register("shared-phones", &shared_phones_rxfilename, "rxfilename containing, on each line, a list of phones " "whose pdfs should be shared."); + po.Register("ci-phones", &ci_phones_str, "Colon-separated list of " + "integer indices of context-independent phones."); po.Read(argc, argv); if (po.NumArgs() != 4) { @@ -169,6 +206,14 @@ int main(int argc, char *argv[]) { std::string model_filename = po.GetArg(3); std::string tree_filename = po.GetArg(4); + if (!ci_phones_str.empty()) { + SplitStringToIntegers(ci_phones_str, ":", false, &ci_phones); + std::sort(ci_phones.begin(), ci_phones.end()); + if (!IsSortedAndUniq(ci_phones) || ci_phones.empty() || ci_phones[0] == 0) + KALDI_ERR << "Invalid --ci-phones option: " << ci_phones_str; + } + + Vector glob_inv_var(dim); glob_inv_var.Set(1.0); Vector glob_mean(dim); @@ -200,7 +245,8 @@ int main(int argc, char *argv[]) { ReadSharedPhonesList(shared_phones_rxfilename, &shared_phones); // ReadSharedPhonesList crashes on error. } - ctx_dep = BiphoneContextDependencyFull(shared_phones, phone2num_pdf_classes); + ctx_dep = BiphoneContextDependencyFull(shared_phones, phone2num_pdf_classes, + ci_phones); int32 num_pdfs = ctx_dep->NumPdfs();