diff --git a/src/bin/Makefile b/src/bin/Makefile index 7cb01b50120..02c95ff4804 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -22,7 +22,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ matrix-sum build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ transform-vec align-text matrix-dim post-to-smat compile-graph \ - compare-int-vector + compare-int-vector latgen-incremental-mapped OBJFILES = diff --git a/src/bin/latgen-incremental-mapped.cc b/src/bin/latgen-incremental-mapped.cc new file mode 100644 index 00000000000..80c65bfb535 --- /dev/null +++ b/src/bin/latgen-incremental-mapped.cc @@ -0,0 +1,183 @@ +// bin/latgen-incremental-mapped.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "hmm/transition-model.h" +#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" +#include "decoder/decodable-matrix.h" +#include "base/timer.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::Fst; + using fst::StdArc; + + const char *usage = + "Generate lattices, reading log-likelihoods as matrices\n" + " (model is needed only for the integer mappings in its transition-model)\n" + "The lattice determinization algorithm here can operate\n" + "incrementally.\n" + "Usage: latgen-incremental-mapped [options] trans-model-in " + "(fst-in|fsts-rspecifier) loglikes-rspecifier" + " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n"; + ParseOptions po(usage); + Timer timer; + bool allow_partial = false; + BaseFloat acoustic_scale = 0.1; + LatticeIncrementalDecoderConfig config; + + std::string word_syms_filename; + config.Register(&po); + po.Register("acoustic-scale", &acoustic_scale, + "Scaling factor for acoustic likelihoods"); + + po.Register("word-symbol-table", &word_syms_filename, + "Symbol table for words [for debug output]"); + po.Register("allow-partial", &allow_partial, + "If true, produce output even if end state was not reached."); + + po.Read(argc, argv); + + if (po.NumArgs() < 4 || po.NumArgs() > 6) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), fst_in_str = po.GetArg(2), + feature_rspecifier = po.GetArg(3), lattice_wspecifier = po.GetArg(4), + words_wspecifier = po.GetOptArg(5), + alignment_wspecifier = po.GetOptArg(6); + + TransitionModel trans_model; + ReadKaldiObject(model_in_filename, &trans_model); + + bool determinize = true; + CompactLatticeWriter compact_lattice_writer; + LatticeWriter lattice_writer; + if (!(determinize ? compact_lattice_writer.Open(lattice_wspecifier) + : lattice_writer.Open(lattice_wspecifier))) + KALDI_ERR << "Could not open table for writing lattices: " + << lattice_wspecifier; + + Int32VectorWriter words_writer(words_wspecifier); + + Int32VectorWriter alignment_writer(alignment_wspecifier); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_filename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) + KALDI_ERR << "Could not read symbol table from file " << word_syms_filename; + + double tot_like = 0.0; + kaldi::int64 frame_count = 0; + int num_success = 0, num_fail = 0; + + if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) { + SequentialBaseFloatMatrixReader loglike_reader(feature_rspecifier); + // Input FST is just one FST, not a table of FSTs. + Fst *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); + timer.Reset(); + + { + LatticeIncrementalDecoder decoder(*decode_fst, trans_model, config); + + for (; !loglike_reader.Done(); loglike_reader.Next()) { + std::string utt = loglike_reader.Key(); + Matrix loglikes(loglike_reader.Value()); + loglike_reader.FreeCurrent(); + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + + DecodableMatrixScaledMapped decodable(trans_model, loglikes, + acoustic_scale); + + double like; + if (DecodeUtteranceLatticeIncremental( + decoder, decodable, trans_model, word_syms, utt, acoustic_scale, + determinize, allow_partial, &alignment_writer, &words_writer, + &compact_lattice_writer, &lattice_writer, &like)) { + tot_like += like; + frame_count += loglikes.NumRows(); + num_success++; + } else { + num_fail++; + } + } + } + delete decode_fst; // delete this only after decoder goes out of scope. + } else { // We have different FSTs for different utterances. + SequentialTableReader fst_reader(fst_in_str); + RandomAccessBaseFloatMatrixReader loglike_reader(feature_rspecifier); + for (; !fst_reader.Done(); fst_reader.Next()) { + std::string utt = fst_reader.Key(); + if (!loglike_reader.HasKey(utt)) { + KALDI_WARN << "Not decoding utterance " << utt + << " because no loglikes available."; + num_fail++; + continue; + } + const Matrix &loglikes = loglike_reader.Value(utt); + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + LatticeIncrementalDecoder decoder(fst_reader.Value(), trans_model, config); + DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale); + double like; + if (DecodeUtteranceLatticeIncremental( + decoder, decodable, trans_model, word_syms, utt, acoustic_scale, + determinize, allow_partial, &alignment_writer, &words_writer, + &compact_lattice_writer, &lattice_writer, &like)) { + tot_like += like; + frame_count += loglikes.NumRows(); + num_success++; + } else { + num_fail++; + } + } + } + + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken " << elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (elapsed * 100.0 / frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " << num_fail; + KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like / frame_count) + << " over " << frame_count << " frames."; + + delete word_syms; + if (num_success != 0) + return 0; + else + return 1; + } catch (const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/decoder/Makefile b/src/decoder/Makefile index 020fe358fe9..61e9670adba 100644 --- a/src/decoder/Makefile +++ b/src/decoder/Makefile @@ -7,7 +7,8 @@ TESTFILES = OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \ lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \ - decoder-wrappers.o grammar-fst.o decodable-matrix.o + decoder-wrappers.o grammar-fst.o decodable-matrix.o \ + lattice-incremental-decoder.o lattice-incremental-online-decoder.o LIBNAME = kaldi-decoder diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index ff573c74d15..af476f2322f 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -68,7 +68,7 @@ void DecodeUtteranceLatticeFasterClass::operator () () { success_ = true; using fst::VectorFst; if (!decoder_->Decode(decodable_)) { - KALDI_WARN << "Failed to decode file " << utt_; + KALDI_WARN << "Failed to decode utterance with id " << utt_; success_ = false; } if (!decoder_->ReachedFinal()) { @@ -195,6 +195,87 @@ DecodeUtteranceLatticeFasterClass::~DecodeUtteranceLatticeFasterClass() { delete decodable_; } +template +bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl &decoder, // not const but is really an input. + DecodableInterface &decodable, // not const but is really an input. + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr) { // puts utterance's like in like_ptr on success. + using fst::VectorFst; + if (!decoder.Decode(&decodable)) { + KALDI_WARN << "Failed to decode utterance with id " << utt; + return false; + } + if (!decoder.ReachedFinal()) { + if (allow_partial) { + KALDI_WARN << "Outputting partial output for utterance " << utt + << " since no final-state reached\n"; + } else { + KALDI_WARN << "Not producing output for utterance " << utt + << " since no final-state reached and " + << "--allow-partial=false.\n"; + return false; + } + } + + double likelihood; + LatticeWeight weight; + int32 num_frames; + { // First do some stuff with word-level traceback... + VectorFst decoded; + if (!decoder.GetBestPath(&decoded)) + // Shouldn't really reach this point as already checked success. + KALDI_ERR << "Failed to get traceback for utterance " << utt; + + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + num_frames = alignment.size(); + KALDI_ASSERT(num_frames == decoder.NumFramesDecoded()); + if (words_writer->IsOpen()) + words_writer->Write(utt, words); + if (alignment_writer->IsOpen()) + alignment_writer->Write(utt, alignment); + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << '\n'; + } + likelihood = -(weight.Value1() + weight.Value2()); + } + + // Get lattice, and do determinization if requested. + CompactLattice clat; + decoder.GetLattice(&clat); + if (clat.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; + // We'll write the lattice without acoustic scaling. + if (acoustic_scale != 0.0) + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); + compact_lattice_writer->Write(utt, clat); + KALDI_LOG << "Log-like per frame for utterance " << utt << " is " + << (likelihood / num_frames) << " over " + << num_frames << " frames."; + KALDI_VLOG(2) << "Cost for utterance " << utt << " is " + << weight.Value1() << " + " << weight.Value2(); + *like_ptr = likelihood; + return true; +} + // Takes care of output. Returns true on success. template @@ -215,7 +296,7 @@ bool DecodeUtteranceLatticeFaster( using fst::VectorFst; if (!decoder.Decode(&decodable)) { - KALDI_WARN << "Failed to decode file " << utt; + KALDI_WARN << "Failed to decode utterance with id " << utt; return false; } if (!decoder.ReachedFinal()) { @@ -296,6 +377,37 @@ bool DecodeUtteranceLatticeFaster( } // Instantiate the template above for the two required FST types. +template bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl > &decoder, + DecodableInterface &decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); + +template bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl &decoder, + DecodableInterface &decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); + + template bool DecodeUtteranceLatticeFaster( LatticeFasterDecoderTpl > &decoder, DecodableInterface &decodable, @@ -345,7 +457,7 @@ bool DecodeUtteranceLatticeSimple( using fst::VectorFst; if (!decoder.Decode(&decodable)) { - KALDI_WARN << "Failed to decode file " << utt; + KALDI_WARN << "Failed to decode utterance with id " << utt; return false; } if (!decoder.ReachedFinal()) { diff --git a/src/decoder/decoder-wrappers.h b/src/decoder/decoder-wrappers.h index fc81137f356..61134412cfd 100644 --- a/src/decoder/decoder-wrappers.h +++ b/src/decoder/decoder-wrappers.h @@ -22,6 +22,7 @@ #include "itf/options-itf.h" #include "decoder/lattice-faster-decoder.h" +#include "decoder/lattice-incremental-decoder.h" #include "decoder/lattice-simple-decoder.h" // This header contains declarations from various convenience functions that are called @@ -88,6 +89,23 @@ void AlignUtteranceWrapper( void ModifyGraphForCarefulAlignment( fst::VectorFst *fst); +/// TODO +template +bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl &decoder, // not const but is really an input. + DecodableInterface &decodable, // not const but is really an input. + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignments_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); // puts utterance's likelihood in like_ptr on success. + /// This function DecodeUtteranceLatticeFaster is used in several decoders, and /// we have moved it here. Note: this is really "binary-level" code as it diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc new file mode 100644 index 00000000000..f025ccd78ff --- /dev/null +++ b/src/decoder/lattice-incremental-decoder.cc @@ -0,0 +1,1631 @@ +// decoder/lattice-incremental-decoder.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/lattice-incremental-decoder.h" +#include "lat/lattice-functions.h" +#include "base/timer.h" + +namespace kaldi { + +// instantiate this class once for each thing you have to decode. +template +LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( + const FST &fst, const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config) + : fst_(&fst), + delete_fst_(false), + config_(config), + num_toks_(0), + determinizer_(config, trans_model) { + config.Check(); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. +} + +template +LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( + const LatticeIncrementalDecoderConfig &config, FST *fst, + const TransitionModel &trans_model) + : fst_(fst), + delete_fst_(true), + config_(config), + num_toks_(0), + determinizer_(config, trans_model) { + config.Check(); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. +} + +template +LatticeIncrementalDecoderTpl::~LatticeIncrementalDecoderTpl() { + DeleteElems(toks_.Clear()); + ClearActiveTokens(); + if (delete_fst_) delete fst_; +} + +template +void LatticeIncrementalDecoderTpl::InitDecoding() { + // clean up from last time: + DeleteElems(toks_.Clear()); + cost_offsets_.clear(); + ClearActiveTokens(); + warned_ = false; + num_toks_ = 0; + decoding_finalized_ = false; + final_costs_.clear(); + StateId start_state = fst_->Start(); + KALDI_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL); + active_toks_[0].toks = start_tok; + toks_.Insert(start_state, start_tok); + num_toks_++; + + last_get_lattice_frame_ = 0; + token_label_map_.clear(); + token_label_map_.reserve(std::min((int32)1e5, config_.max_active)); + token_label_available_idx_ = config_.max_word_id + 1; + token_label_final_cost_.clear(); + determinizer_.Init(); + + ProcessNonemitting(config_.beam); +} + +template +void LatticeIncrementalDecoderTpl::DeterminizeLattice() { + // We always incrementally determinize the lattice after lattice pruning in + // PruneActiveTokens() since we need extra_cost as the weights + // of final arcs to denote the "future" information of final states (Tokens) + // Moreover, the delay on GetLattice to do determinization + // make it process more skinny lattices which reduces the computation overheads. + int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; + // The minimum length of chunk is config_.determinize_period. + if (frame_det_most % config_.determinize_period == 0) { + int32 frame_det_least = last_get_lattice_frame_ + config_.determinize_period; + // Incremental determinization: + // To adaptively decide the length of chunk, we further compare the number of + // tokens in each frame and a pre-defined threshold. + // If the number of tokens in a certain frame is less than + // config_.determinize_max_active, the lattice can be determinized up to this + // frame. And we try to determinize as most frames as possible so we check + // numbers from frame_det_most to frame_det_least + for (int32 f = frame_det_most; f >= frame_det_least; f--) { + if (config_.determinize_max_active == std::numeric_limits::max() || + GetNumToksForFrame(f) < config_.determinize_max_active) { + KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() + << " incremental determinization up to " << f; + GetLattice(false, f); + break; + } + } + } + return; +} +// Returns true if any kind of traceback is available (not necessarily from +// a final state). It should only very rarely return false; this indicates +// an unusual search error. +template +bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decodable) { + InitDecoding(); + + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. + + while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + + DeterminizeLattice(); + + BaseFloat cost_cutoff = ProcessEmitting(decodable); + ProcessNonemitting(cost_cutoff); + } + Timer timer; + FinalizeDecoding(); + GetLattice(true, NumFramesDecoded()); + KALDI_VLOG(2) << "Delay time during and after FinalizeDecoding()" + << "(secs): " << timer.Elapsed(); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !active_toks_.empty() && active_toks_.back().toks != NULL; +} + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeIncrementalDecoderTpl::GetBestPath(Lattice *olat, + bool use_final_probs) { + CompactLattice lat, slat; + GetLattice(use_final_probs, NumFramesDecoded(), &lat); + ShortestPath(lat, &slat); + ConvertLattice(slat, olat); + return (olat->NumStates() != 0); +} + +template +void LatticeIncrementalDecoderTpl::PossiblyResizeHash(size_t num_toks) { + size_t new_sz = + static_cast(static_cast(num_toks) * config_.hash_ratio); + if (new_sz > toks_.Size()) { + toks_.SetSize(new_sz); + } +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + Define the 'forward cost' of a token as zero for any token on the frame + we're currently decoding; and for other frames, as the shortest-path cost + between that token and a token on the frame we're currently decoding. + (by "currently decoding" I mean the most recently processed frame). + + Then define the extra_cost of a token (always >= 0) as the forward-cost of + the token minus the smallest forward-cost of any token on the same frame. + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. + + The advantage of storing the extra_cost rather than the forward-cost, is that + it is less costly to keep the extra_cost up-to-date when we process new frames. + When we process a new frame, *all* the previous frames' forward-costs would change; + but in general the extra_cost will change only for a finite number of frames. + (Actually we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash of toks_, +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline Token *LatticeIncrementalDecoderTpl::FindOrAddToken( + StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer, + bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + Elem *e_found = toks_.Find(state); + if (e_found == NULL) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. + Token *new_tok = new Token(tot_cost, extra_cost, NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + toks_.Insert(state, new_tok); + if (changed) *changed = true; + return new_tok; + } else { + Token *tok = e_found->val; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. + tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return tok; + } +} + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeIncrementalDecoderTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned, + BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. + if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL; + tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL;) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = + next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - + next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) + prev_link->next = next_link; + else + tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. + // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses +// the final-probs for pruning, otherwise it treats all tokens as final. +template +void LatticeIncrementalDecoderTpl::PruneForwardLinksFinal() { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file"; + + typedef typename unordered_map::const_iterator IterType; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + // We call DeleteElems() as a nicety, not because it's really necessary; + // otherwise there would be a time, after calling PruneTokensForFrame() on the + // final frame, when toks_.GetList() or toks_.Clear() would contain pointers + // to nonexistent tokens. + DeleteElems(toks_.Clear()); + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. + bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL; + tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this + // token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL;) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = + next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - + next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) + prev_link->next = next_link; + else + tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeIncrementalDecoderTpl::FinalRelativeCost() const { + if (!decoding_finalized_) { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; + } else { + // we're not allowed to call that function if FinalizeDecoding() has + // been called; return a cached value. + return final_relative_cost_; + } +} + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. +// It's called by PruneActiveTokens if any forward links have been pruned +template +void LatticeIncrementalDecoderTpl::PruneTokensForFrame( + int32 frame_plus_one) { + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + if (toks == NULL) KALDI_WARN << "No tokens alive [doing pruning]"; + Token *tok, *next_tok, *prev_tok = NULL; + for (tok = toks; tok != NULL; tok = next_tok) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) + prev_tok->next = tok->next; + else + toks = tok->next; + delete tok; + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame before +// that. We go backwards through the frames and stop when we reach a point +// where the delta-costs are not changing (and the delta controls when we consider +// a cost to have "not changed"). +template +void LatticeIncrementalDecoderTpl::PruneActiveTokens(BaseFloat delta) { + int32 cur_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract + // one to get the corresponding index for the decodable object. + for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them (new TokenList) + // (2) we have not yet pruned the forward links to the next f, + // after any of those tokens have changed their extra_cost. + if (active_toks_[f].must_prune_forward_links) { + bool extra_costs_changed = false, links_pruned = false; + PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); + if (extra_costs_changed && f > 0) // any token has changed extra_cost + active_toks_[f - 1].must_prune_forward_links = true; + if (links_pruned) // any link was pruned + active_toks_[f].must_prune_tokens = true; + active_toks_[f].must_prune_forward_links = false; // job done + } + if (f + 1 < cur_frame_plus_one && // except for last f (no forward links) + active_toks_[f + 1].must_prune_tokens) { + PruneTokensForFrame(f + 1); + active_toks_[f + 1].must_prune_tokens = false; + } + } + KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +template +void LatticeIncrementalDecoderTpl::ComputeFinalCosts( + unordered_map *final_costs, BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const { + KALDI_ASSERT(!decoding_finalized_); + if (final_costs != NULL) final_costs->clear(); + const Elem *final_toks = toks_.GetList(); + BaseFloat infinity = std::numeric_limits::infinity(); + BaseFloat best_cost = infinity, best_cost_with_final = infinity; + + while (final_toks != NULL) { + StateId state = final_toks->key; + Token *tok = final_toks->val; + const Elem *next = final_toks->tail; + BaseFloat final_cost = fst_->Final(state).Value(); + BaseFloat cost = tok->tot_cost, cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != NULL && final_cost != infinity) + (*final_costs)[tok] = final_cost; + final_toks = next; + } + if (final_relative_cost != NULL) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != NULL) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +template +void LatticeIncrementalDecoderTpl::AdvanceDecoding( + DecodableInterface *decodable, int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. + if (fst_->Type() == "const") { + LatticeIncrementalDecoderTpl, Token> *this_cast = + reinterpret_cast< + LatticeIncrementalDecoderTpl, Token> *>( + this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeIncrementalDecoderTpl, Token> *this_cast = + reinterpret_cast< + LatticeIncrementalDecoderTpl, Token> *>( + this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && + "You must call InitDecoding() before AdvanceDecoding"); + int32 num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). + KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = + std::min(target_frames_decoded, NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + + DeterminizeLattice(); + + BaseFloat cost_cutoff = ProcessEmitting(decodable); + ProcessNonemitting(cost_cutoff); + } +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeIncrementalDecoderTpl::FinalizeDecoding() { + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. + BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin << " to " << num_toks_; +} + +/// Gets the weight cutoff. Also counts the active tokens. +template +BaseFloat LatticeIncrementalDecoderTpl::GetCutoff( + Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem) { + BaseFloat best_weight = std::numeric_limits::infinity(); + // positive == high cost == bad. + size_t count = 0; + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = static_cast(e->val->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; + } + } + if (tok_count != NULL) *tok_count = count; + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = e->val->tot_cost; + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; + } + } + if (tok_count != NULL) *tok_count = count; + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() + << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) + min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) + ? tmp_array_.begin() + config_.max_active + : tmp_array_.end()); + min_active_cutoff = tmp_array_[config_.min_active]; + } + } + if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. + if (adaptive_beam) + *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; + return min_active_cutoff; + } else { + *adaptive_beam = config_.beam; + return beam_cutoff; + } + } +} + +template +BaseFloat LatticeIncrementalDecoderTpl::ProcessEmitting( + DecodableInterface *decodable) { + KALDI_ASSERT(active_toks_.size() > 0); + int32 frame = active_toks_.size() - 1; // frame is the frame-index + // (zero-based) used to get likelihoods + // from the decodable object. + active_toks_.resize(active_toks_.size() + 1); + + Elem *final_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_ + // in simple-decoder.h. Removes the Elems from + // being indexed in the hash in toks_. + Elem *best_elem = NULL; + BaseFloat adaptive_beam; + size_t tok_cnt; + BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem); + KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " + << adaptive_beam; + + PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough. + + BaseFloat next_cutoff = std::numeric_limits::infinity(); + // pruning "online" before having seen all tokens + + BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good + // dynamic range. + + // First process the best token to get a hopefully + // reasonably tight bound on the next cutoff. The only + // products of the next block are "next_cutoff" and "cost_offset". + if (best_elem) { + StateId state = best_elem->key; + Token *tok = best_elem->val; + cost_offset = -tok->tot_cost; + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + BaseFloat new_weight = arc.weight.Value() + cost_offset - + decodable->LogLikelihood(frame, arc.ilabel) + + tok->tot_cost; + if (new_weight + adaptive_beam < next_cutoff) + next_cutoff = new_weight + adaptive_beam; + } + } + } + + // Store the offset on the acoustic likelihoods that we're applying. + // Could just do cost_offsets_.push_back(cost_offset), but we + // do it this way as it's more robust to future code changes. + cost_offsets_.resize(frame + 1, 0.0); + cost_offsets_[frame] = cost_offset; + + // the tokens are now owned here, in final_toks, and the hash is empty. + // 'owned' is a complex thing here; the point is we need to call DeleteElem + // on each elem 'e' to let toks_ know we're done with them. + for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) { + // loop this way because we delete "e" as we go. + StateId state = e->key; + Token *tok = e->val; + if (tok->tot_cost <= cur_cutoff) { + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + BaseFloat ac_cost = + cost_offset - decodable->LogLikelihood(frame, arc.ilabel), + graph_cost = arc.weight.Value(), cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost > next_cutoff) + continue; + else if (tot_cost + adaptive_beam < next_cutoff) + next_cutoff = tot_cost + adaptive_beam; // prune by best current token + // Note: the frame indexes into active_toks_ are one-based, + // hence the + 1. + Token *next_tok = + FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, NULL); + // NULL: no change indicator needed + + // Add ForwardLink from tok to next_tok (put on head of list tok->links) + tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, graph_cost, + ac_cost, tok->links); + } + } // for all arcs + } + e_tail = e->tail; + toks_.Delete(e); // delete Elem + } + return next_cutoff; +} + +// static inline +template +void LatticeIncrementalDecoderTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; + } + tok->links = NULL; +} + +template +void LatticeIncrementalDecoderTpl::ProcessNonemitting(BaseFloat cutoff) { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame = static_cast(active_toks_.size()) - 2; + // Note: "frame" is the time-index we just processed, or -1 if + // we are processing the nonemitting transitions before the + // first frame (called from InitDecoding()). + + // Processes nonemitting arcs for one frame. Propagates within toks_. + // Note-- this queue structure is is not very optimal as + // it may cause us to process states unnecessarily (e.g. more than once), + // but in the baseline code, turning this vector into a set to fix this + // problem did not improve overall speed. + + KALDI_ASSERT(queue_.empty()); + + if (toks_.GetList() == NULL) { + if (!warned_) { + KALDI_WARN << "Error, no surviving tokens: frame is " << frame; + warned_ = true; + } + } + + for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) { + StateId state = e->key; + if (fst_->NumInputEpsilons(state) != 0) queue_.push_back(state); + } + + while (!queue_.empty()) { + StateId state = queue_.back(); + queue_.pop_back(); + + Token *tok = + toks_.Find(state) + ->val; // would segfault if state not in toks_ but this can't happen. + BaseFloat cur_cost = tok->tot_cost; + if (cur_cost > cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + // but since most states are emitting it's not a huge issue. + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel == 0) { // propagate nonemitting only... + BaseFloat graph_cost = arc.weight.Value(), tot_cost = cur_cost + graph_cost; + if (tot_cost < cutoff) { + bool changed; + + Token *new_tok = + FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, &changed); + + tok->links = + new ForwardLinkT(new_tok, 0, arc.olabel, graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new [if so, add into queue]. + if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0) + queue_.push_back(arc.nextstate); + } + } + } // for all arcs + } // while queue not empty +} + +template +void LatticeIncrementalDecoderTpl::DeleteElems(Elem *list) { + for (Elem *e = list, *e_tail; e != NULL; e = e_tail) { + e_tail = e->tail; + toks_.Delete(e); + } +} + +template +void LatticeIncrementalDecoderTpl< + FST, Token>::ClearActiveTokens() { // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL;) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + +// static +template +void LatticeIncrementalDecoderTpl::TopSortTokens( + Token *tok_list, std::vector *topsorted_list) { + unordered_map token2pos; + typedef typename unordered_map::iterator IterType; + int32 num_toks = 0; + for (Token *tok = tok_list; tok != NULL; tok = tok->next) num_toks++; + int32 cur_pos = 0; + // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. + // This is likely to be in closer to topological order than + // if we had given them ascending order, because of the way + // new tokens are put at the front of the list. + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + token2pos[tok] = num_toks - ++cur_pos; + + unordered_set reprocess; + + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { + Token *tok = iter->first; + int32 pos = iter->second; + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + // We only need to consider epsilon links, since non-epsilon links + // transition between frames and this function only needs to sort a list + // of tokens from a single frame. + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { // another token on this frame, + // so must consider it. + int32 next_pos = following_iter->second; + if (next_pos < pos) { // reassign the position of the next Token. + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + // In case we had previously assigned this token to be reprocessed, we can + // erase it from that set because it's "happy now" (we just processed it). + reprocess.erase(tok); + } + + size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles. + for (loop_count = 0; !reprocess.empty() && loop_count < max_loop; ++loop_count) { + std::vector reprocess_vec; + for (typename unordered_set::iterator iter = reprocess.begin(); + iter != reprocess.end(); ++iter) + reprocess_vec.push_back(*iter); + reprocess.clear(); + for (typename std::vector::iterator iter = reprocess_vec.begin(); + iter != reprocess_vec.end(); ++iter) { + Token *tok = *iter; + int32 pos = token2pos[tok]; + // Repeat the processing we did above (for comments, see above). + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { + int32 next_pos = following_iter->second; + if (next_pos < pos) { + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + } + } + KALDI_ASSERT(loop_count < max_loop && + "Epsilon loops exist in your decoding " + "graph (this is not allowed!)"); + + topsorted_list->clear(); + topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between. + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) + (*topsorted_list)[iter->second] = iter->first; +} + +template +bool LatticeIncrementalDecoderTpl::GetLattice(CompactLattice *olat) { + return GetLattice(true, NumFramesDecoded(), olat); +} + +template +bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, + int32 last_frame_of_chunk, + CompactLattice *olat) { + using namespace fst; + bool not_first_chunk = last_get_lattice_frame_ != 0; + bool ret = true; + + // last_get_lattice_frame_ is used to record the first frame of the chunk + // last time we obtain from calling this function. If it reaches + // last_frame_of_chunk + // we cannot generate any more chunk + if (last_get_lattice_frame_ < last_frame_of_chunk) { + Lattice raw_fst; + // step 1: Get lattice chunk with initial and final states + // In this function, we do not create the initial state in + // the first chunk, and we do not create the final state in the last chunk + if (!GetIncrementalRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, + last_frame_of_chunk, not_first_chunk, + !decoding_finalized_)) + KALDI_ERR << "Unexpected problem when getting lattice"; + // step 2-3 + ret = determinizer_.ProcessChunk(raw_fst, last_get_lattice_frame_, + last_frame_of_chunk); + last_get_lattice_frame_ = last_frame_of_chunk; + } else if (last_get_lattice_frame_ > last_frame_of_chunk) { + KALDI_WARN << "Call GetLattice up to frame: " << last_frame_of_chunk + << " while the determinizer_ has already done up to frame: " + << last_get_lattice_frame_; + } + + // step 4 + if (decoding_finalized_) ret &= determinizer_.Finalize(); + if (olat) { + *olat = determinizer_.GetDeterminizedLattice(); + ret &= (olat->NumStates() > 0); + } + if (!ret) { + KALDI_WARN << "Last chunk processing failed." + << " We will retry from frame 0."; + // Reset determinizer_ and re-determinize from + // frame 0 to last_frame_of_chunk + last_get_lattice_frame_ = 0; + determinizer_.Init(); + } + + return ret; +} + +template +bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( + Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, + bool create_initial_state, bool create_final_state) { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + + if (decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetIncrementalRawLattice() with use_final_probs == false"; + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (decoding_finalized_ ? final_costs_ : final_costs_local); + if (!decoding_finalized_ && use_final_probs) + ComputeFinalCosts(&final_costs_local, NULL, NULL); + + ofst->DeleteStates(); + unordered_multimap + token_label2last_state; // for GetInitialRawLattice + // initial arcs for the chunk + if (create_initial_state) + determinizer_.GetInitialRawLattice(ofst, &token_label2last_state, + token_label_final_cost_); + // num-frames plus one (since frames are one-based, and we have + // an extra frame for the start-state). + KALDI_ASSERT(frame_end > 0); + const int32 bucket_count = num_toks_ / 2 + 3; + unordered_map tok_map(bucket_count); + // First create all states. + std::vector token_list; + for (int32 f = frame_begin; f <= frame_end; f++) { + if (active_toks_[f].toks == NULL) { + KALDI_WARN << "GetIncrementalRawLattice: no tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + TopSortTokens(active_toks_[f].toks, &token_list); + for (size_t i = 0; i < token_list.size(); i++) + if (token_list[i] != NULL) tok_map[token_list[i]] = ofst->AddState(); + } + // The next statement sets the start state of the output FST. + // No matter create_initial_state or not , state zero must be the start-state. + StateId start_state = 0; + ofst->SetStart(start_state); + + KALDI_VLOG(4) << "init:" << num_toks_ / 2 + 3 + << " buckets:" << tok_map.bucket_count() + << " load:" << tok_map.load_factor() + << " max:" << tok_map.max_load_factor(); + // step 1.1: create initial_arc for later appending with the previous chunk + if (create_initial_state) { + for (Token *tok = active_toks_[frame_begin].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + // token_label_map_ is construct during create_final_state + auto r = token_label_map_.find(tok); + KALDI_ASSERT(r != token_label_map_.end()); // it should exist + int32 token_label = r->second; + auto range = token_label2last_state.equal_range(token_label); + if (range.first == range.second) { + KALDI_WARN + << "The token in the first frame of this chunk does not " + "exist in the last frame of previous chunk. It should seldom" + " happen and would be caused by over-pruning in determinization," + "e.g. the lattice reaches --max-mem constrain."; + continue; + } + for (auto it = range.first; it != range.second; ++it) { + // the destination state of the last of the sequence of arcs w.r.t the token + // label + // here created by GetInitialRawLattice + auto state_last_initial = it->second; + // connect it to the state correponding to the token w.r.t the token label + // here + Arc arc(0, 0, Weight::One(), cur_state); + ofst->AddArc(state_last_initial, arc); + } + } + } + // step 1.2: create all arcs as GetRawLattice() of LatticeFasterDecoder + for (int32 f = frame_begin; f <= frame_end; f++) { + for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { + // for the arcs outgoing from the last frame Token in this chunk, we will + // create these arcs in the next chunk + if (f == frame_end && l->ilabel > 0) continue; + typename unordered_map::const_iterator iter = + tok_map.find(l->next_tok); + KALDI_ASSERT(iter != tok_map.end()); + StateId nextstate = iter->second; + BaseFloat cost_offset = 0.0; + if (l->ilabel != 0) { // emitting.. + KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); + cost_offset = cost_offsets_[f]; + } + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), nextstate); + ofst->AddArc(cur_state, arc); + } + // For the last frame in this chunk, we need to work out a + // proper final weight for the corresponding state. + // If use_final_probs == true, we will try to use the final cost we just + // calculated + // Otherwise, we use LatticeWeight::One(). We record these cost in the state + // Later in the code, if create_final_state == true, we will create + // a specific final state, and move the final costs to the cost of an arc + // connecting to the final state + if (f == frame_end) { + LatticeWeight weight = LatticeWeight::One(); + if (use_final_probs && !final_costs.empty()) { + typename unordered_map::const_iterator iter = + final_costs.find(tok); + if (iter != final_costs.end()) + weight = LatticeWeight(iter->second, 0); + else + weight = LatticeWeight::Zero(); + } + ofst->SetFinal(cur_state, weight); + } + } + } + // step 1.3 create final_arc for later appending with the next chunk + if (create_final_state) { + StateId end_state = ofst->AddState(); // final-state for the chunk + ofst->SetFinal(end_state, Weight::One()); + + token_label_map_.clear(); + token_label_map_.reserve(std::min((int32)1e5, config_.max_active)); + for (Token *tok = active_toks_[frame_end].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + // We assign an unique state label for each of the token in the last frame + // of this chunk + int32 id = token_label_available_idx_++; + token_label_map_[tok] = id; + // The final weight has been worked out in the previous for loop and + // store in the states + // Here, we create a specific final state, and move the final costs to + // the cost of an arc connecting to the final state + KALDI_ASSERT(ofst->Final(cur_state) != Weight::Zero()); + Weight final_weight = ofst->Final(cur_state); + // Use cost_offsets to guide DeterminizeLatticePruned() + // For now, we use extra_cost from the decoding stage , which has some + // "future information", as + // the final weights of this chunk + BaseFloat cost_offset = tok->extra_cost - tok->tot_cost; + // We record these cost_offset, and after we appending two chunks + // we will cancel them out + token_label_final_cost_[id] = cost_offset; + Arc arc(0, id, Times(final_weight, Weight(0, cost_offset)), end_state); + ofst->AddArc(cur_state, arc); + ofst->SetFinal(cur_state, Weight::Zero()); + } + } + TopSortLatticeIfNeeded(ofst); + return (ofst->NumStates() > 0); +} + +template +int32 LatticeIncrementalDecoderTpl::GetNumToksForFrame(int32 frame) { + int32 r = 0; + for (Token *tok = active_toks_[frame].toks; tok; tok = tok->next) r++; + return r; +} + +template +LatticeIncrementalDeterminizer::LatticeIncrementalDeterminizer( + const LatticeIncrementalDecoderConfig &config, const TransitionModel &trans_model) + : config_(config), trans_model_(trans_model) {} + +template +void LatticeIncrementalDeterminizer::Init() { + final_arc_list_.clear(); + final_arc_list_prev_.clear(); + lat_.DeleteStates(); + determinization_finalized_ = false; + forward_costs_.clear(); + state_last_initial_offset_ = 2 * config_.max_word_id; + redeterminized_states_.clear(); + processed_prefinal_states_.clear(); +} +template +bool LatticeIncrementalDeterminizer::AddRedeterminizedState( + Lattice::StateId nextstate, Lattice *olat, Lattice::StateId *nextstate_copy) { + using namespace fst; + bool modified = false; + StateId nextstate_insert = kNoStateId; + auto r = redeterminized_states_.insert({nextstate, nextstate_insert}); + if (r.second) { // didn't exist, successfully insert here + // create a new state w.r.t state + nextstate_insert = olat->AddState(); + // map from arc.nextstate to nextstate_insert + r.first->second = nextstate_insert; + modified = true; + } else { // else already exist + // get nextstate_insert + nextstate_insert = r.first->second; + KALDI_ASSERT(nextstate_insert != kNoStateId); + modified = false; + } + if (nextstate_copy) *nextstate_copy = nextstate_insert; + return modified; +} + +template +void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( + StateId start_state, StateId state, + const unordered_map &token_label_final_cost, + unordered_multimap *token_label2last_state, + Lattice *olat) { + using namespace fst; + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + + auto r = redeterminized_states_.find(state); + KALDI_ASSERT(r != redeterminized_states_.end()); + auto state_copy = r->second; + KALDI_ASSERT(state_copy != kNoStateId); + ArcIterator aiter(lat_, state); + + // use state_label in initial arcs + int state_label = state + state_last_initial_offset_; + // Moreover, we need to use the forward coast (alpha) of this determinized and + // appended state to guide the determinization later + KALDI_ASSERT(state < forward_costs_.size()); + auto alpha_cost = forward_costs_[state]; + Arc arc_initial(0, state_label, LatticeWeight(0, alpha_cost), state_copy); + if (alpha_cost != std::numeric_limits::infinity()) + olat->AddArc(start_state, arc_initial); + + for (; !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + auto laststate_copy = kNoStateId; + bool proc_nextstate = false; + auto arc_weight = arc.weight; + + KALDI_ASSERT(arc.olabel == arc.ilabel); + auto arc_olabel = arc.olabel; + + // the destination of the arc is the final state + if (lat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { + KALDI_ASSERT(arc_olabel > config_.max_word_id && + arc_olabel < state_last_initial_offset_); // token label + // create a initial arc + + // Get arc weight here + // We will include it in arc_last in the following + CompactLatticeWeight weight_offset; + // To cancel out the weight on the final arcs, which is (extra cost - forward + // cost). + // see token_label_final_cost for more details + const auto r = token_label_final_cost.find(arc_olabel); + KALDI_ASSERT(r != token_label_final_cost.end()); + auto cost_offset = r->second; + weight_offset.SetWeight(LatticeWeight(0, -cost_offset)); + // The arc weight is a combination of original arc weight, above cost_offset + // and the weights on the final state + arc_weight = Times(Times(arc_weight, lat_.Final(arc.nextstate)), weight_offset); + + // We create a respective destination state for each final arc + // later we will connect it to the state correponding to the token w.r.t + // arc_olabel + laststate_copy = olat->AddState(); + // the destination state of the last of the sequence of arcs will be recorded + // and connected to the state corresponding to token w.r.t arc_olabel + // Notably, we have multiple states for one token label after determinization, + // hence we use multiset here + token_label2last_state->insert( + std::pair(arc_olabel, laststate_copy)); + arc_olabel = 0; // remove token label + } else { + // the arc connects to a non-final state (redeterminized state) + KALDI_ASSERT(arc_olabel < config_.max_word_id); // no token label + KALDI_ASSERT(arc_olabel); + // get the nextstate_copy w.r.t arc.nextstate + StateId nextstate_copy = kNoStateId; + proc_nextstate = AddRedeterminizedState(arc.nextstate, olat, &nextstate_copy); + KALDI_ASSERT(nextstate_copy != kNoStateId); + laststate_copy = nextstate_copy; + } + auto &state_seqs = arc_weight.String(); + // create new arcs w.r.t arc + // the following is for a normal arc + // We generate a linear sequence of arcs sufficient to contain all the + // transition-ids on the string + auto prev_state = state_copy; // from state_copy + for (auto &j : state_seqs) { + auto cur_state = olat->AddState(); + Arc arc(j, 0, LatticeWeight::One(), cur_state); + olat->AddArc(prev_state, arc); + prev_state = cur_state; + } + + // connect previous sequence of arcs to the laststate_copy + // the weight on the previous arc is stored in the arc to laststate_copy here + Arc arc_last(0, arc_olabel, arc_weight.Weight(), laststate_copy); + olat->AddArc(prev_state, arc_last); + + // not final state && previously didn't process this state + if (proc_nextstate) + GetRawLatticeForRedeterminizedStates(start_state, arc.nextstate, + token_label_final_cost, + token_label2last_state, olat); + } +} +template +void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { + using namespace fst; + processed_prefinal_states_.clear(); + // go over all prefinal state + KALDI_ASSERT(final_arc_list_prev_.size()); + unordered_set prefinal_states; + + for (auto &i : final_arc_list_prev_) { + auto prefinal_state = i.first; + ArcIterator aiter(lat_, prefinal_state); + KALDI_ASSERT(lat_.NumArcs(prefinal_state) > i.second); + aiter.Seek(i.second); + auto final_arc = aiter.Value(); + auto final_weight = lat_.Final(final_arc.nextstate); + KALDI_ASSERT(final_weight != CompactLatticeWeight::Zero()); + auto num_frames = Times(final_arc.weight, final_weight).String().size(); + // If the state is too far from the end of the current appended lattice, + // we leave the non-final arcs unchanged and only redeterminize the final + // arcs by the following procedure. + // We also do above things once we prepare to redeterminize the start state. + if (num_frames <= config_.redeterminize_max_frames && prefinal_state != 0) + processed_prefinal_states_[prefinal_state] = prefinal_state; + else { + KALDI_VLOG(7) << "Impose a limit of " << config_.redeterminize_max_frames + << " on how far back in time we will redeterminize states. " + << num_frames << " frames in this arc. "; + + auto new_prefinal_state = lat_.AddState(); + forward_costs_.resize(new_prefinal_state + 1); + forward_costs_[new_prefinal_state] = forward_costs_[prefinal_state]; + + std::vector arcs_remained; + for (aiter.Reset(); !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); + bool remain_the_arc = true; // If we remain the arc, the state will not be + // re-determinized, vice versa. + if (arc.olabel > config_.max_word_id) { // final arc + KALDI_ASSERT(arc.olabel < state_last_initial_offset_); + KALDI_ASSERT(lat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()); + remain_the_arc = false; + } else { + int num_frames_exclude_arc = num_frames - arc.weight.String().size(); + // destination-state of the arc is further than redeterminize_max_frames + // from the most recent frame we are determinizing + if (num_frames_exclude_arc > config_.redeterminize_max_frames) + remain_the_arc = true; + else { + // destination-state of the arc is no further than + // redeterminize_max_frames from the most recent frame we are + // determinizing + auto r = final_arc_list_prev_.find(arc.nextstate); + // destination-state of the arc is not prefinal state + if (r == final_arc_list_prev_.end()) remain_the_arc = true; + // destination-state of the arc is prefinal state + else + remain_the_arc = false; + } + } + + if (remain_the_arc) + arcs_remained.push_back(arc); + else + lat_.AddArc(new_prefinal_state, arc); + } + CompactLatticeArc arc_to_new(0, 0, CompactLatticeWeight::One(), + new_prefinal_state); + arcs_remained.push_back(arc_to_new); + + lat_.DeleteArcs(prefinal_state); + for (auto &i : arcs_remained) lat_.AddArc(prefinal_state, i); + processed_prefinal_states_[prefinal_state] = new_prefinal_state; + } + } + KALDI_VLOG(8) << "states of the lattice after GetRedeterminizedStates: " + << lat_.NumStates(); +} + +// This function is specifically designed to obtain the initial arcs for a chunk +// We have multiple states for one token label after determinization +template +void LatticeIncrementalDeterminizer::GetInitialRawLattice( + Lattice *olat, + unordered_multimap *token_label2last_state, + const unordered_map &token_label_final_cost) { + using namespace fst; + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + + GetRedeterminizedStates(); + + olat->DeleteStates(); + token_label2last_state->clear(); + + auto start_state = olat->AddState(); + olat->SetStart(start_state); + // go over all prefinal states after preprocessing + for (auto &i : processed_prefinal_states_) { + auto prefinal_state = i.second; + bool modified = AddRedeterminizedState(prefinal_state, olat); + if (modified) + GetRawLatticeForRedeterminizedStates(start_state, prefinal_state, + token_label_final_cost, + token_label2last_state, olat); + } +} + +template +bool LatticeIncrementalDeterminizer::ProcessChunk(Lattice &raw_fst, + int32 first_frame, + int32 last_frame) { + bool not_first_chunk = first_frame != 0; + bool ret = true; + // step 2: Determinize the chunk + CompactLattice clat; + // We do determinization with beam pruning here + // Only if we use a beam larger than (config_.beam+config_.lattice_beam) here, we + // can guarantee no final or initial arcs in clat are pruned by this function. + // These pruned final arcs can hurt oracle WER performance in the final lattice + // (also result in less lattice density) but they seldom hurt 1-best WER. + // Since pruning behaviors in DeterminizeLatticePhonePrunedWrapper and + // PruneActiveTokens are not the same, to get similar lattice density as + // LatticeFasterDecoder, we need to use a slightly larger beam here + // than the lattice_beam used PruneActiveTokens. Hence the beam we use is + // (0.1 + config_.lattice_beam) + ret &= DeterminizeLatticePhonePrunedWrapper( + trans_model_, &raw_fst, (config_.lattice_beam + 0.1), &clat, config_.det_opts); + + // step 3: Appending the new chunk in clat to the old one in lat_ + ret &= AppendLatticeChunks(clat, not_first_chunk); + + ret &= (lat_.NumStates() > 0); + KALDI_VLOG(2) << "Frame: ( " << first_frame << " , " << last_frame << " )" + << " states of the chunk: " << clat.NumStates() + << " states of the lattice: " << lat_.NumStates(); + return ret; +} + +template +bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice clat, + bool not_first_chunk) { + using namespace fst; + CompactLattice *olat = &lat_; + + // later we need to calculate forward_costs_ for clat + TopSortCompactLatticeIfNeeded(&clat); + + // step 3.1: Appending new chunk to the old one + int32 state_offset = olat->NumStates(); + if (not_first_chunk) { + state_offset--; // since we do not append initial state in the first chunk + // remove arcs from redeterminized_states_ + for (auto i : redeterminized_states_) { + olat->DeleteArcs(i.first); + olat->SetFinal(i.first, CompactLatticeWeight::Zero()); + } + redeterminized_states_.clear(); + } else { + forward_costs_.push_back(0); // for the first state + } + forward_costs_.resize(state_offset + clat.NumStates(), + std::numeric_limits::infinity()); + + // Here we construct a map from the original prefinal state to the prefinal states + // for later use + unordered_map invert_processed_prefinal_states; + invert_processed_prefinal_states.reserve(processed_prefinal_states_.size()); + for (auto i : processed_prefinal_states_) + invert_processed_prefinal_states[i.second] = i.first; + for (StateIterator siter(clat); !siter.Done(); siter.Next()) { + auto s = siter.Value(); + StateId state_appended = kNoStateId; + // We do not copy initial state, which exists except the first chunk + if (!not_first_chunk || s != 0) { + state_appended = s + state_offset; + auto r = olat->AddState(); + KALDI_ASSERT(state_appended == r); + olat->SetFinal(state_appended, clat.Final(s)); + } + + for (ArcIterator aiter(clat, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + + StateId source_state = kNoStateId; + // We do not copy initial arcs, which exists except the first chunk. + // These arcs will be taken care later in step 3.2 + CompactLatticeArc arc_appended(arc); + arc_appended.nextstate += state_offset; + // In the first chunk, there could be a final arc starting from state 0, and we + // process it here + // In the last chunk, there could be a initial arc ending in final state, and + // we process it in "process initial arcs" in the following + bool is_initial_state = (not_first_chunk && s == 0); + if (!is_initial_state) { + KALDI_ASSERT(state_appended != kNoStateId); + KALDI_ASSERT(arc.olabel < state_last_initial_offset_); + source_state = state_appended; + // process final arcs + if (arc.olabel > config_.max_word_id) { + // record final_arc in this chunk for the step 3.2 in the next call + KALDI_ASSERT(arc.olabel < state_last_initial_offset_); + KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); + // state_appended shouldn't be in invert_processed_prefinal_states + // So we do not need to map it + final_arc_list_.insert( + pair(state_appended, aiter.Position())); + } + olat->AddArc(source_state, arc_appended); + } else { // process initial arcs + // a special olabel in the arc that corresponds to the identity of the + // source-state of the last arc, we use its StateId and a offset here, called + // state_label + auto state_label = arc.olabel; + KALDI_ASSERT(state_label > config_.max_word_id); + KALDI_ASSERT(state_label >= state_last_initial_offset_); + source_state = state_label - state_last_initial_offset_; + arc_appended.olabel = 0; + arc_appended.ilabel = 0; + CompactLatticeWeight weight_offset; + // remove alpha in weight + weight_offset.SetWeight(LatticeWeight(0, -forward_costs_[source_state])); + arc_appended.weight = Times(arc_appended.weight, weight_offset); + + // if it is an extra prefinal state, we should use its original prefinal + // state + int arc_offset = 0; + auto r = invert_processed_prefinal_states.find(source_state); + if (r != invert_processed_prefinal_states.end() && r->second != r->first) { + source_state = r->second; + arc_offset = olat->NumArcs(source_state); + } + + if (clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { + // it should be the last chunk + olat->SetFinal(source_state, + Times(arc_appended.weight, clat.Final(arc.nextstate))); + } else { + // append lattice chunk and remove Epsilon together + for (ArcIterator aiter_postinitial(clat, arc.nextstate); + !aiter_postinitial.Done(); aiter_postinitial.Next()) { + auto arc_postinitial(aiter_postinitial.Value()); + arc_postinitial.weight = + Times(arc_appended.weight, arc_postinitial.weight); + arc_postinitial.nextstate += state_offset; + olat->AddArc(source_state, arc_postinitial); + if (arc_postinitial.olabel > config_.max_word_id) { + KALDI_ASSERT(arc_postinitial.olabel < state_last_initial_offset_); + final_arc_list_.insert(pair( + source_state, aiter_postinitial.Position() + arc_offset)); + } + } + } + } + // update forward_costs_ (alpha) + KALDI_ASSERT(arc_appended.nextstate < forward_costs_.size()); + auto &alpha_nextstate = forward_costs_[arc_appended.nextstate]; + auto &weight = arc_appended.weight.Weight(); + alpha_nextstate = + std::min(alpha_nextstate, + forward_costs_[source_state] + weight.Value1() + weight.Value2()); + } + } + KALDI_ASSERT(olat->NumStates() == clat.NumStates() + state_offset); + KALDI_VLOG(8) << "states of the lattice: " << olat->NumStates(); + + if (!not_first_chunk) { + olat->SetStart(0); // Initialize the first chunk for olat + } else { + // The extra prefinal states generated by + // GetRedeterminizedStates are removed here, while splicing + // the compact lattices together + for (auto &i : processed_prefinal_states_) { + auto prefinal_state = i.first; + auto new_prefinal_state = i.second; + // It is without an extra prefinal state, hence do not need to process + if (prefinal_state == new_prefinal_state) continue; + for (ArcIterator aiter(*olat, new_prefinal_state); + !aiter.Done(); aiter.Next()) + olat->AddArc(prefinal_state, aiter.Value()); + olat->DeleteArcs(new_prefinal_state); + olat->SetFinal(new_prefinal_state, CompactLatticeWeight::Zero()); + } + } + + final_arc_list_.swap(final_arc_list_prev_); + final_arc_list_.clear(); + + return true; +} + +template +bool LatticeIncrementalDeterminizer::Finalize() { + using namespace fst; + auto *olat = &lat_; + // The lattice determinization only needs to be finalized once + if (determinization_finalized_) return true; + // step 4: remove dead states + if (config_.final_prune_after_determinize) + PruneLattice(config_.lattice_beam, olat); + else + Connect(olat); // Remove unreachable states... there might be + + KALDI_VLOG(2) << "states of the lattice: " << olat->NumStates(); + determinization_finalized_ = true; + + return (olat->NumStates() > 0); +} + +// Instantiate the template for the combination of token types and FST types +// that we'll need. +template class LatticeIncrementalDecoderTpl, decoder::StdToken>; +template class LatticeIncrementalDecoderTpl, + decoder::StdToken>; +template class LatticeIncrementalDecoderTpl, + decoder::StdToken>; +template class LatticeIncrementalDecoderTpl; + +template class LatticeIncrementalDecoderTpl, + decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl, + decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl, + decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl; + +} // end namespace kaldi. diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h new file mode 100644 index 00000000000..9f930b6610d --- /dev/null +++ b/src/decoder/lattice-incremental-decoder.h @@ -0,0 +1,656 @@ +// decoder/lattice-incremental-decoder.h + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_INCREMENTAL_DECODER_H_ +#define KALDI_DECODER_LATTICE_INCREMENTAL_DECODER_H_ + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "lattice-faster-decoder.h" + +namespace kaldi { + +struct LatticeIncrementalDecoderConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + int32 determinize_delay; + int32 determinize_period; + int32 determinize_max_active; + int32 redeterminize_max_frames; + bool final_prune_after_determinize; + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeIncrementalDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeIncremental. + int32 max_word_id; // for GetLattice + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeIncrementalDecoderConfig() + : beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_delay(25), + determinize_period(20), + determinize_max_active(std::numeric_limits::max()), + redeterminize_max_frames(std::numeric_limits::max()), + final_prune_after_determinize(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1), + max_word_id(1e8) {} + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, + "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, + "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, + "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-delay", &determinize_delay, + "Delay (in frames) at which to incrementally determinize " + "lattices. A larger delay reduces the computational " + "overhead of incremental deteriminization while increasing" + "the length of the last chunk which may increase latency."); + opts->Register("determinize-period", &determinize_period, + "The size (in frames) of chunk to do incrementally " + "determinization. If working with --determinize-max-active," + "it will become a lower bound of the size of chunk."); + opts->Register("determinize-max-active", &determinize_max_active, + "This option is to adaptively decide the size of the chunk " + "to be determinized. " + "If the number of active tokens(in a certain frame) is less " + "than this number (typically 50), we will start to " + "incrementally determinize lattices from the last frame we " + "determinized up to this frame. It can work with " + "--determinize-delay to further reduce the computation " + "introduced by incremental determinization. "); + opts->Register("redeterminize-max-frames", &redeterminize_max_frames, + "To impose a limit on how far back in time we will " + "redeterminize states. This is mainly intended to avoid " + "pathological cases. Smaller value leads to less " + "deterministic but less likely to blow up the processing" + "time in bad cases. You could set it infinite to get a fully " + "determinized lattice."); + opts->Register("final-prune-after-determinize", &final_prune_after_determinize, + "prune lattice after determinization "); + opts->Register("beam-delta", &beam_delta, + "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, + "Setting used in decoder to " + "control hash behavior"); + } + void Check() const { + KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 && + min_active <= max_active && prune_interval > 0 && + determinize_delay >= 0 && determinize_max_active >= 0 && + determinize_period >= 0 && redeterminize_max_frames >= 0 && + beam_delta > 0.0 && hash_ratio >= 1.0 && prune_scale > 0.0 && + prune_scale < 1.0); + } +}; + +template +class LatticeIncrementalDeterminizer; + +/* This is an extention to the "normal" lattice-generating decoder. + See \ref lattices_generation \ref decoders_faster and \ref decoders_simple + for more information. + + The main difference is the incremental determinization which will be + discussed in the function GetLattice(). + + The decoder is templated on the FST type and the token type. The token type + will normally be StdToken, but also may be BackpointerToken which is to support + quick lookup of the current best path (see lattice-faster-online-decoder.h) + + The FST you invoke this decoder with is expected to be of type + Fst::Fst, a.k.a. StdFst, or GrammarFst. If you invoke it with + FST == StdFst and it notices that the actual FST type is + fst::VectorFst or fst::ConstFst, the decoder object + will internally cast itself to one that is templated on those more specific + types; this is an optimization for speed. + */ +template +class LatticeIncrementalDecoderTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeIncrementalDecoderTpl(const FST &fst, const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config); + + // This version of the constructor takes ownership of the fst, and will delete + // it when this object is destroyed. + LatticeIncrementalDecoderTpl(const LatticeIncrementalDecoderConfig &config, + FST *fst, const TransitionModel &trans_model); + + void SetOptions(const LatticeIncrementalDecoderConfig &config) { config_ = config; } + + const LatticeIncrementalDecoderConfig &GetOptions() const { return config_; } + + ~LatticeIncrementalDecoderTpl(); + + /// An example of how to do decoding together with incremental + /// determinization. It decodes until there are no more frames left in the + /// "decodable" object. Note, this may block waiting for input + /// if the "decodable" object blocks. + /// In this example, config_.determinize_delay, config_.determinize_period + /// and config_.determinize_max_active are used to determine the time to + /// call GetLattice(). + /// Users may do it in their own ways by calling + /// AdvanceDecoding() and GetLattice(). So the logic for deciding + /// when we get the lattice would be driven by the user. + /// The function returns true if any kind + /// of traceback is available (not necessarily from a final state). + bool Decode(DecodableInterface *decodable); + + /// says whether a final-state was active on the last frame. If it was not, the + /// lattice (or traceback) will end with states that are not final-states. + bool ReachedFinal() const { + return FinalRelativeCost() != std::numeric_limits::infinity(); + } + + /// Outputs an FST corresponding to the single best path through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + bool GetBestPath(Lattice *ofst, bool use_final_probs = true); + + /** + The following function is specifically designed for incremental + determinization. The function obtains a CompactLattice for + the part of this utterance up to the frame last_frame_of_chunk. + If you call this multiple times + (calling it on every frame would not make sense, but every, say, + 10 to 40 frames might make sense) it will spread out the work of + determinization over time, which might be useful for online applications. + config_.determinize_delay, config_.determinize_period + and config_.determinize_max_active can be used to determine the time to + call this function. We show an example in Decode(). + + The procedure of incremental determinization is as follow: + step 1: Get lattice chunk with initial and final states and arcs, called `raw + lattice`. + Here, we define a `final arc` as an arc to a final-state, and the source state + of it as a `pre-final state` + Similarly, we define a `initial arc` as an arc from a initial-state, and the + destination state of it as a `post-initial state` + The post-initial states are constructed corresponding to pre-final states + in the determinized and appended lattice before this chunk + The pre-final states are constructed correponding to tokens in the last frames + of this chunk. + Since the StateId can change during determinization, we need to give permanent + unique labels (as olabel) to these + raw-lattice states for latter appending. + We give each token an olabel id, called `token_label`, and each determinized and + appended state an olabel id, called `state_label`. Notably, in our + paper, we call both of them ``state labels'' for simplicity. + step 2: Determinize the chunk of above raw lattice using determinization + algorithm the same as LatticeFasterDecoder. Benefit from above `state_label` and + `token_label` in initial and final arcs, each pre-final state in the last chunk + w.r.t the initial arc of this chunk can be treated uniquely and each token in + the last frame of this chunk can also be treated uniquely. We call the + determinized new + chunk `compact lattice (clat)` + step 3: Appending the new chunk `clat` to the determinized lattice + before this chunk. First, for each StateId in clat except its + initial state, allocate a new StateId in the appended + compact lattice. Copy the arcs except whose incoming state is initial + state. Secondly, for each initial arcs, change its source state to the state + corresponding to its `state_label`, which is a determinized and appended state + Finally, we make the previous final arcs point to a "dead state" + step 4: We remove dead states in the very end. + + In our implementation, step 1 is done in GetIncrementalRawLattice(), + step 2-4 is taken care by the class + LatticeIncrementalDeterminizer + + @param [in] use_final_probs If true *and* at least one final-state in HCLG + was active on the final frame, include final-probs from + HCLG + in the lattice. Otherwise treat all final-costs of + states active + on the most recent frame as zero (i.e. Weight::One()). + @param [in] last_frame_of_chunk Pass the last frame of this chunk to + the function. We make it not always equal to + NumFramesDecoded() to have a delay on the + deteriminization + @param [out] olat The CompactLattice representing what has been decoded + so far. + If lat == NULL, the CompactLattice won't be outputed. + @return ret This function will returns true if the chunk is processed + successfully + */ + bool GetLattice(bool use_final_probs, int32 last_frame_of_chunk, + CompactLattice *olat = NULL); + /// Specifically design when decoding_finalized_==true + bool GetLattice(CompactLattice *olat); + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to + /// call this. You can also call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object. You can keep calling it each time more frames become available. + /// If max_num_frames is specified, it specifies the maximum number of frames + /// the function will decode before returning. + void AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames = -1); + + /// This function may be optionally called after AdvanceDecoding(), when you + /// do not plan to decode any further. It does an extra pruning step that + /// will help to prune the lattices output by GetLattice more accurately, + /// particularly toward the end of the utterance. + /// It does this by using the final-probs in pruning (if any + /// final-state survived); it also does a final pruning step that visits all + /// states (the pruning that is done during decoding may fail to prune states + /// that are within kPruningScale = 0.1 outside of the beam). If you call + /// this, you cannot call AdvanceDecoding again (it will fail), and you + /// cannot call GetLattice() and related functions with use_final_probs = + /// false. + /// Used to be called PruneActiveTokensFinal(). + void FinalizeDecoding(); + + /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. If it not + /// too positive (e.g. < 5 is my first guess, but this is not tested) you can + /// take it as a good indication that we reached the final-state with + /// reasonable likelihood. + BaseFloat FinalRelativeCost() const; + + // Returns the number of frames decoded so far. The value returned changes + // whenever we call ProcessEmitting(). + inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + + protected: + // we make things protected instead of private, as future code in + // LatticeIncrementalOnlineDecoderTpl, which inherits from this, also will + // use the internals. + + // Deletes the elements of the singly linked list tok->links. + inline static void DeleteForwardLinks(Token *tok); + + // head of per-frame list of Tokens (list is in topological order), + // and something saying whether we ever pruned it using PruneForwardLinks. + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + TokenList() + : toks(NULL), must_prune_forward_links(true), must_prune_tokens(true) {} + }; + + using Elem = typename HashList::Elem; + // Equivalent to: + // struct Elem { + // StateId key; + // Token *val; + // Elem *tail; + // }; + + void PossiblyResizeHash(size_t num_toks); + + // FindOrAddToken either locates a token in hash of toks_, or if necessary + // inserts a new, empty token (i.e. with no forward links) for the current + // frame. [note: it's inserted if necessary into hash toks_ and also into the + // singly linked list of tokens active on this frame (whose head is at + // active_toks_[frame]). The frame_plus_one argument is the acoustic frame + // index plus one, which is used to index into the active_toks_ array. + // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the + // token was newly created or the cost changed. + // If Token == StdToken, the 'backpointer' argument has no purpose (and will + // hopefully be optimized out). + inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, + BaseFloat tot_cost, Token *backpointer, bool *changed); + + // prunes outgoing links for all tokens in active_toks_[frame] + // it's called by PruneActiveTokens + // all links, that have link_extra_cost > lattice_beam are pruned + // delta is the amount by which the extra_costs must change + // before we set *extra_costs_changed = true. + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta); + + // This function computes the final-costs for tokens active on the final + // frame. It outputs to final-costs, if non-NULL, a map from the Token* + // pointer to the final-prob of the corresponding state, for all Tokens + // that correspond to states that have final-probs. This map will be + // empty if there were no final-probs. It outputs to + // final_relative_cost, if non-NULL, the difference between the best + // forward-cost including the final-prob cost, and the best forward-cost + // without including the final-prob cost (this will usually be positive), or + // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which + // outputs this quanitity]. It outputs to final_best_cost, if + // non-NULL, the lowest for any token t active on the final frame, of + // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in + // the graph of the state corresponding to token t, or the best of + // forward-cost[t] if there were no final-probs active on the final frame. + // You cannot call this after FinalizeDecoding() has been called; in that + // case you should get the answer from class-member variables. + void ComputeFinalCosts(unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const; + + // PruneForwardLinksFinal is a version of PruneForwardLinks that we call + // on the final frame. If there are final tokens active, it uses + // the final-probs for pruning, otherwise it treats all tokens as final. + void PruneForwardLinksFinal(); + + // Prune away any tokens on this frame that have no forward links. + // [we don't do this in PruneForwardLinks because it would give us + // a problem with dangling pointers]. + // It's called by PruneActiveTokens if any forward links have been pruned + void PruneTokensForFrame(int32 frame_plus_one); + + // Go backwards through still-alive tokens, pruning them if the + // forward+backward cost is more than lat_beam away from the best path. It's + // possible to prove that this is "correct" in the sense that we won't lose + // anything outside of lat_beam, regardless of what happens in the future. + // delta controls when it considers a cost to have changed enough to continue + // going backward and propagating the change. larger delta -> will recurse + // less far. + void PruneActiveTokens(BaseFloat delta); + + /// Gets the weight cutoff. Also counts the active tokens. + BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, + Elem **best_elem); + + /// Processes emitting arcs for one frame. Propagates from prev_toks_ to + /// cur_toks_. Returns the cost cutoff for subsequent ProcessNonemitting() to + /// use. + BaseFloat ProcessEmitting(DecodableInterface *decodable); + + /// Processes nonemitting (epsilon) arcs for one frame. Called after + /// ProcessEmitting() on each frame. The cost cutoff is computed by the + /// preceding ProcessEmitting(). + void ProcessNonemitting(BaseFloat cost_cutoff); + + // HashList defined in ../util/hash-list.h. It actually allows us to maintain + // more than one list (e.g. for current and previous frames), but only one of + // them at a time can be indexed by StateId. It is indexed by frame-index + // plus one, where the frame-index is zero-based, as used in decodable object. + // That is, the emitting probs of frame t are accounted for in tokens at + // toks_[t+1]. The zeroth frame is for nonemitting transition at the start of + // the graph. + HashList toks_; + + std::vector active_toks_; // Lists of tokens, indexed by + // frame (members of TokenList are toks, must_prune_forward_links, + // must_prune_tokens). + std::vector queue_; // temp variable used in ProcessNonemitting, + std::vector tmp_array_; // used in GetCutoff. + + // fst_ is a pointer to the FST we are decoding from. + const FST *fst_; + // delete_fst_ is true if the pointer fst_ needs to be deleted when this + // object is destroyed. + bool delete_fst_; + + std::vector cost_offsets_; // This contains, for each + // frame, an offset that was added to the acoustic log-likelihoods on that + // frame in order to keep everything in a nice dynamic range i.e. close to + // zero, to reduce roundoff errors. + LatticeIncrementalDecoderConfig config_; + int32 num_toks_; // current total #toks allocated... + bool warned_; + + /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, + /// calling this is optional]. If true, it's forbidden to decode more. Also, + /// if this is set, then the output of ComputeFinalCosts() is in the next + /// three variables. The reason we need to do this is that after + /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some + /// of the tokens on the last frame are freed, so we free the list from toks_ + /// to avoid having dangling pointers hanging around. + bool decoding_finalized_; + /// For the meaning of the next 3 variables, see the comment for + /// decoding_finalized_ above., and ComputeFinalCosts(). + unordered_map final_costs_; + BaseFloat final_relative_cost_; + BaseFloat final_best_cost_; + + // There are various cleanup tasks... the the toks_ structure contains + // singly linked lists of Token pointers, where Elem is the list type. + // It also indexes them in a hash, indexed by state (this hash is only + // maintained for the most recent frame). toks_.Clear() + // deletes them from the hash and returns the list of Elems. The + // function DeleteElems calls toks_.Delete(elem) for each elem in + // the list, which returns ownership of the Elem to the toks_ structure + // for reuse, but does not delete the Token pointer. The Token pointers + // are reference-counted and are ultimately deleted in PruneTokensForFrame, + // but are also linked together on each frame by their own linked-list, + // using the "next" pointer. We delete them manually. + void DeleteElems(Elem *list); + + // This function takes a singly linked list of tokens for a single frame, and + // outputs a list of them in topological order (it will crash if no such order + // can be found, which will typically be due to decoding graphs with epsilon + // cycles, which are not allowed). Note: the output list may contain NULLs, + // which the caller should pass over; it just happens to be more efficient for + // the algorithm to output a list that contains NULLs. + static void TopSortTokens(Token *tok_list, std::vector *topsorted_list); + + void ClearActiveTokens(); + + // The following part is specifically designed for incremental determinization + // This function is modified from LatticeFasterDecoderTpl::GetRawLattice() + // and specific design for step 1 of incremental determinization + // introduced before above GetLattice() + // It does the same thing as GetRawLattice in lattice-faster-decoder.cc except: + // + // i) it creates a initial state, and connect + // each token in the first frame of this chunk to the initial state + // by one or more arcs with a state_label correponding to the pre-final state w.r.t + // this token(the pre-final state is appended in the last chunk) as its olabel + // ii) it creates a final state, and connect + // all the tokens in the last frame of this chunk to the final state + // by an arc with a per-token token_label as its olabel + // `frame_begin` and `frame_end` are the first and last frame of this chunk + // if `create_initial_state` == false, we will not create initial state and + // the corresponding initial arcs. Similar for `create_final_state` + // In incremental GetLattice, we do not create the initial state in + // the first chunk, and we do not create the final state in the last chunk + bool GetIncrementalRawLattice(Lattice *ofst, bool use_final_probs, + int32 frame_begin, int32 frame_end, + bool create_initial_state, bool create_final_state); + // Get the number of tokens in each frame + // It is useful, e.g. in using config_.determinize_max_active + int32 GetNumToksForFrame(int32 frame); + void DeterminizeLattice(); + + // The incremental lattice determinizer to take care of determinization + // and appending the lattice. + LatticeIncrementalDeterminizer determinizer_; + int32 last_get_lattice_frame_; // the last time we call GetLattice + // a map from Token to its token_label + unordered_map token_label_map_; + // we allocate a unique id for each Token + int32 token_label_available_idx_; + // We keep cost_offset for each token_label (Token) in final arcs. We need them to + // guide determinization + // We cancel them after determinization + unordered_map token_label_final_cost_; + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); +}; + +typedef LatticeIncrementalDecoderTpl + LatticeIncrementalDecoder; + +// This class is designed for part of generating raw lattices and determnization +// and appending the lattice. +template +class LatticeIncrementalDeterminizer { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + LatticeIncrementalDeterminizer(const LatticeIncrementalDecoderConfig &config, + const TransitionModel &trans_model); + // Reset the lattice determinization data for an utterance + void Init(); + // Output the resultant determinized lattice in the form of CompactLattice + const CompactLattice &GetDeterminizedLattice() const { return lat_; } + + // Part of step 1 of incremental determinization, + // where the post-initial states are constructed corresponding to + // redeterminized states (see the description in redeterminized_states_) in the + // determinized and appended lattice before this chunk. + // We give each determinized and appended state an olabel id, called `state_label` + // We maintain a map (`token_label2last_state`) from token label (obtained from + // final arcs) to the destination state of the last of the sequence of initial arcs + // w.r.t the token label here + // Notably, we have multiple states for one token label after determinization, + // hence we use multiset here + // We need `token_label_final_cost` to cancel out the cost offset used in guiding + // DeterminizeLatticePhonePrunedWrapper + void GetInitialRawLattice( + Lattice *olat, + unordered_multimap *token_label2last_state, + const unordered_map &token_label_final_cost); + // This function consumes raw_fst generated by step 1 of incremental + // determinization with specific initial and final arcs. + // It processes lattices and outputs the resultant CompactLattice if + // needed. Otherwise, it keeps the resultant lattice in lat_ + bool ProcessChunk(Lattice &raw_fst, int32 first_frame, int32 last_frame); + + // Step 3 of incremental determinization, + // which is to append the new chunk in clat to the old one in lat_ + // If not_first_chunk == false, we do not need to append and just copy + // clat into olat + // Otherwise, we need to connect states of the last frame of + // the last chunk to states of the first frame of this chunk. + // These post-initial and pre-final states are corresponding to the same Token, + // guaranteed by unique state labels. + bool AppendLatticeChunks(CompactLattice clat, bool not_first_chunk); + + // Step 4 of incremental determinization, + // which either re-determinize above lat_, or simply remove the dead + // states of lat_ + bool Finalize(); + std::vector &GetForwardCosts() { return forward_costs_; } + + private: + // This function either locates a redeterminized state w.r.t nextstate previously + // added, or if necessary inserts a new one. + // The new one is inserted in olat and kept by the map (redeterminized_states_) + // which is from the state in the appended compact lattice to the state_copy in the + // raw lattice. The function returns whether a new one is inserted + // The StateId of the redeterminized state will be outputed by nextstate_copy + bool AddRedeterminizedState(Lattice::StateId nextstate, Lattice *olat, + Lattice::StateId *nextstate_copy = NULL); + // Sub function of GetInitialRawLattice(). Refer to description there + void GetRawLatticeForRedeterminizedStates( + StateId start_state, StateId state, + const unordered_map &token_label_final_cost, + unordered_multimap *token_label2last_state, + Lattice *olat); + // This function is to preprocess the appended compact lattice before + // generating raw lattices for the next chunk. + // After identifying pre-final states, for any such state that is separated by + // more than config_.redeterminize_max_frames from the end of the current + // appended lattice, we create an extra state for it; we add an epsilon arc + // from that pre-final state to the extra state; we copy any final arcs from + // the pre-final state to its extra state and we remove those final arcs from + // the original pre-final state. + // We also copy arcs meet the following requirements: i) destination-state of the + // arc is prefinal state. ii) destination-state of the arc is no further than than + // redeterminize_max_frames from the most recent frame we are determinizing. + // Now this extra state is the pre-final state to + // redeterminize and the original pre-final state does not need to redeterminize + // The epsilon would be removed later on in AppendLatticeChunks, while + // splicing the compact lattices together + void GetRedeterminizedStates(); + + const LatticeIncrementalDecoderConfig config_; + const TransitionModel &trans_model_; // keep it for determinization + + // Record whether we have finished determinized the whole utterance + // (including re-determinize) + bool determinization_finalized_; + // A map from the prefinal state to its correponding first final arc (there could be + // multiple final arcs). We keep final arc information for GetRedeterminizedStates() + // later. It can also be used to identify whether a state is a prefinal state. + unordered_map final_arc_list_; + unordered_map final_arc_list_prev_; + // alpha of each state in lat_ + std::vector forward_costs_; + // we allocate a unique id for each source-state of the last arc of a series of + // initial arcs in GetInitialRawLattice + int32 state_last_initial_offset_; + // We define a state in the appended lattice as a 'redeterminized-state' (meaning: + // one that will be redeterminized), if it is: a pre-final state, or there + // exists an arc from a redeterminized state to this state. We keep reapplying + // this rule until there are no more redeterminized states. The final state + // is not included. These redeterminized states will be stored in this map + // which is a map from the state in the appended compact lattice to the + // state_copy in the newly-created raw lattice. + unordered_map redeterminized_states_; + // It is a map used in GetRedeterminizedStates (see the description there) + // A map from the original pre-final state to the pre-final states (i.e. the + // original pre-final state or an extra state generated by + // GetRedeterminizedStates) used for generating raw lattices of the next chunk. + unordered_map processed_prefinal_states_; + + // The compact lattice we obtain. It should be reseted before processing a + // new utterance + CompactLattice lat_; + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDeterminizer); +}; + +} // end namespace kaldi. + +#endif diff --git a/src/decoder/lattice-incremental-online-decoder.cc b/src/decoder/lattice-incremental-online-decoder.cc new file mode 100644 index 00000000000..85f902bde3d --- /dev/null +++ b/src/decoder/lattice-incremental-online-decoder.cc @@ -0,0 +1,150 @@ +// decoder/lattice-incremental-online-decoder.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// see note at the top of lattice-faster-decoder.cc, about how to maintain this +// file in sync with lattice-faster-decoder.cc + +#include "decoder/lattice-incremental-decoder.h" +#include "decoder/lattice-incremental-online-decoder.h" +#include "lat/lattice-functions.h" +#include "base/timer.h" + +namespace kaldi { + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeIncrementalOnlineDecoderTpl::GetBestPath(Lattice *olat, + bool use_final_probs) const { + olat->DeleteStates(); + BaseFloat final_graph_cost; + BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost); + if (iter.Done()) + return false; // would have printed warning. + StateId state = olat->AddState(); + olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0)); + while (!iter.Done()) { + LatticeArc arc; + iter = TraceBackBestPath(iter, &arc); + arc.nextstate = state; + StateId new_state = olat->AddState(); + olat->AddArc(new_state, arc); + state = new_state; + } + olat->SetStart(state); + return true; +} + +template +typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator LatticeIncrementalOnlineDecoderTpl::BestPathEnd( + bool use_final_probs, + BaseFloat *final_cost_out) const { + if (this->decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "BestPathEnd() with use_final_probs == false"; + KALDI_ASSERT(this->NumFramesDecoded() > 0 && + "You cannot call BestPathEnd if no frames were decoded."); + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (this->decoding_finalized_ ? this->final_costs_ :final_costs_local); + if (!this->decoding_finalized_ && use_final_probs) + this->ComputeFinalCosts(&final_costs_local, NULL, NULL); + + // Singly linked list of tokens on last frame (access list through "next" + // pointer). + BaseFloat best_cost = std::numeric_limits::infinity(); + BaseFloat best_final_cost = 0; + Token *best_tok = NULL; + for (Token *tok = this->active_toks_.back().toks; + tok != NULL; tok = tok->next) { + BaseFloat cost = tok->tot_cost, final_cost = 0.0; + if (use_final_probs && !final_costs.empty()) { + // if we are instructed to use final-probs, and any final tokens were + // active on final frame, include the final-prob in the cost of the token. + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) { + final_cost = iter->second; + cost += final_cost; + } else { + cost = std::numeric_limits::infinity(); + } + } + if (cost < best_cost) { + best_cost = cost; + best_tok = tok; + best_final_cost = final_cost; + } + } + if (best_tok == NULL) { // this should not happen, and is likely a code error or + // caused by infinities in likelihoods, but I'm not making + // it a fatal error for now. + KALDI_WARN << "No final token found."; + } + if (final_cost_out == NULL) + *final_cost_out = best_final_cost; + return BestPathIterator(best_tok, this->NumFramesDecoded() - 1); +} + + +template +typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator LatticeIncrementalOnlineDecoderTpl::TraceBackBestPath( + BestPathIterator iter, LatticeArc *oarc) const { + KALDI_ASSERT(!iter.Done() && oarc != NULL); + Token *tok = static_cast(iter.tok); + int32 cur_t = iter.frame, ret_t = cur_t; + if (tok->backpointer != NULL) { + ForwardLinkT *link; + for (link = tok->backpointer->links; + link != NULL; link = link->next) { + if (link->next_tok == tok) { // this is the link to "tok" + oarc->ilabel = link->ilabel; + oarc->olabel = link->olabel; + BaseFloat graph_cost = link->graph_cost, + acoustic_cost = link->acoustic_cost; + if (link->ilabel != 0) { + KALDI_ASSERT(static_cast(cur_t) < this->cost_offsets_.size()); + acoustic_cost -= this->cost_offsets_[cur_t]; + ret_t--; + } + oarc->weight = LatticeWeight(graph_cost, acoustic_cost); + break; + } + } + if (link == NULL) { // Did not find correct link. + KALDI_ERR << "Error tracing best-path back (likely " + << "bug in token-pruning algorithm)"; + } + } else { + oarc->ilabel = 0; + oarc->olabel = 0; + oarc->weight = LatticeWeight::One(); // zero costs. + } + return BestPathIterator(tok->backpointer, ret_t); +} + +// Instantiate the template for the FST types that we'll need. +template class LatticeIncrementalOnlineDecoderTpl >; +template class LatticeIncrementalOnlineDecoderTpl >; +template class LatticeIncrementalOnlineDecoderTpl >; +template class LatticeIncrementalOnlineDecoderTpl; + + +} // end namespace kaldi. diff --git a/src/decoder/lattice-incremental-online-decoder.h b/src/decoder/lattice-incremental-online-decoder.h new file mode 100644 index 00000000000..8bd41c851ab --- /dev/null +++ b/src/decoder/lattice-incremental-online-decoder.h @@ -0,0 +1,132 @@ +// decoder/lattice-incremental-online-decoder.h + +// Copyright 2019 Zhehuai Chen +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// see note at the top of lattice-faster-decoder.h, about how to maintain this +// file in sync with lattice-faster-decoder.h + + +#ifndef KALDI_DECODER_LATTICE_INCREMENTAL_ONLINE_DECODER_H_ +#define KALDI_DECODER_LATTICE_INCREMENTAL_ONLINE_DECODER_H_ + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/lattice-incremental-decoder.h" + +namespace kaldi { + + + +/** LatticeIncrementalOnlineDecoderTpl is as LatticeIncrementalDecoderTpl but also + supports an efficient way to get the best path (see the function + BestPathEnd()), which is useful in endpointing and in situations where you + might want to frequently access the best path. + + This is only templated on the FST type, since the Token type is required to + be BackpointerToken. Actually it only makes sense to instantiate + LatticeIncrementalDecoderTpl with Token == BackpointerToken if you do so indirectly via + this child class. + */ +template +class LatticeIncrementalOnlineDecoderTpl: + public LatticeIncrementalDecoderTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using Token = decoder::BackpointerToken; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeIncrementalOnlineDecoderTpl(const FST &fst, + const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config): + LatticeIncrementalDecoderTpl(fst, trans_model, config) { } + + // This version of the initializer takes ownership of 'fst', and will delete + // it when this object is destroyed. + LatticeIncrementalOnlineDecoderTpl(const LatticeIncrementalDecoderConfig &config, + FST *fst, + const TransitionModel &trans_model): + LatticeIncrementalDecoderTpl(config, fst, trans_model) { } + + + struct BestPathIterator { + void *tok; + int32 frame; + // note, "frame" is the frame-index of the frame you'll get the + // transition-id for next time, if you call TraceBackBestPath on this + // iterator (assuming it's not an epsilon transition). Note that this + // is one less than you might reasonably expect, e.g. it's -1 for + // the nonemitting transitions before the first frame. + BestPathIterator(void *t, int32 f): tok(t), frame(f) { } + bool Done() { return tok == NULL; } + }; + + + /// Outputs an FST corresponding to the single best path through the lattice. + /// This is quite efficient because it doesn't get the entire raw lattice and find + /// the best path through it; instead, it uses the BestPathEnd and BestPathIterator + /// so it basically traces it back through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true) const; + + + + /// This function returns an iterator that can be used to trace back + /// the best path. If use_final_probs == true and at least one final state + /// survived till the end, it will use the final-probs in working out the best + /// final Token, and will output the final cost to *final_cost (if non-NULL), + /// else it will use only the forward likelihood, and will put zero in + /// *final_cost (if non-NULL). + /// Requires that NumFramesDecoded() > 0. + BestPathIterator BestPathEnd(bool use_final_probs, + BaseFloat *final_cost = NULL) const; + + + /// This function can be used in conjunction with BestPathEnd() to trace back + /// the best path one link at a time (e.g. this can be useful in endpoint + /// detection). By "link" we mean a link in the graph; not all links cross + /// frame boundaries, but each time you see a nonzero ilabel you can interpret + /// that as a frame. The return value is the updated iterator. It outputs + /// the ilabel and olabel, and the (graph and acoustic) weight to the "arc" pointer, + /// while leaving its "nextstate" variable unchanged. + BestPathIterator TraceBackBestPath( + BestPathIterator iter, LatticeArc *arc) const; + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalOnlineDecoderTpl); +}; + +typedef LatticeIncrementalOnlineDecoderTpl LatticeIncrementalOnlineDecoder; + + +} // end namespace kaldi. + +#endif diff --git a/src/online2/Makefile b/src/online2/Makefile index 242c7be6da6..bbc7ac07bb1 100644 --- a/src/online2/Makefile +++ b/src/online2/Makefile @@ -9,7 +9,7 @@ OBJFILES = online-gmm-decodable.o online-feature-pipeline.o online-ivector-featu online-nnet2-feature-pipeline.o online-gmm-decoding.o online-timing.o \ online-endpoint.o onlinebin-util.o online-speex-wrapper.o \ online-nnet2-decoding.o online-nnet2-decoding-threaded.o \ - online-nnet3-decoding.o + online-nnet3-decoding.o online-nnet3-incremental-decoding.o LIBNAME = kaldi-online2 diff --git a/src/online2/online-endpoint.cc b/src/online2/online-endpoint.cc index aa7752c4484..a3be0791f03 100644 --- a/src/online2/online-endpoint.cc +++ b/src/online2/online-endpoint.cc @@ -71,10 +71,10 @@ bool EndpointDetected(const OnlineEndpointConfig &config, return false; } -template +template int32 TrailingSilenceLength(const TransitionModel &tmodel, const std::string &silence_phones_str, - const LatticeFasterOnlineDecoderTpl &decoder) { + const DEC &decoder) { std::vector silence_phones; if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones)) KALDI_ERR << "Bad --silence-phones option in endpointing config: " @@ -87,7 +87,7 @@ int32 TrailingSilenceLength(const TransitionModel &tmodel, ConstIntegerSet silence_set(silence_phones); bool use_final_probs = false; - typename LatticeFasterOnlineDecoderTpl::BestPathIterator iter = + typename DEC::BestPathIterator iter = decoder.BestPathEnd(use_final_probs, NULL); int32 num_silence_frames = 0; while (!iter.Done()) { // we're going backwards in time from the most @@ -117,7 +117,7 @@ bool EndpointDetected( BaseFloat final_relative_cost = decoder.FinalRelativeCost(); int32 num_frames_decoded = decoder.NumFramesDecoded(), - trailing_silence_frames = TrailingSilenceLength(tmodel, + trailing_silence_frames = TrailingSilenceLength>(tmodel, config.silence_phones, decoder); @@ -125,6 +125,26 @@ bool EndpointDetected( frame_shift_in_seconds, final_relative_cost); } +template +bool EndpointDetected( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl &decoder) { + if (decoder.NumFramesDecoded() == 0) return false; + + BaseFloat final_relative_cost = decoder.FinalRelativeCost(); + + int32 num_frames_decoded = decoder.NumFramesDecoded(), + trailing_silence_frames = TrailingSilenceLength>(tmodel, + config.silence_phones, + decoder); + + return EndpointDetected(config, num_frames_decoded, trailing_silence_frames, + frame_shift_in_seconds, final_relative_cost); +} + + // Instantiate EndpointDetected for the types we need. // It will require TrailingSilenceLength so we don't have to instantiate that. @@ -143,5 +163,21 @@ bool EndpointDetected( BaseFloat frame_shift_in_seconds, const LatticeFasterOnlineDecoderTpl &decoder); +template +bool EndpointDetected >( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl > &decoder); + + +template +bool EndpointDetected( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl &decoder); + + } // namespace kaldi diff --git a/src/online2/online-endpoint.h b/src/online2/online-endpoint.h index aaf9232db13..3171f0c532c 100644 --- a/src/online2/online-endpoint.h +++ b/src/online2/online-endpoint.h @@ -35,6 +35,7 @@ #include "lat/kaldi-lattice.h" #include "hmm/transition-model.h" #include "decoder/lattice-faster-online-decoder.h" +#include "decoder/lattice-incremental-online-decoder.h" namespace kaldi { /// @addtogroup onlinedecoding OnlineDecoding @@ -187,10 +188,10 @@ bool EndpointDetected(const OnlineEndpointConfig &config, /// integer id's of phones that we consider silence. We use the the /// BestPathEnd() and TraceBackOneLink() functions of LatticeFasterOnlineDecoder /// to do this efficiently. -template +template int32 TrailingSilenceLength(const TransitionModel &tmodel, const std::string &silence_phones, - const LatticeFasterOnlineDecoderTpl &decoder); + const DEC &decoder); /// This is a higher-level convenience function that works out the @@ -202,6 +203,15 @@ bool EndpointDetected( BaseFloat frame_shift_in_seconds, const LatticeFasterOnlineDecoderTpl &decoder); +/// This is a higher-level convenience function that works out the +/// arguments to the EndpointDetected function above, from the decoder. +template +bool EndpointDetected( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl &decoder); + diff --git a/src/online2/online-ivector-feature.cc b/src/online2/online-ivector-feature.cc index 2042fbb8b80..fb1b7d9225d 100644 --- a/src/online2/online-ivector-feature.cc +++ b/src/online2/online-ivector-feature.cc @@ -510,6 +510,57 @@ void OnlineSilenceWeighting::ComputeCurrentTraceback( } } +template +void OnlineSilenceWeighting::ComputeCurrentTraceback( + const LatticeIncrementalOnlineDecoderTpl &decoder) { + int32 num_frames_decoded = decoder.NumFramesDecoded(), + num_frames_prev = frame_info_.size(); + // note, num_frames_prev is not the number of frames previously decoded, + // it's the generally-larger number of frames that we were requested to + // provide weights for. + if (num_frames_prev < num_frames_decoded) + frame_info_.resize(num_frames_decoded); + if (num_frames_prev > num_frames_decoded && + frame_info_[num_frames_decoded].transition_id != -1) + KALDI_ERR << "Number of frames decoded decreased"; // Likely bug + + if (num_frames_decoded == 0) + return; + int32 frame = num_frames_decoded - 1; + bool use_final_probs = false; + typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator iter = + decoder.BestPathEnd(use_final_probs, NULL); + while (frame >= 0) { + LatticeArc arc; + arc.ilabel = 0; + while (arc.ilabel == 0) // the while loop skips over input-epsilons + iter = decoder.TraceBackBestPath(iter, &arc); + // note, the iter.frame values are slightly unintuitively defined, + // they are one less than you might expect. + KALDI_ASSERT(iter.frame == frame - 1); + + if (frame_info_[frame].token == iter.tok) { + // we know that the traceback from this point back will be identical, so + // no point tracing back further. Note: we are comparing memory addresses + // of tokens of the decoder; this guarantees it's the same exact token, + // because tokens, once allocated on a frame, are only deleted, never + // reallocated for that frame. + break; + } + + if (num_frames_output_and_correct_ > frame) + num_frames_output_and_correct_ = frame; + + frame_info_[frame].token = iter.tok; + frame_info_[frame].transition_id = arc.ilabel; + frame--; + // leave frame_info_.current_weight at zero for now (as set in the + // constructor), reflecting that we haven't already output a weight for that + // frame. + } +} + + // Instantiate the template OnlineSilenceWeighting::ComputeCurrentTraceback(). template void OnlineSilenceWeighting::ComputeCurrentTraceback >( @@ -517,6 +568,13 @@ void OnlineSilenceWeighting::ComputeCurrentTraceback >( template void OnlineSilenceWeighting::ComputeCurrentTraceback( const LatticeFasterOnlineDecoderTpl &decoder); +template +void OnlineSilenceWeighting::ComputeCurrentTraceback >( + const LatticeIncrementalOnlineDecoderTpl > &decoder); +template +void OnlineSilenceWeighting::ComputeCurrentTraceback( + const LatticeIncrementalOnlineDecoderTpl &decoder); + int32 OnlineSilenceWeighting::GetBeginFrame() { int32 max_duration = config_.max_state_duration; diff --git a/src/online2/online-ivector-feature.h b/src/online2/online-ivector-feature.h index 25e078f1a98..5e674e2b7f1 100644 --- a/src/online2/online-ivector-feature.h +++ b/src/online2/online-ivector-feature.h @@ -33,6 +33,7 @@ #include "feat/online-feature.h" #include "ivector/ivector-extractor.h" #include "decoder/lattice-faster-online-decoder.h" +#include "decoder/lattice-incremental-online-decoder.h" namespace kaldi { /// @addtogroup onlinefeat OnlineFeatureExtraction @@ -471,6 +472,8 @@ class OnlineSilenceWeighting { // It will be instantiated for FST == fst::Fst and fst::GrammarFst. template void ComputeCurrentTraceback(const LatticeFasterOnlineDecoderTpl &decoder); + template + void ComputeCurrentTraceback(const LatticeIncrementalOnlineDecoderTpl &decoder); // Calling this function gets the changes in weight that require us to modify // the stats... the output format is (frame-index, delta-weight). The diff --git a/src/online2/online-nnet3-incremental-decoding.cc b/src/online2/online-nnet3-incremental-decoding.cc new file mode 100644 index 00000000000..540a3a4f850 --- /dev/null +++ b/src/online2/online-nnet3-incremental-decoding.cc @@ -0,0 +1,93 @@ +// online2/online-nnet3-incremental-decoding.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "online2/online-nnet3-incremental-decoding.h" +#include "lat/lattice-functions.h" +#include "lat/determinize-lattice-pruned.h" +#include "decoder/grammar-fst.h" + +namespace kaldi { + +template +SingleUtteranceNnet3IncrementalDecoderTpl::SingleUtteranceNnet3IncrementalDecoderTpl( + const LatticeIncrementalDecoderConfig &decoder_opts, + const TransitionModel &trans_model, + const nnet3::DecodableNnetSimpleLoopedInfo &info, + const FST &fst, + OnlineNnet2FeaturePipeline *features): + decoder_opts_(decoder_opts), + input_feature_frame_shift_in_seconds_(features->FrameShiftInSeconds()), + trans_model_(trans_model), + decodable_(trans_model_, info, + features->InputFeature(), features->IvectorFeature()), + decoder_(fst, trans_model, decoder_opts_) { + decoder_.InitDecoding(); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::InitDecoding(int32 frame_offset) { + decoder_.InitDecoding(); + decodable_.SetFrameOffset(frame_offset); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::AdvanceDecoding() { + decoder_.AdvanceDecoding(&decodable_); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::FinalizeDecoding() { + decoder_.FinalizeDecoding(); +} + +template +int32 SingleUtteranceNnet3IncrementalDecoderTpl::NumFramesDecoded() const { + return decoder_.NumFramesDecoded(); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::GetLattice(bool end_of_utterance, + CompactLattice *clat) { + if (NumFramesDecoded() == 0) + KALDI_ERR << "You cannot get a lattice if you decoded no frames."; + decoder_.GetLattice(end_of_utterance, decoder_.NumFramesDecoded(), clat); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::GetBestPath(bool end_of_utterance, + Lattice *best_path) const { + decoder_.GetBestPath(best_path, end_of_utterance); +} + +template +bool SingleUtteranceNnet3IncrementalDecoderTpl::EndpointDetected( + const OnlineEndpointConfig &config) { + BaseFloat output_frame_shift = + input_feature_frame_shift_in_seconds_ * + decodable_.FrameSubsamplingFactor(); + return kaldi::EndpointDetected(config, trans_model_, + output_frame_shift, decoder_); +} + + +// Instantiate the template for the types needed. +template class SingleUtteranceNnet3IncrementalDecoderTpl >; +template class SingleUtteranceNnet3IncrementalDecoderTpl; + +} // namespace kaldi diff --git a/src/online2/online-nnet3-incremental-decoding.h b/src/online2/online-nnet3-incremental-decoding.h new file mode 100644 index 00000000000..ddd9707bf54 --- /dev/null +++ b/src/online2/online-nnet3-incremental-decoding.h @@ -0,0 +1,128 @@ +// online2/online-nnet3-incremental-decoding.h + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_ONLINE2_ONLINE_NNET3_INCREMENTAL_DECODING_H_ +#define KALDI_ONLINE2_ONLINE_NNET3_INCREMENTAL_DECODING_H_ + +#include +#include +#include + +#include "nnet3/decodable-online-looped.h" +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" +#include "base/kaldi-error.h" +#include "itf/online-feature-itf.h" +#include "online2/online-endpoint.h" +#include "online2/online-nnet2-feature-pipeline.h" +#include "decoder/lattice-incremental-online-decoder.h" +#include "hmm/transition-model.h" +#include "hmm/posterior.h" + +namespace kaldi { +/// @addtogroup onlinedecoding OnlineDecoding +/// @{ + + +/** + You will instantiate this class when you want to decode a single utterance + using the online-decoding setup for neural nets. The template will be + instantiated only for FST = fst::Fst and FST = fst::GrammarFst. +*/ + +template +class SingleUtteranceNnet3IncrementalDecoderTpl { + public: + + // Constructor. The pointer 'features' is not being given to this class to own + // and deallocate, it is owned externally. + SingleUtteranceNnet3IncrementalDecoderTpl(const LatticeIncrementalDecoderConfig &decoder_opts, + const TransitionModel &trans_model, + const nnet3::DecodableNnetSimpleLoopedInfo &info, + const FST &fst, + OnlineNnet2FeaturePipeline *features); + + /// Initializes the decoding and sets the frame offset of the underlying + /// decodable object. This method is called by the constructor. You can also + /// call this method when you want to reset the decoder state, but want to + /// keep using the same decodable object, e.g. in case of an endpoint. + void InitDecoding(int32 frame_offset = 0); + + /// Advances the decoding as far as we can. + void AdvanceDecoding(); + + /// Finalizes the decoding. Cleans up and prunes remaining tokens, so the + /// GetLattice() call will return faster. You must not call this before + /// calling (TerminateDecoding() or InputIsFinished()) and then Wait(). + void FinalizeDecoding(); + + int32 NumFramesDecoded() const; + + /// Gets the lattice. The output lattice has any acoustic scaling in it + /// (which will typically be desirable in an online-decoding context); if you + /// want an un-scaled lattice, scale it using ScaleLattice() with the inverse + /// of the acoustic weight. "end_of_utterance" will be true if you want the + /// final-probs to be included. + void GetLattice(bool end_of_utterance, + CompactLattice *clat); + + /// Outputs an FST corresponding to the single best path through the current + /// lattice. If "use_final_probs" is true AND we reached the final-state of + /// the graph then it will include those as final-probs, else it will treat + /// all final-probs as one. + void GetBestPath(bool end_of_utterance, + Lattice *best_path) const; + + + /// This function calls EndpointDetected from online-endpoint.h, + /// with the required arguments. + bool EndpointDetected(const OnlineEndpointConfig &config); + + const LatticeIncrementalOnlineDecoderTpl &Decoder() const { return decoder_; } + + ~SingleUtteranceNnet3IncrementalDecoderTpl() { } + private: + + const LatticeIncrementalDecoderConfig &decoder_opts_; + + // this is remembered from the constructor; it's ultimately + // derived from calling FrameShiftInSeconds() on the feature pipeline. + BaseFloat input_feature_frame_shift_in_seconds_; + + // we need to keep a reference to the transition model around only because + // it's needed by the endpointing code. + const TransitionModel &trans_model_; + + nnet3::DecodableAmNnetLoopedOnline decodable_; + + LatticeIncrementalOnlineDecoderTpl decoder_; + +}; + + +typedef SingleUtteranceNnet3IncrementalDecoderTpl > SingleUtteranceNnet3IncrementalDecoder; + +/// @} End of "addtogroup onlinedecoding" + +} // namespace kaldi + + + +#endif // KALDI_ONLINE2_ONLINE_NNET3_DECODING_H_ diff --git a/src/online2bin/Makefile b/src/online2bin/Makefile index 28c135eb950..2552e7148dc 100644 --- a/src/online2bin/Makefile +++ b/src/online2bin/Makefile @@ -12,7 +12,7 @@ BINFILES = online2-wav-gmm-latgen-faster apply-cmvn-online \ online2-wav-dump-features ivector-randomize \ online2-wav-nnet2-am-compute online2-wav-nnet2-latgen-threaded \ online2-wav-nnet3-latgen-faster online2-wav-nnet3-latgen-grammar \ - online2-tcp-nnet3-decode-faster + online2-tcp-nnet3-decode-faster online2-wav-nnet3-latgen-incremental OBJFILES = diff --git a/src/online2bin/online2-wav-nnet3-latgen-incremental.cc b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc new file mode 100644 index 00000000000..b48337af5fb --- /dev/null +++ b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc @@ -0,0 +1,304 @@ +// online2bin/online2-wav-nnet3-latgen-incremental.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "feat/wave-reader.h" +#include "online2/online-nnet3-incremental-decoding.h" +#include "online2/online-nnet2-feature-pipeline.h" +#include "online2/onlinebin-util.h" +#include "online2/online-timing.h" +#include "online2/online-endpoint.h" +#include "fstext/fstext-lib.h" +#include "lat/lattice-functions.h" +#include "util/kaldi-thread.h" +#include "nnet3/nnet-utils.h" + +namespace kaldi { + +void GetDiagnosticsAndPrintOutput(const std::string &utt, + const fst::SymbolTable *word_syms, + const CompactLattice &clat, + int64 *tot_num_frames, + double *tot_like) { + if (clat.NumStates() == 0) { + KALDI_WARN << "Empty lattice."; + return; + } + CompactLattice best_path_clat; + CompactLatticeShortestPath(clat, &best_path_clat); + + Lattice best_path_lat; + ConvertLattice(best_path_clat, &best_path_lat); + + double likelihood; + LatticeWeight weight; + int32 num_frames; + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight); + num_frames = alignment.size(); + likelihood = -(weight.Value1() + weight.Value2()); + *tot_num_frames += num_frames; + *tot_like += likelihood; + KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is " + << (likelihood / num_frames) << " over " << num_frames + << " frames."; + + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << std::endl; + } +} + +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace fst; + + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Reads in wav file(s) and simulates online decoding with neural nets\n" + "(nnet3 setup), with optional iVector-based speaker adaptation and\n" + "optional endpointing. Note: some configuration values and inputs are\n" + "set via config files whose filenames are passed as options\n" + "The lattice determinization algorithm here can operate\n" + "incrementally.\n" + "\n" + "Usage: online2-wav-nnet3-latgen-incremental [options] " + " \n" + "The spk2utt-rspecifier can just be if\n" + "you want to decode utterance by utterance.\n"; + + ParseOptions po(usage); + + std::string word_syms_rxfilename; + + // feature_opts includes configuration for the iVector adaptation, + // as well as the basic features. + OnlineNnet2FeaturePipelineConfig feature_opts; + nnet3::NnetSimpleLoopedComputationOptions decodable_opts; + LatticeIncrementalDecoderConfig decoder_opts; + OnlineEndpointConfig endpoint_opts; + + BaseFloat chunk_length_secs = 0.18; + bool do_endpointing = false; + bool online = true; + + po.Register("chunk-length", &chunk_length_secs, + "Length of chunk size in seconds, that we process. Set to <= 0 " + "to use all input in one chunk."); + po.Register("word-symbol-table", &word_syms_rxfilename, + "Symbol table for words [for debug output]"); + po.Register("do-endpointing", &do_endpointing, + "If true, apply endpoint detection"); + po.Register("online", &online, + "You can set this to false to disable online iVector estimation " + "and have all the data for each utterance used, even at " + "utterance start. This is useful where you just want the best " + "results and don't care about online operation. Setting this to " + "false has the same effect as setting " + "--use-most-recent-ivector=true and --greedy-ivector-extractor=true " + "in the file given to --ivector-extraction-config, and " + "--chunk-length=-1."); + po.Register("num-threads-startup", &g_num_threads, + "Number of threads used when initializing iVector extractor."); + + feature_opts.Register(&po); + decodable_opts.Register(&po); + decoder_opts.Register(&po); + endpoint_opts.Register(&po); + + + po.Read(argc, argv); + + if (po.NumArgs() != 5) { + po.PrintUsage(); + return 1; + } + + std::string nnet3_rxfilename = po.GetArg(1), + fst_rxfilename = po.GetArg(2), + spk2utt_rspecifier = po.GetArg(3), + wav_rspecifier = po.GetArg(4), + clat_wspecifier = po.GetArg(5); + + OnlineNnet2FeaturePipelineInfo feature_info(feature_opts); + + if (!online) { + feature_info.ivector_extractor_info.use_most_recent_ivector = true; + feature_info.ivector_extractor_info.greedy_ivector_extractor = true; + chunk_length_secs = -1.0; + } + + TransitionModel trans_model; + nnet3::AmNnetSimple am_nnet; + { + bool binary; + Input ki(nnet3_rxfilename, &binary); + trans_model.Read(ki.Stream(), binary); + am_nnet.Read(ki.Stream(), binary); + SetBatchnormTestMode(true, &(am_nnet.GetNnet())); + SetDropoutTestMode(true, &(am_nnet.GetNnet())); + nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(am_nnet.GetNnet())); + } + + // this object contains precomputed stuff that is used by all decodable + // objects. It takes a pointer to am_nnet because if it has iVectors it has + // to modify the nnet to accept iVectors at intervals. + nnet3::DecodableNnetSimpleLoopedInfo decodable_info(decodable_opts, + &am_nnet); + + + fst::Fst *decode_fst = ReadFstKaldiGeneric(fst_rxfilename); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_rxfilename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename))) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_rxfilename; + + int32 num_done = 0, num_err = 0; + double tot_like = 0.0; + int64 num_frames = 0; + + SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); + RandomAccessTableReader wav_reader(wav_rspecifier); + CompactLatticeWriter clat_writer(clat_wspecifier); + + OnlineTimingStats timing_stats; + + for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { + std::string spk = spk2utt_reader.Key(); + const std::vector &uttlist = spk2utt_reader.Value(); + OnlineIvectorExtractorAdaptationState adaptation_state( + feature_info.ivector_extractor_info); + for (size_t i = 0; i < uttlist.size(); i++) { + std::string utt = uttlist[i]; + if (!wav_reader.HasKey(utt)) { + KALDI_WARN << "Did not find audio for utterance " << utt; + num_err++; + continue; + } + const WaveData &wave_data = wav_reader.Value(utt); + // get the data for channel zero (if the signal is not mono, we only + // take the first channel). + SubVector data(wave_data.Data(), 0); + + OnlineNnet2FeaturePipeline feature_pipeline(feature_info); + feature_pipeline.SetAdaptationState(adaptation_state); + + OnlineSilenceWeighting silence_weighting( + trans_model, + feature_info.silence_weighting_config, + decodable_opts.frame_subsampling_factor); + + SingleUtteranceNnet3IncrementalDecoder decoder(decoder_opts, trans_model, + decodable_info, + *decode_fst, &feature_pipeline); + OnlineTimer decoding_timer(utt); + + BaseFloat samp_freq = wave_data.SampFreq(); + int32 chunk_length; + if (chunk_length_secs > 0) { + chunk_length = int32(samp_freq * chunk_length_secs); + if (chunk_length == 0) chunk_length = 1; + } else { + chunk_length = std::numeric_limits::max(); + } + + int32 samp_offset = 0; + std::vector > delta_weights; + + while (samp_offset < data.Dim()) { + int32 samp_remaining = data.Dim() - samp_offset; + int32 num_samp = chunk_length < samp_remaining ? chunk_length + : samp_remaining; + + SubVector wave_part(data, samp_offset, num_samp); + feature_pipeline.AcceptWaveform(samp_freq, wave_part); + + samp_offset += num_samp; + decoding_timer.WaitUntil(samp_offset / samp_freq); + if (samp_offset == data.Dim()) { + // no more input. flush out last frames + feature_pipeline.InputFinished(); + } + + if (silence_weighting.Active() && + feature_pipeline.IvectorFeature() != NULL) { + silence_weighting.ComputeCurrentTraceback(decoder.Decoder()); + silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(), + &delta_weights); + feature_pipeline.IvectorFeature()->UpdateFrameWeights(delta_weights); + } + + decoder.AdvanceDecoding(); + + if (do_endpointing && decoder.EndpointDetected(endpoint_opts)) { + break; + } + } + decoder.FinalizeDecoding(); + + CompactLattice clat; + bool end_of_utterance = true; + decoder.GetLattice(end_of_utterance, &clat); + + GetDiagnosticsAndPrintOutput(utt, word_syms, clat, + &num_frames, &tot_like); + + decoding_timer.OutputStats(&timing_stats); + + // In an application you might avoid updating the adaptation state if + // you felt the utterance had low confidence. See lat/confidence.h + feature_pipeline.GetAdaptationState(&adaptation_state); + + // we want to output the lattice with un-scaled acoustics. + BaseFloat inv_acoustic_scale = + 1.0 / decodable_opts.acoustic_scale; + ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat); + + clat_writer.Write(utt, clat); + KALDI_LOG << "Decoded utterance " << utt; + num_done++; + } + } + timing_stats.Print(online); + + KALDI_LOG << "Decoded " << num_done << " utterances, " + << num_err << " with errors."; + KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames) + << " per frame over " << num_frames << " frames."; + delete decode_fst; + delete word_syms; // will delete if non-NULL. + return (num_done != 0 ? 0 : 1); + } catch(const std::exception& e) { + std::cerr << e.what(); + return -1; + } +} // main()