From d8ff7ee0cd9f5c670c36ad4525899e668d50abef Mon Sep 17 00:00:00 2001 From: chenzhehuai Date: Tue, 3 Apr 2018 00:53:59 -0400 Subject: [PATCH 01/93] make fst templates inline to eliminate linking errors in other places --- src/fstext/fstext-utils-inl.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fstext/fstext-utils-inl.h b/src/fstext/fstext-utils-inl.h index 923c67c07e2..756e449fcaa 100644 --- a/src/fstext/fstext-utils-inl.h +++ b/src/fstext/fstext-utils-inl.h @@ -1132,7 +1132,7 @@ inline bool IsStochasticFst(const Fst &fst, // Will override this for LogArc where NaturalLess will not work. template -bool IsStochasticFst(const Fst &fst, +inline bool IsStochasticFst(const Fst &fst, float delta, typename Arc::Weight *min_sum, typename Arc::Weight *max_sum) { @@ -1168,7 +1168,7 @@ bool IsStochasticFst(const Fst &fst, // Overriding template for LogArc as NaturalLess does not work there. template<> -bool IsStochasticFst(const Fst &fst, +inline bool IsStochasticFst(const Fst &fst, float delta, LogArc::Weight *min_sum, LogArc::Weight *max_sum) { @@ -1208,7 +1208,7 @@ bool IsStochasticFst(const Fst &fst, // This function deals with the generic fst. // This version currently supports ConstFst or VectorFst. // Otherwise, it will be died with an error. -bool IsStochasticFstInLog(const Fst &fst, +inline bool IsStochasticFstInLog(const Fst &fst, float delta, StdArc::Weight *min_sum, StdArc::Weight *max_sum) { From fa53cc6a9866821b8606fa4cbce0e506948fb375 Mon Sep 17 00:00:00 2001 From: chenzhehuai Date: Fri, 6 Apr 2018 09:01:10 -0400 Subject: [PATCH 02/93] tmp --- src/bin/latgen-biglm-faster-mapped.cc | 278 ++++++++++++++++++++++++++ 1 file changed, 278 insertions(+) create mode 100644 src/bin/latgen-biglm-faster-mapped.cc diff --git a/src/bin/latgen-biglm-faster-mapped.cc b/src/bin/latgen-biglm-faster-mapped.cc new file mode 100644 index 00000000000..18a3336540b --- /dev/null +++ b/src/bin/latgen-biglm-faster-mapped.cc @@ -0,0 +1,278 @@ +// bin/latgen-biglm-faster-mapped .cc + +// Copyright 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "hmm/transition-model.h" +#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" +#include "decoder/decodable-matrix.h" +#include "base/timer.h" + + +namespace kaldi { +// Takes care of output. Returns true on success. +bool DecodeUtterance(LatticeBiglmFasterDecoder &decoder, // not const but is really an input. + DecodableInterface &decodable, // not const but is really an input. 
+ const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr) { // puts utterance's like in like_ptr on success. + using fst::VectorFst; + + if (!decoder.Decode(&decodable)) { + KALDI_WARN << "Failed to decode file " << utt; + return false; + } + if (!decoder.ReachedFinal()) { + if (allow_partial) { + KALDI_WARN << "Outputting partial output for utterance " << utt + << " since no final-state reached\n"; + } else { + KALDI_WARN << "Not producing output for utterance " << utt + << " since no final-state reached and " + << "--allow-partial=false.\n"; + return false; + } + } + + double likelihood; + LatticeWeight weight; + int32 num_frames; + { // First do some stuff with word-level traceback... + VectorFst decoded; + decoder.GetBestPath(&decoded); + if (decoded.NumStates() == 0) + // Shouldn't really reach this point as already checked success. + KALDI_ERR << "Failed to get traceback for utterance " << utt; + + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + num_frames = alignment.size(); + if (words_writer->IsOpen()) + words_writer->Write(utt, words); + if (alignment_writer->IsOpen()) + alignment_writer->Write(utt, alignment); + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] <<" not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << '\n'; + } + likelihood = -(weight.Value1() + weight.Value2()); + } + + // Get lattice, and do determinization if requested. + Lattice lat; + decoder.GetRawLattice(&lat); + if (lat.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; + fst::Connect(&lat); + if (determinize) { + CompactLattice clat; + if (!DeterminizeLatticePhonePrunedWrapper( + trans_model, + &lat, + decoder.GetOptions().lattice_beam, + &clat, + decoder.GetOptions().det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam for " + << "utterance " << utt; + // We'll write the lattice without acoustic scaling. + if (acoustic_scale != 0.0) + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); + compact_lattice_writer->Write(utt, clat); + } else { + Lattice fst; + decoder.GetRawLattice(&fst); + if (fst.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " + << utt; + fst::Connect(&fst); // Will get rid of this later... shouldn't have any + // disconnected states there, but we seem to. 
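+      // Scaling by 1.0/acoustic_scale below undoes the acoustic scaling that
+      // was applied during decoding, so the written lattice carries raw
+      // acoustic costs and a downstream stage (e.g. lattice rescoring) can
+      // apply its own scale.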
+ if (acoustic_scale != 0.0) // We'll write the lattice without acoustic scaling + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &fst); + lattice_writer->Write(utt, fst); + } + KALDI_LOG << "Log-like per frame for utterance " << utt << " is " + << (likelihood / num_frames) << " over " + << num_frames << " frames."; + KALDI_VLOG(2) << "Cost for utterance " << utt << " is " + << weight.Value1() << " + " << weight.Value2(); + *like_ptr = likelihood; + return true; +} + +} + + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::VectorFst; + using fst::Fst; + using fst::StdArc; + using fst::ReadFstKaldi; + + const char *usage = + "Generate lattices using on-the-fly composition.\n" + "User supplies LM used to generate decoding graph, and desired LM;\n" + "this decoder applies the difference during decoding\n" + "Usage: latgen-biglm-faster-mapped [options] model-in (fst-in|fsts-rspecifier) " + "oldlm-fst-in newlm-fst-in features-rspecifier" + " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n"; + ParseOptions po(usage); + Timer timer; + bool allow_partial = false; + BaseFloat acoustic_scale = 0.1; + LatticeBiglmFasterDecoderConfig config; + + std::string word_syms_filename; + config.Register(&po); + po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods"); + + po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]"); + po.Register("allow-partial", &allow_partial, "If true, produce output even if end state was not reached."); + + po.Read(argc, argv); + + if (po.NumArgs() < 6 || po.NumArgs() > 8) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + fst_in_str = po.GetArg(2), + old_lm_fst_rxfilename = po.GetArg(3), + new_lm_fst_rxfilename = po.GetArg(4), + feature_rspecifier = po.GetArg(5), + lattice_wspecifier = po.GetArg(6), + words_wspecifier = po.GetOptArg(7), + alignment_wspecifier = po.GetOptArg(8); + + TransitionModel trans_model; + ReadKaldiObject(model_in_filename, &trans_model); + + VectorFst *old_lm_fst = fst::CastOrConvertToVectorFst( + fst::ReadFstKaldiGeneric(old_lm_fst_rxfilename)); + ApplyProbabilityScale(-1.0, old_lm_fst); // Negate old LM probs... + + VectorFst *new_lm_fst = fst::CastOrConvertToVectorFst( + fst::ReadFstKaldiGeneric(new_lm_fst_rxfilename)); + + fst::BackoffDeterministicOnDemandFst old_lm_dfst(*old_lm_fst); + fst::BackoffDeterministicOnDemandFst new_lm_dfst(*new_lm_fst); + fst::ComposeDeterministicOnDemandFst compose_dfst(&old_lm_dfst, + &new_lm_dfst); + fst::CacheDeterministicOnDemandFst cache_dfst(&compose_dfst); + + bool determinize = config.determinize_lattice; + CompactLatticeWriter compact_lattice_writer; + LatticeWriter lattice_writer; + if (! (determinize ? 
compact_lattice_writer.Open(lattice_wspecifier) + : lattice_writer.Open(lattice_wspecifier))) + KALDI_ERR << "Could not open table for writing lattices: " + << lattice_wspecifier; + + Int32VectorWriter words_writer(words_wspecifier); + + Int32VectorWriter alignment_writer(alignment_wspecifier); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_filename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_filename; + + double tot_like = 0.0; + kaldi::int64 frame_count = 0; + int num_success = 0, num_fail = 0; + + + if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) { + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + // Input FST is just one FST, not a table of FSTs. + Fst *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); + + { + LatticeBiglmFasterDecoder decoder(*decode_fst, config, &cache_dfst); + + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string utt = feature_reader.Key(); + Matrix features (feature_reader.Value()); + feature_reader.FreeCurrent(); + if (features.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + + DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale); + + double like; + if (DecodeUtterance(decoder, decodable, trans_model, word_syms, + utt, acoustic_scale, determinize, allow_partial, + &alignment_writer, &words_writer, + &compact_lattice_writer, &lattice_writer, + &like)) { + tot_like += like; + frame_count += features.NumRows(); + num_success++; + } else num_fail++; + } + } + delete decode_fst; // delete this only after decoder goes out of scope. + } else { // We have different FSTs for different utterances. 
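+      // The table-of-FSTs mode (an rspecifier of per-utterance decoding
+      // graphs, as supported by latgen-faster-mapped) is not implemented in
+      // this tool yet; the assert below makes that explicit.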
+ assert(0); + } + + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken "<< elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (elapsed*100.0/frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " + << num_fail; + KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like/frame_count) << " over " + << frame_count<<" frames."; + + delete word_syms; + if (num_success != 0) return 0; + else return 1; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} From efccda41d514de1452b59bf1f890007b88d7876e Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Fri, 6 Apr 2018 12:26:29 -0400 Subject: [PATCH 03/93] zchen@c05:/export/a12/zchen/works/decoder/egs/mini_librispeech/s5_otf$ bash run.biglm.sh --- src/bin/Makefile | 1 + src/bin/latgen-biglm-faster-mapped.cc | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/bin/Makefile b/src/bin/Makefile index 627c4f8a131..165eac6bb26 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -23,6 +23,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ transform-vec align-text matrix-dim post-to-smat +BINFILES += latgen-biglm-faster-mapped OBJFILES = diff --git a/src/bin/latgen-biglm-faster-mapped.cc b/src/bin/latgen-biglm-faster-mapped.cc index 18a3336540b..10265551a89 100644 --- a/src/bin/latgen-biglm-faster-mapped.cc +++ b/src/bin/latgen-biglm-faster-mapped.cc @@ -26,6 +26,7 @@ #include "decoder/decoder-wrappers.h" #include "decoder/decodable-matrix.h" #include "base/timer.h" +#include "decoder/lattice-biglm-faster-decoder.h" namespace kaldi { @@ -240,7 +241,7 @@ int main(int argc, char *argv[]) { continue; } - DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale); + DecodableMatrixScaledMapped decodable(trans_model, features, acoustic_scale); double like; if (DecodeUtterance(decoder, decodable, trans_model, word_syms, From ca4fb58c1dd3362b241ceb53f3634e9ce07f7c6f Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Fri, 6 Apr 2018 14:18:51 -0700 Subject: [PATCH 04/93] log --- src/decoder/lattice-biglm-faster-decoder.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/decoder/lattice-biglm-faster-decoder.h b/src/decoder/lattice-biglm-faster-decoder.h index 6276c25a83d..ff337f81083 100644 --- a/src/decoder/lattice-biglm-faster-decoder.h +++ b/src/decoder/lattice-biglm-faster-decoder.h @@ -640,6 +640,8 @@ class LatticeBiglmFasterDecoder { } } if (tok_count != NULL) *tok_count = count; + KALDI_VLOG(6) << "Number of tokens active on frame " << active_toks_.size() - 1 + << " is " << tmp_array_.size(); if (tmp_array_.size() <= static_cast(config_.max_active)) { if (adaptive_beam) *adaptive_beam = config_.beam; return best_weight + config_.beam; From a5da4f9f42783553c5c18e1e1d467ba3efd5f8eb Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Fri, 6 Apr 2018 16:30:29 -0700 Subject: [PATCH 05/93] prune --- src/decoder/lattice-biglm-faster-decoder.h | 69 ++++++++++++++-------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/src/decoder/lattice-biglm-faster-decoder.h b/src/decoder/lattice-biglm-faster-decoder.h index ff337f81083..39c0658a830 100644 --- a/src/decoder/lattice-biglm-faster-decoder.h +++ b/src/decoder/lattice-biglm-faster-decoder.h @@ -630,36 +630,55 @@ class LatticeBiglmFasterDecoder { if (adaptive_beam != NULL) *adaptive_beam = config_.beam; return best_weight + config_.beam; } else { - tmp_array_.clear(); - for (Elem *e = list_head; e 
!= NULL; e = e->tail, count++) { - BaseFloat w = e->val->tot_cost; - tmp_array_.push_back(w); - if (w < best_weight) { - best_weight = w; - if (best_elem) *best_elem = e; - } + tmp_array_.clear(); + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = e->val->tot_cost; + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; } - if (tok_count != NULL) *tok_count = count; - KALDI_VLOG(6) << "Number of tokens active on frame " << active_toks_.size() - 1 - << " is " << tmp_array_.size(); - if (tmp_array_.size() <= static_cast(config_.max_active)) { - if (adaptive_beam) *adaptive_beam = config_.beam; - return best_weight + config_.beam; - } else { - // the lowest elements (lowest costs, highest likes) - // will be put in the left part of tmp_array. + } + if (tok_count != NULL) *tok_count = count; + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(6) << "Number of tokens active on frame " << active_toks_.size() + << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) min_active_cutoff = best_weight; + else { std::nth_element(tmp_array_.begin(), - tmp_array_.begin()+config_.max_active, + tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) ? + tmp_array_.begin() + config_.max_active : tmp_array_.end()); - // return the tighter of the two beams. - BaseFloat ans = std::min(best_weight + config_.beam, - *(tmp_array_.begin()+config_.max_active)); - if (adaptive_beam) - *adaptive_beam = std::min(config_.beam, - ans - best_weight + config_.beam_delta); - return ans; + min_active_cutoff = tmp_array_[config_.min_active]; } } + if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. 
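+        // min_active puts a floor on the number of surviving tokens: when
+        // fewer than min_active tokens fall inside the beam, the beam is
+        // widened to the cost of the min_active-th best token, mirroring how
+        // max_active above tightens it.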
+ if (adaptive_beam) + *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; + return min_active_cutoff; + } else { + *adaptive_beam = config_.beam; + return beam_cutoff; + } + } } inline StateId PropagateLm(StateId lm_state, From dd1a532e96a31f44fa22ee15df56c55c6bcd51b2 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sat, 7 Apr 2018 05:00:55 -0700 Subject: [PATCH 06/93] add profile --- src/bin/latgen-biglm-faster-mapped.cc | 1 + src/decoder/lattice-biglm-faster-decoder.h | 36 ++++++++++++---------- src/decoder/lattice-faster-decoder.cc | 4 +-- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/bin/latgen-biglm-faster-mapped.cc b/src/bin/latgen-biglm-faster-mapped.cc index 10265551a89..e8bc461afe4 100644 --- a/src/bin/latgen-biglm-faster-mapped.cc +++ b/src/bin/latgen-biglm-faster-mapped.cc @@ -230,6 +230,7 @@ int main(int argc, char *argv[]) { { LatticeBiglmFasterDecoder decoder(*decode_fst, config, &cache_dfst); + timer.Reset(); for (; !feature_reader.Done(); feature_reader.Next()) { std::string utt = feature_reader.Key(); diff --git a/src/decoder/lattice-biglm-faster-decoder.h b/src/decoder/lattice-biglm-faster-decoder.h index 39c0658a830..b13236c2970 100644 --- a/src/decoder/lattice-biglm-faster-decoder.h +++ b/src/decoder/lattice-biglm-faster-decoder.h @@ -615,21 +615,22 @@ class LatticeBiglmFasterDecoder { /// Gets the weight cutoff. Also counts the active tokens. BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem) { - BaseFloat best_weight = std::numeric_limits::infinity(); - // positive == high cost == bad. - size_t count = 0; - if (config_.max_active == std::numeric_limits::max()) { - for (Elem *e = list_head; e != NULL; e = e->tail, count++) { - BaseFloat w = static_cast(e->val->tot_cost); - if (w < best_weight) { - best_weight = w; - if (best_elem) *best_elem = e; - } + BaseFloat best_weight = std::numeric_limits::infinity(); + // positive == high cost == bad. + size_t count = 0; + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = static_cast(e->val->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; } - if (tok_count != NULL) *tok_count = count; - if (adaptive_beam != NULL) *adaptive_beam = config_.beam; - return best_weight + config_.beam; - } else { + } + if (tok_count != NULL) *tok_count = count; + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { tmp_array_.clear(); for (Elem *e = list_head; e != NULL; e = e->tail, count++) { BaseFloat w = e->val->tot_cost; @@ -678,7 +679,7 @@ class LatticeBiglmFasterDecoder { *adaptive_beam = config_.beam; return beam_cutoff; } - } + } } inline StateId PropagateLm(StateId lm_state, @@ -713,7 +714,10 @@ class LatticeBiglmFasterDecoder { size_t tok_cnt; BaseFloat cur_cutoff = GetCutoff(last_toks, &tok_cnt, &adaptive_beam, &best_elem); PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough. 
- + KALDI_VLOG(6) << "Adaptive beam on frame " << frame << "\t" << active_toks_.size() << " is " + << adaptive_beam << "\t" << cur_cutoff; + + BaseFloat next_cutoff = std::numeric_limits::infinity(); // pruning "online" before having seen all tokens diff --git a/src/decoder/lattice-faster-decoder.cc b/src/decoder/lattice-faster-decoder.cc index 963430a63f1..161f9bf228a 100644 --- a/src/decoder/lattice-faster-decoder.cc +++ b/src/decoder/lattice-faster-decoder.cc @@ -699,8 +699,8 @@ BaseFloat LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) { BaseFloat adaptive_beam; size_t tok_cnt; BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem); - KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " - << adaptive_beam; + KALDI_VLOG(6) << "Adaptive beam on frame " << frame << "\t" << NumFramesDecoded() << " is " + << adaptive_beam << "\t" << cur_cutoff; PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough. From ebf91405b2fca3e4687f41daaac69196aca09279 Mon Sep 17 00:00:00 2001 From: chenzhehuai Date: Sat, 7 Apr 2018 19:41:52 -0400 Subject: [PATCH 07/93] tmp --- src/bin/latgen-constlm-faster-mapped.cc | 278 ++++++++++++++++++++++++ 1 file changed, 278 insertions(+) create mode 100644 src/bin/latgen-constlm-faster-mapped.cc diff --git a/src/bin/latgen-constlm-faster-mapped.cc b/src/bin/latgen-constlm-faster-mapped.cc new file mode 100644 index 00000000000..e986814628b --- /dev/null +++ b/src/bin/latgen-constlm-faster-mapped.cc @@ -0,0 +1,278 @@ +// bin/latgen-constlm-faster-mapped .cc + +// Copyright 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "hmm/transition-model.h" +#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" +#include "decoder/decodable-matrix.h" +#include "base/timer.h" +#include "decoder/lattice-constlm-faster-decoder.h" + + +namespace kaldi { +// Takes care of output. Returns true on success. +bool DecodeUtterance(LatticeConstlmFasterDecoder &decoder, // not const but is really an input. + DecodableInterface &decodable, // not const but is really an input. + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr) { // puts utterance's like in like_ptr on success. 
+ using fst::VectorFst; + + if (!decoder.Decode(&decodable)) { + KALDI_WARN << "Failed to decode file " << utt; + return false; + } + if (!decoder.ReachedFinal()) { + if (allow_partial) { + KALDI_WARN << "Outputting partial output for utterance " << utt + << " since no final-state reached\n"; + } else { + KALDI_WARN << "Not producing output for utterance " << utt + << " since no final-state reached and " + << "--allow-partial=false.\n"; + return false; + } + } + + double likelihood; + LatticeWeight weight; + int32 num_frames; + { // First do some stuff with word-level traceback... + VectorFst decoded; + decoder.GetBestPath(&decoded); + if (decoded.NumStates() == 0) + // Shouldn't really reach this point as already checked success. + KALDI_ERR << "Failed to get traceback for utterance " << utt; + + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + num_frames = alignment.size(); + if (words_writer->IsOpen()) + words_writer->Write(utt, words); + if (alignment_writer->IsOpen()) + alignment_writer->Write(utt, alignment); + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] <<" not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << '\n'; + } + likelihood = -(weight.Value1() + weight.Value2()); + } + + // Get lattice, and do determinization if requested. + Lattice lat; + decoder.GetRawLattice(&lat); + if (lat.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; + fst::Connect(&lat); + if (determinize) { + CompactLattice clat; + if (!DeterminizeLatticePhonePrunedWrapper( + trans_model, + &lat, + decoder.GetOptions().lattice_beam, + &clat, + decoder.GetOptions().det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam for " + << "utterance " << utt; + // We'll write the lattice without acoustic scaling. + if (acoustic_scale != 0.0) + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); + compact_lattice_writer->Write(utt, clat); + } else { + Lattice fst; + decoder.GetRawLattice(&fst); + if (fst.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " + << utt; + fst::Connect(&fst); // Will get rid of this later... shouldn't have any + // disconnected states there, but we seem to. 
+ if (acoustic_scale != 0.0) // We'll write the lattice without acoustic scaling + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &fst); + lattice_writer->Write(utt, fst); + } + KALDI_LOG << "Log-like per frame for utterance " << utt << " is " + << (likelihood / num_frames) << " over " + << num_frames << " frames."; + KALDI_VLOG(2) << "Cost for utterance " << utt << " is " + << weight.Value1() << " + " << weight.Value2(); + *like_ptr = likelihood; + return true; +} + +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::VectorFst; + using fst::Fst; + using fst::StdArc; + using fst::ReadFstKaldi; + + const char *usage = + "Generate lattices using on-the-fly composition.\n" + "User supplies LM used to generate decoding graph, and desired LM;\n" + "this decoder applies the difference during decoding\n" + "Usage: latgen-biglm-faster-mapped [options] model-in (fst-in|fsts-rspecifier) " + "oldlm-fst-in newlm-fst-in features-rspecifier" + " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n"; + ParseOptions po(usage); + Timer timer; + bool allow_partial = false; + BaseFloat acoustic_scale = 0.1; + LatticeConstlmFasterDecoderConfig config; + + std::string word_syms_filename; + config.Register(&po); + po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods"); + + po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]"); + po.Register("allow-partial", &allow_partial, "If true, produce output even if end state was not reached."); + + po.Read(argc, argv); + + if (po.NumArgs() < 6 || po.NumArgs() > 8) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + fst_in_str = po.GetArg(2), + old_lm_fst_rxfilename = po.GetArg(3), + new_lm_fst_rxfilename = po.GetArg(4), + feature_rspecifier = po.GetArg(5), + lattice_wspecifier = po.GetArg(6), + words_wspecifier = po.GetOptArg(7), + alignment_wspecifier = po.GetOptArg(8); + + TransitionModel trans_model; + ReadKaldiObject(model_in_filename, &trans_model); + + ConstArpaLm old_lm; + ReadKaldiObject(old_lm_fst_rxfilename, &old_lm); + ConstArpaLmDeterministicFst old_lm_dfst(old_lm); + ApplyProbabilityScale(-1.0, old_lm_dfst); // Negate old LM probs... + + ConstArpaLm new_lm; + ReadKaldiObject(new_lm_fst_rxfilename, &new_lm); + ConstArpaLmDeterministicFst new_lm_dfst(new_lm); + + fst::ComposeDeterministicOnDemandFst compose_dfst(&old_lm_dfst, + &new_lm_dfst); + fst::CacheDeterministicOnDemandFst cache_dfst(&compose_dfst); + + bool determinize = config.determinize_lattice; + CompactLatticeWriter compact_lattice_writer; + LatticeWriter lattice_writer; + if (! (determinize ? 
compact_lattice_writer.Open(lattice_wspecifier) + : lattice_writer.Open(lattice_wspecifier))) + KALDI_ERR << "Could not open table for writing lattices: " + << lattice_wspecifier; + + Int32VectorWriter words_writer(words_wspecifier); + + Int32VectorWriter alignment_writer(alignment_wspecifier); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_filename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_filename; + + double tot_like = 0.0; + kaldi::int64 frame_count = 0; + int num_success = 0, num_fail = 0; + + + if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) { + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + // Input FST is just one FST, not a table of FSTs. + Fst *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); + + { + LatticeConstlmFasterDecoder decoder(*decode_fst, config, &cache_dfst); + timer.Reset(); + + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string utt = feature_reader.Key(); + Matrix features (feature_reader.Value()); + feature_reader.FreeCurrent(); + if (features.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + + DecodableMatrixScaledMapped decodable(trans_model, features, acoustic_scale); + + double like; + if (DecodeUtterance(decoder, decodable, trans_model, word_syms, + utt, acoustic_scale, determinize, allow_partial, + &alignment_writer, &words_writer, + &compact_lattice_writer, &lattice_writer, + &like)) { + tot_like += like; + frame_count += features.NumRows(); + num_success++; + } else num_fail++; + } + } + delete decode_fst; // delete this only after decoder goes out of scope. + } else { // We have different FSTs for different utterances. 
+ assert(0); + } + + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken "<< elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (elapsed*100.0/frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " + << num_fail; + KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like/frame_count) << " over " + << frame_count<<" frames."; + + delete word_syms; + if (num_success != 0) return 0; + else return 1; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} From 4dd846ed82f1c70b09aa957f15db7c3b163ed3c2 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sat, 7 Apr 2018 18:36:55 -0700 Subject: [PATCH 08/93] single det --- src/bin/Makefile | 2 +- src/bin/latgen-biglm-faster-mapped.cc | 2 +- src/bin/latgen-constlm-faster-mapped.cc | 15 ++++++++------- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/bin/Makefile b/src/bin/Makefile index 165eac6bb26..5cf8ed63032 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -23,7 +23,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ transform-vec align-text matrix-dim post-to-smat -BINFILES += latgen-biglm-faster-mapped +BINFILES += latgen-biglm-faster-mapped latgen-constlm-faster-mapped OBJFILES = diff --git a/src/bin/latgen-biglm-faster-mapped.cc b/src/bin/latgen-biglm-faster-mapped.cc index e8bc461afe4..548dff7533d 100644 --- a/src/bin/latgen-biglm-faster-mapped.cc +++ b/src/bin/latgen-biglm-faster-mapped.cc @@ -198,7 +198,7 @@ int main(int argc, char *argv[]) { fst::BackoffDeterministicOnDemandFst new_lm_dfst(*new_lm_fst); fst::ComposeDeterministicOnDemandFst compose_dfst(&old_lm_dfst, &new_lm_dfst); - fst::CacheDeterministicOnDemandFst cache_dfst(&compose_dfst); + fst::CacheDeterministicOnDemandFst cache_dfst(&new_lm_dfst); bool determinize = config.determinize_lattice; CompactLatticeWriter compact_lattice_writer; diff --git a/src/bin/latgen-constlm-faster-mapped.cc b/src/bin/latgen-constlm-faster-mapped.cc index e986814628b..caf8dbc5004 100644 --- a/src/bin/latgen-constlm-faster-mapped.cc +++ b/src/bin/latgen-constlm-faster-mapped.cc @@ -26,12 +26,13 @@ #include "decoder/decoder-wrappers.h" #include "decoder/decodable-matrix.h" #include "base/timer.h" -#include "decoder/lattice-constlm-faster-decoder.h" +#include "lm/const-arpa-lm.h" +#include "decoder/lattice-biglm-faster-decoder.h" namespace kaldi { // Takes care of output. Returns true on success. -bool DecodeUtterance(LatticeConstlmFasterDecoder &decoder, // not const but is really an input. +bool DecodeUtterance(LatticeBiglmFasterDecoder &decoder, // not const but is really an input. DecodableInterface &decodable, // not const but is really an input. const TransitionModel &trans_model, const fst::SymbolTable *word_syms, @@ -157,7 +158,7 @@ int main(int argc, char *argv[]) { Timer timer; bool allow_partial = false; BaseFloat acoustic_scale = 0.1; - LatticeConstlmFasterDecoderConfig config; + LatticeBiglmFasterDecoderConfig config; std::string word_syms_filename; config.Register(&po); @@ -185,18 +186,18 @@ int main(int argc, char *argv[]) { TransitionModel trans_model; ReadKaldiObject(model_in_filename, &trans_model); + /* ConstArpaLm old_lm; ReadKaldiObject(old_lm_fst_rxfilename, &old_lm); ConstArpaLmDeterministicFst old_lm_dfst(old_lm); ApplyProbabilityScale(-1.0, old_lm_dfst); // Negate old LM probs... 
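+
+    (Presumably disabled because ApplyProbabilityScale() requires a mutable
+    FST while ConstArpaLmDeterministicFst is an on-demand
+    DeterministicOnDemandFst; with this block commented out the decoder
+    below composes with the new LM only, and no old-LM term is subtracted.)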
+ */ ConstArpaLm new_lm; ReadKaldiObject(new_lm_fst_rxfilename, &new_lm); ConstArpaLmDeterministicFst new_lm_dfst(new_lm); - fst::ComposeDeterministicOnDemandFst compose_dfst(&old_lm_dfst, - &new_lm_dfst); - fst::CacheDeterministicOnDemandFst cache_dfst(&compose_dfst); + fst::CacheDeterministicOnDemandFst cache_dfst(&new_lm_dfst); bool determinize = config.determinize_lattice; CompactLatticeWriter compact_lattice_writer; @@ -227,7 +228,7 @@ int main(int argc, char *argv[]) { Fst *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); { - LatticeConstlmFasterDecoder decoder(*decode_fst, config, &cache_dfst); + LatticeBiglmFasterDecoder decoder(*decode_fst, config, &cache_dfst); timer.Reset(); for (; !feature_reader.Done(); feature_reader.Next()) { From fc767523651ba2bf0acdf464a1bf5b7126ed18c6 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sun, 8 Apr 2018 14:06:30 -0400 Subject: [PATCH 09/93] tmp --- src/bin/latgen-biglm-faster-mapped.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bin/latgen-biglm-faster-mapped.cc b/src/bin/latgen-biglm-faster-mapped.cc index 548dff7533d..e8bc461afe4 100644 --- a/src/bin/latgen-biglm-faster-mapped.cc +++ b/src/bin/latgen-biglm-faster-mapped.cc @@ -198,7 +198,7 @@ int main(int argc, char *argv[]) { fst::BackoffDeterministicOnDemandFst new_lm_dfst(*new_lm_fst); fst::ComposeDeterministicOnDemandFst compose_dfst(&old_lm_dfst, &new_lm_dfst); - fst::CacheDeterministicOnDemandFst cache_dfst(&new_lm_dfst); + fst::CacheDeterministicOnDemandFst cache_dfst(&compose_dfst); bool determinize = config.determinize_lattice; CompactLatticeWriter compact_lattice_writer; From 451c471bc2ec9dcc189bac6bda7484afda71f43f Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sun, 8 Apr 2018 12:37:31 -0700 Subject: [PATCH 10/93] otf-res ntok=1 --- src/fstext/deterministic-fst-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fstext/deterministic-fst-inl.h b/src/fstext/deterministic-fst-inl.h index c6f99697e00..3dc49d04ff6 100644 --- a/src/fstext/deterministic-fst-inl.h +++ b/src/fstext/deterministic-fst-inl.h @@ -190,7 +190,7 @@ bool ComposeDeterministicOnDemandFst::GetArc(StateId s, Label ilabel, Arc arc2; if (!fst2_->GetArc(pr.second, arc1.olabel, &arc2)) return false; std::pair, StateId> new_value( - std::pair(arc1.nextstate, arc2.nextstate), + std::pair(arc1.nextstate, arc1.nextstate), next_state_); std::pair result = state_map_.insert(new_value); @@ -199,7 +199,7 @@ bool ComposeDeterministicOnDemandFst::GetArc(StateId s, Label ilabel, oarc->nextstate = result.first->second; oarc->weight = Times(arc1.weight, arc2.weight); if (result.second == true) { // was inserted - next_state_++; + //next_state_++; const std::pair &new_pair (new_value.first); state_vec_.push_back(new_pair); } From f66fc65336ab5bb9f144f24937b45c99ccaa23e8 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sun, 8 Apr 2018 17:41:02 -0700 Subject: [PATCH 11/93] ntok=1 --- src/decoder/lattice-biglm-faster-decoder.h | 45 ++++++++++++++++++++-- src/fstext/deterministic-fst-inl.h | 4 +- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/src/decoder/lattice-biglm-faster-decoder.h b/src/decoder/lattice-biglm-faster-decoder.h index b13236c2970..33bc551d76c 100644 --- a/src/decoder/lattice-biglm-faster-decoder.h +++ b/src/decoder/lattice-biglm-faster-decoder.h @@ -64,11 +64,12 @@ class LatticeBiglmFasterDecoder { KALDI_ASSERT(fst.Start() != fst::kNoStateId && lm_diff_fst->Start() != fst::kNoStateId); toks_.SetSize(1000); // just so on the first frame we do 
something reasonable. + toks_g1.SetSize(1000); // just so on the first frame we do something reasonable. } void SetOptions(const LatticeBiglmFasterDecoderConfig &config) { config_ = config; } LatticeBiglmFasterDecoderConfig GetOptions() { return config_; } ~LatticeBiglmFasterDecoder() { - DeleteElems(toks_.Clear()); + DeleteElems(toks_.Clear()); ClearActiveTokens(); } @@ -87,6 +88,7 @@ class LatticeBiglmFasterDecoder { Token *start_tok = new Token(0.0, 0.0, NULL, NULL); active_toks_[0].toks = start_tok; toks_.Insert(start_pair, start_tok); + toks_g1.Insert(PairToState(start_pair), start_pair); num_toks_++; ProcessNonemitting(0); @@ -298,6 +300,7 @@ class LatticeBiglmFasterDecoder { }; typedef HashList::Elem Elem; + typedef HashList::Elem Elem_g1; void PossiblyResizeHash(size_t num_toks) { size_t new_sz = static_cast(static_cast(num_toks) @@ -305,6 +308,9 @@ class LatticeBiglmFasterDecoder { if (new_sz > toks_.Size()) { toks_.SetSize(new_sz); } + if (new_sz > toks_g1.Size()) { + toks_g1.SetSize(new_sz); + } } // FindOrAddToken either locates a token in hash of toks_, @@ -312,7 +318,7 @@ class LatticeBiglmFasterDecoder { // for the current frame. [note: it's inserted if necessary into hash toks_ // and also into the singly linked list of tokens active on this frame // (whose head is at active_toks_[frame]). - inline Token *FindOrAddToken(PairId state_pair, int32 frame, BaseFloat tot_cost, + inline Token *FindOrAddToken_2(PairId state_pair, int32 frame, BaseFloat tot_cost, bool emitting, bool *changed) { // Returns the Token pointer. Sets "changed" (if non-NULL) to true // if the token was newly created or the cost changed. @@ -349,7 +355,29 @@ class LatticeBiglmFasterDecoder { return tok; } } - + inline Token *FindOrAddToken(PairId state_pair, int32 frame, BaseFloat tot_cost, + bool emitting, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(frame < active_toks_.size()); + Elem_g1 *e_found = toks_g1.Find(PairToState(state_pair)); + if (e_found == NULL) { // no such token presently. + toks_g1.Insert(PairToState(state_pair), state_pair); + return FindOrAddToken_2(state_pair, frame, tot_cost, emitting, changed); + } else { + Elem* e_f = toks_.Find(e_found->val); + assert(e_f); + Token *tok = e_f->val; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + e_found->val = state_pair; + tok = FindOrAddToken_2(state_pair, frame, tot_cost, emitting, changed); + } else { + if (changed) *changed = false; + } + return tok; + } + } + // prunes outgoing links for all tokens in active_toks_[frame] // it's called by PruneActiveTokens // all links, that have link_extra_cost > lattice_beam are pruned @@ -441,6 +469,7 @@ class LatticeBiglmFasterDecoder { best_cost_nofinal = infinity; unordered_map tok_to_final_cost; Elem *cur_toks = toks_.Clear(); // swapping prev_toks_ / cur_toks_ + DeleteElems_1(toks_g1.Clear()); for (Elem *e = cur_toks, *e_tail; e != NULL; e = e_tail) { PairId state_pair = e->key; StateId state = PairToState(state_pair), @@ -709,6 +738,7 @@ class LatticeBiglmFasterDecoder { void ProcessEmitting(DecodableInterface *decodable, int32 frame) { // Processes emitting arcs for one frame. Propagates from prev_toks_ to cur_toks_. 
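    // A note on the two hashes used below: toks_ is indexed by PairId, a
    // uint64 packing (decoding-graph state, LM-diff-FST state), see
    // ConstructPair() / PairToState() / PairToLmState().  toks_g1, introduced
    // in this patch series, is indexed by the graph state alone and tracks
    // the best cost per graph state, so that pairs far worse than the best
    // token on the same graph state can be skipped before querying the LM.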
Elem *last_toks = toks_.Clear(); // swapping prev_toks_ / cur_toks_ + DeleteElems_1(toks_g1.Clear()); Elem *best_elem = NULL; BaseFloat adaptive_beam; size_t tok_cnt; @@ -857,6 +887,7 @@ class LatticeBiglmFasterDecoder { // more than one list (e.g. for current and previous frames), but only one of // them at a time can be indexed by StateId. HashList toks_; + HashList toks_g1; std::vector active_toks_; // Lists of tokens, indexed by // frame (members of TokenList are toks, must_prune_forward_links, // must_prune_tokens). @@ -886,6 +917,14 @@ class LatticeBiglmFasterDecoder { toks_.Delete(e); } toks_.Clear(); + DeleteElems_1(toks_g1.Clear()); + } + void DeleteElems_1(Elem_g1 *list) { + for (Elem_g1 *e = list, *e_tail; e != NULL; e = e_tail) { + e_tail = e->tail; + toks_g1.Delete(e); + } + toks_g1.Clear(); } void ClearActiveTokens() { // a cleanup routine, at utt end/begin diff --git a/src/fstext/deterministic-fst-inl.h b/src/fstext/deterministic-fst-inl.h index 3dc49d04ff6..c6f99697e00 100644 --- a/src/fstext/deterministic-fst-inl.h +++ b/src/fstext/deterministic-fst-inl.h @@ -190,7 +190,7 @@ bool ComposeDeterministicOnDemandFst::GetArc(StateId s, Label ilabel, Arc arc2; if (!fst2_->GetArc(pr.second, arc1.olabel, &arc2)) return false; std::pair, StateId> new_value( - std::pair(arc1.nextstate, arc1.nextstate), + std::pair(arc1.nextstate, arc2.nextstate), next_state_); std::pair result = state_map_.insert(new_value); @@ -199,7 +199,7 @@ bool ComposeDeterministicOnDemandFst::GetArc(StateId s, Label ilabel, oarc->nextstate = result.first->second; oarc->weight = Times(arc1.weight, arc2.weight); if (result.second == true) { // was inserted - //next_state_++; + next_state_++; const std::pair &new_pair (new_value.first); state_vec_.push_back(new_pair); } From a480cceaae1259ed2132f9dfeb5acdd14947b9be Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sun, 8 Apr 2018 18:47:33 -0700 Subject: [PATCH 12/93] add beam in g1_map --- src/decoder/lattice-biglm-faster-decoder.h | 53 ++++++++++++---------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/src/decoder/lattice-biglm-faster-decoder.h b/src/decoder/lattice-biglm-faster-decoder.h index 33bc551d76c..eb640d99937 100644 --- a/src/decoder/lattice-biglm-faster-decoder.h +++ b/src/decoder/lattice-biglm-faster-decoder.h @@ -300,7 +300,7 @@ class LatticeBiglmFasterDecoder { }; typedef HashList::Elem Elem; - typedef HashList::Elem Elem_g1; + typedef HashList::Elem Elem_g1; void PossiblyResizeHash(size_t num_toks) { size_t new_sz = static_cast(static_cast(num_toks) @@ -355,28 +355,27 @@ class LatticeBiglmFasterDecoder { return tok; } } - inline Token *FindOrAddToken(PairId state_pair, int32 frame, BaseFloat tot_cost, - bool emitting, bool *changed) { +#define res_beam 1 + inline bool FindOrAddToken(StateId state_id, int32 frame, BaseFloat tot_cost, + bool emitting, bool *changed, bool pp) { // Returns the Token pointer. Sets "changed" (if non-NULL) to true // if the token was newly created or the cost changed. KALDI_ASSERT(frame < active_toks_.size()); - Elem_g1 *e_found = toks_g1.Find(PairToState(state_pair)); + Elem_g1 *e_found = toks_g1.Find(state_id); if (e_found == NULL) { // no such token presently. - toks_g1.Insert(PairToState(state_pair), state_pair); - return FindOrAddToken_2(state_pair, frame, tot_cost, emitting, changed); + toks_g1.Insert(state_id, tot_cost); + return true; } else { - Elem* e_f = toks_.Find(e_found->val); - assert(e_f); - Token *tok = e_f->val; // There is an existing Token for this state. 
- if (tok->tot_cost > tot_cost) { // replace old token - e_found->val = state_pair; - tok = FindOrAddToken_2(state_pair, frame, tot_cost, emitting, changed); - } else { - if (changed) *changed = false; + if (tot_cost < e_found->val + res_beam) {// There is an existing Token for this state. + if (tot_cost < e_found->val) + e_found->val = tot_cost; + return true; + } + else { + return false; } - return tok; } - } + } // prunes outgoing links for all tokens in active_toks_[frame] // it's called by PruneActiveTokens @@ -712,10 +711,12 @@ class LatticeBiglmFasterDecoder { } inline StateId PropagateLm(StateId lm_state, - Arc *arc) { // returns new LM state. + Arc *arc, bool *pp=NULL) { // returns new LM state. if (arc->olabel == 0) { + if (pp) *pp=false; return lm_state; // no change in LM state if no word crossed. } else { // Propagate in the LM-diff FST. + if (pp) *pp=false; Arc lm_arc; bool ans = lm_diff_fst_->GetArc(lm_state, arc->olabel, &lm_arc); if (!ans) { // this case is unexpected for statistical LMs. @@ -790,16 +791,18 @@ class LatticeBiglmFasterDecoder { const Arc &arc_ref = aiter.Value(); if (arc_ref.ilabel != 0) { // propagate.. Arc arc(arc_ref); - StateId next_lm_state = PropagateLm(lm_state, &arc); - BaseFloat ac_cost = -decodable->LogLikelihood(frame-1, arc.ilabel), - graph_cost = arc.weight.Value(), + bool pp; + BaseFloat ac_cost = -decodable->LogLikelihood(frame-1, arc.ilabel); + if (!FindOrAddToken(arc.nextstate, frame, tok->tot_cost + ac_cost+ arc.weight.Value(), true, NULL, pp)) continue; + StateId next_lm_state = PropagateLm(lm_state, &arc, &pp); + BaseFloat graph_cost = arc.weight.Value(), cur_cost = tok->tot_cost, tot_cost = cur_cost + ac_cost + graph_cost; if (tot_cost > next_cutoff) continue; else if (tot_cost + config_.beam < next_cutoff) next_cutoff = tot_cost + config_.beam; // prune by best current token PairId next_pair = ConstructPair(arc.nextstate, next_lm_state); - Token *next_tok = FindOrAddToken(next_pair, frame, tot_cost, true, NULL); + Token *next_tok = FindOrAddToken_2(next_pair, frame, tot_cost, true, NULL); // true: emitting, NULL: no change indicator needed // Add ForwardLink from tok to next_tok (put on head of list tok->links) @@ -861,13 +864,15 @@ class LatticeBiglmFasterDecoder { const Arc &arc_ref = aiter.Value(); if (arc_ref.ilabel == 0) { // propagate nonemitting only... Arc arc(arc_ref); - StateId next_lm_state = PropagateLm(lm_state, &arc); + bool pp; + if (!FindOrAddToken(arc.nextstate, frame, tok->tot_cost + arc.weight.Value(), true, NULL, pp)) continue; + StateId next_lm_state = PropagateLm(lm_state, &arc, &pp); BaseFloat graph_cost = arc.weight.Value(), tot_cost = cur_cost + graph_cost; if (tot_cost < cutoff) { bool changed; PairId next_pair = ConstructPair(arc.nextstate, next_lm_state); - Token *new_tok = FindOrAddToken(next_pair, frame, tot_cost, + Token *new_tok = FindOrAddToken_2(next_pair, frame, tot_cost, false, &changed); // false: non-emit tok->links = new ForwardLink(new_tok, 0, arc.olabel, @@ -887,7 +892,7 @@ class LatticeBiglmFasterDecoder { // more than one list (e.g. for current and previous frames), but only one of // them at a time can be indexed by StateId. HashList toks_; - HashList toks_g1; + HashList toks_g1; std::vector active_toks_; // Lists of tokens, indexed by // frame (members of TokenList are toks, must_prune_forward_links, // must_prune_tokens). 
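An aside on the gating idea that patches 11 through 13 converge on, reduced to
a self-contained sketch: a per-graph-state record of the best cost seen this
frame admits a (graph state, LM state) pair only if it lies within res_beam of
that best, and patch 13 applies the test only to word-crossing arcs (pp is set
from arc.olabel > 0).  The names below are illustrative, not code from the
series.

    #include <cstdint>
    #include <unordered_map>

    typedef int32_t StateId;

    struct RescoreGate {
      float res_beam;                                // cf. "#define res_beam 1"
      std::unordered_map<StateId, float> best_cost;  // graph state -> best cost

      // Returns true if a token with this cost on this graph state is worth
      // expanding, i.e. it lies within res_beam of the best token seen on
      // the same graph state this frame.
      bool Admit(StateId graph_state, float tot_cost, bool word_crossing) {
        std::pair<std::unordered_map<StateId, float>::iterator, bool> r =
            best_cost.insert(std::make_pair(graph_state, tot_cost));
        if (r.second) return true;             // first token on this state
        float &best = r.first->second;
        if (tot_cost < best) best = tot_cost;  // keep the running minimum
        if (!word_crossing) return true;       // patch 13: gate word arcs only
        return tot_cost < best + res_beam;
      }
      // Reset at each new frame, cf. DeleteElems_1(toks_g1.Clear()).
      void NewFrame() { best_cost.clear(); }
    };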
From c0bf8268d33f1a563fd114c65ce8fc2455f57c71 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sun, 8 Apr 2018 19:09:19 -0700 Subject: [PATCH 13/93] tiny --- src/decoder/lattice-biglm-faster-decoder.h | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/decoder/lattice-biglm-faster-decoder.h b/src/decoder/lattice-biglm-faster-decoder.h index eb640d99937..429a16b2574 100644 --- a/src/decoder/lattice-biglm-faster-decoder.h +++ b/src/decoder/lattice-biglm-faster-decoder.h @@ -371,9 +371,12 @@ class LatticeBiglmFasterDecoder { e_found->val = tot_cost; return true; } - else { + else if (pp) { return false; } + else { + return true; + } } } @@ -791,7 +794,7 @@ class LatticeBiglmFasterDecoder { const Arc &arc_ref = aiter.Value(); if (arc_ref.ilabel != 0) { // propagate.. Arc arc(arc_ref); - bool pp; + bool pp=arc.olabel>0; BaseFloat ac_cost = -decodable->LogLikelihood(frame-1, arc.ilabel); if (!FindOrAddToken(arc.nextstate, frame, tok->tot_cost + ac_cost+ arc.weight.Value(), true, NULL, pp)) continue; StateId next_lm_state = PropagateLm(lm_state, &arc, &pp); @@ -864,7 +867,7 @@ class LatticeBiglmFasterDecoder { const Arc &arc_ref = aiter.Value(); if (arc_ref.ilabel == 0) { // propagate nonemitting only... Arc arc(arc_ref); - bool pp; + bool pp=arc.olabel>0; if (!FindOrAddToken(arc.nextstate, frame, tok->tot_cost + arc.weight.Value(), true, NULL, pp)) continue; StateId next_lm_state = PropagateLm(lm_state, &arc, &pp); BaseFloat graph_cost = arc.weight.Value(), From d54d45a285ba1c16a6dba95d5127d638494f5e82 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Mon, 9 Apr 2018 12:31:49 -0700 Subject: [PATCH 14/93] tiny --- src/bin/latgen-biglm-faster-mapped.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bin/latgen-biglm-faster-mapped.cc b/src/bin/latgen-biglm-faster-mapped.cc index e8bc461afe4..1f87572a4f3 100644 --- a/src/bin/latgen-biglm-faster-mapped.cc +++ b/src/bin/latgen-biglm-faster-mapped.cc @@ -198,7 +198,7 @@ int main(int argc, char *argv[]) { fst::BackoffDeterministicOnDemandFst new_lm_dfst(*new_lm_fst); fst::ComposeDeterministicOnDemandFst compose_dfst(&old_lm_dfst, &new_lm_dfst); - fst::CacheDeterministicOnDemandFst cache_dfst(&compose_dfst); + fst::CacheDeterministicOnDemandFst cache_dfst(&compose_dfst, 1e7); bool determinize = config.determinize_lattice; CompactLatticeWriter compact_lattice_writer; From e258ec183cd2dd12778762635719998bebaeea12 Mon Sep 17 00:00:00 2001 From: chenzhehuai Date: Thu, 12 Apr 2018 22:28:37 -0400 Subject: [PATCH 15/93] tmp --- src/bin/latgen-fasterlm-faster-mapped.cc | 286 ++++++++++++++++++++ src/lm/faster-arpa-lm.cc | 36 +++ src/lm/faster-arpa-lm.h | 324 +++++++++++++++++++++++ 3 files changed, 646 insertions(+) create mode 100644 src/bin/latgen-fasterlm-faster-mapped.cc create mode 100644 src/lm/faster-arpa-lm.cc create mode 100644 src/lm/faster-arpa-lm.h diff --git a/src/bin/latgen-fasterlm-faster-mapped.cc b/src/bin/latgen-fasterlm-faster-mapped.cc new file mode 100644 index 00000000000..fe6ff62f6eb --- /dev/null +++ b/src/bin/latgen-fasterlm-faster-mapped.cc @@ -0,0 +1,286 @@ +// bin/latgen-fasterlm-faster-mapped .cc + +// Copyright 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "hmm/transition-model.h" +#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" +#include "decoder/decodable-matrix.h" +#include "base/timer.h" +#include "lm/faster-arpa-lm.h" +#include "decoder/lattice-biglm-faster-decoder.h" + + +namespace kaldi { +// Takes care of output. Returns true on success. +bool DecodeUtterance(LatticeBiglmFasterDecoder &decoder, // not const but is really an input. + DecodableInterface &decodable, // not const but is really an input. + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr) { // puts utterance's like in like_ptr on success. + using fst::VectorFst; + + if (!decoder.Decode(&decodable)) { + KALDI_WARN << "Failed to decode file " << utt; + return false; + } + if (!decoder.ReachedFinal()) { + if (allow_partial) { + KALDI_WARN << "Outputting partial output for utterance " << utt + << " since no final-state reached\n"; + } else { + KALDI_WARN << "Not producing output for utterance " << utt + << " since no final-state reached and " + << "--allow-partial=false.\n"; + return false; + } + } + + double likelihood; + LatticeWeight weight; + int32 num_frames; + { // First do some stuff with word-level traceback... + VectorFst decoded; + decoder.GetBestPath(&decoded); + if (decoded.NumStates() == 0) + // Shouldn't really reach this point as already checked success. + KALDI_ERR << "Failed to get traceback for utterance " << utt; + + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + num_frames = alignment.size(); + if (words_writer->IsOpen()) + words_writer->Write(utt, words); + if (alignment_writer->IsOpen()) + alignment_writer->Write(utt, alignment); + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] <<" not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << '\n'; + } + likelihood = -(weight.Value1() + weight.Value2()); + } + + // Get lattice, and do determinization if requested. + Lattice lat; + decoder.GetRawLattice(&lat); + if (lat.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; + fst::Connect(&lat); + if (determinize) { + CompactLattice clat; + if (!DeterminizeLatticePhonePrunedWrapper( + trans_model, + &lat, + decoder.GetOptions().lattice_beam, + &clat, + decoder.GetOptions().det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam for " + << "utterance " << utt; + // We'll write the lattice without acoustic scaling. 
+      if (acoustic_scale != 0.0)
+        fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat);
+      compact_lattice_writer->Write(utt, clat);
+    } else {
+      Lattice fst;
+      decoder.GetRawLattice(&fst);
+      if (fst.NumStates() == 0)
+        KALDI_ERR << "Unexpected problem getting lattice for utterance "
+                  << utt;
+      fst::Connect(&fst); // Will get rid of this later... shouldn't have any
+                          // disconnected states there, but we seem to.
+      if (acoustic_scale != 0.0) // We'll write the lattice without acoustic scaling
+        fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &fst);
+      lattice_writer->Write(utt, fst);
+    }
+    KALDI_LOG << "Log-like per frame for utterance " << utt << " is "
+              << (likelihood / num_frames) << " over "
+              << num_frames << " frames.";
+    KALDI_VLOG(2) << "Cost for utterance " << utt << " is "
+                  << weight.Value1() << " + " << weight.Value2();
+    *like_ptr = likelihood;
+    return true;
+}
+
+}
+
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+    typedef kaldi::int32 int32;
+    using fst::SymbolTable;
+    using fst::VectorFst;
+    using fst::Fst;
+    using fst::StdArc;
+    using fst::ReadFstKaldi;
+
+    const char *usage =
+        "Generate lattices using on-the-fly composition.\n"
+        "User supplies LM used to generate decoding graph, and desired LM;\n"
+        "this decoder applies the difference during decoding\n"
+        "Usage: latgen-fasterlm-faster-mapped [options] model-in (fst-in|fsts-rspecifier) "
+        "oldlm-fst-in newlm-fst-in features-rspecifier"
+        " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n";
+    ParseOptions po(usage);
+    Timer timer;
+    bool allow_partial = false;
+    BaseFloat acoustic_scale = 0.1;
+    LatticeBiglmFasterDecoderConfig config;
+    config.Register(&po);
+
+    ArpaParseOptions arpa_options;
+    arpa_options.Register(&po);
+
+    std::string word_syms_filename;
+    po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods");
+
+    po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]");
+    po.Register("allow-partial", &allow_partial, "If true, produce output even if end state was not reached.");
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() < 6 || po.NumArgs() > 8) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    std::string model_in_filename = po.GetArg(1),
+        fst_in_str = po.GetArg(2),
+        old_lm_fst_rxfilename = po.GetArg(3),
+        new_lm_fst_rxfilename = po.GetArg(4),
+        feature_rspecifier = po.GetArg(5),
+        lattice_wspecifier = po.GetArg(6),
+        words_wspecifier = po.GetOptArg(7),
+        alignment_wspecifier = po.GetOptArg(8);
+
+    TransitionModel trans_model;
+    ReadKaldiObject(model_in_filename, &trans_model);
+
+    /*
+    FasterArpaLm old_lm;
+    ReadKaldiObject(old_lm_fst_rxfilename, &old_lm);
+    FasterArpaLmDeterministicFst old_lm_dfst(old_lm);
+    ApplyProbabilityScale(-1.0, old_lm_dfst); // Negate old LM probs...
+    */
+
+    // The third argument is presumably an lm_scale of -1, so the old-LM
+    // scores enter negated, replacing the disabled block above.
+    FasterArpaLm old_lm(arpa_options, old_lm_fst_rxfilename, -1);
+    FasterArpaLmDeterministicFst old_lm_dfst(old_lm);
+
+    FasterArpaLm new_lm(arpa_options, new_lm_fst_rxfilename);
+    FasterArpaLmDeterministicFst new_lm_dfst(new_lm);
+
+    fst::ComposeDeterministicOnDemandFst<StdArc> compose_dfst(&old_lm_dfst,
+                                                              &new_lm_dfst);
+    fst::CacheDeterministicOnDemandFst<StdArc> cache_dfst(&compose_dfst, 1e7);
+
+    bool determinize = config.determinize_lattice;
+    CompactLatticeWriter compact_lattice_writer;
+    LatticeWriter lattice_writer;
+    if (! (determinize ?
+           compact_lattice_writer.Open(lattice_wspecifier)
+           : lattice_writer.Open(lattice_wspecifier)))
+      KALDI_ERR << "Could not open table for writing lattices: "
+                << lattice_wspecifier;
+
+    Int32VectorWriter words_writer(words_wspecifier);
+
+    Int32VectorWriter alignment_writer(alignment_wspecifier);
+
+    fst::SymbolTable *word_syms = NULL;
+    if (word_syms_filename != "")
+      if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
+        KALDI_ERR << "Could not read symbol table from file "
+                  << word_syms_filename;
+
+    double tot_like = 0.0;
+    kaldi::int64 frame_count = 0;
+    int num_success = 0, num_fail = 0;
+
+
+    if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) {
+      SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
+      // Input FST is just one FST, not a table of FSTs.
+      Fst<StdArc> *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str);
+
+      {
+        LatticeBiglmFasterDecoder decoder(*decode_fst, config, &cache_dfst);
+        timer.Reset();
+
+        for (; !feature_reader.Done(); feature_reader.Next()) {
+          std::string utt = feature_reader.Key();
+          Matrix<BaseFloat> features(feature_reader.Value());
+          feature_reader.FreeCurrent();
+          if (features.NumRows() == 0) {
+            KALDI_WARN << "Zero-length utterance: " << utt;
+            num_fail++;
+            continue;
+          }
+
+          DecodableMatrixScaledMapped decodable(trans_model, features, acoustic_scale);
+
+          double like;
+          if (DecodeUtterance(decoder, decodable, trans_model, word_syms,
+                              utt, acoustic_scale, determinize, allow_partial,
+                              &alignment_writer, &words_writer,
+                              &compact_lattice_writer, &lattice_writer,
+                              &like)) {
+            tot_like += like;
+            frame_count += features.NumRows();
+            num_success++;
+          } else num_fail++;
+        }
+      }
+      delete decode_fst; // delete this only after decoder goes out of scope.
+    } else { // We have different FSTs for different utterances.
+      assert(0);
+    }
+
+    double elapsed = timer.Elapsed();
+    KALDI_LOG << "Time taken " << elapsed
+              << "s: real-time factor assuming 100 frames/sec is "
+              << (elapsed * 100.0 / frame_count);
+    KALDI_LOG << "Done " << num_success << " utterances, failed for "
+              << num_fail;
+    KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like / frame_count)
+              << " over " << frame_count << " frames.";
+
+    delete word_syms;
+    if (num_success != 0) return 0;
+    else return 1;
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
diff --git a/src/lm/faster-arpa-lm.cc b/src/lm/faster-arpa-lm.cc
new file mode 100644
index 00000000000..81d0322ed5b
--- /dev/null
+++ b/src/lm/faster-arpa-lm.cc
@@ -0,0 +1,36 @@
+// lm/faster-arpa-lm.cc
+
+// Copyright 2018 Zhehuai Chen
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
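+
+// (Note: FasterArpaLm is implemented entirely inline in faster-arpa-lm.h at
+// this point, so this .cc mainly gives the header a translation unit to
+// compile against.)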
+
+#include <algorithm>
+#include <limits>
+#include <sstream>
+#include <utility>
+
+#include "base/kaldi-math.h"
+#include "lm/arpa-file-parser.h"
+#include "lm/faster-arpa-lm.h"
+#include "util/stl-utils.h"
+#include "util/text-utils.h"
+
+
+namespace kaldi {
+
+
+
+} // namespace kaldi
diff --git a/src/lm/faster-arpa-lm.h b/src/lm/faster-arpa-lm.h
new file mode 100644
index 00000000000..a9e6f06ac20
--- /dev/null
+++ b/src/lm/faster-arpa-lm.h
@@ -0,0 +1,324 @@
+// lm/faster-arpa-lm.h
+
+// Copyright 2018 Zhehuai Chen
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_LM_FASTER_ARPA_LM_H_
+#define KALDI_LM_FASTER_ARPA_LM_H_
+
+#include <string>
+#include <vector>
+
+#include "base/kaldi-common.h"
+#include "fstext/deterministic-fst.h"
+#include "lm/arpa-file-parser.h"
+#include "util/common-utils.h"
+
+namespace kaldi {
+
+#define MAX_NGRAM (5+1)
+
+class FasterArpaLm {
+ public:
+
+  // LmState in FasterArpaLm: the basic storage unit
+  class LmState {
+   public:
+    LmState(): logprob_(0), backoff_logprob_(0) { }
+    void Allocate(const NGram* ngram, float lm_scale = 1) {
+      logprob_ = ngram->logprob * lm_scale;
+      backoff_logprob_ = ngram->backoff * lm_scale;
+      /*
+      std::vector<int32> &word_ids = ngram->words;
+      int32 ngram_order = word_ids.size();
+      int32 sz = sizeof(int32) * (ngram_order);
+      */
+    }
+    bool IsExist() const { return logprob_ != 0; }
+    ~LmState() { }
+
+    // for current query
+    float logprob_;
+    // for next query; can be optional
+    float backoff_logprob_;
+  };
+
+  // Class to build FasterArpaLm from an Arpa format language model. It relies
+  // on the auxiliary class LmState above.
+  class FasterArpaLmBuilder : public ArpaFileParser {
+   public:
+    FasterArpaLmBuilder(ArpaParseOptions &options, FasterArpaLm *lm,
+                        float lm_scale = 1):
+      ArpaFileParser(options, NULL), lm_(lm), lm_scale_(lm_scale) { }
+    ~FasterArpaLmBuilder() { }
+
+   protected:
+    // ArpaFileParser overrides.
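+    // (How these overrides fit together: ArpaFileParser invokes
+    // HeaderAvailable() once the \data\ n-gram counts are known -- used
+    // here to size the hash tables -- then ConsumeNGram() once per n-gram
+    // line, then ReadComplete() at end of file.  Each n-gram w_1..w_n is
+    // stored at a hashed index built from fixed per-(position, word)
+    // random integers, roughly
+    //   idx = (rand[0][w_1] ^ rand[1][w_2] ^ ... ^ rand[n-1][w_n]) & (size_n - 1),
+    // with power-of-two table sizes; unigrams are indexed directly by
+    // word id.)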
+ virtual void HeaderAvailable() { + lm_->Allocate(NgramCounts(), Symbols()); + } + virtual void ConsumeNGram(const NGram& ngram) { + LmState *lmstate = lm_->GetHashedState(ngram.words); + lmstate->Allocate(&ngram, lm_scale_); + } + + virtual void ReadComplete() { } + + private: + FasterArpaLm *lm_; + float lm_scale_; + }; + + FasterArpaLm(ArpaParseOptions &options, const std::string& arpa_rxfilename, + float lm_scale=0) { + is_built_ = false; + ngram_order_ = 0; + num_words_ = 0; + lm_states_size_ = 0; + ngrams_ = NULL; + randint_per_word_gram_ = NULL; + options_ = options; + + BuildFasterArpaLm(arpa_rxfilename, lm_scale); + } + + ~FasterArpaLm() { + if (is_built_) free(); + } + + inline LmState* GetHashedState(int32* word_ids, + int query_ngram_order) { + assert(query_ngram_order > 0 && query_ngram_order <= ngram_order_); + int32 ngram_order = query_ngram_order; + if (ngram_order == 1) { + return &ngrams_[ngram_order-1][word_ids[ngram_order-1]]; + } else { + int32 hashed_idx=randint_per_word_gram_[0][word_ids[0]]; + for (int i=1; i &word_ids, + int query_ngram_order = 0) { + int32 ngram_order = query_ngram_order==0? word_ids.size(): query_ngram_order; + int32 word_ids_arr[MAX_NGRAM]; + for (int i=0; i& o_word_ids) { + float prob; + assert(ngram_order > 0); + if (ngram_order > ngram_order_) { + //while (wseq.size() >= lm_.NgramOrder()) { + // History state has at most lm_.NgramOrder() -1 words in the state. + // wseq.erase(wseq.begin(), wseq.begin() + 1); + //} + // we don't need to do above things as we do in reverse fashion: + // memcpy(n_wids+1, wids, len(wids)); n_wids[0] = cur_wrd; + ngram_order = ngram_order_; + } + + LmState *lm_state = GetHashedState(word_ids, ngram_order); + assert(lm_state); + if (lm_state->IsExist()) { + prob = lm_state->logprob_; + o_word_ids.resize(ngram_order); + for (int i=0; ibackoff_logprob_ + + GetNgramLogprob(word_ids, ngram_order - 1, o_word_ids); + } + return prob; + } + + bool BuildFasterArpaLm(const std::string& arpa_rxfilename, float lm_scale) { + FasterArpaLmBuilder lm_builder(options_, this, lm_scale); + KALDI_VLOG(1) << "Reading " << arpa_rxfilename; + Input ki(arpa_rxfilename); + lm_builder.Read(ki.Stream()); + return true; + } + + private: + void Allocate(const std::vector& ngram_count, + const fst::SymbolTable* symbols) { + ngram_order_ = ngram_count.size(); + uint64 max_rand = -1; + kaldi::RandomState rstate; + rstate.seed = 27437; + ngrams_ = malloc(ngram_order_ * sizeof(void*)); + randint_per_word_gram_ = malloc(ngram_order_ * sizeof(void*)); + ngrams_hashed_size_ = malloc(ngram_order_ * sizeof(int32)); + for (int i=0; i< ngram_order_; i++) { + if (i == 0) ngrams_hashed_size_[i] = ngram_count[i]; // uni-gram + else { + ngrams_hashed_size_[i] = (1<<(int)ceil(log(ngram_count[i]) / + M_LN2 + 0.3)); + } + KALDI_VLOG(2) << "ngram: "<< i <<" hashed_size/size = "<< + ngrams_hashed_size_[i] / ngram_count[i]; + ngrams_[i] = new LmState[ngrams_hashed_size_[i]]; + randint_per_word_gram_[i] = new int32[symbols->NumSymbols()]; + for (int j=0; jNumSymbols(); j++) { + randint_per_word_gram_[i][j] = kaldi::RandInt(0, max_rand, &rstate); + } + } + is_built_ = true; + } + void free() { + for (int i=0; i< ngram_order_; i++) { + delete ngrams_[i]; + delete randint_per_word_gram_[i]; + } + delete ngrams_; + delete randint_per_word_gram_; + } + + private: + // configurations + + // Indicating if FasterArpaLm has been built or not. + bool is_built_; + // N-gram order of language model. 
This can be figured out from "/data/" + // section in Arpa format language model. + int32 ngram_order_; + // Index of largest word-id plus one. It defines the end of + // array. + int32 num_words_; + // Size of the array, which will be needed by I/O. + int64 lm_states_size_; + // Hash table from word sequences to LmStates. + unordered_map, + LmState*, VectorHasher > seq_to_state_; + ArpaParseOptions &options; + + // data + + // Memory blcok for storing N-gram; ngrams_[ngram_order][hashed_idx] + LmState** ngrams_; + // used to obtain hash value; randint_per_word_gram_[ngram_order][word_id] + uint64** randint_per_word_gram_; + int32* ngrams_hashed_size_; +}; + + +/** + This class wraps a FasterArpaLm format language model with the interface defined + in DeterministicOnDemandFst. + */ +class FasterArpaLmDeterministicFst + : public fst::DeterministicOnDemandFst { + public: + typedef fst::StdArc::Weight Weight; + typedef fst::StdArc::StateId StateId; + typedef fst::StdArc::Label Label; + typedef FasterArpaLm::LmState LmState; + + explicit FasterArpaLmDeterministicFst(const FasterArpaLm& lm): + lm_(lm), start_state_(0) { + // Creates a history state for . + std::vector. + int32 eos_symbol_; + // Integer corresponds to unknown-word. -1 if no unknown-word symbol is + // provided. + int32 unk_symbol_; // N-gram order of language model. This can be figured out from "/data/" // section in Arpa format language model. int32 ngram_order_; + int32 symbol_size_; // Index of largest word-id plus one. It defines the end of // array. int32 num_words_; @@ -215,14 +237,14 @@ class FasterArpaLm { // Hash table from word sequences to LmStates. unordered_map, LmState*, VectorHasher > seq_to_state_; - ArpaParseOptions &options; + ArpaParseOptions &options_; // data // Memory blcok for storing N-gram; ngrams_[ngram_order][hashed_idx] LmState** ngrams_; // used to obtain hash value; randint_per_word_gram_[ngram_order][word_id] - uint64** randint_per_word_gram_; + RAND_TYPE** randint_per_word_gram_; int32* ngrams_hashed_size_; }; @@ -240,7 +262,7 @@ class FasterArpaLmDeterministicFst typedef FasterArpaLm::LmState LmState; explicit FasterArpaLmDeterministicFst(const FasterArpaLm& lm): - lm_(lm), start_state_(0) { + start_state_(0), lm_(lm) { // Creates a history state for . std::vector. 
You must set this to your actual " + "EOS integer."); + + po.Read(argc, argv); + + { + std::string g_lm_fst_rxfilename = po.GetArg(1); + VectorFst *old_lm_fst = fst::CastOrConvertToVectorFst( + fst::ReadFstKaldiGeneric(g_lm_fst_rxfilename)); + fst::BackoffDeterministicOnDemandFst old_lm_dfst(*old_lm_fst); + fst::CacheDeterministicOnDemandFst cache_dfst(&old_lm_dfst, 1e7); + get_score(&cache_dfst, word_ids, state_ids, scores, TEST_SIZE); + } + { + std::string g_lm_fst_rxfilename = po.GetArg(2); + ConstArpaLm new_lm; + ReadKaldiObject(g_lm_fst_rxfilename, &new_lm); + ConstArpaLmDeterministicFst new_lm_dfst(new_lm); + fst::CacheDeterministicOnDemandFst cache_dfst(&new_lm_dfst, 1e7); + get_score(&cache_dfst, word_ids, state_ids, scores2, TEST_SIZE); + } + { + std::string g_lm_fst_rxfilename = po.GetArg(3); + FasterArpaLm new_lm(arpa_options, g_lm_fst_rxfilename, symbol_size); + FasterArpaLmDeterministicFst new_lm_dfst(new_lm); + fst::CacheDeterministicOnDemandFst cache_dfst(&new_lm_dfst, 1e7); + get_score(&cache_dfst, word_ids, state_ids, scores3, TEST_SIZE); + } + for (int i=0;i Date: Sat, 14 Apr 2018 11:10:43 -0700 Subject: [PATCH 23/93] found out it's h_value problem; add a hack to reduce colid --- src/bin/latgen-fasterlm-faster-mapped.cc | 2 +- src/lm/faster-arpa-lm-test.cc | 6 ++++-- src/lm/faster-arpa-lm.h | 15 ++++++++++++++- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/bin/latgen-fasterlm-faster-mapped.cc b/src/bin/latgen-fasterlm-faster-mapped.cc index 637df4896fa..42b310f3100 100644 --- a/src/bin/latgen-fasterlm-faster-mapped.cc +++ b/src/bin/latgen-fasterlm-faster-mapped.cc @@ -208,7 +208,7 @@ int main(int argc, char *argv[]) { FasterArpaLmDeterministicFst old_lm_dfst(old_lm); ApplyProbabilityScale(-1.0, old_lm_dfst); // Negate old LM probs... 
*/ -#if 0 +#if 1 FasterArpaLm old_lm(arpa_options, old_lm_fst_rxfilename, symbol_size, -1); FasterArpaLmDeterministicFst old_lm_dfst(old_lm); #else diff --git a/src/lm/faster-arpa-lm-test.cc b/src/lm/faster-arpa-lm-test.cc index e8d327fc975..44fd25e8d95 100644 --- a/src/lm/faster-arpa-lm-test.cc +++ b/src/lm/faster-arpa-lm-test.cc @@ -67,12 +67,14 @@ int main(int argc, char *argv[]) { #define Arc fst::StdArc using fst::ReadFstKaldi; -#define TEST_SIZE 25 +#define TEST_SIZE 28 +//#define TEST_SIZE 25 ParseOptions po(""); float scores[TEST_SIZE]; float scores2[TEST_SIZE]; float scores3[TEST_SIZE]; - int32 word_ids[]={14207, 198712, 7589, 175861, 171937, 124782, 36528, 175861, 104488, 150861, 139719, 78075, 14268, 124782, 61783, 196158, 4, 20681, 194454, 137421, 158810, 161569, 4, 37434, 50498}; + //int32 word_ids[]={14207, 198712, 7589, 175861, 171937, 124782, 36528, 175861, 104488, 150861, 139719, 78075, 14268, 124782, 61783, 196158, 4, 20681, 194454, 137421, 158810, 161569, 4, 37434, 50498}; + int32 word_ids[] = {14207, 198712, 7589, 4, 171935, 87918, 124782, 36528, 175861, 104488, 150861, 139719, 78075, 14268, 124782, 61783, 196158, 4, 20681, 194454, 138359, 155516, 2379, 160908, 2811, 4, 37434, 50498}; int32 state_ids[TEST_SIZE]={0}; ArpaParseOptions arpa_options; diff --git a/src/lm/faster-arpa-lm.h b/src/lm/faster-arpa-lm.h index f81320f1c2e..fb69360a1fe 100644 --- a/src/lm/faster-arpa-lm.h +++ b/src/lm/faster-arpa-lm.h @@ -123,7 +123,8 @@ class FasterArpaLm { } else { hashed_idx=randint_per_word_gram_[0][word_ids[0]]; for (int i=1; iIsExist()); //assert(ngram_order==1 || GetHashedState(word_ids, ngram_order-1)->IsExist()); prob = lm_state->logprob_; + /* + for (int i=0; i0); + o_word_ids.resize(ngram_order); for (int i=0; i 1); // thus we can do backoff const LmState *lm_state_bo = GetHashedState(word_ids + 1, ngram_order-1); From 907d914db55417f7dc731d2509bac923253c974f Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sat, 14 Apr 2018 12:10:05 -0700 Subject: [PATCH 24/93] separate ngrams_map; use uint64 rand() --- src/bin/latgen-fasterlm-faster-mapped.cc | 1 - src/lm/faster-arpa-lm-test.cc | 10 +++-- src/lm/faster-arpa-lm.h | 48 +++++++++++++++--------- 3 files changed, 37 insertions(+), 22 deletions(-) diff --git a/src/bin/latgen-fasterlm-faster-mapped.cc b/src/bin/latgen-fasterlm-faster-mapped.cc index 42b310f3100..be2fce3e01d 100644 --- a/src/bin/latgen-fasterlm-faster-mapped.cc +++ b/src/bin/latgen-fasterlm-faster-mapped.cc @@ -189,7 +189,6 @@ int main(int argc, char *argv[]) { exit(1); } - KALDI_LOG << RAND_MAX; std::string model_in_filename = po.GetArg(1), fst_in_str = po.GetArg(2), old_lm_fst_rxfilename = po.GetArg(3), diff --git a/src/lm/faster-arpa-lm-test.cc b/src/lm/faster-arpa-lm-test.cc index 44fd25e8d95..b542c662988 100644 --- a/src/lm/faster-arpa-lm-test.cc +++ b/src/lm/faster-arpa-lm-test.cc @@ -67,14 +67,16 @@ int main(int argc, char *argv[]) { #define Arc fst::StdArc using fst::ReadFstKaldi; -#define TEST_SIZE 28 +#define TEST_SIZE 26 +//#define TEST_SIZE 28 //#define TEST_SIZE 25 ParseOptions po(""); float scores[TEST_SIZE]; float scores2[TEST_SIZE]; float scores3[TEST_SIZE]; //int32 word_ids[]={14207, 198712, 7589, 175861, 171937, 124782, 36528, 175861, 104488, 150861, 139719, 78075, 14268, 124782, 61783, 196158, 4, 20681, 194454, 137421, 158810, 161569, 4, 37434, 50498}; - int32 word_ids[] = {14207, 198712, 7589, 4, 171935, 87918, 124782, 36528, 175861, 104488, 150861, 139719, 78075, 14268, 124782, 61783, 196158, 4, 20681, 194454, 138359, 155516, 2379, 160908, 
2811, 4, 37434, 50498}; + //int32 word_ids[] = {14207, 198712, 7589, 4, 171935, 87918, 124782, 36528, 175861, 104488, 150861, 139719, 78075, 14268, 124782, 61783, 196158, 4, 20681, 194454, 138359, 155516, 2379, 160908, 2811, 4, 37434, 50498}; + int32 word_ids[] = {14207, 198712, 7589, 175861, 171937, 124782, 36528, 175861, 104488, 150861, 139719, 78075, 14268, 124782, 61783, 196158, 124782, 19206, 53865, 137753, 2279, 32505, 153074, 4, 37434, 50498}; int32 state_ids[TEST_SIZE]={0}; ArpaParseOptions arpa_options; @@ -117,8 +119,8 @@ int main(int argc, char *argv[]) { get_score(&cache_dfst, word_ids, state_ids, scores3, TEST_SIZE); } for (int i=0;i1e-4) KALDI_LOG<1e-4) KALDI_LOG<IsExist()); //assert(ngram_order==1 || GetHashedState(word_ids, ngram_order-1)->IsExist()); prob = lm_state->logprob_; - /* + +/* for (int i=0; i ngrams_map_; // hash to ngrams_ index // used to obtain hash value; randint_per_word_gram_[ngram_order][word_id] RAND_TYPE** randint_per_word_gram_; - int32* ngrams_hashed_size_; + int32* ngrams_hashed_size_; //after init, it's an accumulate value int32 hash_size_except_uni_; int32 max_collision_; }; From 7f272fdc948641c78ecc6bb539b3b7a3a90e6a39 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sat, 14 Apr 2018 12:30:19 -0700 Subject: [PATCH 25/93] match performance in exp_dec/constlm.1a/dec.log; but still larger toks --- src/lm/faster-arpa-lm-test.cc | 4 ++-- src/lm/faster-arpa-lm.h | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/lm/faster-arpa-lm-test.cc b/src/lm/faster-arpa-lm-test.cc index b542c662988..9a72289be23 100644 --- a/src/lm/faster-arpa-lm-test.cc +++ b/src/lm/faster-arpa-lm-test.cc @@ -67,7 +67,7 @@ int main(int argc, char *argv[]) { #define Arc fst::StdArc using fst::ReadFstKaldi; -#define TEST_SIZE 26 +#define TEST_SIZE 39 //#define TEST_SIZE 28 //#define TEST_SIZE 25 ParseOptions po(""); @@ -76,7 +76,7 @@ int main(int argc, char *argv[]) { float scores3[TEST_SIZE]; //int32 word_ids[]={14207, 198712, 7589, 175861, 171937, 124782, 36528, 175861, 104488, 150861, 139719, 78075, 14268, 124782, 61783, 196158, 4, 20681, 194454, 137421, 158810, 161569, 4, 37434, 50498}; //int32 word_ids[] = {14207, 198712, 7589, 4, 171935, 87918, 124782, 36528, 175861, 104488, 150861, 139719, 78075, 14268, 124782, 61783, 196158, 4, 20681, 194454, 138359, 155516, 2379, 160908, 2811, 4, 37434, 50498}; - int32 word_ids[] = {14207, 198712, 7589, 175861, 171937, 124782, 36528, 175861, 104488, 150861, 139719, 78075, 14268, 124782, 61783, 196158, 124782, 19206, 53865, 137753, 2279, 32505, 153074, 4, 37434, 50498}; + int32 word_ids[] = {78521, 148206, 178313, 175861, 144826, 28459, 25372, 62655, 138328, 175861, 72352, 76155, 152997, 4, 102911, 177031, 193231, 127711, 71590, 47932, 151710, 40606, 5411, 82074, 86219, 81505, 77097, 4, 155384, 194419, 193822, 71589, 76098, 163928, 124918, 177084, 9376, 81505, 78840}; int32 state_ids[TEST_SIZE]={0}; ArpaParseOptions arpa_options; diff --git a/src/lm/faster-arpa-lm.h b/src/lm/faster-arpa-lm.h index 13d7d0e1b53..fab511c0ea5 100644 --- a/src/lm/faster-arpa-lm.h +++ b/src/lm/faster-arpa-lm.h @@ -67,7 +67,7 @@ class FasterArpaLm { float logprob_; // for next query; can be optional float backoff_logprob_; - int32 h_value; + RAND_TYPE h_value; LmState* next; // for colid }; @@ -122,10 +122,10 @@ class FasterArpaLm { int32 NgramOrder() const { return ngram_order_; } inline int32 GetHashedIdx(const int32* word_ids, - int query_ngram_order, int32 *h_value=NULL) const { + int query_ngram_order, RAND_TYPE *h_value=NULL) 
const { assert(query_ngram_order > 0 && query_ngram_order <= ngram_order_); int32 ngram_order = query_ngram_order; - int32 hashed_idx; + RAND_TYPE hashed_idx; if (ngram_order == 1) { hashed_idx = word_ids[ngram_order-1]; } else { @@ -134,7 +134,7 @@ class FasterArpaLm { int word_id=word_ids[i]; hashed_idx ^= randint_per_word_gram_[i][word_id]; } - if (h_value) *h_value = hashed_idx; // to check colid + if (h_value) *h_value = hashed_idx; // to check colid, h_value should be precise int i = ngram_order-1; hashed_idx &= (ngrams_hashed_size_[i]-ngrams_hashed_size_[i-1] - 1); @@ -158,7 +158,7 @@ class FasterArpaLm { } inline void SaveHashedState(const int32* word_ids, int query_ngram_order, LmState &lm_state_pattern) { - int32 h_value=0; + RAND_TYPE h_value=0; int32 hashed_idx = GetHashedIdx(word_ids, query_ngram_order, &h_value); lm_state_pattern.h_value = h_value; int32 ngram_order = query_ngram_order; @@ -184,7 +184,7 @@ class FasterArpaLm { inline const LmState* GetHashedState(const int32* word_ids, int query_ngram_order) const { - int32 h_value; + RAND_TYPE h_value; int32 hashed_idx = GetHashedIdx(word_ids, query_ngram_order, &h_value); int32 ngram_order = query_ngram_order; if (ngram_order == 1) { From 7c65ea5ca28ae9de32362108efde4b01f9476595 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sat, 14 Apr 2018 13:00:27 -0700 Subject: [PATCH 26/93] add otfres; --- src/bin/Makefile | 2 +- .../latgen-otfres-fasterlm-faster-mapped.cc | 304 ++++++ .../lattice-otfres-biglm-faster-decoder.h | 957 ++++++++++++++++++ 3 files changed, 1262 insertions(+), 1 deletion(-) create mode 100644 src/bin/latgen-otfres-fasterlm-faster-mapped.cc create mode 100644 src/decoder/lattice-otfres-biglm-faster-decoder.h diff --git a/src/bin/Makefile b/src/bin/Makefile index 9ce73123612..439353b06eb 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -23,7 +23,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ transform-vec align-text matrix-dim post-to-smat -BINFILES += latgen-biglm-faster-mapped latgen-constlm-faster-mapped latgen-fasterlm-faster-mapped +BINFILES += latgen-biglm-faster-mapped latgen-constlm-faster-mapped latgen-fasterlm-faster-mapped latgen-otfres-fasterlm-faster-mapped OBJFILES = diff --git a/src/bin/latgen-otfres-fasterlm-faster-mapped.cc b/src/bin/latgen-otfres-fasterlm-faster-mapped.cc new file mode 100644 index 00000000000..ad475f9405f --- /dev/null +++ b/src/bin/latgen-otfres-fasterlm-faster-mapped.cc @@ -0,0 +1,304 @@ +// bin/latgen-otfres-fasterlm-faster-mapped .cc + +// Copyright 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
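+
+// ("otfres" here appears to stand for on-the-fly rescoring: the driver
+// below mirrors latgen-fasterlm-faster-mapped, but is built against the
+// experimental decoder in lattice-otfres-biglm-faster-decoder.h.)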
+ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "hmm/transition-model.h" +#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" +#include "decoder/decodable-matrix.h" +#include "base/timer.h" +#include "lm/faster-arpa-lm.h" +#include "decoder/lattice-otfres-biglm-faster-decoder.h" + + +namespace kaldi { +// Takes care of output. Returns true on success. +bool DecodeUtterance(LatticeBiglmFasterDecoder &decoder, // not const but is really an input. + DecodableInterface &decodable, // not const but is really an input. + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr) { // puts utterance's like in like_ptr on success. + using fst::VectorFst; + + if (!decoder.Decode(&decodable)) { + KALDI_WARN << "Failed to decode file " << utt; + return false; + } + if (!decoder.ReachedFinal()) { + if (allow_partial) { + KALDI_WARN << "Outputting partial output for utterance " << utt + << " since no final-state reached\n"; + } else { + KALDI_WARN << "Not producing output for utterance " << utt + << " since no final-state reached and " + << "--allow-partial=false.\n"; + return false; + } + } + + double likelihood; + LatticeWeight weight; + int32 num_frames; + { // First do some stuff with word-level traceback... + VectorFst decoded; + decoder.GetBestPath(&decoded); + if (decoded.NumStates() == 0) + // Shouldn't really reach this point as already checked success. + KALDI_ERR << "Failed to get traceback for utterance " << utt; + + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + num_frames = alignment.size(); + if (words_writer->IsOpen()) + words_writer->Write(utt, words); + if (alignment_writer->IsOpen()) + alignment_writer->Write(utt, alignment); + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] <<" not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << '\n'; + } + likelihood = -(weight.Value1() + weight.Value2()); + } + + // Get lattice, and do determinization if requested. + Lattice lat; + decoder.GetRawLattice(&lat); + if (lat.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; + fst::Connect(&lat); + if (determinize) { + CompactLattice clat; + if (!DeterminizeLatticePhonePrunedWrapper( + trans_model, + &lat, + decoder.GetOptions().lattice_beam, + &clat, + decoder.GetOptions().det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam for " + << "utterance " << utt; + // We'll write the lattice without acoustic scaling. + if (acoustic_scale != 0.0) + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); + compact_lattice_writer->Write(utt, clat); + } else { + Lattice fst; + decoder.GetRawLattice(&fst); + if (fst.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " + << utt; + fst::Connect(&fst); // Will get rid of this later... shouldn't have any + // disconnected states there, but we seem to. 
+    if (acoustic_scale != 0.0) // We'll write the lattice without acoustic scaling
+      fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &fst);
+    lattice_writer->Write(utt, fst);
+  }
+  KALDI_LOG << "Log-like per frame for utterance " << utt << " is "
+            << (likelihood / num_frames) << " over "
+            << num_frames << " frames.";
+  KALDI_VLOG(2) << "Cost for utterance " << utt << " is "
+                << weight.Value1() << " + " << weight.Value2();
+  *like_ptr = likelihood;
+  return true;
+}
+
+}  // namespace kaldi
+
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+    typedef kaldi::int32 int32;
+    using fst::SymbolTable;
+    using fst::VectorFst;
+    using fst::Fst;
+    using fst::StdArc;
+    using fst::ReadFstKaldi;
+
+    const char *usage =
+        "Generate lattices using on-the-fly composition.\n"
+        "User supplies LM used to generate decoding graph, and desired LM;\n"
+        "this decoder applies the difference during decoding.\n"
+        "Usage: latgen-otfres-fasterlm-faster-mapped [options] model-in (fst-in|fsts-rspecifier) "
+        "oldlm-fst-in newlm-fst-in features-rspecifier"
+        " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n";
+    ParseOptions po(usage);
+    Timer timer;
+    bool allow_partial = false;
+    BaseFloat acoustic_scale = 0.1;
+    int32 symbol_size = 0;
+    LatticeBiglmFasterDecoderConfig config;
+    config.Register(&po);
+
+    ArpaParseOptions arpa_options;
+    arpa_options.Register(&po);
+    po.Register("symbol-size", &symbol_size, "symbol table size");
+    po.Register("unk-symbol", &arpa_options.unk_symbol,
+                "Integer corresponds to unknown-word in language model. -1 if "
+                "no such word is provided.");
+    po.Register("bos-symbol", &arpa_options.bos_symbol,
+                "Integer corresponds to <s>. You must set this to your actual "
+                "BOS integer.");
+    po.Register("eos-symbol", &arpa_options.eos_symbol,
+                "Integer corresponds to </s>. You must set this to your actual "
+                "EOS integer.");
+
+
+    std::string word_syms_filename;
+    po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods");
+
+    po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]");
+    po.Register("allow-partial", &allow_partial, "If true, produce output even if end state was not reached.");
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() < 6 || po.NumArgs() > 8) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    std::string model_in_filename = po.GetArg(1),
+        fst_in_str = po.GetArg(2),
+        old_lm_fst_rxfilename = po.GetArg(3),
+        new_lm_fst_rxfilename = po.GetArg(4),
+        feature_rspecifier = po.GetArg(5),
+        lattice_wspecifier = po.GetArg(6),
+        words_wspecifier = po.GetOptArg(7),
+        alignment_wspecifier = po.GetOptArg(8);
+
+    TransitionModel trans_model;
+    ReadKaldiObject(model_in_filename, &trans_model);
+
+    /*
+    FasterArpaLm old_lm;
+    ReadKaldiObject(old_lm_fst_rxfilename, &old_lm);
+    FasterArpaLmDeterministicFst old_lm_dfst(old_lm);
+    ApplyProbabilityScale(-1.0, old_lm_dfst); // Negate old LM probs...
+    */
+#if 1
+    FasterArpaLm old_lm(arpa_options, old_lm_fst_rxfilename, symbol_size, -1);
+    FasterArpaLmDeterministicFst old_lm_dfst(old_lm);
+#else
+    VectorFst<StdArc> *old_lm_fst = fst::CastOrConvertToVectorFst(
+        fst::ReadFstKaldiGeneric(old_lm_fst_rxfilename));
+    ApplyProbabilityScale(-1.0, old_lm_fst); // Negate old LM probs...
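+    // (The #else branch is the reference path: it reads the old G as a
+    // VectorFst and follows its backoff structure on demand through
+    // BackoffDeterministicOnDemandFst, whereas the #if 1 branch queries
+    // the hashed FasterArpaLm built straight from the ARPA file.)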
+ fst::BackoffDeterministicOnDemandFst old_lm_dfst(*old_lm_fst); +#endif + + FasterArpaLm new_lm(arpa_options, new_lm_fst_rxfilename, symbol_size); + FasterArpaLmDeterministicFst new_lm_dfst(new_lm); + + fst::ComposeDeterministicOnDemandFst compose_dfst(&old_lm_dfst, + &new_lm_dfst); + fst::CacheDeterministicOnDemandFst cache_dfst(&compose_dfst, 1e7); + + bool determinize = config.determinize_lattice; + CompactLatticeWriter compact_lattice_writer; + LatticeWriter lattice_writer; + if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier) + : lattice_writer.Open(lattice_wspecifier))) + KALDI_ERR << "Could not open table for writing lattices: " + << lattice_wspecifier; + + Int32VectorWriter words_writer(words_wspecifier); + + Int32VectorWriter alignment_writer(alignment_wspecifier); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_filename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_filename; + + double tot_like = 0.0; + kaldi::int64 frame_count = 0; + int num_success = 0, num_fail = 0; + + + if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) { + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + // Input FST is just one FST, not a table of FSTs. + Fst *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); + + { + LatticeBiglmFasterDecoder decoder(*decode_fst, config, &cache_dfst); + timer.Reset(); + + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string utt = feature_reader.Key(); + Matrix features (feature_reader.Value()); + feature_reader.FreeCurrent(); + if (features.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + + DecodableMatrixScaledMapped decodable(trans_model, features, acoustic_scale); + + double like; + if (DecodeUtterance(decoder, decodable, trans_model, word_syms, + utt, acoustic_scale, determinize, allow_partial, + &alignment_writer, &words_writer, + &compact_lattice_writer, &lattice_writer, + &like)) { + tot_like += like; + frame_count += features.NumRows(); + num_success++; + } else num_fail++; + } + } + delete decode_fst; // delete this only after decoder goes out of scope. + } else { // We have different FSTs for different utterances. + assert(0); + } + + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken "<< elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (elapsed*100.0/frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " + << num_fail; + KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like/frame_count) << " over " + << frame_count<<" frames."; + + delete word_syms; + if (num_success != 0) return 0; + else return 1; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/decoder/lattice-otfres-biglm-faster-decoder.h b/src/decoder/lattice-otfres-biglm-faster-decoder.h new file mode 100644 index 00000000000..841547b9cca --- /dev/null +++ b/src/decoder/lattice-otfres-biglm-faster-decoder.h @@ -0,0 +1,957 @@ +// decoder/lattice-otfres-biglm-faster-decoder.h + +// Copyright 2009-2011 Microsoft Corporation, Mirko Hannemann, +// Gilles Boulianne + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_BIGLM_FASTER_DECODER_H_ +#define KALDI_DECODER_LATTICE_BIGLM_FASTER_DECODER_H_ + + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/kaldi-lattice.h" +#include "decoder/lattice-faster-decoder.h" // for options. + + +namespace kaldi { + +// The options are the same as for lattice-faster-decoder.h for now. +typedef LatticeFasterDecoderConfig LatticeBiglmFasterDecoderConfig; + +/** This is as LatticeFasterDecoder, but does online composition between + HCLG and the "difference language model", which is a deterministic + FST that represents the difference between the language model you want + and the language model you compiled HCLG with. The class + DeterministicOnDemandFst follows through the epsilons in G for you + (assuming G is a standard backoff language model) and makes it look + like a determinized FST. +*/ + +class LatticeBiglmFasterDecoder { + public: + typedef fst::StdArc Arc; + typedef Arc::Label Label; + typedef Arc::StateId StateId; + // A PairId will be constructed as: (StateId in fst) + (StateId in lm_diff_fst) << 32; + typedef uint64 PairId; + typedef Arc::Weight Weight; + // instantiate this class once for each thing you have to decode. + LatticeBiglmFasterDecoder( + const fst::Fst &fst, + const LatticeBiglmFasterDecoderConfig &config, + fst::DeterministicOnDemandFst *lm_diff_fst): + fst_(fst), lm_diff_fst_(lm_diff_fst), config_(config), + warned_noarc_(false), num_toks_(0) { + config.Check(); + KALDI_ASSERT(fst.Start() != fst::kNoStateId && + lm_diff_fst->Start() != fst::kNoStateId); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. + toks_g1.SetSize(1000); // just so on the first frame we do something reasonable. + } + void SetOptions(const LatticeBiglmFasterDecoderConfig &config) { config_ = config; } + LatticeBiglmFasterDecoderConfig GetOptions() { return config_; } + ~LatticeBiglmFasterDecoder() { + DeleteElems(toks_.Clear()); + ClearActiveTokens(); + } + + // Returns true if any kind of traceback is available (not necessarily from + // a final state). + bool Decode(DecodableInterface *decodable) { + // clean up from last time: + DeleteElems(toks_.Clear()); + ClearActiveTokens(); + warned_ = false; + final_active_ = false; + final_costs_.clear(); + num_toks_ = 0; + PairId start_pair = ConstructPair(fst_.Start(), lm_diff_fst_->Start()); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, NULL, NULL); + active_toks_[0].toks = start_tok; + toks_.Insert(start_pair, start_tok); + toks_g1.Insert(PairToState(start_pair), start_pair); + num_toks_++; + ProcessNonemitting(0); + + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. 
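+    // (Concretely: 1-based frame f consumes 0-based decodable frame f-1,
+    // and the loop test IsLastFrame(frame-2) asks whether the frame
+    // processed on the previous iteration was the final one.)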
+ for (int32 frame = 1; !decodable->IsLastFrame(frame-2); frame++) { + active_toks_.resize(frame+1); // new column + + ProcessEmitting(decodable, frame); + + ProcessNonemitting(frame); + + if (decodable->IsLastFrame(frame-1)) + PruneActiveTokensFinal(frame); + else if (frame % config_.prune_interval == 0) + PruneActiveTokens(frame, config_.lattice_beam * 0.1); // use larger delta. + } + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !final_costs_.empty(); + } + + /// says whether a final-state was active on the last frame. If it was not, the + /// lattice (or traceback) will end with states that are not final-states. + bool ReachedFinal() const { return final_active_; } + + + // Outputs an FST corresponding to the single best path + // through the lattice. + bool GetBestPath(fst::MutableFst *ofst, + bool use_final_probs = true) const { + fst::VectorFst fst; + if (!GetRawLattice(&fst, use_final_probs)) return false; + // std::cout << "Raw lattice is:\n"; + // fst::FstPrinter fstprinter(fst, NULL, NULL, NULL, false, true); + // fstprinter.Print(&std::cout, "standard output"); + ShortestPath(fst, ofst); + return true; + } + + // Outputs an FST corresponding to the raw, state-level + // tracebacks. + bool GetRawLattice(fst::MutableFst *ofst, + bool use_final_probs = true) const { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + // A PairId will be constructed as: (StateId in fst) + (StateId in lm_diff_fst) << 32; + typedef uint64 PairId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + ofst->DeleteStates(); + // num-frames plus one (since frames are one-based, and we have + // an extra frame for the start-state). + int32 num_frames = active_toks_.size() - 1; + KALDI_ASSERT(num_frames > 0); + unordered_map tok_map(num_toks_/2 + 3); // bucket count + // First create all states. + for (int32 f = 0; f <= num_frames; f++) { + if (active_toks_[f].toks == NULL) { + KALDI_WARN << "GetRawLattice: no tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) + tok_map[tok] = ofst->AddState(); + // The next statement sets the start state of the output FST. + // Because we always add new states to the head of the list + // active_toks_[f].toks, and the start state was the first one + // added, it will be the last one added to ofst. + if (f == 0 && ofst->NumStates() > 0) + ofst->SetStart(ofst->NumStates()-1); + } + KALDI_VLOG(3) << "init:" << num_toks_/2 + 3 << " buckets:" + << tok_map.bucket_count() << " load:" << tok_map.load_factor() + << " max:" << tok_map.max_load_factor(); + // Now create all arcs. + StateId cur_state = 0; // we rely on the fact that we numbered these + // consecutively (AddState() returns the numbers in order..) 
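+    // (The loop below revisits tokens in exactly the (frame, list) order
+    // in which their states were created above, so cur_state tracks each
+    // source token's state id; only destination states need the tok_map
+    // lookup.)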
+ for (int32 f = 0; f <= num_frames; f++) { + for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next, + cur_state++) { + for (ForwardLink *l = tok->links; + l != NULL; + l = l->next) { + unordered_map::const_iterator iter = + tok_map.find(l->next_tok); + StateId nextstate = iter->second; + KALDI_ASSERT(iter != tok_map.end()); + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost), + nextstate); + ofst->AddArc(cur_state, arc); + } + if (f == num_frames) { + if (use_final_probs && !final_costs_.empty()) { + std::map::const_iterator iter = + final_costs_.find(tok); + if (iter != final_costs_.end()) + ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, LatticeWeight::One()); + } + } + } + } + KALDI_ASSERT(cur_state == ofst->NumStates()); + return (cur_state != 0); + } + + // This function is now deprecated, since now we do determinization from + // outside the LatticeBiglmFasterDecoder class. + // Outputs an FST corresponding to the lattice-determinized + // lattice (one path per word sequence). + bool GetLattice(fst::MutableFst *ofst, + bool use_final_probs = true) const { + Lattice raw_fst; + if (!GetRawLattice(&raw_fst, use_final_probs)) return false; + Invert(&raw_fst); // make it so word labels are on the input. + if (!TopSort(&raw_fst)) // topological sort makes lattice-determinization more efficient + KALDI_WARN << "Topological sorting of state-level lattice failed " + "(probably your lexicon has empty words or your LM has epsilon cycles; this " + " is a bad idea.)"; + // (in phase where we get backward-costs). + fst::ILabelCompare ilabel_comp; + ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes + // lattice-determinization more efficient. + + fst::DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = config_.det_opts.max_mem; + + DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts); + raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed. + Connect(ofst); // Remove unreachable states... there might be + // a small number of these, in some cases. + return true; + } + + private: + inline PairId ConstructPair(StateId fst_state, StateId lm_state) { + return static_cast(fst_state) + (static_cast(lm_state) << 32); + } + + static inline StateId PairToState(PairId state_pair) { + return static_cast(static_cast(state_pair)); + } + static inline StateId PairToLmState(PairId state_pair) { + return static_cast(static_cast(state_pair >> 32)); + } + + struct Token; + // ForwardLinks are the links from a token to a token on the next frame. + // or sometimes on the current frame (for input-epsilon links). + struct ForwardLink { + Token *next_tok; // the next token [or NULL if represents final-state] + Label ilabel; // ilabel on link. + Label olabel; // olabel on link. + BaseFloat graph_cost; // graph cost of traversing link (contains LM, etc.) + BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing link + ForwardLink *next; // next in singly-linked list of forward links from a + // token. + inline ForwardLink(Token *next_tok, Label ilabel, Label olabel, + BaseFloat graph_cost, BaseFloat acoustic_cost, + ForwardLink *next): + next_tok(next_tok), ilabel(ilabel), olabel(olabel), + graph_cost(graph_cost), acoustic_cost(acoustic_cost), + next(next) { } + }; + + // Token is what's resident in a particular state at a particular time. + // In this decoder a Token actually contains *forward* links. + // When first created, a Token just has the (total) cost. 
We add forward + // links to it when we process the next frame. + struct Token { + BaseFloat tot_cost; // would equal weight.Value()... cost up to this point. + BaseFloat extra_cost; // >= 0. After calling PruneForwardLinks, this equals + // the minimum difference between the cost of the best path, and the cost of + // this is on, and the cost of the absolute best path, under the assumption + // that any of the currently active states at the decoding front may + // eventually succeed (e.g. if you were to take the currently active states + // one by one and compute this difference, and then take the minimum). + + ForwardLink *links; // Head of singly linked list of ForwardLinks + + Token *next; // Next in list of tokens for this frame. + + inline Token(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLink *links, + Token *next): tot_cost(tot_cost), extra_cost(extra_cost), + links(links), next(next) { } + inline void DeleteForwardLinks() { + ForwardLink *l = links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; + } + links = NULL; + } + }; + + // head and tail of per-frame list of Tokens (list is in topological order), + // and something saying whether we ever pruned it using PruneForwardLinks. + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + TokenList(): toks(NULL), must_prune_forward_links(true), + must_prune_tokens(true) { } + }; + + typedef HashList::Elem Elem; + typedef HashList::Elem Elem_g1; + + void PossiblyResizeHash(size_t num_toks) { + size_t new_sz = static_cast(static_cast(num_toks) + * config_.hash_ratio); + if (new_sz > toks_.Size()) { + toks_.SetSize(new_sz); + } + if (new_sz > toks_g1.Size()) { + toks_g1.SetSize(new_sz); + } + } + + // FindOrAddToken either locates a token in hash of toks_, + // or if necessary inserts a new, empty token (i.e. with no forward links) + // for the current frame. [note: it's inserted if necessary into hash toks_ + // and also into the singly linked list of tokens active on this frame + // (whose head is at active_toks_[frame]). + inline Token *FindOrAddToken_2(PairId state_pair, int32 frame, BaseFloat tot_cost, + bool emitting, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(frame < active_toks_.size()); + Token *&toks = active_toks_[frame].toks; + Elem *e_found = toks_.Find(state_pair); + if (e_found == NULL) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. + Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + toks_.Insert(state_pair, new_tok); + if (changed) *changed = true; + return new_tok; + } else { + Token *tok = e_found->val; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) 
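+        // (A forward link survives PruneForwardLinks only if its slack,
+        // next_tok->extra_cost + (tok->tot_cost + link costs
+        // - next_tok->tot_cost), stays within lattice_beam; any link with
+        // more slack cannot lie on a path kept in the pruned lattice.)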
+ if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return tok; + } + } +#define res_beam 1 + inline bool FindOrAddToken(StateId state_id, int32 frame, BaseFloat tot_cost, + bool emitting, bool *changed, bool pp) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(frame < active_toks_.size()); + Elem_g1 *e_found = toks_g1.Find(state_id); + if (e_found == NULL) { // no such token presently. + toks_g1.Insert(state_id, tot_cost); + return true; + } else { + if (tot_cost < e_found->val + res_beam) {// There is an existing Token for this state. + if (tot_cost < e_found->val) + e_found->val = tot_cost; + return true; + } + else if (pp) { + return false; + } + else { + return true; + } + } + } + + // prunes outgoing links for all tokens in active_toks_[frame] + // it's called by PruneActiveTokens + // all links, that have link_extra_cost > lattice_beam are pruned + void PruneForwardLinks(int32 frame, bool *extra_costs_changed, + bool *links_pruned, + BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame >= 0 && frame < active_toks_.size()); + if (active_toks_[frame].toks == NULL ) { // empty list; should not happen. + if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + ForwardLink *link, *prev_link=NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLink *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. 
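+            // (extra_cost values are differences of tot_costs, so small
+            // negative values can arise from floating-point rounding;
+            // anything below -0.01 is treated as a real inconsistency.)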
+ if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. + // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed + } + + // PruneForwardLinksFinal is a version of PruneForwardLinks that we call + // on the final frame. If there are final tokens active, it uses + // the final-probs for pruning, otherwise it treats all tokens as final. + void PruneForwardLinksFinal(int32 frame) { + KALDI_ASSERT(static_cast(frame+1) == active_toks_.size()); + if (active_toks_[frame].toks == NULL ) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file\n"; + + // First go through, working out the best token (do it in parallel + // including final-probs and not including final-probs; we'll take + // the one with final-probs if it's valid). + const BaseFloat infinity = std::numeric_limits::infinity(); + BaseFloat best_cost_final = infinity, + best_cost_nofinal = infinity; + unordered_map tok_to_final_cost; + Elem *cur_toks = toks_.Clear(); // swapping prev_toks_ / cur_toks_ + DeleteElems_1(toks_g1.Clear()); + for (Elem *e = cur_toks, *e_tail; e != NULL; e = e_tail) { + PairId state_pair = e->key; + StateId state = PairToState(state_pair), + lm_state = PairToLmState(state_pair); + Token *tok = e->val; + BaseFloat final_cost = fst_.Final(state).Value() + + lm_diff_fst_->Final(lm_state).Value(); + tok_to_final_cost[tok] = final_cost; + best_cost_final = std::min(best_cost_final, tok->tot_cost + final_cost); + best_cost_nofinal = std::min(best_cost_nofinal, tok->tot_cost); + e_tail = e->tail; + toks_.Delete(e); + } + final_active_ = (best_cost_final != infinity); + + // Now go through tokens on this frame, pruning forward links... may have + // to iterate a few times until there is no more change, because the list is + // not in topological order. + + bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + ForwardLink *link, *prev_link=NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this token, + // and the best such (score+final_prob). + BaseFloat tok_extra_cost; + if (final_active_) { + BaseFloat final_cost = tok_to_final_cost[tok]; + tok_extra_cost = (tok->tot_cost + final_cost) - best_cost_final; + } else + tok_extra_cost = tok->tot_cost - best_cost_nofinal; + + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... 
+ Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLink *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = infinity; + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) + changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed + + // Now put surviving Tokens in the final_costs_ hash, which is a class + // member (unlike tok_to_final_costs). + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + if (tok->extra_cost != infinity) { + // If the token was not pruned away, + if (final_active_) { + BaseFloat final_cost = tok_to_final_cost[tok]; + if (final_cost != infinity) + final_costs_[tok] = final_cost; + } else { + final_costs_[tok] = 0; + } + } + } + } + + // Prune away any tokens on this frame that have no forward links. + // [we don't do this in PruneForwardLinks because it would give us + // a problem with dangling pointers]. + // It's called by PruneActiveTokens if any forward links have been pruned + void PruneTokensForFrame(int32 frame) { + KALDI_ASSERT(frame >= 0 && frame < active_toks_.size()); + Token *&toks = active_toks_[frame].toks; + if (toks == NULL) + KALDI_WARN << "No tokens alive [doing pruning]\n"; + Token *tok, *next_tok, *prev_tok = NULL; + for (tok = toks; tok != NULL; tok = next_tok) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) prev_tok->next = tok->next; + else toks = tok->next; + delete tok; + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } + } + + // Go backwards through still-alive tokens, pruning them. note: cur_frame is + // where hash toks_ are (so we do not want to mess with it because these tokens + // don't yet have forward pointers), but we do all previous frames, unless we + // know that we can safely ignore them because the frame after them was unchanged. + // delta controls when it considers a cost to have changed enough to continue + // going backward and propagating the change. 
+  // for a larger delta, we will recurse less far back
+  void PruneActiveTokens(int32 cur_frame, BaseFloat delta) {
+    int32 num_toks_begin = num_toks_;
+    for (int32 frame = cur_frame-1; frame >= 0; frame--) {
+      // Reasons why we need to prune forward links on this frame:
+      // (1) we have never pruned them (new TokenList), or
+      // (2) we have not yet re-pruned the forward links to the next frame
+      // after some of that frame's tokens changed their extra_cost.
+      if (active_toks_[frame].must_prune_forward_links) {
+        bool extra_costs_changed = false, links_pruned = false;
+        PruneForwardLinks(frame, &extra_costs_changed, &links_pruned, delta);
+        if (extra_costs_changed && frame > 0)  // any token has changed extra_cost
+          active_toks_[frame-1].must_prune_forward_links = true;
+        if (links_pruned)  // any link was pruned
+          active_toks_[frame].must_prune_tokens = true;
+        active_toks_[frame].must_prune_forward_links = false;  // job done
+      }
+      if (frame+1 < cur_frame &&  // except for last frame (no forward links)
+          active_toks_[frame+1].must_prune_tokens) {
+        PruneTokensForFrame(frame+1);
+        active_toks_[frame+1].must_prune_tokens = false;
+      }
+    }
+    KALDI_VLOG(3) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
+                  << " to " << num_toks_;
+  }
+
+  // Version of PruneActiveTokens that we call on the final frame.
+  // Takes into account the final-prob of tokens.
+  void PruneActiveTokensFinal(int32 cur_frame) {
+    // If there are final states active we prune with respect to final-probs;
+    // otherwise all states are treated as final while doing the pruning
+    // (this can be useful if you want partial lattice output,
+    // although it can be dangerous, depending what you want the lattices for).
+    // final_active_ and final_costs_ (a hash) are set internally
+    // by PruneForwardLinksFinal.
+    int32 num_toks_begin = num_toks_;
+    PruneForwardLinksFinal(cur_frame);  // prune final frame (with final-probs);
+    // sets final_active_ and final_costs_.
+    for (int32 frame = cur_frame-1; frame >= 0; frame--) {
+      bool b1, b2;  // values not used.
+      BaseFloat dontcare = 0.0;  // delta of zero means we must always update
+      PruneForwardLinks(frame, &b1, &b2, dontcare);
+      PruneTokensForFrame(frame+1);
+    }
+    PruneTokensForFrame(0);
+    KALDI_VLOG(3) << "PruneActiveTokensFinal: pruned tokens from " << num_toks_begin
+                  << " to " << num_toks_;
+  }
+
+  /// Gets the weight cutoff.  Also counts the active tokens.
+  BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
+                      BaseFloat *adaptive_beam, Elem **best_elem) {
+    BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity();
+    // positive == high cost == bad.
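+    // What this computes, conceptually: the pruning cutoff is
+    // min(best_cost + beam, cost of the max_active'th best token), but it is
+    // never allowed to keep fewer than min_active tokens; *adaptive_beam is
+    // set to the effective beam the returned cutoff corresponds to.  As a
+    // made-up example: with beam = 10, max_active = 3 and token costs
+    // {5, 7, 8, 12, 30}, the beam cutoff would be 15, but the 4th-best cost
+    // (tmp_array_[max_active] after nth_element below) is 12, so 12 is
+    // returned and the adaptive beam becomes 12 - 5 + beam_delta.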
+    size_t count = 0;
+    if (config_.max_active == std::numeric_limits<int32>::max() &&
+        config_.min_active == 0) {
+      for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
+        BaseFloat w = static_cast<BaseFloat>(e->val->tot_cost);
+        if (w < best_weight) {
+          best_weight = w;
+          if (best_elem) *best_elem = e;
+        }
+      }
+      if (tok_count != NULL) *tok_count = count;
+      if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
+      return best_weight + config_.beam;
+    } else {
+      tmp_array_.clear();
+      for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
+        BaseFloat w = e->val->tot_cost;
+        tmp_array_.push_back(w);
+        if (w < best_weight) {
+          best_weight = w;
+          if (best_elem) *best_elem = e;
+        }
+      }
+      if (tok_count != NULL) *tok_count = count;
+
+      BaseFloat beam_cutoff = best_weight + config_.beam,
+          min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
+          max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();
+
+      KALDI_VLOG(6) << "Number of tokens active on frame " << active_toks_.size()
+                    << " is " << tmp_array_.size();
+
+      if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
+        std::nth_element(tmp_array_.begin(),
+                         tmp_array_.begin() + config_.max_active,
+                         tmp_array_.end());
+        max_active_cutoff = tmp_array_[config_.max_active];
+      }
+      if (max_active_cutoff < beam_cutoff) {  // max_active is tighter than beam.
+        if (adaptive_beam)
+          *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta;
+        return max_active_cutoff;
+      }
+      if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
+        if (config_.min_active == 0) min_active_cutoff = best_weight;
+        else {
+          std::nth_element(tmp_array_.begin(),
+                           tmp_array_.begin() + config_.min_active,
+                           tmp_array_.size() > static_cast<size_t>(config_.max_active) ?
+                           tmp_array_.begin() + config_.max_active :
+                           tmp_array_.end());
+          min_active_cutoff = tmp_array_[config_.min_active];
+        }
+      }
+      if (min_active_cutoff > beam_cutoff) {  // min_active is looser than beam.
+        if (adaptive_beam)
+          *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
+        return min_active_cutoff;
+      } else {
+        *adaptive_beam = config_.beam;
+        return beam_cutoff;
+      }
+    }
+  }
+
+  inline StateId PropagateLm(StateId lm_state,
+                             Arc *arc, bool *pp = NULL) {  // returns new LM state.
+    if (arc->olabel == 0) {
+      if (pp) *pp = false;
+      return lm_state;  // no change in LM state if no word crossed.
+    } else {  // Propagate in the LM-diff FST.
+      if (pp) *pp = false;
+      Arc lm_arc;
+      bool ans = lm_diff_fst_->GetArc(lm_state, arc->olabel, &lm_arc);
+      if (!ans) {  // this case is unexpected for statistical LMs.
+        if (!warned_noarc_) {
+          warned_noarc_ = true;
+          KALDI_WARN << "No arc available in LM (unlikely to be correct "
+              "if a statistical language model); will not warn again";
+        }
+        arc->weight = Weight::Zero();
+        return lm_state;  // doesn't really matter what we return here; will
+        // be pruned.
+      } else {
+        arc->weight = Times(arc->weight, lm_arc.weight);
+        arc->olabel = lm_arc.olabel;  // probably will be the same.
+        return lm_arc.nextstate;  // return the new LM state.
+      }
+    }
+  }
+
+  void ProcessEmitting(DecodableInterface *decodable, int32 frame) {
+    // Processes emitting arcs for one frame.  Propagates from prev_toks_
+    // to cur_toks_.
+    Elem *last_toks = toks_.Clear();  // swapping prev_toks_ / cur_toks_
+    DeleteElems_1(toks_g1.Clear());
+    Elem *best_elem = NULL;
+    BaseFloat adaptive_beam;
+    size_t tok_cnt;
+    BaseFloat cur_cutoff = GetCutoff(last_toks, &tok_cnt, &adaptive_beam, &best_elem);
+    PossiblyResizeHash(tok_cnt);  // This makes sure the hash is always big enough.
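+    // Two cutoffs are used in this function: cur_cutoff (computed by
+    // GetCutoff above) prunes the tokens of the frame we are expanding from,
+    // while next_cutoff prunes the tokens we are about to create.
+    // next_cutoff is first estimated from the best token's outgoing arcs and
+    // then tightened as better new tokens are found, so most hopeless tokens
+    // are rejected before they are ever entered into the hash.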
+    KALDI_VLOG(6) << "Adaptive beam on frame " << frame << "\t"
+                  << active_toks_.size() << " is " << adaptive_beam
+                  << "\t" << cur_cutoff;
+
+    BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
+    // pruning "online" before having seen all tokens
+
+    // First process the best token, to get a hopefully
+    // reasonably tight bound on the next cutoff.
+    if (best_elem) {
+      PairId state_pair = best_elem->key;
+      StateId state = PairToState(state_pair),  // state in "fst"
+          lm_state = PairToLmState(state_pair);
+      Token *tok = best_elem->val;
+      for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
+           !aiter.Done();
+           aiter.Next()) {
+        Arc arc = aiter.Value();
+        if (arc.ilabel != 0) {  // propagate..
+          PropagateLm(lm_state, &arc);  // may affect "arc.weight".
+          // We don't need the return value (the new LM state).
+          arc.weight = Times(arc.weight,
+                             Weight(-decodable->LogLikelihood(frame-1, arc.ilabel)));
+          BaseFloat new_weight = arc.weight.Value() + tok->tot_cost;
+          if (new_weight + adaptive_beam < next_cutoff)
+            next_cutoff = new_weight + adaptive_beam;
+        }
+      }
+    }
+
+    // the tokens are now owned here, in last_toks, and the hash is empty.
+    // 'owned' is a complex thing here; the point is we need to call DeleteElem
+    // on each elem 'e' to let toks_ know we're done with them.
+    for (Elem *e = last_toks, *e_tail; e != NULL; e = e_tail) {
+      // loop this way because we delete "e" as we go.
+      PairId state_pair = e->key;
+      StateId state = PairToState(state_pair),
+          lm_state = PairToLmState(state_pair);
+      Token *tok = e->val;
+      if (tok->tot_cost <= cur_cutoff) {
+        for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
+             !aiter.Done();
+             aiter.Next()) {
+          const Arc &arc_ref = aiter.Value();
+          if (arc_ref.ilabel != 0) {  // propagate..
+            Arc arc(arc_ref);
+            bool pp = arc.olabel > 0;
+            BaseFloat ac_cost = -decodable->LogLikelihood(frame-1, arc.ilabel);
+            if (!FindOrAddToken(arc.nextstate, frame,
+                                tok->tot_cost + ac_cost + arc.weight.Value(),
+                                true, NULL, pp))
+              continue;
+            StateId next_lm_state = PropagateLm(lm_state, &arc, &pp);
+            BaseFloat graph_cost = arc.weight.Value(),
+                cur_cost = tok->tot_cost,
+                tot_cost = cur_cost + ac_cost + graph_cost;
+            if (tot_cost > next_cutoff) continue;
+            else if (tot_cost + config_.beam < next_cutoff)
+              next_cutoff = tot_cost + config_.beam;  // prune by best current token
+            PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
+            Token *next_tok = FindOrAddToken_2(next_pair, frame, tot_cost, true, NULL);
+            // true: emitting; NULL: no change indicator needed.
+
+            // Add ForwardLink from tok to next_tok (put on head of list tok->links)
+            tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel,
+                                         graph_cost, ac_cost, tok->links);
+          }
+        }  // for all arcs
+      }
+      e_tail = e->tail;
+      toks_.Delete(e);  // delete Elem
+    }
+  }
+
+  void ProcessNonemitting(int32 frame) {
+    // note: "frame" is the same as that of the emitting arcs just processed.
+
+    // Processes nonemitting arcs for one frame.  Propagates within toks_.
+    // Note-- this queue structure is not very optimal, as
+    // it may cause us to process states unnecessarily (e.g. more than once),
+    // but in the baseline code, turning this vector into a set to fix this
+    // problem did not improve overall speed.
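+    // This loop is effectively a beam-pruned epsilon-closure: every
+    // surviving token goes on the queue, and whenever relaxing an epsilon
+    // (ilabel == 0) arc creates or improves a destination token, that
+    // destination is re-queued so its own epsilon successors get
+    // (re)expanded; cutoff = best_cost + config_.beam bounds the work.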
+
+    KALDI_ASSERT(queue_.empty());
+    BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
+    for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
+      queue_.push_back(e->key);
+      // for pruning with current best token
+      best_cost = std::min(best_cost, static_cast<BaseFloat>(e->val->tot_cost));
+    }
+    if (queue_.empty()) {
+      if (!warned_) {
+        KALDI_ERR << "Error in ProcessEmitting: no surviving tokens: frame is "
+                  << frame;
+        warned_ = true;
+      }
+    }
+    BaseFloat cutoff = best_cost + config_.beam;
+
+    while (!queue_.empty()) {
+      PairId state_pair = queue_.back();
+      queue_.pop_back();
+
+      Token *tok = toks_.Find(state_pair)->val;  // would segfault if state not
+      // in toks_, but this can't happen.
+      BaseFloat cur_cost = tok->tot_cost;
+      if (cur_cost > cutoff)  // Don't bother processing successors.
+        continue;
+      StateId state = PairToState(state_pair),
+          lm_state = PairToLmState(state_pair);
+      // If "tok" has any existing forward links, delete them,
+      // because we're about to regenerate them.  This is a kind
+      // of non-optimality (remember, this is the simple decoder),
+      // but since most states are emitting it's not a huge issue.
+      tok->DeleteForwardLinks();  // necessary when re-visiting
+      tok->links = NULL;
+      for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
+           !aiter.Done();
+           aiter.Next()) {
+        const Arc &arc_ref = aiter.Value();
+        if (arc_ref.ilabel == 0) {  // propagate nonemitting only...
+          Arc arc(arc_ref);
+          bool pp = arc.olabel > 0;
+          if (!FindOrAddToken(arc.nextstate, frame,
+                              tok->tot_cost + arc.weight.Value(),
+                              true, NULL, pp))
+            continue;
+          StateId next_lm_state = PropagateLm(lm_state, &arc, &pp);
+          BaseFloat graph_cost = arc.weight.Value(),
+              tot_cost = cur_cost + graph_cost;
+          if (tot_cost < cutoff) {
+            bool changed;
+            PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
+            Token *new_tok = FindOrAddToken_2(next_pair, frame, tot_cost,
+                                              false, &changed);  // false: non-emitting
+
+            tok->links = new ForwardLink(new_tok, 0, arc.olabel,
+                                         graph_cost, 0, tok->links);
+
+            // "changed" tells us whether the new token has a different
+            // cost from before, or is new [if so, add into queue].
+            if (changed) queue_.push_back(next_pair);
+          }
+        }
+      }  // for all arcs
+    }  // while queue not empty
+  }
+
+  // HashList defined in ../util/hash-list.h.  It actually allows us to
+  // maintain more than one list (e.g. for current and previous frames), but
+  // only one of them at a time can be indexed by StateId.
+  HashList<PairId, Token*> toks_;
+  HashList toks_g1;
+  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
+  // frame (members of TokenList are toks, must_prune_forward_links,
+  // must_prune_tokens).
+  std::vector<PairId> queue_;  // temp variable used in ProcessNonemitting.
+  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff;
+  // made a class member to avoid internal new/delete.
+  const fst::Fst<fst::StdArc> &fst_;
+  fst::DeterministicOnDemandFst<fst::StdArc> *lm_diff_fst_;
+  LatticeBiglmFasterDecoderConfig config_;
+  bool warned_noarc_;
+  int32 num_toks_;  // current total #toks allocated...
+  bool warned_;
+  bool final_active_;  // use this to say whether we found active final tokens
+  // on the last frame.
+  std::map<Token*, BaseFloat> final_costs_;  // A cache of final-costs
+  // of tokens on the last frame-- it's just convenient to store it this way.
+
+  // It might seem unclear why we call DeleteElems(toks_.Clear()).
+  // There are two separate cleanup tasks we need to do when we start a new
+  // file: one is to delete the Token objects in the list; the other is to
+  // delete the Elem objects.
+  // toks_.Clear() just clears them from the hash and gives ownership
+  // to the caller, who then has to call toks_.Delete(e) for each one.  It was
+  // designed this way for convenience in propagating tokens from one frame
+  // to the next.
+  void DeleteElems(Elem *list) {
+    for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
+      e_tail = e->tail;
+      toks_.Delete(e);
+    }
+    toks_.Clear();
+    DeleteElems_1(toks_g1.Clear());
+  }
+  void DeleteElems_1(Elem_g1 *list) {
+    for (Elem_g1 *e = list, *e_tail; e != NULL; e = e_tail) {
+      e_tail = e->tail;
+      toks_g1.Delete(e);
+    }
+    toks_g1.Clear();
+  }
+
+  void ClearActiveTokens() {  // a cleanup routine, at utt end/begin
+    for (size_t i = 0; i < active_toks_.size(); i++) {
+      // Delete all tokens alive on this frame, and any forward
+      // links they may have.
+      for (Token *tok = active_toks_[i].toks; tok != NULL; ) {
+        tok->DeleteForwardLinks();
+        Token *next_tok = tok->next;
+        delete tok;
+        num_toks_--;
+        tok = next_tok;
+      }
+    }
+    active_toks_.clear();
+    KALDI_ASSERT(num_toks_ == 0);
+  }
+};
+
+}  // end namespace kaldi.
+
+#endif

From 0c94f2eaa55487267cb6c91ac00a3b7154283bf8 Mon Sep 17 00:00:00 2001
From: Zhehuai Chen
Date: Sat, 14 Apr 2018 17:25:49 -0700
Subject: [PATCH 27/93] remove map in LM, comes to 1.30;
 exp_dec/fasterlm.1b2/dec.log

---
 .../lattice-otfres-biglm-faster-decoder.h |  2 +-
 src/lm/faster-arpa-lm.h                   | 92 +++++++++----------
 2 files changed, 44 insertions(+), 50 deletions(-)

diff --git a/src/decoder/lattice-otfres-biglm-faster-decoder.h b/src/decoder/lattice-otfres-biglm-faster-decoder.h
index 841547b9cca..ac682024ccb 100644
--- a/src/decoder/lattice-otfres-biglm-faster-decoder.h
+++ b/src/decoder/lattice-otfres-biglm-faster-decoder.h
@@ -355,7 +355,7 @@ class LatticeBiglmFasterDecoder {
       return tok;
     }
   }
-#define res_beam 1
+#define res_beam 0.5
   inline bool FindOrAddToken(StateId state_id, int32 frame, BaseFloat tot_cost,
                              bool emitting, bool *changed, bool pp) {
     // Returns the Token pointer.
                             // Sets "changed" (if non-NULL) to true
diff --git a/src/lm/faster-arpa-lm.h b/src/lm/faster-arpa-lm.h
index fab511c0ea5..808d56008b8 100644
--- a/src/lm/faster-arpa-lm.h
+++ b/src/lm/faster-arpa-lm.h
@@ -47,7 +47,7 @@ class FasterArpaLm {
   // LmState in FasterArpaLm: the basic storage unit
   class LmState {
    public:
-    LmState(): logprob_(0), h_value(0), next(NULL) { }
+    LmState(): logprob_(0), h_value(0), word_ids_(NULL), next(NULL) { }
     LmState(float logprob, float backoff_logprob):
       logprob_(logprob), backoff_logprob_(backoff_logprob), h_value(0),
       next(NULL) { }
@@ -60,14 +60,21 @@
       int32 sz = sizeof(int32)*(ngram_order);
     */
     }
+    void SaveWordIds(const int32 *word_ids, const int32 ngram_order) {
+      word_ids_ = (int32 *)malloc(sizeof(int32) * ngram_order);
+      for (int i = 0; i < ngram_order; i++)
+        word_ids_[i] = word_ids[i];
+    }
       if (lm_state->h_value == h_value) {
-        return lm_state;
+        ret_lm_state = lm_state;
+        break;
       }
       lm_state = lm_state->next;
     }
   }
+  if (ret_lm_state && lm_state_idx) *lm_state_idx = ret_lm_state - ngrams_;
   // not found; can be a bug, or the corresponding ngram really is absent
-  return NULL;
+  return ret_lm_state;
 }
 inline const LmState* GetHashedState(const std::vector<int32> &word_ids,
                                      bool reverse = false,
                                      int query_ngram_order = 0) const {
@@ -215,9 +227,15 @@ class FasterArpaLm {
   // if exist, get logprob_, else get backoff_logprob_
   // memcpy(n_wids+1, wids, len(wids)); n_wids[0] = cur_wrd;
+  inline void GetWordIdsByLmStateIdx(int32 **word_ids,
+                                     int32 *word_ngram_order,
+                                     int32 lm_state_idx) const {
+    *word_ids = ngrams_[lm_state_idx].word_ids_;
+    *word_ngram_order = ngrams_[lm_state_idx].ngram_order_;
+  }
+
   inline float GetNgramLogprob(const int32 *word_ids,
                                const int32 word_ngram_order,
-                               std::vector<int32>& o_word_ids) const {
+                               int32 *lm_state_idx) const {
     float prob;
     int32 ngram_order = word_ngram_order;
     assert(ngram_order > 0);
@@ -245,14 +263,8 @@
     ...
     */
     // the code below makes sure the LmState exists, so that states which do
     // not exist can be recombined into the same state
     ngram_order = std::min(ngram_order, ngram_order_-1);
-    while (!GetHashedState(word_ids, ngram_order)) ngram_order--;
+    while (!GetHashedState(word_ids, ngram_order, lm_state_idx)) ngram_order--;
     assert(ngram_order > 0);
-
-    o_word_ids.resize(ngram_order);
-    for (int i = 0; i < ngram_order; i++)
-      o_word_ids[i] = word_ids[i];
       assert(ngram_order > 1);  // thus we can do backoff
       const LmState *lm_state_bo = GetHashedState(word_ids + 1, ngram_order-1);
       ...
       //assert(lm_state_bo && lm_state_bo->IsExist());
       // TODO: assert will fail because some place has false-exist?
       // 84746 4447 8537 without 4447 8537 in LM
       prob = lm_state_bo ? lm_state_bo->backoff_logprob_ : 0;
-      prob += GetNgramLogprob(word_ids, ngram_order - 1, o_word_ids);
+      prob += GetNgramLogprob(word_ids, ngram_order - 1, lm_state_idx);
     }
     return prob;
   }
@@ -343,8 +355,6 @@
   // Size of the array, which will be needed by I/O.
   int64 lm_states_size_;
   // Hash table from word sequences to LmStates.
-  unordered_map<std::vector<int32>,
-                LmState*, VectorHasher<int32> > seq_to_state_;
   ArpaParseOptions &options_;

   // data
@@ -376,10 +386,10 @@ class FasterArpaLmDeterministicFst
   explicit FasterArpaLmDeterministicFst(const FasterArpaLm& lm):
       start_state_(0), lm_(lm) {
+    // TODO
    // Creates a history state for <s>.
-    std::vector