From d8ff7ee0cd9f5c670c36ad4525899e668d50abef Mon Sep 17 00:00:00 2001 From: chenzhehuai Date: Tue, 3 Apr 2018 00:53:59 -0400 Subject: [PATCH 01/60] make fst templates inline to eliminate linking errors in other places --- src/fstext/fstext-utils-inl.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fstext/fstext-utils-inl.h b/src/fstext/fstext-utils-inl.h index 923c67c07e2..756e449fcaa 100644 --- a/src/fstext/fstext-utils-inl.h +++ b/src/fstext/fstext-utils-inl.h @@ -1132,7 +1132,7 @@ inline bool IsStochasticFst(const Fst &fst, // Will override this for LogArc where NaturalLess will not work. template -bool IsStochasticFst(const Fst &fst, +inline bool IsStochasticFst(const Fst &fst, float delta, typename Arc::Weight *min_sum, typename Arc::Weight *max_sum) { @@ -1168,7 +1168,7 @@ bool IsStochasticFst(const Fst &fst, // Overriding template for LogArc as NaturalLess does not work there. template<> -bool IsStochasticFst(const Fst &fst, +inline bool IsStochasticFst(const Fst &fst, float delta, LogArc::Weight *min_sum, LogArc::Weight *max_sum) { @@ -1208,7 +1208,7 @@ bool IsStochasticFst(const Fst &fst, // This function deals with the generic fst. // This version currently supports ConstFst or VectorFst. // Otherwise, it will be died with an error. -bool IsStochasticFstInLog(const Fst &fst, +inline bool IsStochasticFstInLog(const Fst &fst, float delta, StdArc::Weight *min_sum, StdArc::Weight *max_sum) { From 1f50f0644a8f9712daf43bc8820b528d425d0078 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Wed, 27 Mar 2019 07:04:55 +0800 Subject: [PATCH 02/60] WIP --- src/bin/Makefile | 1 + src/bin/latgen-incremental-mapped.cc | 179 +++ src/decoder/Makefile | 1 + src/decoder/decoder-wrappers.cc | 129 +++ src/decoder/decoder-wrappers.h | 18 + src/decoder/lattice-incremental-decoder.cc | 1186 ++++++++++++++++++++ src/decoder/lattice-incremental-decoder.h | 443 ++++++++ 7 files changed, 1957 insertions(+) create mode 100644 src/bin/latgen-incremental-mapped.cc create mode 100644 src/decoder/lattice-incremental-decoder.cc create mode 100644 src/decoder/lattice-incremental-decoder.h diff --git a/src/bin/Makefile b/src/bin/Makefile index 7cb01b50120..8046f6c9ab2 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -17,6 +17,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ post-to-weights sum-tree-stats weight-post post-to-tacc copy-matrix \ copy-vector copy-int-vector sum-post sum-matrices draw-tree \ align-mapped align-compiled-mapped latgen-faster-mapped latgen-faster-mapped-parallel \ + latgen-incremental-mapped \ hmm-info analyze-counts post-to-phone-post \ post-to-pdf-post logprob-to-post prob-to-post copy-post \ matrix-sum build-pfile-from-ali get-post-on-ali tree-info am-info \ diff --git a/src/bin/latgen-incremental-mapped.cc b/src/bin/latgen-incremental-mapped.cc new file mode 100644 index 00000000000..164a513f2d6 --- /dev/null +++ b/src/bin/latgen-incremental-mapped.cc @@ -0,0 +1,179 @@ +// bin/latgen-incremental-mapped.cc + +// Copyright 2009-2012 Microsoft Corporation, Karel Vesely +// 2013 Johns Hopkins University (author: Daniel Povey) +// 2014 Guoguo Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "hmm/transition-model.h" +#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" +#include "decoder/decodable-matrix.h" +#include "base/timer.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::Fst; + using fst::StdArc; + + const char *usage = + "Generate lattices, reading log-likelihoods as matrices\n" + " (model is needed only for the integer mappings in its transition-model)\n" + "Usage: latgen-incremental-mapped [options] trans-model-in (fst-in|fsts-rspecifier) loglikes-rspecifier" + " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n"; + ParseOptions po(usage); + Timer timer; + bool allow_partial = false; + BaseFloat acoustic_scale = 0.1; + LatticeIncrementalDecoderConfig config; + + std::string word_syms_filename; + config.Register(&po); + po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods"); + + po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]"); + po.Register("allow-partial", &allow_partial, "If true, produce output even if end state was not reached."); + + po.Read(argc, argv); + + if (po.NumArgs() < 4 || po.NumArgs() > 6) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + fst_in_str = po.GetArg(2), + feature_rspecifier = po.GetArg(3), + lattice_wspecifier = po.GetArg(4), + words_wspecifier = po.GetOptArg(5), + alignment_wspecifier = po.GetOptArg(6); + + TransitionModel trans_model; + ReadKaldiObject(model_in_filename, &trans_model); + + bool determinize = config.determinize_lattice; + CompactLatticeWriter compact_lattice_writer; + LatticeWriter lattice_writer; + if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier) + : lattice_writer.Open(lattice_wspecifier))) + KALDI_ERR << "Could not open table for writing lattices: " + << lattice_wspecifier; + + Int32VectorWriter words_writer(words_wspecifier); + + Int32VectorWriter alignment_writer(alignment_wspecifier); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_filename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_filename; + + double tot_like = 0.0; + kaldi::int64 frame_count = 0; + int num_success = 0, num_fail = 0; + + if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) { + SequentialBaseFloatMatrixReader loglike_reader(feature_rspecifier); + // Input FST is just one FST, not a table of FSTs. + Fst *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); + timer.Reset(); + + { + LatticeIncrementalDecoder decoder(*decode_fst, trans_model, config); + + for (; !loglike_reader.Done(); loglike_reader.Next()) { + std::string utt = loglike_reader.Key(); + Matrix loglikes (loglike_reader.Value()); + loglike_reader.FreeCurrent(); + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + + DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale); + + double like; + if (DecodeUtteranceLatticeIncremental( + decoder, decodable, trans_model, word_syms, utt, + acoustic_scale, determinize, allow_partial, &alignment_writer, + &words_writer, &compact_lattice_writer, &lattice_writer, + &like)) { + tot_like += like; + frame_count += loglikes.NumRows(); + num_success++; + } else num_fail++; + } + } + delete decode_fst; // delete this only after decoder goes out of scope. + } else { // We have different FSTs for different utterances. + SequentialTableReader fst_reader(fst_in_str); + RandomAccessBaseFloatMatrixReader loglike_reader(feature_rspecifier); + for (; !fst_reader.Done(); fst_reader.Next()) { + std::string utt = fst_reader.Key(); + if (!loglike_reader.HasKey(utt)) { + KALDI_WARN << "Not decoding utterance " << utt + << " because no loglikes available."; + num_fail++; + continue; + } + const Matrix &loglikes = loglike_reader.Value(utt); + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + LatticeIncrementalDecoder decoder(fst_reader.Value(), trans_model, config); + DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale); + double like; + if (DecodeUtteranceLatticeIncremental( + decoder, decodable, trans_model, word_syms, utt, acoustic_scale, + determinize, allow_partial, &alignment_writer, &words_writer, + &compact_lattice_writer, &lattice_writer, &like)) { + tot_like += like; + frame_count += loglikes.NumRows(); + num_success++; + } else num_fail++; + } + } + + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken "<< elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (elapsed*100.0/frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " + << num_fail; + KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like/frame_count) << " over " + << frame_count<<" frames."; + + delete word_syms; + if (num_success != 0) return 0; + else return 1; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/decoder/Makefile b/src/decoder/Makefile index 020fe358fe9..849947e493f 100644 --- a/src/decoder/Makefile +++ b/src/decoder/Makefile @@ -7,6 +7,7 @@ TESTFILES = OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \ lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \ + lattice-incremental-decoder.o \ decoder-wrappers.o grammar-fst.o decodable-matrix.o LIBNAME = kaldi-decoder diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index ff573c74d15..7f05bf274e4 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -195,6 +195,104 @@ DecodeUtteranceLatticeFasterClass::~DecodeUtteranceLatticeFasterClass() { delete decodable_; } +template +bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl &decoder, // not const but is really an input. + DecodableInterface &decodable, // not const but is really an input. + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr) { // puts utterance's like in like_ptr on success. + using fst::VectorFst; + + if (!decoder.Decode(&decodable)) { + KALDI_WARN << "Failed to decode file " << utt; + return false; + } + if (!decoder.ReachedFinal()) { + if (allow_partial) { + KALDI_WARN << "Outputting partial output for utterance " << utt + << " since no final-state reached\n"; + } else { + KALDI_WARN << "Not producing output for utterance " << utt + << " since no final-state reached and " + << "--allow-partial=false.\n"; + return false; + } + } + + double likelihood; + LatticeWeight weight; + int32 num_frames; + { // First do some stuff with word-level traceback... + VectorFst decoded; + if (!decoder.GetBestPath(&decoded)) + // Shouldn't really reach this point as already checked success. + KALDI_ERR << "Failed to get traceback for utterance " << utt; + + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + num_frames = alignment.size(); + if (words_writer->IsOpen()) + words_writer->Write(utt, words); + if (alignment_writer->IsOpen()) + alignment_writer->Write(utt, alignment); + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << '\n'; + } + likelihood = -(weight.Value1() + weight.Value2()); + } + + // Get lattice, and do determinization if requested. + Lattice lat; + decoder.GetRawLattice(&lat); + if (lat.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; + fst::Connect(&lat); + if (determinize) { + CompactLattice clat; + if (!DeterminizeLatticePhonePrunedWrapper( + trans_model, + &lat, + decoder.GetOptions().lattice_beam, + &clat, + decoder.GetOptions().det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam for " + << "utterance " << utt; + // We'll write the lattice without acoustic scaling. + if (acoustic_scale != 0.0) + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); + compact_lattice_writer->Write(utt, clat); + } else { + // We'll write the lattice without acoustic scaling. + if (acoustic_scale != 0.0) + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &lat); + lattice_writer->Write(utt, lat); + } + KALDI_LOG << "Log-like per frame for utterance " << utt << " is " + << (likelihood / num_frames) << " over " + << num_frames << " frames."; + KALDI_VLOG(2) << "Cost for utterance " << utt << " is " + << weight.Value1() << " + " << weight.Value2(); + *like_ptr = likelihood; + return true; +} + // Takes care of output. Returns true on success. template @@ -296,6 +394,37 @@ bool DecodeUtteranceLatticeFaster( } // Instantiate the template above for the two required FST types. +template bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl > &decoder, + DecodableInterface &decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); + +template bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl &decoder, + DecodableInterface &decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); + + template bool DecodeUtteranceLatticeFaster( LatticeFasterDecoderTpl > &decoder, DecodableInterface &decodable, diff --git a/src/decoder/decoder-wrappers.h b/src/decoder/decoder-wrappers.h index fc81137f356..61134412cfd 100644 --- a/src/decoder/decoder-wrappers.h +++ b/src/decoder/decoder-wrappers.h @@ -22,6 +22,7 @@ #include "itf/options-itf.h" #include "decoder/lattice-faster-decoder.h" +#include "decoder/lattice-incremental-decoder.h" #include "decoder/lattice-simple-decoder.h" // This header contains declarations from various convenience functions that are called @@ -88,6 +89,23 @@ void AlignUtteranceWrapper( void ModifyGraphForCarefulAlignment( fst::VectorFst *fst); +/// TODO +template +bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl &decoder, // not const but is really an input. + DecodableInterface &decodable, // not const but is really an input. + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignments_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); // puts utterance's likelihood in like_ptr on success. + /// This function DecodeUtteranceLatticeFaster is used in several decoders, and /// we have moved it here. Note: this is really "binary-level" code as it diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc new file mode 100644 index 00000000000..2432f92ddd4 --- /dev/null +++ b/src/decoder/lattice-incremental-decoder.cc @@ -0,0 +1,1186 @@ +// decoder/lattice-incremental-decoder.cc + +// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann +// 2013-2018 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/lattice-incremental-decoder.h" +#include "lat/lattice-functions.h" + +namespace kaldi { + +// instantiate this class once for each thing you have to decode. +template +LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( + const FST &fst, + const TransitionModel& trans_model, + const LatticeIncrementalDecoderConfig &config): + fst_(&fst), delete_fst_(false), config_(config), num_toks_(0), trans_model_(trans_model) { + config.Check(); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. +} + + +template +LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( + const LatticeIncrementalDecoderConfig &config, FST *fst, + const TransitionModel& trans_model): + fst_(fst), delete_fst_(true), config_(config), num_toks_(0), +trans_model_(trans_model) { + config.Check(); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. +} + + +template +LatticeIncrementalDecoderTpl::~LatticeIncrementalDecoderTpl() { + DeleteElems(toks_.Clear()); + ClearActiveTokens(); + if (delete_fst_) delete fst_; +} + +template +void LatticeIncrementalDecoderTpl::InitDecoding() { + // clean up from last time: + DeleteElems(toks_.Clear()); + cost_offsets_.clear(); + ClearActiveTokens(); + warned_ = false; + num_toks_ = 0; + decoding_finalized_ = false; + final_costs_.clear(); + StateId start_state = fst_->Start(); + KALDI_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL); + active_toks_[0].toks = start_tok; + toks_.Insert(start_state, start_tok); + num_toks_++; + + lat_.DeleteStates(); + last_get_lattice_frame_ = 0; + state_label_map_.clear(); + state_label_map_.reserve(std::min((int32)1e5, config_.max_active)); + state_label_avilable_idx_ = config_.max_word_id+1; + final_arc_list_.clear(); + final_arc_list_prev_.clear(); + state_label_forward_prob_.clear(); + + ProcessNonemitting(config_.beam); +} + +// Returns true if any kind of traceback is available (not necessarily from +// a final state). It should only very rarely return false; this indicates +// an unusual search error. +template +bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decodable) { + InitDecoding(); + + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. + + while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + // TODO: have a delay in GetLattice + GetLattice(false, false, &lat_); + } + BaseFloat cost_cutoff = ProcessEmitting(decodable); + ProcessNonemitting(cost_cutoff); + } + FinalizeDecoding(); + GetLattice(true, false, &lat_); + // GetLattice(true, true, &lat_); // TODO + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !active_toks_.empty() && active_toks_.back().toks != NULL; +} + + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeIncrementalDecoderTpl::GetBestPath(Lattice *olat, + bool use_final_probs) const { + Lattice raw_lat; + GetRawLattice(&raw_lat); + ShortestPath(raw_lat, olat); + return (olat->NumStates() != 0); +} + +// Outputs an FST corresponding to the raw, state-level lattice +template +bool LatticeIncrementalDecoderTpl::GetRawLattice( + Lattice *ofst, + bool use_final_probs) const { + ConvertLattice(lat_, ofst); + return true; +} + +template +bool LatticeIncrementalDecoderTpl::GetRawLattice( + Lattice *ofst, + bool use_final_probs, + int32 frame_begin, + int32 frame_end, + bool create_initial_state, + bool create_final_state) { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + + // Note: you can't use the old interface (Decode()) if you want to + // get the lattice with use_final_probs = false. You'd have to do + // InitDecoding() and then AdvanceDecoding(). + if (decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetRawLattice() with use_final_probs == false"; + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (decoding_finalized_ ? final_costs_ : final_costs_local); + if (!decoding_finalized_ && use_final_probs) + ComputeFinalCosts(&final_costs_local, NULL, NULL); + + ofst->DeleteStates(); + if (frame_begin != 0) ofst->AddState(); // initial-state for the chunk + // num-frames plus one (since frames are one-based, and we have + // an extra frame for the start-state). + KALDI_ASSERT(frame_end > 0); + const int32 bucket_count = num_toks_/2 + 3; + unordered_map tok_map(bucket_count); + // First create all states. + std::vector token_list; + for (int32 f = frame_begin; f <= frame_end; f++) { + if (active_toks_[f].toks == NULL) { + KALDI_WARN << "GetRawLattice: no tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + TopSortTokens(active_toks_[f].toks, &token_list); + for (size_t i = 0; i < token_list.size(); i++) + if (token_list[i] != NULL) + tok_map[token_list[i]] = ofst->AddState(); + } + // The next statement sets the start state of the output FST. Because we + // topologically sorted the tokens, state zero must be the start-state. + StateId begin_state = 0; + StateId end_state = ofst->AddState(); // final-state for the chunk + ofst->SetStart(begin_state); + ofst->SetFinal(end_state, Weight::One()); + + KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:" + << tok_map.bucket_count() << " load:" << tok_map.load_factor() + << " max:" << tok_map.max_load_factor(); + // Create initial_arc for later appending with the previous chunk + if (create_initial_state) { + for (Token *tok = active_toks_[frame_begin].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + int32 id = state_label_map_.find(tok)->second; // it should exist + // TODO: calculate alpha but not use tot_cost or extra_cost + BaseFloat cost_offset = tok->tot_cost; + state_label_forward_prob_[id] = tok->tot_cost; + Arc arc(0, id, + Weight(0, cost_offset), + cur_state); + ofst->AddArc(begin_state, arc); + } + } + // Now create all arcs. + for (int32 f = frame_begin; f <= frame_end; f++) { + for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + for (ForwardLinkT *l = tok->links; + l != NULL; + l = l->next) { + typename unordered_map::const_iterator + iter = tok_map.find(l->next_tok); + StateId nextstate = iter->second; + KALDI_ASSERT(iter != tok_map.end()); + BaseFloat cost_offset = 0.0; + if (l->ilabel != 0) { // emitting.. + KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); + cost_offset = cost_offsets_[f]; + } + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), + nextstate); + ofst->AddArc(cur_state, arc); + } + if (f == frame_end) { + if (use_final_probs && !final_costs.empty()) { + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) + ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, LatticeWeight::One()); + } + } + } + } + // Create final_arc for later appending with the next chunk + if (create_final_state) { + state_label_map_.clear(); + state_label_map_.reserve(std::min((int32)1e5, config_.max_active)); + for (Token *tok = active_toks_[frame_end].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + int32 id = state_label_avilable_idx_++; + state_label_map_[tok] = id; + Weight final_weight = (!decoding_finalized_ && ofst->Final(cur_state) == Weight::Zero())? Weight::One(): ofst->Final(cur_state); + + Arc arc(0, id, + Weight(0, final_weight.Value1()+final_weight.Value2()), + end_state); + ofst->AddArc(cur_state, arc); + ofst->SetFinal(cur_state, Weight::Zero()); + } + } + return (ofst->NumStates() > 0); +} + +template +void LatticeIncrementalDecoderTpl::PossiblyResizeHash(size_t num_toks) { + size_t new_sz = static_cast(static_cast(num_toks) + * config_.hash_ratio); + if (new_sz > toks_.Size()) { + toks_.SetSize(new_sz); + } +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + Define the 'forward cost' of a token as zero for any token on the frame + we're currently decoding; and for other frames, as the shortest-path cost + between that token and a token on the frame we're currently decoding. + (by "currently decoding" I mean the most recently processed frame). + + Then define the extra_cost of a token (always >= 0) as the forward-cost of + the token minus the smallest forward-cost of any token on the same frame. + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. + + The advantage of storing the extra_cost rather than the forward-cost, is that + it is less costly to keep the extra_cost up-to-date when we process new frames. + When we process a new frame, *all* the previous frames' forward-costs would change; + but in general the extra_cost will change only for a finite number of frames. + (Actually we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash of toks_, +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline Token* LatticeIncrementalDecoderTpl::FindOrAddToken( + StateId state, int32 frame_plus_one, BaseFloat tot_cost, + Token *backpointer, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + Elem *e_found = toks_.Find(state); + if (e_found == NULL) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. + Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + toks_.Insert(state, new_tok); + if (changed) *changed = true; + return new_tok; + } else { + Token *tok = e_found->val; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. + tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return tok; + } +} + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeIncrementalDecoderTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. + if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. + // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses +// the final-probs for pruning, otherwise it treats all tokens as final. +template +void LatticeIncrementalDecoderTpl::PruneForwardLinksFinal() { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file"; + + typedef typename unordered_map::const_iterator IterType; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + // We call DeleteElems() as a nicety, not because it's really necessary; + // otherwise there would be a time, after calling PruneTokensForFrame() on the + // final frame, when toks_.GetList() or toks_.Clear() would contain pointers + // to nonexistent tokens. + DeleteElems(toks_.Clear()); + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. + bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) + changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeIncrementalDecoderTpl::FinalRelativeCost() const { + if (!decoding_finalized_) { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; + } else { + // we're not allowed to call that function if FinalizeDecoding() has + // been called; return a cached value. + return final_relative_cost_; + } +} + + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. +// It's called by PruneActiveTokens if any forward links have been pruned +template +void LatticeIncrementalDecoderTpl::PruneTokensForFrame(int32 frame_plus_one) { + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + if (toks == NULL) + KALDI_WARN << "No tokens alive [doing pruning]"; + Token *tok, *next_tok, *prev_tok = NULL; + for (tok = toks; tok != NULL; tok = next_tok) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) prev_tok->next = tok->next; + else toks = tok->next; + delete tok; + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame before +// that. We go backwards through the frames and stop when we reach a point +// where the delta-costs are not changing (and the delta controls when we consider +// a cost to have "not changed"). +template +void LatticeIncrementalDecoderTpl::PruneActiveTokens(BaseFloat delta) { + int32 cur_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract + // one to get the corresponding index for the decodable object. + for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them (new TokenList) + // (2) we have not yet pruned the forward links to the next f, + // after any of those tokens have changed their extra_cost. + if (active_toks_[f].must_prune_forward_links) { + bool extra_costs_changed = false, links_pruned = false; + PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); + if (extra_costs_changed && f > 0) // any token has changed extra_cost + active_toks_[f-1].must_prune_forward_links = true; + if (links_pruned) // any link was pruned + active_toks_[f].must_prune_tokens = true; + active_toks_[f].must_prune_forward_links = false; // job done + } + if (f+1 < cur_frame_plus_one && // except for last f (no forward links) + active_toks_[f+1].must_prune_tokens) { + PruneTokensForFrame(f+1); + active_toks_[f+1].must_prune_tokens = false; + } + } + KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +template +void LatticeIncrementalDecoderTpl::ComputeFinalCosts( + unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const { + KALDI_ASSERT(!decoding_finalized_); + if (final_costs != NULL) + final_costs->clear(); + const Elem *final_toks = toks_.GetList(); + BaseFloat infinity = std::numeric_limits::infinity(); + BaseFloat best_cost = infinity, + best_cost_with_final = infinity; + + while (final_toks != NULL) { + StateId state = final_toks->key; + Token *tok = final_toks->val; + const Elem *next = final_toks->tail; + BaseFloat final_cost = fst_->Final(state).Value(); + BaseFloat cost = tok->tot_cost, + cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != NULL && final_cost != infinity) + (*final_costs)[tok] = final_cost; + final_toks = next; + } + if (final_relative_cost != NULL) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != NULL) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +template +void LatticeIncrementalDecoderTpl::AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. + if (fst_->Type() == "const") { + LatticeIncrementalDecoderTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeIncrementalDecoderTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && + "You must call InitDecoding() before AdvanceDecoding"); + int32 num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). + KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = std::min(target_frames_decoded, + NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + BaseFloat cost_cutoff = ProcessEmitting(decodable); + ProcessNonemitting(cost_cutoff); + } +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeIncrementalDecoderTpl::FinalizeDecoding() { + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. + BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +/// Gets the weight cutoff. Also counts the active tokens. +template +BaseFloat LatticeIncrementalDecoderTpl::GetCutoff(Elem *list_head, size_t *tok_count, + BaseFloat *adaptive_beam, Elem **best_elem) { + BaseFloat best_weight = std::numeric_limits::infinity(); + // positive == high cost == bad. + size_t count = 0; + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = static_cast(e->val->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; + } + } + if (tok_count != NULL) *tok_count = count; + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = e->val->tot_cost; + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; + } + } + if (tok_count != NULL) *tok_count = count; + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() + << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) ? + tmp_array_.begin() + config_.max_active : + tmp_array_.end()); + min_active_cutoff = tmp_array_[config_.min_active]; + } + } + if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. + if (adaptive_beam) + *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; + return min_active_cutoff; + } else { + *adaptive_beam = config_.beam; + return beam_cutoff; + } + } +} + +template +BaseFloat LatticeIncrementalDecoderTpl::ProcessEmitting( + DecodableInterface *decodable) { + KALDI_ASSERT(active_toks_.size() > 0); + int32 frame = active_toks_.size() - 1; // frame is the frame-index + // (zero-based) used to get likelihoods + // from the decodable object. + active_toks_.resize(active_toks_.size() + 1); + + Elem *final_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_ + // in simple-decoder.h. Removes the Elems from + // being indexed in the hash in toks_. + Elem *best_elem = NULL; + BaseFloat adaptive_beam; + size_t tok_cnt; + BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem); + KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " + << adaptive_beam; + + PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough. + + BaseFloat next_cutoff = std::numeric_limits::infinity(); + // pruning "online" before having seen all tokens + + BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good + // dynamic range. + + + // First process the best token to get a hopefully + // reasonably tight bound on the next cutoff. The only + // products of the next block are "next_cutoff" and "cost_offset". + if (best_elem) { + StateId state = best_elem->key; + Token *tok = best_elem->val; + cost_offset = - tok->tot_cost; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + BaseFloat new_weight = arc.weight.Value() + cost_offset - + decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost; + if (new_weight + adaptive_beam < next_cutoff) + next_cutoff = new_weight + adaptive_beam; + } + } + } + + // Store the offset on the acoustic likelihoods that we're applying. + // Could just do cost_offsets_.push_back(cost_offset), but we + // do it this way as it's more robust to future code changes. + cost_offsets_.resize(frame + 1, 0.0); + cost_offsets_[frame] = cost_offset; + + // the tokens are now owned here, in final_toks, and the hash is empty. + // 'owned' is a complex thing here; the point is we need to call DeleteElem + // on each elem 'e' to let toks_ know we're done with them. + for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) { + // loop this way because we delete "e" as we go. + StateId state = e->key; + Token *tok = e->val; + if (tok->tot_cost <= cur_cutoff) { + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + BaseFloat ac_cost = cost_offset - + decodable->LogLikelihood(frame, arc.ilabel), + graph_cost = arc.weight.Value(), + cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost > next_cutoff) continue; + else if (tot_cost + adaptive_beam < next_cutoff) + next_cutoff = tot_cost + adaptive_beam; // prune by best current token + // Note: the frame indexes into active_toks_ are one-based, + // hence the + 1. + Token *next_tok = FindOrAddToken(arc.nextstate, + frame + 1, tot_cost, tok, NULL); + // NULL: no change indicator needed + + // Add ForwardLink from tok to next_tok (put on head of list tok->links) + tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, + graph_cost, ac_cost, tok->links); + } + } // for all arcs + } + e_tail = e->tail; + toks_.Delete(e); // delete Elem + } + return next_cutoff; +} + +// static inline +template +void LatticeIncrementalDecoderTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; + } + tok->links = NULL; +} + + +template +void LatticeIncrementalDecoderTpl::ProcessNonemitting(BaseFloat cutoff) { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame = static_cast(active_toks_.size()) - 2; + // Note: "frame" is the time-index we just processed, or -1 if + // we are processing the nonemitting transitions before the + // first frame (called from InitDecoding()). + + // Processes nonemitting arcs for one frame. Propagates within toks_. + // Note-- this queue structure is is not very optimal as + // it may cause us to process states unnecessarily (e.g. more than once), + // but in the baseline code, turning this vector into a set to fix this + // problem did not improve overall speed. + + KALDI_ASSERT(queue_.empty()); + + if (toks_.GetList() == NULL) { + if (!warned_) { + KALDI_WARN << "Error, no surviving tokens: frame is " << frame; + warned_ = true; + } + } + + for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) { + StateId state = e->key; + if (fst_->NumInputEpsilons(state) != 0) + queue_.push_back(state); + } + + while (!queue_.empty()) { + StateId state = queue_.back(); + queue_.pop_back(); + + Token *tok = toks_.Find(state)->val; // would segfault if state not in toks_ but this can't happen. + BaseFloat cur_cost = tok->tot_cost; + if (cur_cost > cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + // but since most states are emitting it's not a huge issue. + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel == 0) { // propagate nonemitting only... + BaseFloat graph_cost = arc.weight.Value(), + tot_cost = cur_cost + graph_cost; + if (tot_cost < cutoff) { + bool changed; + + Token *new_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, + tok, &changed); + + tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new [if so, add into queue]. + if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0) + queue_.push_back(arc.nextstate); + } + } + } // for all arcs + } // while queue not empty +} + + +template +void LatticeIncrementalDecoderTpl::DeleteElems(Elem *list) { + for (Elem *e = list, *e_tail; e != NULL; e = e_tail) { + e_tail = e->tail; + toks_.Delete(e); + } +} + +template +void LatticeIncrementalDecoderTpl::ClearActiveTokens() { // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL; ) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + +// static +template +void LatticeIncrementalDecoderTpl::TopSortTokens( + Token *tok_list, std::vector *topsorted_list) { + unordered_map token2pos; + typedef typename unordered_map::iterator IterType; + int32 num_toks = 0; + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + num_toks++; + int32 cur_pos = 0; + // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. + // This is likely to be in closer to topological order than + // if we had given them ascending order, because of the way + // new tokens are put at the front of the list. + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + token2pos[tok] = num_toks - ++cur_pos; + + unordered_set reprocess; + + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { + Token *tok = iter->first; + int32 pos = iter->second; + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + // We only need to consider epsilon links, since non-epsilon links + // transition between frames and this function only needs to sort a list + // of tokens from a single frame. + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { // another token on this frame, + // so must consider it. + int32 next_pos = following_iter->second; + if (next_pos < pos) { // reassign the position of the next Token. + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + // In case we had previously assigned this token to be reprocessed, we can + // erase it from that set because it's "happy now" (we just processed it). + reprocess.erase(tok); + } + + size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles. + for (loop_count = 0; + !reprocess.empty() && loop_count < max_loop; ++loop_count) { + std::vector reprocess_vec; + for (typename unordered_set::iterator iter = reprocess.begin(); + iter != reprocess.end(); ++iter) + reprocess_vec.push_back(*iter); + reprocess.clear(); + for (typename std::vector::iterator iter = reprocess_vec.begin(); + iter != reprocess_vec.end(); ++iter) { + Token *tok = *iter; + int32 pos = token2pos[tok]; + // Repeat the processing we did above (for comments, see above). + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + IterType following_iter = token2pos.find(link->next_tok); + if (following_iter != token2pos.end()) { + int32 next_pos = following_iter->second; + if (next_pos < pos) { + following_iter->second = cur_pos++; + reprocess.insert(link->next_tok); + } + } + } + } + } + } + KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding " + "graph (this is not allowed!)"); + + topsorted_list->clear(); + topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between. + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) + (*topsorted_list)[iter->second] = iter->first; +} + +template +bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, + bool redeterminize, CompactLattice *olat) { + using namespace fst; + + if (last_get_lattice_frame_ < NumFramesDecoded()) { + // Get lattice chunk with initial state + Lattice raw_fst; + KALDI_ASSERT(GetRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, NumFramesDecoded(), last_get_lattice_frame_ != 0, !decoding_finalized_)); + // Determinize the chunk + CompactLattice clat; + if (!DeterminizeLatticePhonePrunedWrapper( + trans_model_, + &raw_fst, + config_.lattice_beam, + &clat, + config_.det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam"; + + final_arc_list_.swap(final_arc_list_prev_); + final_arc_list_.clear(); + + // Appending new chunk to the old one + int32 state_offset=olat->NumStates(); + unordered_map initial_arc_map; // the previous states of these arcs are initial states + initial_arc_map.reserve(std::min((int32)1e5, config_.max_active)); + for (StateIterator siter(clat); !siter.Done(); siter.Next()) { + auto s = siter.Value(); + StateId state_append = -1; + if (last_get_lattice_frame_ == 0) { // do not need to copy initial state + state_append = s+state_offset; + KALDI_ASSERT(state_append == olat->AddState()); + olat->SetFinal(state_append, clat.Final(s)); + } else if (s != 0) { + state_append = s+state_offset-1; // do not include the first state + KALDI_ASSERT(state_append == olat->AddState()); + olat->SetFinal(state_append, clat.Final(s)); + } + + for (ArcIterator aiter(clat, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + // construct a copy of the state & arcs + if (last_get_lattice_frame_ == 0 || s != 0) { // do not need to copy initial state + CompactLatticeArc arc_append(arc); + arc_append.nextstate += state_offset; + olat->AddArc(state_append, arc_append); + } + if (arc.olabel > config_.max_word_id) { + if (s==0) { // initial_arc + initial_arc_map[arc.olabel]=aiter.Position(); + } else { // final_arc + final_arc_list_.push_back(pair(state_append, aiter.Position())); + } + } + } + } + // connect the states between two chunks + if (last_get_lattice_frame_ != 0) { + KALDI_ASSERT(final_arc_list_prev_.size()); + StateId prev_final_state = -1; + for (auto&i:final_arc_list_prev_) { + MutableArcIterator aiter(olat, i.first); + aiter.Seek(i.second); + auto &arc_append = aiter.Value(); + auto r = initial_arc_map.find(arc_append.olabel); + if (r!=initial_arc_map.end()) { + ArcIterator aiter_chunk(clat, 0); // initial state + aiter_chunk.Seek(r->second); + const auto &arc_chunk = aiter_chunk.Value(); + KALDI_ASSERT(arc_chunk.olabel == arc_append.olabel); + StateId state_append = arc_chunk.nextstate+state_offset; + if (prev_final_state == -1) prev_final_state=arc_append.nextstate; + else KALDI_ASSERT(arc_append.nextstate == prev_final_state); + CompactLatticeArc arc_append_mod(arc_append); + arc_append_mod.nextstate = state_append; + + CompactLatticeWeight weight_offset; + weight_offset.SetWeight(LatticeWeight(0, -state_label_forward_prob_[arc_append.olabel])); + vector weights = {arc_append_mod.weight, arc_chunk.weight, olat->Final(prev_final_state), weight_offset}; + BaseFloat v1=0, v2=0; + for (auto& i:weights) + v1+=i.Weight().Value1(); + for (auto& i:weights) + v2+=i.Weight().Value2(); + vector s; + for (auto& i:weights) + s.insert(s.end(), i.String().begin(), i.String().end()); + + arc_append_mod.weight = CompactLatticeWeight(LatticeWeight(v1,v2), s); + arc_append_mod.olabel = 0; + aiter.SetValue(arc_append_mod); + } // otherwise, it has been pruned + state_label_forward_prob_.erase(arc_append.olabel); + } + // making all unmodified remaining arcs of final_arc_list_prev_ are connected to a dead state + olat->SetFinal(prev_final_state, CompactLatticeWeight::Zero()); + } + KALDI_VLOG(2) << "Frame: " <NumStates(); + } // TODO: check in the case the last frame is det twice + last_get_lattice_frame_ = NumFramesDecoded(); + // Determinize the final lattice + if (redeterminize) { + DeterminizeLatticePrunedOptions det_opts; + det_opts.delta = config_.det_opts.delta; + det_opts.max_mem = config_.det_opts.max_mem; + Lattice lat; + ConvertLattice(*olat, &lat); + Invert(&lat); + if (lat.Properties(fst::kTopSorted, true) == 0) { + if (!TopSort(&lat)) { + // Cannot topologically sort the lattice -- determinization will fail. + KALDI_ERR << "Topological sorting of state-level lattice failed (probably" + << " your lexicon has empty words or your LM has epsilon cycles" + << ")."; + } + } + if (!DeterminizeLatticePruned( + lat, + config_.lattice_beam, + olat, + det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam"; + Connect(olat); // Remove unreachable states... there might be + } + + // a small number of these, in some cases. + // Note: if something went wrong and the raw lattice was empty, + // we should still get to this point in the code without warnings or failures. + return (olat->NumStates() != 0); +} + +// Instantiate the template for the combination of token types and FST types +// that we'll need. +template class LatticeIncrementalDecoderTpl, decoder::StdToken>; +template class LatticeIncrementalDecoderTpl, decoder::StdToken >; +template class LatticeIncrementalDecoderTpl, decoder::StdToken >; +template class LatticeIncrementalDecoderTpl; + +template class LatticeIncrementalDecoderTpl , decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl, decoder::BackpointerToken >; +template class LatticeIncrementalDecoderTpl, decoder::BackpointerToken >; +template class LatticeIncrementalDecoderTpl; + + +} // end namespace kaldi. diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h new file mode 100644 index 00000000000..ae033d865ae --- /dev/null +++ b/src/decoder/lattice-incremental-decoder.h @@ -0,0 +1,443 @@ +// decoder/lattice-incremental-decoder.h + +// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; +// 2013-2014 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_INCREMENTAL_DECODER_H_ +#define KALDI_DECODER_LATTICE_INCREMENTAL_DECODER_H_ + + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "lattice-faster-decoder.h" + +namespace kaldi { + +struct LatticeIncrementalDecoderConfig { + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + bool determinize_lattice; // not inspected by this class... used in + // command-line program. + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeIncrementalDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeIncremental. + int32 max_word_id; // for GetLattice + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + LatticeIncrementalDecoderConfig(): beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1), + max_word_id(1e7) { } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + "which to prune tokens"); + opts->Register("determinize-lattice", &determinize_lattice, "If true, " + "determinize the lattice (lattice-determinization, keeping only " + "best pdf-sequence for each word-sequence)."); + opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to " + "control hash behavior"); + } + void Check() const { + KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 + && min_active <= max_active + && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0 + && prune_scale > 0.0 && prune_scale < 1.0); + } +}; + +/** This is the "normal" lattice-generating decoder. + See \ref lattices_generation \ref decoders_faster and \ref decoders_simple + for more information. + + The decoder is templated on the FST type and the token type. The token type + will normally be StdToken, but also may be BackpointerToken which is to support + quick lookup of the current best path (see lattice-faster-online-decoder.h) + + The FST you invoke this decoder with is expected to equal + Fst::Fst, a.k.a. StdFst, or GrammarFst. If you invoke it with + FST == StdFst and it notices that the actual FST type is + fst::VectorFst or fst::ConstFst, the decoder object + will internally cast itself to one that is templated on those more specific + types; this is an optimization for speed. + */ +template +class LatticeIncrementalDecoderTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeIncrementalDecoderTpl(const FST &fst, const TransitionModel& trans_model, + const LatticeIncrementalDecoderConfig &config); + + // This version of the constructor takes ownership of the fst, and will delete + // it when this object is destroyed. + LatticeIncrementalDecoderTpl(const LatticeIncrementalDecoderConfig &config, + FST *fst, const TransitionModel& trans_model); + + void SetOptions(const LatticeIncrementalDecoderConfig &config) { + config_ = config; + } + + const LatticeIncrementalDecoderConfig &GetOptions() const { + return config_; + } + + ~LatticeIncrementalDecoderTpl(); + + /// Decodes until there are no more frames left in the "decodable" object.. + /// note, this may block waiting for input if the "decodable" object blocks. + /// Returns true if any kind of traceback is available (not necessarily from a + /// final state). + bool Decode(DecodableInterface *decodable); + + + /// says whether a final-state was active on the last frame. If it was not, the + /// lattice (or traceback) will end with states that are not final-states. + bool ReachedFinal() const { + return FinalRelativeCost() != std::numeric_limits::infinity(); + } + + /// Outputs an FST corresponding to the single best path through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. Note: this just calls GetRawLattice() + /// and figures out the shortest path. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true) const; + + /// Outputs an FST corresponding to the raw, state-level + /// tracebacks. Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state + /// of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + /// The raw lattice will be topologically sorted. + /// + /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, + /// which also supports a pruning beam, in case for some reason + /// you want it pruned tighter than the regular lattice beam. + /// We could put that here in future needed. + bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const; + + bool GetRawLattice(Lattice *ofst, bool use_final_probs, + int32 frame_begin, + int32 frame_end, + bool create_initial_state, + bool create_final_state); + + + + /// InitDecoding initializes the decoding, and should only be used if you + /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to + /// call this. You can also call InitDecoding if you have already decoded an + /// utterance and want to start with a new utterance. + void InitDecoding(); + + /// This will decode until there are no more frames ready in the decodable + /// object. You can keep calling it each time more frames become available. + /// If max_num_frames is specified, it specifies the maximum number of frames + /// the function will decode before returning. + void AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames = -1); + + /// This function may be optionally called after AdvanceDecoding(), when you + /// do not plan to decode any further. It does an extra pruning step that + /// will help to prune the lattices output by GetLattice and (particularly) + /// GetRawLattice more accurately, particularly toward the end of the + /// utterance. It does this by using the final-probs in pruning (if any + /// final-state survived); it also does a final pruning step that visits all + /// states (the pruning that is done during decoding may fail to prune states + /// that are within kPruningScale = 0.1 outside of the beam). If you call + /// this, you cannot call AdvanceDecoding again (it will fail), and you + /// cannot call GetLattice() and related functions with use_final_probs = + /// false. + /// Used to be called PruneActiveTokensFinal(). + void FinalizeDecoding(); + + /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + /// more information. It returns the difference between the best (final-cost + /// plus cost) of any token on the final frame, and the best cost of any token + /// on the final frame. If it is infinity it means no final-states were + /// present on the final frame. It will usually be nonnegative. If it not + /// too positive (e.g. < 5 is my first guess, but this is not tested) you can + /// take it as a good indication that we reached the final-state with + /// reasonable likelihood. + BaseFloat FinalRelativeCost() const; + + + // Returns the number of frames decoded so far. The value returned changes + // whenever we call ProcessEmitting(). + inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + + protected: + // we make things protected instead of private, as code in + // LatticeIncrementalOnlineDecoderTpl, which inherits from this, also uses the + // internals. + + // Deletes the elements of the singly linked list tok->links. + inline static void DeleteForwardLinks(Token *tok); + + // head of per-frame list of Tokens (list is in topological order), + // and something saying whether we ever pruned it using PruneForwardLinks. + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + TokenList(): toks(NULL), must_prune_forward_links(true), + must_prune_tokens(true) { } + }; + + using Elem = typename HashList::Elem; + // Equivalent to: + // struct Elem { + // StateId key; + // Token *val; + // Elem *tail; + // }; + + void PossiblyResizeHash(size_t num_toks); + + // FindOrAddToken either locates a token in hash of toks_, or if necessary + // inserts a new, empty token (i.e. with no forward links) for the current + // frame. [note: it's inserted if necessary into hash toks_ and also into the + // singly linked list of tokens active on this frame (whose head is at + // active_toks_[frame]). The frame_plus_one argument is the acoustic frame + // index plus one, which is used to index into the active_toks_ array. + // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the + // token was newly created or the cost changed. + // If Token == StdToken, the 'backpointer' argument has no purpose (and will + // hopefully be optimized out). + inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, + BaseFloat tot_cost, Token *backpointer, + bool *changed); + + // prunes outgoing links for all tokens in active_toks_[frame] + // it's called by PruneActiveTokens + // all links, that have link_extra_cost > lattice_beam are pruned + // delta is the amount by which the extra_costs must change + // before we set *extra_costs_changed = true. + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, + BaseFloat delta); + + // This function computes the final-costs for tokens active on the final + // frame. It outputs to final-costs, if non-NULL, a map from the Token* + // pointer to the final-prob of the corresponding state, for all Tokens + // that correspond to states that have final-probs. This map will be + // empty if there were no final-probs. It outputs to + // final_relative_cost, if non-NULL, the difference between the best + // forward-cost including the final-prob cost, and the best forward-cost + // without including the final-prob cost (this will usually be positive), or + // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which + // outputs this quanitity]. It outputs to final_best_cost, if + // non-NULL, the lowest for any token t active on the final frame, of + // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in + // the graph of the state corresponding to token t, or the best of + // forward-cost[t] if there were no final-probs active on the final frame. + // You cannot call this after FinalizeDecoding() has been called; in that + // case you should get the answer from class-member variables. + void ComputeFinalCosts(unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const; + + // PruneForwardLinksFinal is a version of PruneForwardLinks that we call + // on the final frame. If there are final tokens active, it uses + // the final-probs for pruning, otherwise it treats all tokens as final. + void PruneForwardLinksFinal(); + + // Prune away any tokens on this frame that have no forward links. + // [we don't do this in PruneForwardLinks because it would give us + // a problem with dangling pointers]. + // It's called by PruneActiveTokens if any forward links have been pruned + void PruneTokensForFrame(int32 frame_plus_one); + + + // Go backwards through still-alive tokens, pruning them if the + // forward+backward cost is more than lat_beam away from the best path. It's + // possible to prove that this is "correct" in the sense that we won't lose + // anything outside of lat_beam, regardless of what happens in the future. + // delta controls when it considers a cost to have changed enough to continue + // going backward and propagating the change. larger delta -> will recurse + // less far. + void PruneActiveTokens(BaseFloat delta); + + /// Gets the weight cutoff. Also counts the active tokens. + BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, + BaseFloat *adaptive_beam, Elem **best_elem); + + /// Processes emitting arcs for one frame. Propagates from prev_toks_ to + /// cur_toks_. Returns the cost cutoff for subsequent ProcessNonemitting() to + /// use. + BaseFloat ProcessEmitting(DecodableInterface *decodable); + + /// Processes nonemitting (epsilon) arcs for one frame. Called after + /// ProcessEmitting() on each frame. The cost cutoff is computed by the + /// preceding ProcessEmitting(). + void ProcessNonemitting(BaseFloat cost_cutoff); + + // HashList defined in ../util/hash-list.h. It actually allows us to maintain + // more than one list (e.g. for current and previous frames), but only one of + // them at a time can be indexed by StateId. It is indexed by frame-index + // plus one, where the frame-index is zero-based, as used in decodable object. + // That is, the emitting probs of frame t are accounted for in tokens at + // toks_[t+1]. The zeroth frame is for nonemitting transition at the start of + // the graph. + HashList toks_; + + std::vector active_toks_; // Lists of tokens, indexed by + // frame (members of TokenList are toks, must_prune_forward_links, + // must_prune_tokens). + std::vector queue_; // temp variable used in ProcessNonemitting, + std::vector tmp_array_; // used in GetCutoff. + + // fst_ is a pointer to the FST we are decoding from. + const FST *fst_; + // delete_fst_ is true if the pointer fst_ needs to be deleted when this + // object is destroyed. + bool delete_fst_; + + std::vector cost_offsets_; // This contains, for each + // frame, an offset that was added to the acoustic log-likelihoods on that + // frame in order to keep everything in a nice dynamic range i.e. close to + // zero, to reduce roundoff errors. + LatticeIncrementalDecoderConfig config_; + int32 num_toks_; // current total #toks allocated... + bool warned_; + + /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, + /// calling this is optional]. If true, it's forbidden to decode more. Also, + /// if this is set, then the output of ComputeFinalCosts() is in the next + /// three variables. The reason we need to do this is that after + /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some + /// of the tokens on the last frame are freed, so we free the list from toks_ + /// to avoid having dangling pointers hanging around. + bool decoding_finalized_; + /// For the meaning of the next 3 variables, see the comment for + /// decoding_finalized_ above., and ComputeFinalCosts(). + unordered_map final_costs_; + BaseFloat final_relative_cost_; + BaseFloat final_best_cost_; + + // There are various cleanup tasks... the the toks_ structure contains + // singly linked lists of Token pointers, where Elem is the list type. + // It also indexes them in a hash, indexed by state (this hash is only + // maintained for the most recent frame). toks_.Clear() + // deletes them from the hash and returns the list of Elems. The + // function DeleteElems calls toks_.Delete(elem) for each elem in + // the list, which returns ownership of the Elem to the toks_ structure + // for reuse, but does not delete the Token pointer. The Token pointers + // are reference-counted and are ultimately deleted in PruneTokensForFrame, + // but are also linked together on each frame by their own linked-list, + // using the "next" pointer. We delete them manually. + void DeleteElems(Elem *list); + + // This function takes a singly linked list of tokens for a single frame, and + // outputs a list of them in topological order (it will crash if no such order + // can be found, which will typically be due to decoding graphs with epsilon + // cycles, which are not allowed). Note: the output list may contain NULLs, + // which the caller should pass over; it just happens to be more efficient for + // the algorithm to output a list that contains NULLs. + static void TopSortTokens(Token *tok_list, + std::vector *topsorted_list); + + void ClearActiveTokens(); + + /// Obtains a CompactLattice for the part of this utterance that has been + /// decoded so far. If you call this multiple times (calling it on every frame would not make + /// sense, but every, say, 10, to 40 frames might make sense) it will spread out the + /// work of determinization over time,which might be useful for online applications. + /// + /// @param [in] use_final_probs If true *and* at least one final-state in HCLG + /// was active on the final frame, include final-probs from HCLG + /// in the lattice. Otherwise treat all final-costs of states active + /// on the most recent frame as zero (i.e. Weight::One()). + /// @param [in] redeterminize If true, re-determinize the CompactLattice + /// after appending the most recently decoded chunk to it, to + /// ensure that the output is fully deterministic. + /// This does extra work, but not nearly as much as determinizing + /// a RawLattice from scratch. + /// @param [out] lat The CompactLattice representing what has been decoded + /// so far. + /// @return reached_final This function will returns true if a state that was final in + /// HCLG was active on the most recent frame, and false otherwise. + /// CAUTION: this is not the same meaning as the return value of + /// LatticeFasterDecoder::GetLattice(). + bool GetLattice(bool use_final_probs, + bool redeterminize, CompactLattice *olat); + CompactLattice lat_; + int32 last_get_lattice_frame_; + unordered_map state_label_map_; + int32 state_label_avilable_idx_; + const TransitionModel& trans_model_; + std::vector> final_arc_list_; + std::vector> final_arc_list_prev_; + // TODO use 2 vector since state_label is continuous in each frame, and we need 2 frames + unordered_map state_label_forward_prob_; // alpha for each state_label (Token) + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); +}; + +typedef LatticeIncrementalDecoderTpl LatticeIncrementalDecoder; + + + +} // end namespace kaldi. + +#endif From be6fba21d163f71ca1671d9f05b14eaefca8d8e0 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Wed, 27 Mar 2019 09:14:49 +0800 Subject: [PATCH 03/60] worse wer & ower --- src/decoder/decoder-wrappers.cc | 19 +- src/decoder/lattice-incremental-decoder.cc | 287 +++++++++++---------- src/decoder/lattice-incremental-decoder.h | 15 +- 3 files changed, 165 insertions(+), 156 deletions(-) diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index 7f05bf274e4..22655878caa 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -241,6 +241,7 @@ bool DecodeUtteranceLatticeIncremental( std::vector words; GetLinearSymbolSequence(decoded, &alignment, &words, &weight); num_frames = alignment.size(); + KALDI_ASSERT(num_frames == decoder.NumFramesDecoded()); if (words_writer->IsOpen()) words_writer->Write(utt, words); if (alignment_writer->IsOpen()) @@ -259,26 +260,18 @@ bool DecodeUtteranceLatticeIncremental( } // Get lattice, and do determinization if requested. - Lattice lat; - decoder.GetRawLattice(&lat); - if (lat.NumStates() == 0) + CompactLattice clat; + decoder.GetCompactLattice(&clat); + if (clat.NumStates() == 0) KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; - fst::Connect(&lat); if (determinize) { - CompactLattice clat; - if (!DeterminizeLatticePhonePrunedWrapper( - trans_model, - &lat, - decoder.GetOptions().lattice_beam, - &clat, - decoder.GetOptions().det_opts)) - KALDI_WARN << "Determinization finished earlier than the beam for " - << "utterance " << utt; // We'll write the lattice without acoustic scaling. if (acoustic_scale != 0.0) fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); compact_lattice_writer->Write(utt, clat); } else { + Lattice lat; + decoder.GetRawLattice(&lat); // We'll write the lattice without acoustic scaling. if (acoustic_scale != 0.0) fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &lat); diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 2432f92ddd4..baf0cbd7b5c 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -106,8 +106,7 @@ bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decoda ProcessNonemitting(cost_cutoff); } FinalizeDecoding(); - GetLattice(true, false, &lat_); - // GetLattice(true, true, &lat_); // TODO + GetLattice(true, true, &lat_); // Returns true if we have any kind of traceback available (not necessarily // to the end state; query ReachedFinal() for that). @@ -119,9 +118,9 @@ bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decoda template bool LatticeIncrementalDecoderTpl::GetBestPath(Lattice *olat, bool use_final_probs) const { - Lattice raw_lat; - GetRawLattice(&raw_lat); - ShortestPath(raw_lat, olat); + CompactLattice lat; + ShortestPath(lat_, &lat); + ConvertLattice(lat, olat); return (olat->NumStates() != 0); } @@ -131,131 +130,15 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( Lattice *ofst, bool use_final_probs) const { ConvertLattice(lat_, ofst); - return true; + Connect(ofst); + return (ofst->NumStates() != 0); } template -bool LatticeIncrementalDecoderTpl::GetRawLattice( - Lattice *ofst, - bool use_final_probs, - int32 frame_begin, - int32 frame_end, - bool create_initial_state, - bool create_final_state) { - typedef LatticeArc Arc; - typedef Arc::StateId StateId; - typedef Arc::Weight Weight; - typedef Arc::Label Label; - - // Note: you can't use the old interface (Decode()) if you want to - // get the lattice with use_final_probs = false. You'd have to do - // InitDecoding() and then AdvanceDecoding(). - if (decoding_finalized_ && !use_final_probs) - KALDI_ERR << "You cannot call FinalizeDecoding() and then call " - << "GetRawLattice() with use_final_probs == false"; - - unordered_map final_costs_local; - - const unordered_map &final_costs = - (decoding_finalized_ ? final_costs_ : final_costs_local); - if (!decoding_finalized_ && use_final_probs) - ComputeFinalCosts(&final_costs_local, NULL, NULL); - - ofst->DeleteStates(); - if (frame_begin != 0) ofst->AddState(); // initial-state for the chunk - // num-frames plus one (since frames are one-based, and we have - // an extra frame for the start-state). - KALDI_ASSERT(frame_end > 0); - const int32 bucket_count = num_toks_/2 + 3; - unordered_map tok_map(bucket_count); - // First create all states. - std::vector token_list; - for (int32 f = frame_begin; f <= frame_end; f++) { - if (active_toks_[f].toks == NULL) { - KALDI_WARN << "GetRawLattice: no tokens active on frame " << f - << ": not producing lattice.\n"; - return false; - } - TopSortTokens(active_toks_[f].toks, &token_list); - for (size_t i = 0; i < token_list.size(); i++) - if (token_list[i] != NULL) - tok_map[token_list[i]] = ofst->AddState(); - } - // The next statement sets the start state of the output FST. Because we - // topologically sorted the tokens, state zero must be the start-state. - StateId begin_state = 0; - StateId end_state = ofst->AddState(); // final-state for the chunk - ofst->SetStart(begin_state); - ofst->SetFinal(end_state, Weight::One()); - - KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:" - << tok_map.bucket_count() << " load:" << tok_map.load_factor() - << " max:" << tok_map.max_load_factor(); - // Create initial_arc for later appending with the previous chunk - if (create_initial_state) { - for (Token *tok = active_toks_[frame_begin].toks; tok != NULL; tok = tok->next) { - StateId cur_state = tok_map[tok]; - int32 id = state_label_map_.find(tok)->second; // it should exist - // TODO: calculate alpha but not use tot_cost or extra_cost - BaseFloat cost_offset = tok->tot_cost; - state_label_forward_prob_[id] = tok->tot_cost; - Arc arc(0, id, - Weight(0, cost_offset), - cur_state); - ofst->AddArc(begin_state, arc); - } - } - // Now create all arcs. - for (int32 f = frame_begin; f <= frame_end; f++) { - for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { - StateId cur_state = tok_map[tok]; - for (ForwardLinkT *l = tok->links; - l != NULL; - l = l->next) { - typename unordered_map::const_iterator - iter = tok_map.find(l->next_tok); - StateId nextstate = iter->second; - KALDI_ASSERT(iter != tok_map.end()); - BaseFloat cost_offset = 0.0; - if (l->ilabel != 0) { // emitting.. - KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); - cost_offset = cost_offsets_[f]; - } - Arc arc(l->ilabel, l->olabel, - Weight(l->graph_cost, l->acoustic_cost - cost_offset), - nextstate); - ofst->AddArc(cur_state, arc); - } - if (f == frame_end) { - if (use_final_probs && !final_costs.empty()) { - typename unordered_map::const_iterator - iter = final_costs.find(tok); - if (iter != final_costs.end()) - ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); - } else { - ofst->SetFinal(cur_state, LatticeWeight::One()); - } - } - } - } - // Create final_arc for later appending with the next chunk - if (create_final_state) { - state_label_map_.clear(); - state_label_map_.reserve(std::min((int32)1e5, config_.max_active)); - for (Token *tok = active_toks_[frame_end].toks; tok != NULL; tok = tok->next) { - StateId cur_state = tok_map[tok]; - int32 id = state_label_avilable_idx_++; - state_label_map_[tok] = id; - Weight final_weight = (!decoding_finalized_ && ofst->Final(cur_state) == Weight::Zero())? Weight::One(): ofst->Final(cur_state); - - Arc arc(0, id, - Weight(0, final_weight.Value1()+final_weight.Value2()), - end_state); - ofst->AddArc(cur_state, arc); - ofst->SetFinal(cur_state, Weight::Zero()); - } - } - return (ofst->NumStates() > 0); +bool LatticeIncrementalDecoderTpl::GetCompactLattice( + CompactLattice *ofst) const { + *ofst = lat_; + return (ofst->NumStates() != 0); } template @@ -1044,6 +927,7 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, bool redeterminize, CompactLattice *olat) { using namespace fst; + CompactLatticeWriter lattice_writer("ark,t:/tmp/lat.1"); // TODO if (last_get_lattice_frame_ < NumFramesDecoded()) { // Get lattice chunk with initial state Lattice raw_fst; @@ -1061,22 +945,20 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, final_arc_list_.swap(final_arc_list_prev_); final_arc_list_.clear(); + if (redeterminize) lattice_writer.Write("TODO1", *olat); // TODO // Appending new chunk to the old one int32 state_offset=olat->NumStates(); + if (last_get_lattice_frame_ != 0) state_offset--; // since we do not append initial state unordered_map initial_arc_map; // the previous states of these arcs are initial states initial_arc_map.reserve(std::min((int32)1e5, config_.max_active)); for (StateIterator siter(clat); !siter.Done(); siter.Next()) { auto s = siter.Value(); StateId state_append = -1; - if (last_get_lattice_frame_ == 0) { // do not need to copy initial state + if (last_get_lattice_frame_ == 0 || s != 0) { // do not need to copy initial state state_append = s+state_offset; KALDI_ASSERT(state_append == olat->AddState()); olat->SetFinal(state_append, clat.Final(s)); - } else if (s != 0) { - state_append = s+state_offset-1; // do not include the first state - KALDI_ASSERT(state_append == olat->AddState()); - olat->SetFinal(state_append, clat.Final(s)); - } + } for (ArcIterator aiter(clat, s); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); @@ -1090,6 +972,7 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, if (s==0) { // initial_arc initial_arc_map[arc.olabel]=aiter.Position(); } else { // final_arc + KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); final_arc_list_.push_back(pair(state_append, aiter.Position())); } } @@ -1129,18 +1012,21 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, arc_append_mod.weight = CompactLatticeWeight(LatticeWeight(v1,v2), s); arc_append_mod.olabel = 0; + arc_append_mod.ilabel = 0; aiter.SetValue(arc_append_mod); } // otherwise, it has been pruned state_label_forward_prob_.erase(arc_append.olabel); } + KALDI_ASSERT(prev_final_state != -1); // at least one arc should be appended // making all unmodified remaining arcs of final_arc_list_prev_ are connected to a dead state olat->SetFinal(prev_final_state, CompactLatticeWeight::Zero()); - } + } else olat->SetStart(0); KALDI_VLOG(2) << "Frame: " <NumStates(); } // TODO: check in the case the last frame is det twice last_get_lattice_frame_ = NumFramesDecoded(); // Determinize the final lattice if (redeterminize) { + lattice_writer.Write("TODO2", *olat); // TODO DeterminizeLatticePrunedOptions det_opts; det_opts.delta = config_.det_opts.delta; det_opts.max_mem = config_.det_opts.max_mem; @@ -1162,14 +1048,143 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, det_opts)) KALDI_WARN << "Determinization finished earlier than the beam"; Connect(olat); // Remove unreachable states... there might be - } - + lattice_writer.Write("TODO3", *olat); + KALDI_VLOG(2) << "states of the lattice: "<NumStates(); + } // a small number of these, in some cases. // Note: if something went wrong and the raw lattice was empty, // we should still get to this point in the code without warnings or failures. return (olat->NumStates() != 0); } +template +bool LatticeIncrementalDecoderTpl::GetRawLattice( + Lattice *ofst, + bool use_final_probs, + int32 frame_begin, + int32 frame_end, + bool create_initial_state, + bool create_final_state) { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + + // Note: you can't use the old interface (Decode()) if you want to + // get the lattice with use_final_probs = false. You'd have to do + // InitDecoding() and then AdvanceDecoding(). + if (decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetRawLattice() with use_final_probs == false"; + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (decoding_finalized_ ? final_costs_ : final_costs_local); + if (!decoding_finalized_ && use_final_probs) + ComputeFinalCosts(&final_costs_local, NULL, NULL); + + ofst->DeleteStates(); + if (frame_begin != 0) ofst->AddState(); // initial-state for the chunk + // num-frames plus one (since frames are one-based, and we have + // an extra frame for the start-state). + KALDI_ASSERT(frame_end > 0); + const int32 bucket_count = num_toks_/2 + 3; + unordered_map tok_map(bucket_count); + // First create all states. + std::vector token_list; + for (int32 f = frame_begin; f <= frame_end; f++) { + if (active_toks_[f].toks == NULL) { + KALDI_WARN << "GetRawLattice: no tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + TopSortTokens(active_toks_[f].toks, &token_list); + for (size_t i = 0; i < token_list.size(); i++) + if (token_list[i] != NULL) + tok_map[token_list[i]] = ofst->AddState(); + } + // The next statement sets the start state of the output FST. Because we + // topologically sorted the tokens, state zero must be the start-state. + StateId begin_state = 0; + ofst->SetStart(begin_state); + + KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:" + << tok_map.bucket_count() << " load:" << tok_map.load_factor() + << " max:" << tok_map.max_load_factor(); + // Create initial_arc for later appending with the previous chunk + if (create_initial_state) { + for (Token *tok = active_toks_[frame_begin].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + int32 id = state_label_map_.find(tok)->second; // it should exist + // TODO: calculate alpha but not use tot_cost or extra_cost + BaseFloat cost_offset = tok->tot_cost; + state_label_forward_prob_[id] = tok->tot_cost; + Arc arc(0, id, + Weight(0, cost_offset), + cur_state); + ofst->AddArc(begin_state, arc); + } + } + // Now create all arcs. + for (int32 f = frame_begin; f <= frame_end; f++) { + for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + for (ForwardLinkT *l = tok->links; + l != NULL; + l = l->next) { + if (f==frame_begin && create_initial_state && l->ilabel==0) continue; // has existed in the last chunk + if (f==frame_end && create_final_state && l->ilabel!=0) continue; // will exist in the next chunk + typename unordered_map::const_iterator + iter = tok_map.find(l->next_tok); + StateId nextstate = iter->second; + KALDI_ASSERT(iter != tok_map.end()); + BaseFloat cost_offset = 0.0; + if (l->ilabel != 0) { // emitting.. + KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); + cost_offset = cost_offsets_[f]; + } + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), + nextstate); + ofst->AddArc(cur_state, arc); + } + if (f == frame_end) { + if (use_final_probs && !final_costs.empty()) { + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) + ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, LatticeWeight::One()); + } + } + } + } + // Create final_arc for later appending with the next chunk + if (create_final_state) { + StateId end_state = ofst->AddState(); // final-state for the chunk + ofst->SetFinal(end_state, Weight::One()); + + state_label_map_.clear(); + state_label_map_.reserve(std::min((int32)1e5, config_.max_active)); + for (Token *tok = active_toks_[frame_end].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + int32 id = state_label_avilable_idx_++; + state_label_map_[tok] = id; + Weight final_weight = (!decoding_finalized_ && ofst->Final(cur_state) == Weight::Zero())? Weight::One(): ofst->Final(cur_state); + + Arc arc(0, id, + final_weight, + end_state); + ofst->AddArc(cur_state, arc); + ofst->SetFinal(cur_state, Weight::Zero()); + } + } + return (ofst->NumStates() > 0); +} + + // Instantiate the template for the combination of token types and FST types // that we'll need. template class LatticeIncrementalDecoderTpl, decoder::StdToken>; diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index ae033d865ae..c7436002506 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -171,13 +171,7 @@ class LatticeIncrementalDecoderTpl { /// you want it pruned tighter than the regular lattice beam. /// We could put that here in future needed. bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const; - - bool GetRawLattice(Lattice *ofst, bool use_final_probs, - int32 frame_begin, - int32 frame_end, - bool create_initial_state, - bool create_final_state); - + bool GetCompactLattice(CompactLattice *ofst) const; /// InitDecoding initializes the decoding, and should only be used if you @@ -430,6 +424,13 @@ class LatticeIncrementalDecoderTpl { std::vector> final_arc_list_prev_; // TODO use 2 vector since state_label is continuous in each frame, and we need 2 frames unordered_map state_label_forward_prob_; // alpha for each state_label (Token) + bool GetRawLattice(Lattice *ofst, bool use_final_probs, + int32 frame_begin, + int32 frame_end, + bool create_initial_state, + bool create_final_state); + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); }; From 6f92369636f7a2b02a524b66aa2819ee83213048 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Wed, 27 Mar 2019 21:01:09 +0800 Subject: [PATCH 04/60] clean code --- src/decoder/lattice-incremental-decoder.cc | 544 ++++++++++----------- src/decoder/lattice-incremental-decoder.h | 170 +++---- 2 files changed, 358 insertions(+), 356 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index baf0cbd7b5c..103e7fbeee7 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -28,26 +28,30 @@ namespace kaldi { // instantiate this class once for each thing you have to decode. template LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( - const FST &fst, - const TransitionModel& trans_model, - const LatticeIncrementalDecoderConfig &config): - fst_(&fst), delete_fst_(false), config_(config), num_toks_(0), trans_model_(trans_model) { + const FST &fst, const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config) + : fst_(&fst), + delete_fst_(false), + config_(config), + num_toks_(0), + trans_model_(trans_model) { config.Check(); - toks_.SetSize(1000); // just so on the first frame we do something reasonable. + toks_.SetSize(1000); // just so on the first frame we do something reasonable. } - template LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( const LatticeIncrementalDecoderConfig &config, FST *fst, - const TransitionModel& trans_model): - fst_(fst), delete_fst_(true), config_(config), num_toks_(0), -trans_model_(trans_model) { + const TransitionModel &trans_model) + : fst_(fst), + delete_fst_(true), + config_(config), + num_toks_(0), + trans_model_(trans_model) { config.Check(); - toks_.SetSize(1000); // just so on the first frame we do something reasonable. + toks_.SetSize(1000); // just so on the first frame we do something reasonable. } - template LatticeIncrementalDecoderTpl::~LatticeIncrementalDecoderTpl() { DeleteElems(toks_.Clear()); @@ -77,7 +81,7 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { last_get_lattice_frame_ = 0; state_label_map_.clear(); state_label_map_.reserve(std::min((int32)1e5, config_.max_active)); - state_label_avilable_idx_ = config_.max_word_id+1; + state_label_avilable_idx_ = config_.max_word_id + 1; final_arc_list_.clear(); final_arc_list_prev_.clear(); state_label_forward_prob_.clear(); @@ -89,7 +93,8 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { // a final state). It should only very rarely return false; this indicates // an unusual search error. template -bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decodable) { +bool LatticeIncrementalDecoderTpl::Decode( + DecodableInterface *decodable) { InitDecoding(); // We use 1-based indexing for frames in this decoder (if you view it in @@ -113,11 +118,10 @@ bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decoda return !active_toks_.empty() && active_toks_.back().toks != NULL; } - // Outputs an FST corresponding to the single best path through the lattice. template -bool LatticeIncrementalDecoderTpl::GetBestPath(Lattice *olat, - bool use_final_probs) const { +bool LatticeIncrementalDecoderTpl::GetBestPath( + Lattice *olat, bool use_final_probs) const { CompactLattice lat; ShortestPath(lat_, &lat); ConvertLattice(lat, olat); @@ -127,8 +131,7 @@ bool LatticeIncrementalDecoderTpl::GetBestPath(Lattice *olat, // Outputs an FST corresponding to the raw, state-level lattice template bool LatticeIncrementalDecoderTpl::GetRawLattice( - Lattice *ofst, - bool use_final_probs) const { + Lattice *ofst, bool use_final_probs) const { ConvertLattice(lat_, ofst); Connect(ofst); return (ofst->NumStates() != 0); @@ -143,8 +146,8 @@ bool LatticeIncrementalDecoderTpl::GetCompactLattice( template void LatticeIncrementalDecoderTpl::PossiblyResizeHash(size_t num_toks) { - size_t new_sz = static_cast(static_cast(num_toks) - * config_.hash_ratio); + size_t new_sz = + static_cast(static_cast(num_toks) * config_.hash_ratio); if (new_sz > toks_.Size()) { toks_.SetSize(new_sz); } @@ -182,20 +185,20 @@ void LatticeIncrementalDecoderTpl::PossiblyResizeHash(size_t num_tok // and also into the singly linked list of tokens active on this frame // (whose head is at active_toks_[frame]). template -inline Token* LatticeIncrementalDecoderTpl::FindOrAddToken( - StateId state, int32 frame_plus_one, BaseFloat tot_cost, - Token *backpointer, bool *changed) { +inline Token *LatticeIncrementalDecoderTpl::FindOrAddToken( + StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer, + bool *changed) { // Returns the Token pointer. Sets "changed" (if non-NULL) to true // if the token was newly created or the cost changed. KALDI_ASSERT(frame_plus_one < active_toks_.size()); Token *&toks = active_toks_[frame_plus_one].toks; Elem *e_found = toks_.Find(state); - if (e_found == NULL) { // no such token presently. + if (e_found == NULL) { // no such token presently. const BaseFloat extra_cost = 0.0; // tokens on the currently final frame have zero extra_cost // as any of them could end up // on the winning path. - Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer); + Token *new_tok = new Token(tot_cost, extra_cost, NULL, toks, backpointer); // NULL: no forward links yet toks = new_tok; num_toks_++; @@ -203,8 +206,8 @@ inline Token* LatticeIncrementalDecoderTpl::FindOrAddToken( if (changed) *changed = true; return new_tok; } else { - Token *tok = e_found->val; // There is an existing Token for this state. - if (tok->tot_cost > tot_cost) { // replace old token + Token *tok = e_found->val; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token tok->tot_cost = tot_cost; // SetBackpointer() just does tok->backpointer = backpointer in // the case where Token == BackpointerToken, else nothing. @@ -229,8 +232,8 @@ inline Token* LatticeIncrementalDecoderTpl::FindOrAddToken( // all links, that have link_extra_cost > lattice_beam are pruned template void LatticeIncrementalDecoderTpl::PruneForwardLinks( - int32 frame_plus_one, bool *extra_costs_changed, - bool *links_pruned, BaseFloat delta) { + int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned, + BaseFloat delta) { // delta is the amount by which the extra_costs must change // If delta is larger, we'll tend to go back less far // toward the beginning of the file. @@ -240,59 +243,61 @@ void LatticeIncrementalDecoderTpl::PruneForwardLinks( *extra_costs_changed = false; *links_pruned = false; KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); - if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. if (!warned_) { KALDI_WARN << "No tokens alive [doing pruning].. warning first " - "time only for each utterance\n"; + "time only for each utterance\n"; warned_ = true; } } // We have to iterate until there is no more change, because the links // are not guaranteed to be in topological order. - bool changed = true; // difference new minus old extra cost >= delta ? + bool changed = true; // difference new minus old extra cost >= delta ? while (changed) { changed = false; - for (Token *tok = active_toks_[frame_plus_one].toks; - tok != NULL; tok = tok->next) { + for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL; + tok = tok->next) { ForwardLinkT *link, *prev_link = NULL; // will recompute tok_extra_cost for tok. BaseFloat tok_extra_cost = std::numeric_limits::infinity(); // tok_extra_cost is the best (min) of link_extra_cost of outgoing links - for (link = tok->links; link != NULL; ) { + for (link = tok->links; link != NULL;) { // See if we need to excise this link... Token *next_tok = link->next_tok; - BaseFloat link_extra_cost = next_tok->extra_cost + - ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - - next_tok->tot_cost); // difference in brackets is >= 0 + BaseFloat link_extra_cost = + next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - + next_tok->tot_cost); // difference in brackets is >= 0 // link_exta_cost is the difference in score between the best paths // through link source state and through link destination state - KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN - if (link_extra_cost > config_.lattice_beam) { // excise link + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link ForwardLinkT *next_link = link->next; - if (prev_link != NULL) prev_link->next = next_link; - else tok->links = next_link; + if (prev_link != NULL) + prev_link->next = next_link; + else + tok->links = next_link; delete link; - link = next_link; // advance link but leave prev_link the same. + link = next_link; // advance link but leave prev_link the same. *links_pruned = true; - } else { // keep the link and update the tok_extra_cost if needed. - if (link_extra_cost < 0.0) { // this is just a precaution. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. if (link_extra_cost < -0.01) KALDI_WARN << "Negative extra_cost: " << link_extra_cost; link_extra_cost = 0.0; } - if (link_extra_cost < tok_extra_cost) - tok_extra_cost = link_extra_cost; - prev_link = link; // move to next link + if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link link = link->next; } - } // for all outgoing links + } // for all outgoing links if (fabs(tok_extra_cost - tok->extra_cost) > delta) - changed = true; // difference new minus old is bigger than delta + changed = true; // difference new minus old is bigger than delta tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. // infinity indicates, that no forward link survived pruning - } // for all Token on active_toks_[frame] + } // for all Token on active_toks_[frame] if (changed) *extra_costs_changed = true; // Note: it's theoretically possible that aggressive compiler @@ -309,10 +314,10 @@ void LatticeIncrementalDecoderTpl::PruneForwardLinksFinal() { KALDI_ASSERT(!active_toks_.empty()); int32 frame_plus_one = active_toks_.size() - 1; - if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. KALDI_WARN << "No tokens alive at end of file"; - typedef typename unordered_map::const_iterator IterType; + typedef typename unordered_map::const_iterator IterType; ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); decoding_finalized_ = true; // We call DeleteElems() as a nicety, not because it's really necessary; @@ -329,12 +334,13 @@ void LatticeIncrementalDecoderTpl::PruneForwardLinksFinal() { BaseFloat delta = 1.0e-05; while (changed) { changed = false; - for (Token *tok = active_toks_[frame_plus_one].toks; - tok != NULL; tok = tok->next) { + for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL; + tok = tok->next) { ForwardLinkT *link, *prev_link = NULL; // will recompute tok_extra_cost. It has a term in it that corresponds // to the "final-prob", so instead of initializing tok_extra_cost to infinity - // below we set it to the difference between the (score+final_prob) of this token, + // below we set it to the difference between the (score+final_prob) of this + // token, // and the best such (score+final_prob). BaseFloat final_cost; if (final_costs_.empty()) { @@ -350,26 +356,28 @@ void LatticeIncrementalDecoderTpl::PruneForwardLinksFinal() { // tok_extra_cost will be a "min" over either directly being final, or // being indirectly final through other links, and the loop below may // decrease its value: - for (link = tok->links; link != NULL; ) { + for (link = tok->links; link != NULL;) { // See if we need to excise this link... Token *next_tok = link->next_tok; - BaseFloat link_extra_cost = next_tok->extra_cost + - ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - - next_tok->tot_cost); - if (link_extra_cost > config_.lattice_beam) { // excise link + BaseFloat link_extra_cost = + next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - + next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link ForwardLinkT *next_link = link->next; - if (prev_link != NULL) prev_link->next = next_link; - else tok->links = next_link; + if (prev_link != NULL) + prev_link->next = next_link; + else + tok->links = next_link; delete link; link = next_link; // advance link but leave prev_link the same. - } else { // keep the link and update the tok_extra_cost if needed. + } else { // keep the link and update the tok_extra_cost if needed. if (link_extra_cost < 0.0) { // this is just a precaution. if (link_extra_cost < -0.01) KALDI_WARN << "Negative extra_cost: " << link_extra_cost; link_extra_cost = 0.0; } - if (link_extra_cost < tok_extra_cost) - tok_extra_cost = link_extra_cost; + if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost; prev_link = link; link = link->next; } @@ -382,8 +390,7 @@ void LatticeIncrementalDecoderTpl::PruneForwardLinksFinal() { tok_extra_cost = std::numeric_limits::infinity(); // to be pruned in PruneTokensForFrame - if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) - changed = true; + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) changed = true; tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. } } // while changed @@ -402,28 +409,29 @@ BaseFloat LatticeIncrementalDecoderTpl::FinalRelativeCost() const { } } - // Prune away any tokens on this frame that have no forward links. // [we don't do this in PruneForwardLinks because it would give us // a problem with dangling pointers]. // It's called by PruneActiveTokens if any forward links have been pruned template -void LatticeIncrementalDecoderTpl::PruneTokensForFrame(int32 frame_plus_one) { +void LatticeIncrementalDecoderTpl::PruneTokensForFrame( + int32 frame_plus_one) { KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); Token *&toks = active_toks_[frame_plus_one].toks; - if (toks == NULL) - KALDI_WARN << "No tokens alive [doing pruning]"; + if (toks == NULL) KALDI_WARN << "No tokens alive [doing pruning]"; Token *tok, *next_tok, *prev_tok = NULL; for (tok = toks; tok != NULL; tok = next_tok) { next_tok = tok->next; if (tok->extra_cost == std::numeric_limits::infinity()) { // token is unreachable from end of graph; (no forward links survived) // excise tok from list and delete tok. - if (prev_tok != NULL) prev_tok->next = tok->next; - else toks = tok->next; + if (prev_tok != NULL) + prev_tok->next = tok->next; + else + toks = tok->next; delete tok; num_toks_--; - } else { // fetch next Token + } else { // fetch next Token prev_tok = tok; } } @@ -449,15 +457,15 @@ void LatticeIncrementalDecoderTpl::PruneActiveTokens(BaseFloat delta bool extra_costs_changed = false, links_pruned = false; PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); if (extra_costs_changed && f > 0) // any token has changed extra_cost - active_toks_[f-1].must_prune_forward_links = true; + active_toks_[f - 1].must_prune_forward_links = true; if (links_pruned) // any link was pruned active_toks_[f].must_prune_tokens = true; active_toks_[f].must_prune_forward_links = false; // job done } - if (f+1 < cur_frame_plus_one && // except for last f (no forward links) - active_toks_[f+1].must_prune_tokens) { - PruneTokensForFrame(f+1); - active_toks_[f+1].must_prune_tokens = false; + if (f + 1 < cur_frame_plus_one && // except for last f (no forward links) + active_toks_[f + 1].must_prune_tokens) { + PruneTokensForFrame(f + 1); + active_toks_[f + 1].must_prune_tokens = false; } } KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin @@ -466,24 +474,20 @@ void LatticeIncrementalDecoderTpl::PruneActiveTokens(BaseFloat delta template void LatticeIncrementalDecoderTpl::ComputeFinalCosts( - unordered_map *final_costs, - BaseFloat *final_relative_cost, + unordered_map *final_costs, BaseFloat *final_relative_cost, BaseFloat *final_best_cost) const { KALDI_ASSERT(!decoding_finalized_); - if (final_costs != NULL) - final_costs->clear(); + if (final_costs != NULL) final_costs->clear(); const Elem *final_toks = toks_.GetList(); BaseFloat infinity = std::numeric_limits::infinity(); - BaseFloat best_cost = infinity, - best_cost_with_final = infinity; + BaseFloat best_cost = infinity, best_cost_with_final = infinity; while (final_toks != NULL) { StateId state = final_toks->key; Token *tok = final_toks->val; const Elem *next = final_toks->tail; BaseFloat final_cost = fst_->Final(state).Value(); - BaseFloat cost = tok->tot_cost, - cost_with_final = cost + final_cost; + BaseFloat cost = tok->tot_cost, cost_with_final = cost + final_cost; best_cost = std::min(cost, best_cost); best_cost_with_final = std::min(cost_with_final, best_cost_with_final); if (final_costs != NULL && final_cost != infinity) @@ -509,26 +513,29 @@ void LatticeIncrementalDecoderTpl::ComputeFinalCosts( } template -void LatticeIncrementalDecoderTpl::AdvanceDecoding(DecodableInterface *decodable, - int32 max_num_frames) { +void LatticeIncrementalDecoderTpl::AdvanceDecoding( + DecodableInterface *decodable, int32 max_num_frames) { if (std::is_same >::value) { // if the type 'FST' is the FST base-class, then see if the FST type of fst_ // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() // function after casting *this to the more specific type. if (fst_->Type() == "const") { LatticeIncrementalDecoderTpl, Token> *this_cast = - reinterpret_cast, Token>* >(this); + reinterpret_cast< + LatticeIncrementalDecoderTpl, Token> *>( + this); this_cast->AdvanceDecoding(decodable, max_num_frames); return; } else if (fst_->Type() == "vector") { LatticeIncrementalDecoderTpl, Token> *this_cast = - reinterpret_cast, Token>* >(this); + reinterpret_cast< + LatticeIncrementalDecoderTpl, Token> *>( + this); this_cast->AdvanceDecoding(decodable, max_num_frames); return; } } - KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && "You must call InitDecoding() before AdvanceDecoding"); int32 num_frames_ready = decodable->NumFramesReady(); @@ -539,8 +546,8 @@ void LatticeIncrementalDecoderTpl::AdvanceDecoding(DecodableInterfac KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); int32 target_frames_decoded = num_frames_ready; if (max_num_frames >= 0) - target_frames_decoded = std::min(target_frames_decoded, - NumFramesDecoded() + max_num_frames); + target_frames_decoded = + std::min(target_frames_decoded, NumFramesDecoded() + max_num_frames); while (NumFramesDecoded() < target_frames_decoded) { if (NumFramesDecoded() % config_.prune_interval == 0) { PruneActiveTokens(config_.lattice_beam * config_.prune_scale); @@ -561,20 +568,19 @@ void LatticeIncrementalDecoderTpl::FinalizeDecoding() { // sets decoding_finalized_. PruneForwardLinksFinal(); for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { - bool b1, b2; // values not used. + bool b1, b2; // values not used. BaseFloat dontcare = 0.0; // delta of zero means we must always update PruneForwardLinks(f, &b1, &b2, dontcare); PruneTokensForFrame(f + 1); } PruneTokensForFrame(0); - KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin - << " to " << num_toks_; + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin << " to " << num_toks_; } /// Gets the weight cutoff. Also counts the active tokens. template -BaseFloat LatticeIncrementalDecoderTpl::GetCutoff(Elem *list_head, size_t *tok_count, - BaseFloat *adaptive_beam, Elem **best_elem) { +BaseFloat LatticeIncrementalDecoderTpl::GetCutoff( + Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem) { BaseFloat best_weight = std::numeric_limits::infinity(); // positive == high cost == bad. size_t count = 0; @@ -603,15 +609,14 @@ BaseFloat LatticeIncrementalDecoderTpl::GetCutoff(Elem *list_head, s if (tok_count != NULL) *tok_count = count; BaseFloat beam_cutoff = best_weight + config_.beam, - min_active_cutoff = std::numeric_limits::infinity(), - max_active_cutoff = std::numeric_limits::infinity(); + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() << " is " << tmp_array_.size(); if (tmp_array_.size() > static_cast(config_.max_active)) { - std::nth_element(tmp_array_.begin(), - tmp_array_.begin() + config_.max_active, + std::nth_element(tmp_array_.begin(), tmp_array_.begin() + config_.max_active, tmp_array_.end()); max_active_cutoff = tmp_array_[config_.max_active]; } @@ -621,13 +626,13 @@ BaseFloat LatticeIncrementalDecoderTpl::GetCutoff(Elem *list_head, s return max_active_cutoff; } if (tmp_array_.size() > static_cast(config_.min_active)) { - if (config_.min_active == 0) min_active_cutoff = best_weight; + if (config_.min_active == 0) + min_active_cutoff = best_weight; else { - std::nth_element(tmp_array_.begin(), - tmp_array_.begin() + config_.min_active, - tmp_array_.size() > static_cast(config_.max_active) ? - tmp_array_.begin() + config_.max_active : - tmp_array_.end()); + std::nth_element(tmp_array_.begin(), tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) + ? tmp_array_.begin() + config_.max_active + : tmp_array_.end()); min_active_cutoff = tmp_array_[config_.min_active]; } } @@ -652,8 +657,8 @@ BaseFloat LatticeIncrementalDecoderTpl::ProcessEmitting( active_toks_.resize(active_toks_.size() + 1); Elem *final_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_ - // in simple-decoder.h. Removes the Elems from - // being indexed in the hash in toks_. + // in simple-decoder.h. Removes the Elems from + // being indexed in the hash in toks_. Elem *best_elem = NULL; BaseFloat adaptive_beam; size_t tok_cnt; @@ -661,7 +666,7 @@ BaseFloat LatticeIncrementalDecoderTpl::ProcessEmitting( KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " << adaptive_beam; - PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough. + PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough. BaseFloat next_cutoff = std::numeric_limits::infinity(); // pruning "online" before having seen all tokens @@ -669,21 +674,19 @@ BaseFloat LatticeIncrementalDecoderTpl::ProcessEmitting( BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good // dynamic range. - // First process the best token to get a hopefully // reasonably tight bound on the next cutoff. The only // products of the next block are "next_cutoff" and "cost_offset". if (best_elem) { StateId state = best_elem->key; Token *tok = best_elem->val; - cost_offset = - tok->tot_cost; - for (fst::ArcIterator aiter(*fst_, state); - !aiter.Done(); - aiter.Next()) { + cost_offset = -tok->tot_cost; + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); - if (arc.ilabel != 0) { // propagate.. + if (arc.ilabel != 0) { // propagate.. BaseFloat new_weight = arc.weight.Value() + cost_offset - - decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost; + decodable->LogLikelihood(frame, arc.ilabel) + + tok->tot_cost; if (new_weight + adaptive_beam < next_cutoff) next_cutoff = new_weight + adaptive_beam; } @@ -704,28 +707,26 @@ BaseFloat LatticeIncrementalDecoderTpl::ProcessEmitting( StateId state = e->key; Token *tok = e->val; if (tok->tot_cost <= cur_cutoff) { - for (fst::ArcIterator aiter(*fst_, state); - !aiter.Done(); - aiter.Next()) { + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); - if (arc.ilabel != 0) { // propagate.. - BaseFloat ac_cost = cost_offset - - decodable->LogLikelihood(frame, arc.ilabel), - graph_cost = arc.weight.Value(), - cur_cost = tok->tot_cost, - tot_cost = cur_cost + ac_cost + graph_cost; - if (tot_cost > next_cutoff) continue; + if (arc.ilabel != 0) { // propagate.. + BaseFloat ac_cost = + cost_offset - decodable->LogLikelihood(frame, arc.ilabel), + graph_cost = arc.weight.Value(), cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost > next_cutoff) + continue; else if (tot_cost + adaptive_beam < next_cutoff) next_cutoff = tot_cost + adaptive_beam; // prune by best current token // Note: the frame indexes into active_toks_ are one-based, // hence the + 1. - Token *next_tok = FindOrAddToken(arc.nextstate, - frame + 1, tot_cost, tok, NULL); + Token *next_tok = + FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, NULL); // NULL: no change indicator needed // Add ForwardLink from tok to next_tok (put on head of list tok->links) - tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, - graph_cost, ac_cost, tok->links); + tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, graph_cost, + ac_cost, tok->links); } } // for all arcs } @@ -747,7 +748,6 @@ void LatticeIncrementalDecoderTpl::DeleteForwardLinks(Token *tok) { tok->links = NULL; } - template void LatticeIncrementalDecoderTpl::ProcessNonemitting(BaseFloat cutoff) { KALDI_ASSERT(!active_toks_.empty()); @@ -771,17 +771,18 @@ void LatticeIncrementalDecoderTpl::ProcessNonemitting(BaseFloat cuto } } - for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) { + for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) { StateId state = e->key; - if (fst_->NumInputEpsilons(state) != 0) - queue_.push_back(state); + if (fst_->NumInputEpsilons(state) != 0) queue_.push_back(state); } while (!queue_.empty()) { StateId state = queue_.back(); queue_.pop_back(); - Token *tok = toks_.Find(state)->val; // would segfault if state not in toks_ but this can't happen. + Token *tok = + toks_.Find(state) + ->val; // would segfault if state not in toks_ but this can't happen. BaseFloat cur_cost = tok->tot_cost; if (cur_cost > cutoff) // Don't bother processing successors. continue; @@ -791,21 +792,18 @@ void LatticeIncrementalDecoderTpl::ProcessNonemitting(BaseFloat cuto // but since most states are emitting it's not a huge issue. DeleteForwardLinks(tok); // necessary when re-visiting tok->links = NULL; - for (fst::ArcIterator aiter(*fst_, state); - !aiter.Done(); - aiter.Next()) { + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); - if (arc.ilabel == 0) { // propagate nonemitting only... - BaseFloat graph_cost = arc.weight.Value(), - tot_cost = cur_cost + graph_cost; + if (arc.ilabel == 0) { // propagate nonemitting only... + BaseFloat graph_cost = arc.weight.Value(), tot_cost = cur_cost + graph_cost; if (tot_cost < cutoff) { bool changed; - Token *new_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, - tok, &changed); + Token *new_tok = + FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, &changed); - tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, - graph_cost, 0, tok->links); + tok->links = + new ForwardLinkT(new_tok, 0, arc.olabel, graph_cost, 0, tok->links); // "changed" tells us whether the new token has a different // cost from before, or is new [if so, add into queue]. @@ -814,10 +812,9 @@ void LatticeIncrementalDecoderTpl::ProcessNonemitting(BaseFloat cuto } } } // for all arcs - } // while queue not empty + } // while queue not empty } - template void LatticeIncrementalDecoderTpl::DeleteElems(Elem *list) { for (Elem *e = list, *e_tail; e != NULL; e = e_tail) { @@ -827,11 +824,12 @@ void LatticeIncrementalDecoderTpl::DeleteElems(Elem *list) { } template -void LatticeIncrementalDecoderTpl::ClearActiveTokens() { // a cleanup routine, at utt end/begin +void LatticeIncrementalDecoderTpl< + FST, Token>::ClearActiveTokens() { // a cleanup routine, at utt end/begin for (size_t i = 0; i < active_toks_.size(); i++) { // Delete all tokens alive on this frame, and any forward // links they may have. - for (Token *tok = active_toks_[i].toks; tok != NULL; ) { + for (Token *tok = active_toks_[i].toks; tok != NULL;) { DeleteForwardLinks(tok); Token *next_tok = tok->next; delete tok; @@ -846,12 +844,11 @@ void LatticeIncrementalDecoderTpl::ClearActiveTokens() { // a cleanu // static template void LatticeIncrementalDecoderTpl::TopSortTokens( - Token *tok_list, std::vector *topsorted_list) { - unordered_map token2pos; - typedef typename unordered_map::iterator IterType; + Token *tok_list, std::vector *topsorted_list) { + unordered_map token2pos; + typedef typename unordered_map::iterator IterType; int32 num_toks = 0; - for (Token *tok = tok_list; tok != NULL; tok = tok->next) - num_toks++; + for (Token *tok = tok_list; tok != NULL; tok = tok->next) num_toks++; int32 cur_pos = 0; // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. // This is likely to be in closer to topological order than @@ -860,7 +857,7 @@ void LatticeIncrementalDecoderTpl::TopSortTokens( for (Token *tok = tok_list; tok != NULL; tok = tok->next) token2pos[tok] = num_toks - ++cur_pos; - unordered_set reprocess; + unordered_set reprocess; for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { Token *tok = iter->first; @@ -887,14 +884,13 @@ void LatticeIncrementalDecoderTpl::TopSortTokens( } size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles. - for (loop_count = 0; - !reprocess.empty() && loop_count < max_loop; ++loop_count) { - std::vector reprocess_vec; - for (typename unordered_set::iterator iter = reprocess.begin(); + for (loop_count = 0; !reprocess.empty() && loop_count < max_loop; ++loop_count) { + std::vector reprocess_vec; + for (typename unordered_set::iterator iter = reprocess.begin(); iter != reprocess.end(); ++iter) reprocess_vec.push_back(*iter); reprocess.clear(); - for (typename std::vector::iterator iter = reprocess_vec.begin(); + for (typename std::vector::iterator iter = reprocess_vec.begin(); iter != reprocess_vec.end(); ++iter) { Token *tok = *iter; int32 pos = token2pos[tok]; @@ -913,104 +909,111 @@ void LatticeIncrementalDecoderTpl::TopSortTokens( } } } - KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding " + KALDI_ASSERT(loop_count < max_loop && + "Epsilon loops exist in your decoding " "graph (this is not allowed!)"); topsorted_list->clear(); - topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between. + topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between. for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) (*topsorted_list)[iter->second] = iter->first; } template bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, - bool redeterminize, CompactLattice *olat) { + bool redeterminize, + CompactLattice *olat) { using namespace fst; - CompactLatticeWriter lattice_writer("ark,t:/tmp/lat.1"); // TODO if (last_get_lattice_frame_ < NumFramesDecoded()) { - // Get lattice chunk with initial state Lattice raw_fst; - KALDI_ASSERT(GetRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, NumFramesDecoded(), last_get_lattice_frame_ != 0, !decoding_finalized_)); - // Determinize the chunk + // step 1: Get lattice chunk with initial state + KALDI_ASSERT(GetRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, + NumFramesDecoded(), last_get_lattice_frame_ != 0, + !decoding_finalized_)); + // step 2: Determinize the chunk CompactLattice clat; if (!DeterminizeLatticePhonePrunedWrapper( - trans_model_, - &raw_fst, - config_.lattice_beam, - &clat, - config_.det_opts)) - KALDI_WARN << "Determinization finished earlier than the beam"; - + trans_model_, &raw_fst, config_.lattice_beam, &clat, config_.det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam"; + final_arc_list_.swap(final_arc_list_prev_); final_arc_list_.clear(); - if (redeterminize) lattice_writer.Write("TODO1", *olat); // TODO - // Appending new chunk to the old one - int32 state_offset=olat->NumStates(); - if (last_get_lattice_frame_ != 0) state_offset--; // since we do not append initial state - unordered_map initial_arc_map; // the previous states of these arcs are initial states + // step 3.1: Appending new chunk to the old one + int32 state_offset = olat->NumStates(); + if (last_get_lattice_frame_ != 0) + state_offset--; // since we do not append initial state + unordered_map + initial_arc_map; // the incoming states of these arcs are initial states initial_arc_map.reserve(std::min((int32)1e5, config_.max_active)); for (StateIterator siter(clat); !siter.Done(); siter.Next()) { auto s = siter.Value(); StateId state_append = -1; - if (last_get_lattice_frame_ == 0 || s != 0) { // do not need to copy initial state - state_append = s+state_offset; + if (last_get_lattice_frame_ == 0 || + s != 0) { // do not need to copy initial state + state_append = s + state_offset; KALDI_ASSERT(state_append == olat->AddState()); olat->SetFinal(state_append, clat.Final(s)); - } + } for (ArcIterator aiter(clat, s); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); // construct a copy of the state & arcs - if (last_get_lattice_frame_ == 0 || s != 0) { // do not need to copy initial state + if (last_get_lattice_frame_ == 0 || + s != 0) { // do not need to copy initial arc CompactLatticeArc arc_append(arc); arc_append.nextstate += state_offset; olat->AddArc(state_append, arc_append); } if (arc.olabel > config_.max_word_id) { - if (s==0) { // initial_arc - initial_arc_map[arc.olabel]=aiter.Position(); + if (s == 0) { // initial_arc + initial_arc_map[arc.olabel] = aiter.Position(); } else { // final_arc KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); - final_arc_list_.push_back(pair(state_append, aiter.Position())); + final_arc_list_.push_back( + pair(state_append, aiter.Position())); } - } + } } } - // connect the states between two chunks + + // step 3.2: connect the states between two chunks if (last_get_lattice_frame_ != 0) { KALDI_ASSERT(final_arc_list_prev_.size()); StateId prev_final_state = -1; - for (auto&i:final_arc_list_prev_) { + for (auto &i : final_arc_list_prev_) { MutableArcIterator aiter(olat, i.first); aiter.Seek(i.second); auto &arc_append = aiter.Value(); auto r = initial_arc_map.find(arc_append.olabel); - if (r!=initial_arc_map.end()) { + if (r != initial_arc_map.end()) { ArcIterator aiter_chunk(clat, 0); // initial state aiter_chunk.Seek(r->second); const auto &arc_chunk = aiter_chunk.Value(); KALDI_ASSERT(arc_chunk.olabel == arc_append.olabel); - StateId state_append = arc_chunk.nextstate+state_offset; - if (prev_final_state == -1) prev_final_state=arc_append.nextstate; - else KALDI_ASSERT(arc_append.nextstate == prev_final_state); + StateId state_append = arc_chunk.nextstate + state_offset; + if (prev_final_state == -1) + prev_final_state = arc_append.nextstate; + else + KALDI_ASSERT(arc_append.nextstate == prev_final_state); CompactLatticeArc arc_append_mod(arc_append); arc_append_mod.nextstate = state_append; CompactLatticeWeight weight_offset; - weight_offset.SetWeight(LatticeWeight(0, -state_label_forward_prob_[arc_append.olabel])); - vector weights = {arc_append_mod.weight, arc_chunk.weight, olat->Final(prev_final_state), weight_offset}; - BaseFloat v1=0, v2=0; - for (auto& i:weights) - v1+=i.Weight().Value1(); - for (auto& i:weights) - v2+=i.Weight().Value2(); + weight_offset.SetWeight( + LatticeWeight(0, -state_label_forward_prob_[arc_append.olabel])); + vector weights = { + arc_append_mod.weight, arc_chunk.weight, olat->Final(prev_final_state), + weight_offset}; + BaseFloat v1 = 0, v2 = 0; + for (auto &i : weights) v1 += i.Weight().Value1(); + for (auto &i : weights) v2 += i.Weight().Value2(); vector s; - for (auto& i:weights) + for (auto &i : weights) s.insert(s.end(), i.String().begin(), i.String().end()); - arc_append_mod.weight = CompactLatticeWeight(LatticeWeight(v1,v2), s); + arc_append_mod.weight = CompactLatticeWeight(LatticeWeight(v1, v2), s); arc_append_mod.olabel = 0; arc_append_mod.ilabel = 0; aiter.SetValue(arc_append_mod); @@ -1018,15 +1021,19 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, state_label_forward_prob_.erase(arc_append.olabel); } KALDI_ASSERT(prev_final_state != -1); // at least one arc should be appended - // making all unmodified remaining arcs of final_arc_list_prev_ are connected to a dead state + // making all unmodified remaining arcs of final_arc_list_prev_ be connected to + // a dead state olat->SetFinal(prev_final_state, CompactLatticeWeight::Zero()); - } else olat->SetStart(0); - KALDI_VLOG(2) << "Frame: " <NumStates(); + } else + olat->SetStart(0); + KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() + << " states of chunk: " << clat.NumStates() + << " states of the lattice: " << olat->NumStates(); } // TODO: check in the case the last frame is det twice + last_get_lattice_frame_ = NumFramesDecoded(); - // Determinize the final lattice + // step 4: re-determinize the final lattice if (redeterminize) { - lattice_writer.Write("TODO2", *olat); // TODO DeterminizeLatticePrunedOptions det_opts; det_opts.delta = config_.det_opts.delta; det_opts.max_mem = config_.det_opts.max_mem; @@ -1041,58 +1048,45 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, << ")."; } } - if (!DeterminizeLatticePruned( - lat, - config_.lattice_beam, - olat, - det_opts)) - KALDI_WARN << "Determinization finished earlier than the beam"; - Connect(olat); // Remove unreachable states... there might be - lattice_writer.Write("TODO3", *olat); - KALDI_VLOG(2) << "states of the lattice: "<NumStates(); - } - // a small number of these, in some cases. - // Note: if something went wrong and the raw lattice was empty, - // we should still get to this point in the code without warnings or failures. + if (!DeterminizeLatticePruned(lat, config_.lattice_beam, olat, det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam"; + Connect(olat); // Remove unreachable states... there might be + KALDI_VLOG(2) << "states of the lattice: " << olat->NumStates(); + } + return (olat->NumStates() != 0); } - + +// It is modified from LatticeFasterDecoderTpl::GetRawLattice() template bool LatticeIncrementalDecoderTpl::GetRawLattice( - Lattice *ofst, - bool use_final_probs, - int32 frame_begin, - int32 frame_end, - bool create_initial_state, - bool create_final_state) { + Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, + bool create_initial_state, bool create_final_state) { typedef LatticeArc Arc; typedef Arc::StateId StateId; typedef Arc::Weight Weight; typedef Arc::Label Label; - // Note: you can't use the old interface (Decode()) if you want to - // get the lattice with use_final_probs = false. You'd have to do - // InitDecoding() and then AdvanceDecoding(). if (decoding_finalized_ && !use_final_probs) KALDI_ERR << "You cannot call FinalizeDecoding() and then call " << "GetRawLattice() with use_final_probs == false"; - unordered_map final_costs_local; + unordered_map final_costs_local; - const unordered_map &final_costs = + const unordered_map &final_costs = (decoding_finalized_ ? final_costs_ : final_costs_local); if (!decoding_finalized_ && use_final_probs) ComputeFinalCosts(&final_costs_local, NULL, NULL); ofst->DeleteStates(); - if (frame_begin != 0) ofst->AddState(); // initial-state for the chunk + if (create_initial_state) ofst->AddState(); // initial-state for the chunk // num-frames plus one (since frames are one-based, and we have // an extra frame for the start-state). KALDI_ASSERT(frame_end > 0); - const int32 bucket_count = num_toks_/2 + 3; - unordered_map tok_map(bucket_count); + const int32 bucket_count = num_toks_ / 2 + 3; + unordered_map tok_map(bucket_count); // First create all states. - std::vector token_list; + std::vector token_list; for (int32 f = frame_begin; f <= frame_end; f++) { if (active_toks_[f].toks == NULL) { KALDI_WARN << "GetRawLattice: no tokens active on frame " << f @@ -1101,28 +1095,27 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( } TopSortTokens(active_toks_[f].toks, &token_list); for (size_t i = 0; i < token_list.size(); i++) - if (token_list[i] != NULL) - tok_map[token_list[i]] = ofst->AddState(); + if (token_list[i] != NULL) tok_map[token_list[i]] = ofst->AddState(); } // The next statement sets the start state of the output FST. Because we // topologically sorted the tokens, state zero must be the start-state. StateId begin_state = 0; ofst->SetStart(begin_state); - KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:" - << tok_map.bucket_count() << " load:" << tok_map.load_factor() + KALDI_VLOG(4) << "init:" << num_toks_ / 2 + 3 + << " buckets:" << tok_map.bucket_count() + << " load:" << tok_map.load_factor() << " max:" << tok_map.max_load_factor(); // Create initial_arc for later appending with the previous chunk if (create_initial_state) { for (Token *tok = active_toks_[frame_begin].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; + // state_label_map_ is construct during create_final_state int32 id = state_label_map_.find(tok)->second; // it should exist - // TODO: calculate alpha but not use tot_cost or extra_cost + // TODO: calculate alpha but not use tot_cost BaseFloat cost_offset = tok->tot_cost; state_label_forward_prob_[id] = tok->tot_cost; - Arc arc(0, id, - Weight(0, cost_offset), - cur_state); + Arc arc(0, id, Weight(0, cost_offset), cur_state); ofst->AddArc(begin_state, arc); } } @@ -1130,29 +1123,28 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( for (int32 f = frame_begin; f <= frame_end; f++) { for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; - for (ForwardLinkT *l = tok->links; - l != NULL; - l = l->next) { - if (f==frame_begin && create_initial_state && l->ilabel==0) continue; // has existed in the last chunk - if (f==frame_end && create_final_state && l->ilabel!=0) continue; // will exist in the next chunk - typename unordered_map::const_iterator - iter = tok_map.find(l->next_tok); + for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { + if (f == frame_begin && create_initial_state && l->ilabel == 0) + continue; // has existed in the last chunk + if (f == frame_end && create_final_state && l->ilabel != 0) + continue; // will exist in the next chunk + typename unordered_map::const_iterator iter = + tok_map.find(l->next_tok); StateId nextstate = iter->second; KALDI_ASSERT(iter != tok_map.end()); BaseFloat cost_offset = 0.0; - if (l->ilabel != 0) { // emitting.. + if (l->ilabel != 0) { // emitting.. KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); cost_offset = cost_offsets_[f]; } Arc arc(l->ilabel, l->olabel, - Weight(l->graph_cost, l->acoustic_cost - cost_offset), - nextstate); + Weight(l->graph_cost, l->acoustic_cost - cost_offset), nextstate); ofst->AddArc(cur_state, arc); } if (f == frame_end) { if (use_final_probs && !final_costs.empty()) { - typename unordered_map::const_iterator - iter = final_costs.find(tok); + typename unordered_map::const_iterator iter = + final_costs.find(tok); if (iter != final_costs.end()) ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); } else { @@ -1172,11 +1164,12 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( StateId cur_state = tok_map[tok]; int32 id = state_label_avilable_idx_++; state_label_map_[tok] = id; - Weight final_weight = (!decoding_finalized_ && ofst->Final(cur_state) == Weight::Zero())? Weight::One(): ofst->Final(cur_state); + Weight final_weight = + (!decoding_finalized_ && ofst->Final(cur_state) == Weight::Zero()) + ? Weight::One() + : ofst->Final(cur_state); - Arc arc(0, id, - final_weight, - end_state); + Arc arc(0, id, final_weight, end_state); ofst->AddArc(cur_state, arc); ofst->SetFinal(cur_state, Weight::Zero()); } @@ -1184,18 +1177,23 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( return (ofst->NumStates() > 0); } - // Instantiate the template for the combination of token types and FST types // that we'll need. -template class LatticeIncrementalDecoderTpl, decoder::StdToken>; -template class LatticeIncrementalDecoderTpl, decoder::StdToken >; -template class LatticeIncrementalDecoderTpl, decoder::StdToken >; +template class LatticeIncrementalDecoderTpl, + decoder::StdToken>; +template class LatticeIncrementalDecoderTpl, + decoder::StdToken>; +template class LatticeIncrementalDecoderTpl, + decoder::StdToken>; template class LatticeIncrementalDecoderTpl; -template class LatticeIncrementalDecoderTpl , decoder::BackpointerToken>; -template class LatticeIncrementalDecoderTpl, decoder::BackpointerToken >; -template class LatticeIncrementalDecoderTpl, decoder::BackpointerToken >; -template class LatticeIncrementalDecoderTpl; - +template class LatticeIncrementalDecoderTpl, + decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl, + decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl, + decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl; } // end namespace kaldi. diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index c7436002506..c289820e822 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -23,7 +23,6 @@ #ifndef KALDI_DECODER_LATTICE_INCREMENTAL_DECODER_H_ #define KALDI_DECODER_LATTICE_INCREMENTAL_DECODER_H_ - #include "util/stl-utils.h" #include "util/hash-list.h" #include "fst/fstlib.h" @@ -44,51 +43,59 @@ struct LatticeIncrementalDecoderConfig { int32 prune_interval; bool determinize_lattice; // not inspected by this class... used in // command-line program. - BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat beam_delta; // has nothing to do with beam_ratio BaseFloat hash_ratio; - BaseFloat prune_scale; // Note: we don't make this configurable on the command line, - // it's not a very important parameter. It affects the - // algorithm that prunes the tokens as we go. + BaseFloat + prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. // Most of the options inside det_opts are not actually queried by the // LatticeIncrementalDecoder class itself, but by the code that calls it, for // example in the function DecodeUtteranceLatticeIncremental. int32 max_word_id; // for GetLattice fst::DeterminizeLatticePhonePrunedOptions det_opts; - LatticeIncrementalDecoderConfig(): beam(16.0), - max_active(std::numeric_limits::max()), - min_active(200), - lattice_beam(10.0), - prune_interval(25), - determinize_lattice(true), - beam_delta(0.5), - hash_ratio(2.0), - prune_scale(0.1), - max_word_id(1e7) { } + LatticeIncrementalDecoderConfig() + : beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + determinize_lattice(true), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1), + max_word_id(1e7) {} void Register(OptionsItf *opts) { det_opts.Register(opts); opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); - opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; " + opts->Register("max-active", &max_active, + "Decoder max active states. Larger->slower; " "more accurate"); opts->Register("min-active", &min_active, "Decoder minimum #active states."); - opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, " + opts->Register("lattice-beam", &lattice_beam, + "Lattice generation beam. Larger->slower, " "and deeper lattices"); - opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " + opts->Register("prune-interval", &prune_interval, + "Interval (in frames) at " "which to prune tokens"); - opts->Register("determinize-lattice", &determinize_lattice, "If true, " + opts->Register("determinize-lattice", &determinize_lattice, + "If true, " "determinize the lattice (lattice-determinization, keeping only " "best pdf-sequence for each word-sequence)."); - opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " + opts->Register("beam-delta", &beam_delta, + "Increment used in decoding-- this " "parameter is obscure and relates to a speedup in the way the " "max-active constraint is applied. Larger is more accurate."); - opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to " + opts->Register("hash-ratio", &hash_ratio, + "Setting used in decoder to " "control hash behavior"); } void Check() const { - KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 - && min_active <= max_active - && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0 - && prune_scale > 0.0 && prune_scale < 1.0); + KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 && + min_active <= max_active && prune_interval > 0 && + beam_delta > 0.0 && hash_ratio >= 1.0 && prune_scale > 0.0 && + prune_scale < 1.0); } }; @@ -119,21 +126,19 @@ class LatticeIncrementalDecoderTpl { // Instantiate this class once for each thing you have to decode. // This version of the constructor does not take ownership of // 'fst'. - LatticeIncrementalDecoderTpl(const FST &fst, const TransitionModel& trans_model, - const LatticeIncrementalDecoderConfig &config); + LatticeIncrementalDecoderTpl(const FST &fst, const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config); // This version of the constructor takes ownership of the fst, and will delete // it when this object is destroyed. LatticeIncrementalDecoderTpl(const LatticeIncrementalDecoderConfig &config, - FST *fst, const TransitionModel& trans_model); + FST *fst, const TransitionModel &trans_model); void SetOptions(const LatticeIncrementalDecoderConfig &config) { config_ = config; } - const LatticeIncrementalDecoderConfig &GetOptions() const { - return config_; - } + const LatticeIncrementalDecoderConfig &GetOptions() const { return config_; } ~LatticeIncrementalDecoderTpl(); @@ -143,7 +148,6 @@ class LatticeIncrementalDecoderTpl { /// final state). bool Decode(DecodableInterface *decodable); - /// says whether a final-state was active on the last frame. If it was not, the /// lattice (or traceback) will end with states that are not final-states. bool ReachedFinal() const { @@ -156,8 +160,7 @@ class LatticeIncrementalDecoderTpl { /// final-state of the graph then it will include those as final-probs, else /// it will treat all final-probs as one. Note: this just calls GetRawLattice() /// and figures out the shortest path. - bool GetBestPath(Lattice *ofst, - bool use_final_probs = true) const; + bool GetBestPath(Lattice *ofst, bool use_final_probs = true) const; /// Outputs an FST corresponding to the raw, state-level /// tracebacks. Returns true if result is nonempty. @@ -173,7 +176,6 @@ class LatticeIncrementalDecoderTpl { bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const; bool GetCompactLattice(CompactLattice *ofst) const; - /// InitDecoding initializes the decoding, and should only be used if you /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to /// call this. You can also call InitDecoding if you have already decoded an @@ -184,8 +186,7 @@ class LatticeIncrementalDecoderTpl { /// object. You can keep calling it each time more frames become available. /// If max_num_frames is specified, it specifies the maximum number of frames /// the function will decode before returning. - void AdvanceDecoding(DecodableInterface *decodable, - int32 max_num_frames = -1); + void AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames = -1); /// This function may be optionally called after AdvanceDecoding(), when you /// do not plan to decode any further. It does an extra pruning step that @@ -211,7 +212,6 @@ class LatticeIncrementalDecoderTpl { /// reasonable likelihood. BaseFloat FinalRelativeCost() const; - // Returns the number of frames decoded so far. The value returned changes // whenever we call ProcessEmitting(). inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } @@ -230,11 +230,11 @@ class LatticeIncrementalDecoderTpl { Token *toks; bool must_prune_forward_links; bool must_prune_tokens; - TokenList(): toks(NULL), must_prune_forward_links(true), - must_prune_tokens(true) { } + TokenList() + : toks(NULL), must_prune_forward_links(true), must_prune_tokens(true) {} }; - using Elem = typename HashList::Elem; + using Elem = typename HashList::Elem; // Equivalent to: // struct Elem { // StateId key; @@ -268,8 +268,7 @@ class LatticeIncrementalDecoderTpl { // extra_costs_changed is set to true if extra_cost was changed for any token // links_pruned is set to true if any link in any token was pruned void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed, - bool *links_pruned, - BaseFloat delta); + bool *links_pruned, BaseFloat delta); // This function computes the final-costs for tokens active on the final // frame. It outputs to final-costs, if non-NULL, a map from the Token* @@ -287,7 +286,7 @@ class LatticeIncrementalDecoderTpl { // forward-cost[t] if there were no final-probs active on the final frame. // You cannot call this after FinalizeDecoding() has been called; in that // case you should get the answer from class-member variables. - void ComputeFinalCosts(unordered_map *final_costs, + void ComputeFinalCosts(unordered_map *final_costs, BaseFloat *final_relative_cost, BaseFloat *final_best_cost) const; @@ -302,7 +301,6 @@ class LatticeIncrementalDecoderTpl { // It's called by PruneActiveTokens if any forward links have been pruned void PruneTokensForFrame(int32 frame_plus_one); - // Go backwards through still-alive tokens, pruning them if the // forward+backward cost is more than lat_beam away from the best path. It's // possible to prove that this is "correct" in the sense that we won't lose @@ -313,8 +311,8 @@ class LatticeIncrementalDecoderTpl { void PruneActiveTokens(BaseFloat delta); /// Gets the weight cutoff. Also counts the active tokens. - BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, - BaseFloat *adaptive_beam, Elem **best_elem); + BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, + Elem **best_elem); /// Processes emitting arcs for one frame. Propagates from prev_toks_ to /// cur_toks_. Returns the cost cutoff for subsequent ProcessNonemitting() to @@ -333,13 +331,13 @@ class LatticeIncrementalDecoderTpl { // That is, the emitting probs of frame t are accounted for in tokens at // toks_[t+1]. The zeroth frame is for nonemitting transition at the start of // the graph. - HashList toks_; + HashList toks_; std::vector active_toks_; // Lists of tokens, indexed by // frame (members of TokenList are toks, must_prune_forward_links, // must_prune_tokens). - std::vector queue_; // temp variable used in ProcessNonemitting, - std::vector tmp_array_; // used in GetCutoff. + std::vector queue_; // temp variable used in ProcessNonemitting, + std::vector tmp_array_; // used in GetCutoff. // fst_ is a pointer to the FST we are decoding from. const FST *fst_; @@ -365,7 +363,7 @@ class LatticeIncrementalDecoderTpl { bool decoding_finalized_; /// For the meaning of the next 3 variables, see the comment for /// decoding_finalized_ above., and ComputeFinalCosts(). - unordered_map final_costs_; + unordered_map final_costs_; BaseFloat final_relative_cost_; BaseFloat final_best_cost_; @@ -388,56 +386,62 @@ class LatticeIncrementalDecoderTpl { // cycles, which are not allowed). Note: the output list may contain NULLs, // which the caller should pass over; it just happens to be more efficient for // the algorithm to output a list that contains NULLs. - static void TopSortTokens(Token *tok_list, - std::vector *topsorted_list); + static void TopSortTokens(Token *tok_list, std::vector *topsorted_list); void ClearActiveTokens(); /// Obtains a CompactLattice for the part of this utterance that has been - /// decoded so far. If you call this multiple times (calling it on every frame would not make - /// sense, but every, say, 10, to 40 frames might make sense) it will spread out the - /// work of determinization over time,which might be useful for online applications. + /// decoded so far. If you call this multiple times (calling it on every frame + /// would not make + /// sense, but every, say, 10, to 40 frames might make sense) it will spread out + /// the + /// work of determinization over time,which might be useful for online + /// applications. /// /// @param [in] use_final_probs If true *and* at least one final-state in HCLG - /// was active on the final frame, include final-probs from HCLG - /// in the lattice. Otherwise treat all final-costs of states active + /// was active on the final frame, include final-probs from + /// HCLG + /// in the lattice. Otherwise treat all final-costs of + /// states active /// on the most recent frame as zero (i.e. Weight::One()). /// @param [in] redeterminize If true, re-determinize the CompactLattice - /// after appending the most recently decoded chunk to it, to + /// after appending the most recently decoded chunk to it, + /// to /// ensure that the output is fully deterministic. - /// This does extra work, but not nearly as much as determinizing + /// This does extra work, but not nearly as much as + /// determinizing /// a RawLattice from scratch. /// @param [out] lat The CompactLattice representing what has been decoded /// so far. - /// @return reached_final This function will returns true if a state that was final in - /// HCLG was active on the most recent frame, and false otherwise. - /// CAUTION: this is not the same meaning as the return value of + /// @return reached_final This function will returns true if a state that was + /// final in + /// HCLG was active on the most recent frame, and false + /// otherwise. + /// CAUTION: this is not the same meaning as the return + /// value of /// LatticeFasterDecoder::GetLattice(). - bool GetLattice(bool use_final_probs, - bool redeterminize, CompactLattice *olat); - CompactLattice lat_; - int32 last_get_lattice_frame_; - unordered_map state_label_map_; - int32 state_label_avilable_idx_; - const TransitionModel& trans_model_; - std::vector> final_arc_list_; + bool GetLattice(bool use_final_probs, bool redeterminize, CompactLattice *olat); + CompactLattice lat_; // the compact lattice we obtain + int32 last_get_lattice_frame_; // the last time we call GetLattice + unordered_map state_label_map_; // between Token and state_label + int32 state_label_avilable_idx_; // we allocate a unique id for each Token + const TransitionModel &trans_model_; // keep it for determinization + std::vector> final_arc_list_; // keep final_arc std::vector> final_arc_list_prev_; - // TODO use 2 vector since state_label is continuous in each frame, and we need 2 frames - unordered_map state_label_forward_prob_; // alpha for each state_label (Token) - bool GetRawLattice(Lattice *ofst, bool use_final_probs, - int32 frame_begin, - int32 frame_end, - bool create_initial_state, - bool create_final_state); - - + // TODO use 2 vector to replace this map, since state_label is continuous + // in each frame, and we need 2 frames of them + unordered_map + state_label_forward_prob_; // alpha for each state_label (Token) + // specific design for incremental GetLattice + bool GetRawLattice(Lattice *ofst, bool use_final_probs, int32 frame_begin, + int32 frame_end, bool create_initial_state, + bool create_final_state); KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); }; -typedef LatticeIncrementalDecoderTpl LatticeIncrementalDecoder; - - +typedef LatticeIncrementalDecoderTpl + LatticeIncrementalDecoder; } // end namespace kaldi. From 8080697f64d7e2dab6a3b86aaede773051183867 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Thu, 28 Mar 2019 12:23:29 +0800 Subject: [PATCH 05/60] this commit is for sanity check --- src/decoder/lattice-incremental-decoder.cc | 47 ++++++++++++++++++---- src/decoder/lattice-incremental-decoder.h | 2 + 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 103e7fbeee7..f2d0575f05b 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -933,9 +933,14 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, !decoding_finalized_)); // step 2: Determinize the chunk CompactLattice clat; +#if 0 if (!DeterminizeLatticePhonePrunedWrapper( trans_model_, &raw_fst, config_.lattice_beam, &clat, config_.det_opts)) KALDI_WARN << "Determinization finished earlier than the beam"; +#else + ConvertLattice(raw_fst, &clat); + Connect(&clat); +#endif final_arc_list_.swap(final_arc_list_prev_); final_arc_list_.clear(); @@ -969,6 +974,7 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, if (arc.olabel > config_.max_word_id) { if (s == 0) { // initial_arc initial_arc_map[arc.olabel] = aiter.Position(); + if (last_get_lattice_frame_ != 0) initial_state_in_chunk_.erase(arc.olabel); } else { // final_arc KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); final_arc_list_.push_back( @@ -977,6 +983,8 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, } } } + // sanity check + KALDI_ASSERT(initial_state_in_chunk_.size() == 0); // step 3.2: connect the states between two chunks if (last_get_lattice_frame_ != 0) { @@ -1029,6 +1037,20 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() << " states of chunk: " << clat.NumStates() << " states of the lattice: " << olat->NumStates(); + // sanity check + CompactLattice cdecoded; + Lattice decoded; + ShortestPath(*olat, &cdecoded); + ConvertLattice(cdecoded, &decoded); + LatticeWeight weight; + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + BaseFloat offset_sum=0; + for (auto& i:cost_offsets_) offset_sum+=i; + KALDI_ASSERT(alignment.size() == NumFramesDecoded()); + KALDI_ASSERT(ApproxEqual(best_cost_in_chunk_, weight.Value1() + weight.Value2()+offset_sum, 1e-2)); + } // TODO: check in the case the last frame is det twice last_get_lattice_frame_ = NumFramesDecoded(); @@ -1108,26 +1130,33 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( << " max:" << tok_map.max_load_factor(); // Create initial_arc for later appending with the previous chunk if (create_initial_state) { + initial_state_in_chunk_.clear(); for (Token *tok = active_toks_[frame_begin].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; // state_label_map_ is construct during create_final_state int32 id = state_label_map_.find(tok)->second; // it should exist + initial_state_in_chunk_.insert(id); // TODO: calculate alpha but not use tot_cost BaseFloat cost_offset = tok->tot_cost; state_label_forward_prob_[id] = tok->tot_cost; Arc arc(0, id, Weight(0, cost_offset), cur_state); ofst->AddArc(begin_state, arc); } + KALDI_VLOG(6) << initial_state_in_chunk_.size(); } // Now create all arcs. + best_cost_in_chunk_ = std::numeric_limits::infinity(); for (int32 f = frame_begin; f <= frame_end; f++) { for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { + // TODO + /* if (f == frame_begin && create_initial_state && l->ilabel == 0) continue; // has existed in the last chunk if (f == frame_end && create_final_state && l->ilabel != 0) continue; // will exist in the next chunk + */ typename unordered_map::const_iterator iter = tok_map.find(l->next_tok); StateId nextstate = iter->second; @@ -1142,14 +1171,18 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( ofst->AddArc(cur_state, arc); } if (f == frame_end) { + LatticeWeight weight = LatticeWeight::One(); if (use_final_probs && !final_costs.empty()) { typename unordered_map::const_iterator iter = final_costs.find(tok); if (iter != final_costs.end()) - ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); - } else { - ofst->SetFinal(cur_state, LatticeWeight::One()); - } + weight = LatticeWeight(iter->second, 0); + else + weight = LatticeWeight::Zero(); + } + ofst->SetFinal(cur_state, weight); + // for sanity check + best_cost_in_chunk_ = std::min(best_cost_in_chunk_, tok->tot_cost + weight.Value1() + weight.Value2()); } } } @@ -1164,10 +1197,8 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( StateId cur_state = tok_map[tok]; int32 id = state_label_avilable_idx_++; state_label_map_[tok] = id; - Weight final_weight = - (!decoding_finalized_ && ofst->Final(cur_state) == Weight::Zero()) - ? Weight::One() - : ofst->Final(cur_state); + KALDI_ASSERT(ofst->Final(cur_state) != Weight::Zero()); + Weight final_weight = ofst->Final(cur_state); Arc arc(0, id, final_weight, end_state); ofst->AddArc(cur_state, arc); diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index c289820e822..5ac47a1706f 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -436,6 +436,8 @@ class LatticeIncrementalDecoderTpl { bool GetRawLattice(Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, bool create_initial_state, bool create_final_state); + BaseFloat best_cost_in_chunk_; + unordered_set initial_state_in_chunk_; KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); }; From b302f126219e76cf2e7400c68b9b96056610ef7e Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Thu, 28 Mar 2019 16:58:59 +0800 Subject: [PATCH 06/60] code clean --- src/decoder/lattice-incremental-decoder.cc | 282 ++++++++++++--------- src/decoder/lattice-incremental-decoder.h | 50 +++- 2 files changed, 202 insertions(+), 130 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index f2d0575f05b..6291b0a0a20 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -81,10 +81,10 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { last_get_lattice_frame_ = 0; state_label_map_.clear(); state_label_map_.reserve(std::min((int32)1e5, config_.max_active)); - state_label_avilable_idx_ = config_.max_word_id + 1; + state_label_available_idx_ = config_.max_word_id + 1; final_arc_list_.clear(); final_arc_list_prev_.clear(); - state_label_forward_prob_.clear(); + state_label_forward_cost_.clear(); ProcessNonemitting(config_.beam); } @@ -924,20 +924,27 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, bool redeterminize, CompactLattice *olat) { using namespace fst; + bool not_first_chunk = last_get_lattice_frame_ != 0; + // last_get_lattice_frame_ is used to record the first frame of the chunk + // last time we obtain from calling this function. If it reaches NumFramesDecoded() + // we cannot generate any more chunk if (last_get_lattice_frame_ < NumFramesDecoded()) { Lattice raw_fst; // step 1: Get lattice chunk with initial state + // In this function, we do not create the initial state in + // the first chunk, and we do not create the final state in the last chunk KALDI_ASSERT(GetRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, - NumFramesDecoded(), last_get_lattice_frame_ != 0, + NumFramesDecoded(), not_first_chunk, !decoding_finalized_)); // step 2: Determinize the chunk CompactLattice clat; -#if 0 +#if 1 if (!DeterminizeLatticePhonePrunedWrapper( trans_model_, &raw_fst, config_.lattice_beam, &clat, config_.det_opts)) KALDI_WARN << "Determinization finished earlier than the beam"; #else + // sanity check, remove it later ConvertLattice(raw_fst, &clat); Connect(&clat); #endif @@ -945,99 +952,10 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, final_arc_list_.swap(final_arc_list_prev_); final_arc_list_.clear(); - // step 3.1: Appending new chunk to the old one - int32 state_offset = olat->NumStates(); - if (last_get_lattice_frame_ != 0) - state_offset--; // since we do not append initial state - unordered_map - initial_arc_map; // the incoming states of these arcs are initial states - initial_arc_map.reserve(std::min((int32)1e5, config_.max_active)); - for (StateIterator siter(clat); !siter.Done(); siter.Next()) { - auto s = siter.Value(); - StateId state_append = -1; - if (last_get_lattice_frame_ == 0 || - s != 0) { // do not need to copy initial state - state_append = s + state_offset; - KALDI_ASSERT(state_append == olat->AddState()); - olat->SetFinal(state_append, clat.Final(s)); - } + // step 3: Appending the new chunk in clat to the old one in olat + AppendLatticeChunks(clat, not_first_chunk, olat); - for (ArcIterator aiter(clat, s); !aiter.Done(); aiter.Next()) { - const auto &arc = aiter.Value(); - // construct a copy of the state & arcs - if (last_get_lattice_frame_ == 0 || - s != 0) { // do not need to copy initial arc - CompactLatticeArc arc_append(arc); - arc_append.nextstate += state_offset; - olat->AddArc(state_append, arc_append); - } - if (arc.olabel > config_.max_word_id) { - if (s == 0) { // initial_arc - initial_arc_map[arc.olabel] = aiter.Position(); - if (last_get_lattice_frame_ != 0) initial_state_in_chunk_.erase(arc.olabel); - } else { // final_arc - KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); - final_arc_list_.push_back( - pair(state_append, aiter.Position())); - } - } - } - } - // sanity check - KALDI_ASSERT(initial_state_in_chunk_.size() == 0); - - // step 3.2: connect the states between two chunks - if (last_get_lattice_frame_ != 0) { - KALDI_ASSERT(final_arc_list_prev_.size()); - StateId prev_final_state = -1; - for (auto &i : final_arc_list_prev_) { - MutableArcIterator aiter(olat, i.first); - aiter.Seek(i.second); - auto &arc_append = aiter.Value(); - auto r = initial_arc_map.find(arc_append.olabel); - if (r != initial_arc_map.end()) { - ArcIterator aiter_chunk(clat, 0); // initial state - aiter_chunk.Seek(r->second); - const auto &arc_chunk = aiter_chunk.Value(); - KALDI_ASSERT(arc_chunk.olabel == arc_append.olabel); - StateId state_append = arc_chunk.nextstate + state_offset; - if (prev_final_state == -1) - prev_final_state = arc_append.nextstate; - else - KALDI_ASSERT(arc_append.nextstate == prev_final_state); - CompactLatticeArc arc_append_mod(arc_append); - arc_append_mod.nextstate = state_append; - - CompactLatticeWeight weight_offset; - weight_offset.SetWeight( - LatticeWeight(0, -state_label_forward_prob_[arc_append.olabel])); - vector weights = { - arc_append_mod.weight, arc_chunk.weight, olat->Final(prev_final_state), - weight_offset}; - BaseFloat v1 = 0, v2 = 0; - for (auto &i : weights) v1 += i.Weight().Value1(); - for (auto &i : weights) v2 += i.Weight().Value2(); - vector s; - for (auto &i : weights) - s.insert(s.end(), i.String().begin(), i.String().end()); - - arc_append_mod.weight = CompactLatticeWeight(LatticeWeight(v1, v2), s); - arc_append_mod.olabel = 0; - arc_append_mod.ilabel = 0; - aiter.SetValue(arc_append_mod); - } // otherwise, it has been pruned - state_label_forward_prob_.erase(arc_append.olabel); - } - KALDI_ASSERT(prev_final_state != -1); // at least one arc should be appended - // making all unmodified remaining arcs of final_arc_list_prev_ be connected to - // a dead state - olat->SetFinal(prev_final_state, CompactLatticeWeight::Zero()); - } else - olat->SetStart(0); - KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() - << " states of chunk: " << clat.NumStates() - << " states of the lattice: " << olat->NumStates(); - // sanity check + // sanity check, remove them later CompactLattice cdecoded; Lattice decoded; ShortestPath(*olat, &cdecoded); @@ -1046,12 +964,13 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, std::vector alignment; std::vector words; GetLinearSymbolSequence(decoded, &alignment, &words, &weight); - BaseFloat offset_sum=0; - for (auto& i:cost_offsets_) offset_sum+=i; + BaseFloat offset_sum = 0; + for (auto &i : cost_offsets_) offset_sum += i; KALDI_ASSERT(alignment.size() == NumFramesDecoded()); - KALDI_ASSERT(ApproxEqual(best_cost_in_chunk_, weight.Value1() + weight.Value2()+offset_sum, 1e-2)); - - } // TODO: check in the case the last frame is det twice + // TODO: the following KALDI_ASSERT will fail some time, which is unexpected + // KALDI_ASSERT(ApproxEqual(best_cost_in_chunk_, weight.Value1() + + // weight.Value2()+offset_sum, 1e-2)); + } last_get_lattice_frame_ = NumFramesDecoded(); // step 4: re-determinize the final lattice @@ -1079,7 +998,6 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, return (olat->NumStates() != 0); } -// It is modified from LatticeFasterDecoderTpl::GetRawLattice() template bool LatticeIncrementalDecoderTpl::GetRawLattice( Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, @@ -1119,8 +1037,8 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( for (size_t i = 0; i < token_list.size(); i++) if (token_list[i] != NULL) tok_map[token_list[i]] = ofst->AddState(); } - // The next statement sets the start state of the output FST. Because we - // topologically sorted the tokens, state zero must be the start-state. + // The next statement sets the start state of the output FST. + // No matter create_initial_state or not , state zero must be the start-state. StateId begin_state = 0; ofst->SetStart(begin_state); @@ -1128,39 +1046,42 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( << " buckets:" << tok_map.bucket_count() << " load:" << tok_map.load_factor() << " max:" << tok_map.max_load_factor(); - // Create initial_arc for later appending with the previous chunk + // step 1.1: create initial_arc for later appending with the previous chunk if (create_initial_state) { initial_state_in_chunk_.clear(); for (Token *tok = active_toks_[frame_begin].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; // state_label_map_ is construct during create_final_state - int32 id = state_label_map_.find(tok)->second; // it should exist + auto r = state_label_map_.find(tok); + KALDI_ASSERT(r != state_label_map_.end()); // it should exist + int32 id = r->second; initial_state_in_chunk_.insert(id); + // Use cost_offsets to guide DeterminizeLatticePruned() + // later in GetLattice() + // For now, we use alpha (tot_cost) from the decoding stage as + // the initial weights of arcs connecting to the states in the begin + // of this chunk // TODO: calculate alpha but not use tot_cost BaseFloat cost_offset = tok->tot_cost; - state_label_forward_prob_[id] = tok->tot_cost; + // We record these cost_offset, and after we appending two chunks + // we will cancel them out + state_label_forward_cost_[id] = cost_offset; Arc arc(0, id, Weight(0, cost_offset), cur_state); ofst->AddArc(begin_state, arc); } KALDI_VLOG(6) << initial_state_in_chunk_.size(); } - // Now create all arcs. + // for sanity check best_cost_in_chunk_ = std::numeric_limits::infinity(); + // step 1.2: create all arcs as GetRawLattice() for (int32 f = frame_begin; f <= frame_end; f++) { for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { - // TODO - /* - if (f == frame_begin && create_initial_state && l->ilabel == 0) - continue; // has existed in the last chunk - if (f == frame_end && create_final_state && l->ilabel != 0) - continue; // will exist in the next chunk - */ typename unordered_map::const_iterator iter = tok_map.find(l->next_tok); - StateId nextstate = iter->second; KALDI_ASSERT(iter != tok_map.end()); + StateId nextstate = iter->second; BaseFloat cost_offset = 0.0; if (l->ilabel != 0) { // emitting.. KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); @@ -1170,6 +1091,14 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( Weight(l->graph_cost, l->acoustic_cost - cost_offset), nextstate); ofst->AddArc(cur_state, arc); } + // For the last frame in this chunk, we need to work out a + // proper final weight for the corresponding state. + // If use_final_probs == true, we will try to use the final cost we just + // calculated + // Otherwise, we use LatticeWeight::One(). We record these cost in the state + // Later in the code, if create_final_state == true, we will create + // a specific final state, and move the final costs to the cost of an arc + // connecting to the final state if (f == frame_end) { LatticeWeight weight = LatticeWeight::One(); if (use_final_probs && !final_costs.empty()) { @@ -1179,14 +1108,15 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( weight = LatticeWeight(iter->second, 0); else weight = LatticeWeight::Zero(); - } + } ofst->SetFinal(cur_state, weight); // for sanity check - best_cost_in_chunk_ = std::min(best_cost_in_chunk_, tok->tot_cost + weight.Value1() + weight.Value2()); + best_cost_in_chunk_ = std::min( + best_cost_in_chunk_, tok->tot_cost + weight.Value1() + weight.Value2()); } } } - // Create final_arc for later appending with the next chunk + // step 1.3 create final_arc for later appending with the next chunk (in GetLattice) if (create_final_state) { StateId end_state = ofst->AddState(); // final-state for the chunk ofst->SetFinal(end_state, Weight::One()); @@ -1195,8 +1125,14 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( state_label_map_.reserve(std::min((int32)1e5, config_.max_active)); for (Token *tok = active_toks_[frame_end].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; - int32 id = state_label_avilable_idx_++; + // We assign an unique state label for each of the token in the last frame + // of this chunk + int32 id = state_label_available_idx_++; state_label_map_[tok] = id; + // The final weight has been worked out in the previous for loop and + // store in the states + // Here, we create a specific final state, and move the final costs to + // the cost of an arc connecting to the final state KALDI_ASSERT(ofst->Final(cur_state) != Weight::Zero()); Weight final_weight = ofst->Final(cur_state); @@ -1208,6 +1144,112 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( return (ofst->NumStates() > 0); } +template +void LatticeIncrementalDecoderTpl::AppendLatticeChunks( + CompactLattice clat, bool not_first_chunk, CompactLattice *olat) { + using namespace fst; + // step 3.1: Appending new chunk to the old one + int32 state_offset = olat->NumStates(); + if (not_first_chunk) + state_offset--; // since we do not append initial state in the first chunk + + // A map between state label and the arc position (index) + // the incoming states of these arcs are initial states of the chunk + // and the olabel of these arcs are the key of this map (state label) + // The arc position are obtained from ArcIterator corresponding to the state + unordered_map initial_arc_map; + initial_arc_map.reserve(std::min((int32)1e5, config_.max_active)); + for (StateIterator siter(clat); !siter.Done(); siter.Next()) { + auto s = siter.Value(); + StateId state_appended = -1; + // We do not copy initial state, which exists except the first chunk + if (!not_first_chunk || s != 0) { + state_appended = s + state_offset; + KALDI_ASSERT(state_appended == olat->AddState()); + olat->SetFinal(state_appended, clat.Final(s)); + } + + for (ArcIterator aiter(clat, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + // We do not copy initial arcs, which exists except the first chunk. + // These arcs will be taken care later in step 3.2 + if (!not_first_chunk || s != 0) { + CompactLatticeArc arc_appended(arc); + arc_appended.nextstate += state_offset; + olat->AddArc(state_appended, arc_appended); + } + // Process state labels, which will be used in step 3.2 + if (arc.olabel > config_.max_word_id) { // initial_arc + if (s == 0) { // record initial_arc in this chunk, we will use it right now + initial_arc_map[arc.olabel] = aiter.Position(); + if (last_get_lattice_frame_ != 0) // Erase since we are not interested + initial_state_in_chunk_.erase(arc.olabel); + } else { // final_arc + // record final_arc in this chunk for the step 3.2 in the next call + KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); + final_arc_list_.push_back( + pair(state_appended, aiter.Position())); + } + } + } + } + // sanity check, remove it later + // KALDI_ASSERT(initial_state_in_chunk_.size() == 0); + + // step 3.2: connect the states between two chunks + if (last_get_lattice_frame_ != 0) { + KALDI_ASSERT(final_arc_list_prev_.size()); + StateId prev_final_state = -1; + for (auto &i : final_arc_list_prev_) { + MutableArcIterator aiter(olat, i.first); + aiter.Seek(i.second); + // Obtain the appended final arcs in the last chunk + auto &arc_appended = aiter.Value(); + // Find out whether its corresponding Token still exists in the begin + // of this chunk. If not, it is pruned by PruneActiveTokens() + auto r = initial_arc_map.find(arc_appended.olabel); + if (r != initial_arc_map.end()) { + ArcIterator aiter_unappended(clat, 0); // initial state + aiter_unappended.Seek(r->second); + const auto &arc_unappended = aiter_unappended.Value(); + KALDI_ASSERT(arc_unappended.olabel == arc_appended.olabel); + StateId state_appended = arc_unappended.nextstate + state_offset; + if (prev_final_state == -1) + prev_final_state = arc_appended.nextstate; + else + KALDI_ASSERT(arc_appended.nextstate == prev_final_state); + // For the later code in this loop, we try to modify the arc_appended + // to connect the last frame state of last chunk to the first frame + // state of this chunk. These begin and final states are + // corresponding to the same Token, guaranteed by unique state labels. + CompactLatticeArc arc_appended_mod(arc_appended); + arc_appended_mod.nextstate = state_appended; + + CompactLatticeWeight weight_offset; + weight_offset.SetWeight( + LatticeWeight(0, -state_label_forward_cost_[arc_appended.olabel])); + arc_appended_mod.weight = + Times(arc_appended_mod.weight, arc_unappended.weight); + arc_appended_mod.weight = + Times(arc_appended_mod.weight, olat->Final(prev_final_state)); + arc_appended_mod.weight = Times(arc_appended_mod.weight, weight_offset); + arc_appended_mod.olabel = 0; + arc_appended_mod.ilabel = 0; + aiter.SetValue(arc_appended_mod); + } // otherwise, it has been pruned + state_label_forward_cost_.erase(arc_appended.olabel); + } + KALDI_ASSERT(prev_final_state != -1); // at least one arc should be appended + // making all unmodified remaining arcs of final_arc_list_prev_ be connected to + // a dead state + olat->SetFinal(prev_final_state, CompactLatticeWeight::Zero()); + } else + olat->SetStart(0); + KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() + << " states of chunk: " << clat.NumStates() + << " states of the lattice: " << olat->NumStates(); +} + // Instantiate the template for the combination of token types and FST types // that we'll need. template class LatticeIncrementalDecoderTpl, diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 5ac47a1706f..19fc7cc0288 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -390,10 +390,12 @@ class LatticeIncrementalDecoderTpl { void ClearActiveTokens(); - /// Obtains a CompactLattice for the part of this utterance that has been - /// decoded so far. If you call this multiple times (calling it on every frame - /// would not make - /// sense, but every, say, 10, to 40 frames might make sense) it will spread out + /// The following part is specifically designed for incremental determinization + /// + /// The function obtains a CompactLattice for the part of this utterance that has + /// been decoded so far. If you call this multiple times (calling it on + /// every frame would not make sense, + /// but every, say, 10, to 40 frames might make sense) it will spread out /// the /// work of determinization over time,which might be useful for online /// applications. @@ -424,21 +426,49 @@ class LatticeIncrementalDecoderTpl { CompactLattice lat_; // the compact lattice we obtain int32 last_get_lattice_frame_; // the last time we call GetLattice unordered_map state_label_map_; // between Token and state_label - int32 state_label_avilable_idx_; // we allocate a unique id for each Token + int32 state_label_available_idx_; // we allocate a unique id for each Token const TransitionModel &trans_model_; // keep it for determinization std::vector> final_arc_list_; // keep final_arc std::vector> final_arc_list_prev_; + // We keep alpha for each state_label (Token). We need them before determinization + // We cancel them after determinization // TODO use 2 vector to replace this map, since state_label is continuous // in each frame, and we need 2 frames of them - unordered_map - state_label_forward_prob_; // alpha for each state_label (Token) - // specific design for incremental GetLattice + unordered_map state_label_forward_cost_; + + /// This function is modified from LatticeFasterDecoderTpl::GetRawLattice() + /// and specific design for incremental GetLattice + /// It does the same thing as GetRawLattice in lattice-faster-decoder.cc except: + /// + /// i) it creates a initial state, and connect + /// all the tokens in the first frame of this chunk to the initial state + /// by an arc with a per-token state-label as its olabel + /// ii) it creates a final state, and connect + /// all the tokens in the last frame of this chunk to the final state + /// by an arc with a per-token state-label as its olabel + /// the state-label for a token in both i) and ii) should be the same + /// frame_begin and frame_end are the first and last frame of this chunk + /// if create_initial_state == false, we will not create initial state and + /// the corresponding state-label arcs. Similar for create_final_state + /// In incremental GetLattice, we do not create the initial state in + /// the first chunk, and we do not create the final state in the last chunk bool GetRawLattice(Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, bool create_initial_state, bool create_final_state); - BaseFloat best_cost_in_chunk_; - unordered_set initial_state_in_chunk_; + // Take care of the step 3 in GetLattice, which is to + // appending the new chunk in clat to the old one in olat + // If not_first_chunk == false, we do not need to append and just copy + // clat into olat + // Otherwise, we need to connect the last frame state of + // last chunk to the first frame state of this chunk. + // These begin and final states are corresponding to the same Token, + // guaranteed by unique state labels. + void AppendLatticeChunks(CompactLattice clat, bool not_first_chunk, + CompactLattice *olat); + + BaseFloat best_cost_in_chunk_; // for sanity check + unordered_set initial_state_in_chunk_; // for sanity check KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); }; From 7c0f7d7c44f8d39b8683b57be7740278f7c49355 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Fri, 29 Mar 2019 12:13:23 +0800 Subject: [PATCH 07/60] each time we determinize the piece of lattice, instead of going all the way to the currently-decoded frame, we go up to, say, t-10 (unless this is the end of the utterance), and the same way that we put in temporary initial-probs, we also put in temporary final-probs which reflect the on the states at frame t-10. (we remove them later on, of course). --- src/decoder/lattice-incremental-decoder.cc | 61 +++++++++++++++------- src/decoder/lattice-incremental-decoder.h | 20 ++++--- 2 files changed, 56 insertions(+), 25 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 6291b0a0a20..bb790e69f34 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -84,7 +84,8 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { state_label_available_idx_ = config_.max_word_id + 1; final_arc_list_.clear(); final_arc_list_prev_.clear(); - state_label_forward_cost_.clear(); + state_label_initial_cost_.clear(); + state_label_final_cost_.clear(); ProcessNonemitting(config_.beam); } @@ -104,14 +105,17 @@ bool LatticeIncrementalDecoderTpl::Decode( while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { if (NumFramesDecoded() % config_.prune_interval == 0) { PruneActiveTokens(config_.lattice_beam * config_.prune_scale); - // TODO: have a delay in GetLattice - GetLattice(false, false, &lat_); + // The chunk length of determinization is equal to prune_interval + // We have a delay on GetLattice to do determinization on more skinny lattices + if (NumFramesDecoded() - config_.determinize_delay > 0) + GetLattice(false, false, NumFramesDecoded() - config_.determinize_delay, + &lat_); } BaseFloat cost_cutoff = ProcessEmitting(decodable); ProcessNonemitting(cost_cutoff); } FinalizeDecoding(); - GetLattice(true, true, &lat_); + GetLattice(true, true, NumFramesDecoded(), &lat_); // Returns true if we have any kind of traceback available (not necessarily // to the end state; query ReachedFinal() for that). @@ -922,6 +926,7 @@ void LatticeIncrementalDecoderTpl::TopSortTokens( template bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, bool redeterminize, + int32 last_frame_of_chunk, CompactLattice *olat) { using namespace fst; bool not_first_chunk = last_get_lattice_frame_ != 0; @@ -929,13 +934,13 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, // last_get_lattice_frame_ is used to record the first frame of the chunk // last time we obtain from calling this function. If it reaches NumFramesDecoded() // we cannot generate any more chunk - if (last_get_lattice_frame_ < NumFramesDecoded()) { + if (last_get_lattice_frame_ < last_frame_of_chunk) { Lattice raw_fst; // step 1: Get lattice chunk with initial state // In this function, we do not create the initial state in // the first chunk, and we do not create the final state in the last chunk KALDI_ASSERT(GetRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, - NumFramesDecoded(), not_first_chunk, + last_frame_of_chunk, not_first_chunk, !decoding_finalized_)); // step 2: Determinize the chunk CompactLattice clat; @@ -953,7 +958,7 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, final_arc_list_.clear(); // step 3: Appending the new chunk in clat to the old one in olat - AppendLatticeChunks(clat, not_first_chunk, olat); + AppendLatticeChunks(clat, not_first_chunk, last_frame_of_chunk, olat); // sanity check, remove them later CompactLattice cdecoded; @@ -966,13 +971,13 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, GetLinearSymbolSequence(decoded, &alignment, &words, &weight); BaseFloat offset_sum = 0; for (auto &i : cost_offsets_) offset_sum += i; - KALDI_ASSERT(alignment.size() == NumFramesDecoded()); + KALDI_ASSERT(alignment.size() == last_frame_of_chunk); // TODO: the following KALDI_ASSERT will fail some time, which is unexpected // KALDI_ASSERT(ApproxEqual(best_cost_in_chunk_, weight.Value1() + // weight.Value2()+offset_sum, 1e-2)); } - last_get_lattice_frame_ = NumFramesDecoded(); + last_get_lattice_frame_ = last_frame_of_chunk; // step 4: re-determinize the final lattice if (redeterminize) { DeterminizeLatticePrunedOptions det_opts; @@ -1065,7 +1070,7 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( BaseFloat cost_offset = tok->tot_cost; // We record these cost_offset, and after we appending two chunks // we will cancel them out - state_label_forward_cost_[id] = cost_offset; + state_label_initial_cost_[id] = cost_offset; Arc arc(0, id, Weight(0, cost_offset), cur_state); ofst->AddArc(begin_state, arc); } @@ -1078,6 +1083,9 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { + // for the arcs outgoing from the last frame Token in this chunk, we will + // create these arcs in the next chunk + if (f == frame_end && l->ilabel > 0) continue; typename unordered_map::const_iterator iter = tok_map.find(l->next_tok); KALDI_ASSERT(iter != tok_map.end()); @@ -1116,7 +1124,8 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( } } } - // step 1.3 create final_arc for later appending with the next chunk (in GetLattice) + // step 1.3 create final_arc for later appending with the next chunk (in + // GetLattice) if (create_final_state) { StateId end_state = ofst->AddState(); // final-state for the chunk ofst->SetFinal(end_state, Weight::One()); @@ -1135,8 +1144,16 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( // the cost of an arc connecting to the final state KALDI_ASSERT(ofst->Final(cur_state) != Weight::Zero()); Weight final_weight = ofst->Final(cur_state); - - Arc arc(0, id, final_weight, end_state); + // Use cost_offsets to guide DeterminizeLatticePruned() + // later in GetLattice() + // For now, we use extra_cost from the decoding stage , which has some + // "future information", as + // the final weights of this chunk + BaseFloat cost_offset = tok->extra_cost; + // We record these cost_offset, and after we appending two chunks + // we will cancel them out + state_label_final_cost_[id] = cost_offset; + Arc arc(0, id, Times(final_weight, Weight(0, cost_offset)), end_state); ofst->AddArc(cur_state, arc); ofst->SetFinal(cur_state, Weight::Zero()); } @@ -1146,9 +1163,10 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( template void LatticeIncrementalDecoderTpl::AppendLatticeChunks( - CompactLattice clat, bool not_first_chunk, CompactLattice *olat) { + CompactLattice clat, bool not_first_chunk, int32 last_frame_of_chunk, + CompactLattice *olat) { using namespace fst; - // step 3.1: Appending new chunk to the old one + // step 3.1: Appending new chunk to the old one int32 state_offset = olat->NumStates(); if (not_first_chunk) state_offset--; // since we do not append initial state in the first chunk @@ -1225,19 +1243,24 @@ void LatticeIncrementalDecoderTpl::AppendLatticeChunks( CompactLatticeArc arc_appended_mod(arc_appended); arc_appended_mod.nextstate = state_appended; - CompactLatticeWeight weight_offset; + CompactLatticeWeight weight_offset, weight_offset_final; weight_offset.SetWeight( - LatticeWeight(0, -state_label_forward_cost_[arc_appended.olabel])); + LatticeWeight(0, -state_label_initial_cost_[arc_appended.olabel])); + weight_offset_final.SetWeight( + LatticeWeight(0, -state_label_final_cost_[arc_appended.olabel])); arc_appended_mod.weight = Times(arc_appended_mod.weight, arc_unappended.weight); arc_appended_mod.weight = Times(arc_appended_mod.weight, olat->Final(prev_final_state)); arc_appended_mod.weight = Times(arc_appended_mod.weight, weight_offset); + arc_appended_mod.weight = + Times(arc_appended_mod.weight, weight_offset_final); arc_appended_mod.olabel = 0; arc_appended_mod.ilabel = 0; aiter.SetValue(arc_appended_mod); } // otherwise, it has been pruned - state_label_forward_cost_.erase(arc_appended.olabel); + state_label_initial_cost_.erase(arc_appended.olabel); + state_label_final_cost_.erase(arc_appended.olabel); } KALDI_ASSERT(prev_final_state != -1); // at least one arc should be appended // making all unmodified remaining arcs of final_arc_list_prev_ be connected to @@ -1245,7 +1268,7 @@ void LatticeIncrementalDecoderTpl::AppendLatticeChunks( olat->SetFinal(prev_final_state, CompactLatticeWeight::Zero()); } else olat->SetStart(0); - KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() + KALDI_VLOG(2) << "Frame: " << last_frame_of_chunk << " states of chunk: " << clat.NumStates() << " states of the lattice: " << olat->NumStates(); } diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 19fc7cc0288..408e358f135 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -41,6 +41,7 @@ struct LatticeIncrementalDecoderConfig { int32 min_active; BaseFloat lattice_beam; int32 prune_interval; + int32 determinize_delay; bool determinize_lattice; // not inspected by this class... used in // command-line program. BaseFloat beam_delta; // has nothing to do with beam_ratio @@ -61,6 +62,7 @@ struct LatticeIncrementalDecoderConfig { min_active(200), lattice_beam(10.0), prune_interval(25), + determinize_delay(25), determinize_lattice(true), beam_delta(0.5), hash_ratio(2.0), @@ -79,6 +81,9 @@ struct LatticeIncrementalDecoderConfig { opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " "which to prune tokens"); + opts->Register("determinize-delay", &determinize_delay, + "delay (in frames) at " + "which to incrementally determinize lattices"); opts->Register("determinize-lattice", &determinize_lattice, "If true, " "determinize the lattice (lattice-determinization, keeping only " @@ -393,8 +398,8 @@ class LatticeIncrementalDecoderTpl { /// The following part is specifically designed for incremental determinization /// /// The function obtains a CompactLattice for the part of this utterance that has - /// been decoded so far. If you call this multiple times (calling it on - /// every frame would not make sense, + /// been decoded so far. If you call this multiple times (calling it on + /// every frame would not make sense, /// but every, say, 10, to 40 frames might make sense) it will spread out /// the /// work of determinization over time,which might be useful for online @@ -422,7 +427,8 @@ class LatticeIncrementalDecoderTpl { /// CAUTION: this is not the same meaning as the return /// value of /// LatticeFasterDecoder::GetLattice(). - bool GetLattice(bool use_final_probs, bool redeterminize, CompactLattice *olat); + bool GetLattice(bool use_final_probs, bool redeterminize, + int32 last_frame_of_chunk, CompactLattice *olat); CompactLattice lat_; // the compact lattice we obtain int32 last_get_lattice_frame_; // the last time we call GetLattice unordered_map state_label_map_; // between Token and state_label @@ -430,11 +436,13 @@ class LatticeIncrementalDecoderTpl { const TransitionModel &trans_model_; // keep it for determinization std::vector> final_arc_list_; // keep final_arc std::vector> final_arc_list_prev_; - // We keep alpha for each state_label (Token). We need them before determinization + // We keep tot_cost or extra_cost for each state_label (Token) in final and + // initial arcs. We need them before determinization // We cancel them after determinization // TODO use 2 vector to replace this map, since state_label is continuous // in each frame, and we need 2 frames of them - unordered_map state_label_forward_cost_; + unordered_map state_label_initial_cost_; + unordered_map state_label_final_cost_; /// This function is modified from LatticeFasterDecoderTpl::GetRawLattice() /// and specific design for incremental GetLattice @@ -465,7 +473,7 @@ class LatticeIncrementalDecoderTpl { // These begin and final states are corresponding to the same Token, // guaranteed by unique state labels. void AppendLatticeChunks(CompactLattice clat, bool not_first_chunk, - CompactLattice *olat); + int32 last_frame_of_chunk, CompactLattice *olat); BaseFloat best_cost_in_chunk_; // for sanity check unordered_set initial_state_in_chunk_; // for sanity check From bb4e68f56cda7e6ed2dce56545075323db40f563 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sat, 30 Mar 2019 08:11:47 +0800 Subject: [PATCH 08/60] bug fix: 1. in the determinized lattice, there could be multiple final arcs with the same state label. I need to change the logic here. 2. for the first chunk, there could be some final arcs starting from state 0, while for the last chunk, there could be some initial arcs ending in final state. Hence, I found that we cannot distinguish final and initial arcs by simply "if (s==0)" or "if (clat.Final(arc_appended.nextstate)!=CompactLatticeWeight::Zero()" --- src/decoder/lattice-incremental-decoder.cc | 23 +++++++++++++--------- src/decoder/lattice-incremental-decoder.h | 1 + 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index bb790e69f34..306664f8e26 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -970,11 +970,11 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, std::vector words; GetLinearSymbolSequence(decoded, &alignment, &words, &weight); BaseFloat offset_sum = 0; - for (auto &i : cost_offsets_) offset_sum += i; + for (int32 i = 1; i < last_frame_of_chunk; i++) offset_sum += cost_offsets_[i]; KALDI_ASSERT(alignment.size() == last_frame_of_chunk); - // TODO: the following KALDI_ASSERT will fail some time, which is unexpected - // KALDI_ASSERT(ApproxEqual(best_cost_in_chunk_, weight.Value1() + - // weight.Value2()+offset_sum, 1e-2)); + // the following KALDI_ASSERT will fail some time, which is unexpected + KALDI_ASSERT(ApproxEqual(best_cost_in_chunk_, weight.Value1() + + weight.Value2()+offset_sum, 1e-2)); } last_get_lattice_frame_ = last_frame_of_chunk; @@ -1066,7 +1066,6 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( // For now, we use alpha (tot_cost) from the decoding stage as // the initial weights of arcs connecting to the states in the begin // of this chunk - // TODO: calculate alpha but not use tot_cost BaseFloat cost_offset = tok->tot_cost; // We record these cost_offset, and after we appending two chunks // we will cancel them out @@ -1119,14 +1118,16 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( } ofst->SetFinal(cur_state, weight); // for sanity check + // we will use extra_cost in step 1.3 (see the following code) best_cost_in_chunk_ = std::min( - best_cost_in_chunk_, tok->tot_cost + weight.Value1() + weight.Value2()); + best_cost_in_chunk_, tok->tot_cost + tok->extra_cost + weight.Value1() + weight.Value2()); } } } // step 1.3 create final_arc for later appending with the next chunk (in // GetLattice) if (create_final_state) { + final_state_in_chunk_.clear(); StateId end_state = ofst->AddState(); // final-state for the chunk ofst->SetFinal(end_state, Weight::One()); @@ -1137,6 +1138,7 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( // We assign an unique state label for each of the token in the last frame // of this chunk int32 id = state_label_available_idx_++; + final_state_in_chunk_.insert(id); state_label_map_[tok] = id; // The final weight has been worked out in the previous for loop and // store in the states @@ -1157,6 +1159,7 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( ofst->AddArc(cur_state, arc); ofst->SetFinal(cur_state, Weight::Zero()); } + KALDI_VLOG(6) << final_state_in_chunk_.size(); } return (ofst->NumStates() > 0); } @@ -1198,13 +1201,16 @@ void LatticeIncrementalDecoderTpl::AppendLatticeChunks( } // Process state labels, which will be used in step 3.2 if (arc.olabel > config_.max_word_id) { // initial_arc - if (s == 0) { // record initial_arc in this chunk, we will use it right now + // In first chunk, there could be a final arc starting from state 0 + // In the last chunk, there could be a initial arc ending in final state + if (not_first_chunk && s == 0) { // record initial_arc in this chunk, we will use it right now initial_arc_map[arc.olabel] = aiter.Position(); if (last_get_lattice_frame_ != 0) // Erase since we are not interested initial_state_in_chunk_.erase(arc.olabel); } else { // final_arc // record final_arc in this chunk for the step 3.2 in the next call KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); + final_state_in_chunk_.erase(arc.olabel); final_arc_list_.push_back( pair(state_appended, aiter.Position())); } @@ -1213,6 +1219,7 @@ void LatticeIncrementalDecoderTpl::AppendLatticeChunks( } // sanity check, remove it later // KALDI_ASSERT(initial_state_in_chunk_.size() == 0); + // KALDI_ASSERT(final_state_in_chunk_.size() == 0); // step 3.2: connect the states between two chunks if (last_get_lattice_frame_ != 0) { @@ -1259,8 +1266,6 @@ void LatticeIncrementalDecoderTpl::AppendLatticeChunks( arc_appended_mod.ilabel = 0; aiter.SetValue(arc_appended_mod); } // otherwise, it has been pruned - state_label_initial_cost_.erase(arc_appended.olabel); - state_label_final_cost_.erase(arc_appended.olabel); } KALDI_ASSERT(prev_final_state != -1); // at least one arc should be appended // making all unmodified remaining arcs of final_arc_list_prev_ be connected to diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 408e358f135..be16789c560 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -477,6 +477,7 @@ class LatticeIncrementalDecoderTpl { BaseFloat best_cost_in_chunk_; // for sanity check unordered_set initial_state_in_chunk_; // for sanity check + unordered_set final_state_in_chunk_; // for sanity check KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); }; From 86595f95b7e2f673fbb31168eafe01f7d794f901 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sat, 30 Mar 2019 19:54:26 +0800 Subject: [PATCH 09/60] test in libri speech + grep -H Overall exp_dec/incre.fl.1f/base/ora.base exp_dec/incre.fl.1f/base/ora.den.base exp_dec/incre.fl.1f/incre/ora.base exp_dec/incre.fl.1f/incre/ora.den.base exp_dec/incre.fl.1f/base/ora.base:LOG (lattice-oracle[5.5.276~4-6f366]:main():lattice-oracle.cc:383) Overall %WER 1.70591 [ 342 / 20048, 109 insertions, 22 deletions, 211 substitutions ] exp_dec/incre.fl.1f/base/ora.den.base:LOG (lattice-depth[5.5.276~4-6f366]:main():lattice-depth.cc:79) Overall density is 25.1613 over 244027 frames. exp_dec/incre.fl.1f/incre/ora.base:LOG (lattice-oracle[5.5.276~4-6f366]:main():lattice-oracle.cc:383) Overall %WER 1.80567 [ 362 / 20048, 108 insertions, 25 deletions, 229 substitutions ] exp_dec/incre.fl.1f/incre/ora.den.base:LOG (lattice-depth[5.5.276~4-6f366]:main():lattice-depth.cc:79) Overall density is 28.1682 over 244027 frames. + grep -H WER exp_dec/incre.fl.1f/base/wer exp_dec/incre.fl.1f/incre/wer exp_dec/incre.fl.1f/base/wer:%WER 12.57 [ 2532 / 20138, 305 ins, 287 del, 1940 sub ] exp_dec/incre.fl.1f/incre/wer:%WER 12.57 [ 2532 / 20138, 305 ins, 287 del, 1940 sub ] + grep real exp_dec/incre.fl.1f/base/log/decode.1.log exp_dec/incre.fl.1f/incre/log/decode.1.log exp_dec/incre.fl.1f/base/log/decode.1.log:LOG (latgen-faster-mapped[5.5.276~4-6f366]:main():latgen-faster-mapped.cc:164) Time taken 48.4324s: real-time factor assuming 100 frames/sec is 0.912442 exp_dec/incre.fl.1f/incre/log/decode.1.log:LOG (latgen-incremental-mapped[5.5.276~4-6f366]:main():latgen-incremental-mapped.cc:164) Time taken 54.6669s: real-time factor assuming 100 frames/sec is 1.0299 --- src/decoder/lattice-incremental-decoder.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 306664f8e26..e98567edda3 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -946,7 +946,7 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, CompactLattice clat; #if 1 if (!DeterminizeLatticePhonePrunedWrapper( - trans_model_, &raw_fst, config_.lattice_beam, &clat, config_.det_opts)) + trans_model_, &raw_fst, config_.beam, &clat, config_.det_opts)) KALDI_WARN << "Determinization finished earlier than the beam"; #else // sanity check, remove it later @@ -972,9 +972,9 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, BaseFloat offset_sum = 0; for (int32 i = 1; i < last_frame_of_chunk; i++) offset_sum += cost_offsets_[i]; KALDI_ASSERT(alignment.size() == last_frame_of_chunk); - // the following KALDI_ASSERT will fail some time, which is unexpected - KALDI_ASSERT(ApproxEqual(best_cost_in_chunk_, weight.Value1() + - weight.Value2()+offset_sum, 1e-2)); + // TODO: the following KALDI_ASSERT will fail some time, which is unexpected + //KALDI_ASSERT(ApproxEqual(best_cost_in_chunk_, weight.Value1() + + // weight.Value2()+offset_sum, 1e-2)); } last_get_lattice_frame_ = last_frame_of_chunk; From 080e5b4c90701f8fb7f6e8f9f1f7e183fe859188 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sat, 30 Mar 2019 21:57:22 +0800 Subject: [PATCH 10/60] clean; without class --- src/decoder/lattice-incremental-decoder.cc | 147 +++++++++------------ src/decoder/lattice-incremental-decoder.h | 36 ++--- 2 files changed, 79 insertions(+), 104 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index e98567edda3..5b361a959f3 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -939,42 +939,28 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, // step 1: Get lattice chunk with initial state // In this function, we do not create the initial state in // the first chunk, and we do not create the final state in the last chunk - KALDI_ASSERT(GetRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, - last_frame_of_chunk, not_first_chunk, - !decoding_finalized_)); + if (!GetRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, + last_frame_of_chunk, not_first_chunk, !decoding_finalized_)) + KALDI_ERR << "Unexpected problem when getting lattice"; // step 2: Determinize the chunk CompactLattice clat; -#if 1 - if (!DeterminizeLatticePhonePrunedWrapper( - trans_model_, &raw_fst, config_.beam, &clat, config_.det_opts)) + // We do determinization with beam pruning here + // Only if we use a beam larger than (config_.beam+config_.lattice_beam) here, we + // can guarantee no final or initial arcs in clat are pruned by this function. + // These pruned final arcs can hurt oracle WER performance in the final lattice + // (also result in less lattice density) but they seldom hurt 1-best WER. + // Moreover, if we use (config_.beam) as the beam here, the oracle WER + // performance can be similar to that in lattice-faster-decoder (also similar + // lattice density). Hence we decide to use config_.beam here. + if (!DeterminizeLatticePhonePrunedWrapper(trans_model_, &raw_fst, config_.beam, + &clat, config_.det_opts)) KALDI_WARN << "Determinization finished earlier than the beam"; -#else - // sanity check, remove it later - ConvertLattice(raw_fst, &clat); - Connect(&clat); -#endif final_arc_list_.swap(final_arc_list_prev_); final_arc_list_.clear(); // step 3: Appending the new chunk in clat to the old one in olat AppendLatticeChunks(clat, not_first_chunk, last_frame_of_chunk, olat); - - // sanity check, remove them later - CompactLattice cdecoded; - Lattice decoded; - ShortestPath(*olat, &cdecoded); - ConvertLattice(cdecoded, &decoded); - LatticeWeight weight; - std::vector alignment; - std::vector words; - GetLinearSymbolSequence(decoded, &alignment, &words, &weight); - BaseFloat offset_sum = 0; - for (int32 i = 1; i < last_frame_of_chunk; i++) offset_sum += cost_offsets_[i]; - KALDI_ASSERT(alignment.size() == last_frame_of_chunk); - // TODO: the following KALDI_ASSERT will fail some time, which is unexpected - //KALDI_ASSERT(ApproxEqual(best_cost_in_chunk_, weight.Value1() + - // weight.Value2()+offset_sum, 1e-2)); } last_get_lattice_frame_ = last_frame_of_chunk; @@ -1053,14 +1039,12 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( << " max:" << tok_map.max_load_factor(); // step 1.1: create initial_arc for later appending with the previous chunk if (create_initial_state) { - initial_state_in_chunk_.clear(); for (Token *tok = active_toks_[frame_begin].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; // state_label_map_ is construct during create_final_state auto r = state_label_map_.find(tok); KALDI_ASSERT(r != state_label_map_.end()); // it should exist int32 id = r->second; - initial_state_in_chunk_.insert(id); // Use cost_offsets to guide DeterminizeLatticePruned() // later in GetLattice() // For now, we use alpha (tot_cost) from the decoding stage as @@ -1073,10 +1057,7 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( Arc arc(0, id, Weight(0, cost_offset), cur_state); ofst->AddArc(begin_state, arc); } - KALDI_VLOG(6) << initial_state_in_chunk_.size(); } - // for sanity check - best_cost_in_chunk_ = std::numeric_limits::infinity(); // step 1.2: create all arcs as GetRawLattice() for (int32 f = frame_begin; f <= frame_end; f++) { for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { @@ -1117,17 +1098,12 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( weight = LatticeWeight::Zero(); } ofst->SetFinal(cur_state, weight); - // for sanity check - // we will use extra_cost in step 1.3 (see the following code) - best_cost_in_chunk_ = std::min( - best_cost_in_chunk_, tok->tot_cost + tok->extra_cost + weight.Value1() + weight.Value2()); } } } // step 1.3 create final_arc for later appending with the next chunk (in // GetLattice) if (create_final_state) { - final_state_in_chunk_.clear(); StateId end_state = ofst->AddState(); // final-state for the chunk ofst->SetFinal(end_state, Weight::One()); @@ -1138,7 +1114,6 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( // We assign an unique state label for each of the token in the last frame // of this chunk int32 id = state_label_available_idx_++; - final_state_in_chunk_.insert(id); state_label_map_[tok] = id; // The final weight has been worked out in the previous for loop and // store in the states @@ -1159,7 +1134,6 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( ofst->AddArc(cur_state, arc); ofst->SetFinal(cur_state, Weight::Zero()); } - KALDI_VLOG(6) << final_state_in_chunk_.size(); } return (ofst->NumStates() > 0); } @@ -1174,7 +1148,7 @@ void LatticeIncrementalDecoderTpl::AppendLatticeChunks( if (not_first_chunk) state_offset--; // since we do not append initial state in the first chunk - // A map between state label and the arc position (index) + // A map from state label to the arc position (index) // the incoming states of these arcs are initial states of the chunk // and the olabel of these arcs are the key of this map (state label) // The arc position are obtained from ArcIterator corresponding to the state @@ -1203,76 +1177,77 @@ void LatticeIncrementalDecoderTpl::AppendLatticeChunks( if (arc.olabel > config_.max_word_id) { // initial_arc // In first chunk, there could be a final arc starting from state 0 // In the last chunk, there could be a initial arc ending in final state - if (not_first_chunk && s == 0) { // record initial_arc in this chunk, we will use it right now + if (not_first_chunk && + s == 0) { // record initial_arc in this chunk, we will use it right now initial_arc_map[arc.olabel] = aiter.Position(); if (last_get_lattice_frame_ != 0) // Erase since we are not interested - initial_state_in_chunk_.erase(arc.olabel); } else { // final_arc // record final_arc in this chunk for the step 3.2 in the next call KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); - final_state_in_chunk_.erase(arc.olabel); final_arc_list_.push_back( pair(state_appended, aiter.Position())); } } } } - // sanity check, remove it later - // KALDI_ASSERT(initial_state_in_chunk_.size() == 0); - // KALDI_ASSERT(final_state_in_chunk_.size() == 0); - // step 3.2: connect the states between two chunks + // step 3.2: connect the states between two chunks, i.e. chunk1 in olat and chunk2 + // in clat in the following + // Notably, most states and arcs of clat has been copied to olat in step 3.1 + // This step is mainly to process the boundary of these two chunks if (last_get_lattice_frame_ != 0) { KALDI_ASSERT(final_arc_list_prev_.size()); - StateId prev_final_state = -1; + vector prev_final_states; for (auto &i : final_arc_list_prev_) { - MutableArcIterator aiter(olat, i.first); - aiter.Seek(i.second); - // Obtain the appended final arcs in the last chunk - auto &arc_appended = aiter.Value(); + MutableArcIterator aiter_chunk1(olat, i.first); + aiter_chunk1.Seek(i.second); + // Obtain the appended final arcs in the previous chunk + auto &arc_chunk1 = aiter_chunk1.Value(); // Find out whether its corresponding Token still exists in the begin // of this chunk. If not, it is pruned by PruneActiveTokens() - auto r = initial_arc_map.find(arc_appended.olabel); + auto r = initial_arc_map.find(arc_chunk1.olabel); if (r != initial_arc_map.end()) { - ArcIterator aiter_unappended(clat, 0); // initial state - aiter_unappended.Seek(r->second); - const auto &arc_unappended = aiter_unappended.Value(); - KALDI_ASSERT(arc_unappended.olabel == arc_appended.olabel); - StateId state_appended = arc_unappended.nextstate + state_offset; - if (prev_final_state == -1) - prev_final_state = arc_appended.nextstate; - else - KALDI_ASSERT(arc_appended.nextstate == prev_final_state); - // For the later code in this loop, we try to modify the arc_appended + ArcIterator aiter_chunk2(clat, 0); // initial state + aiter_chunk2.Seek(r->second); + const auto &arc_chunk2 = aiter_chunk2.Value(); + KALDI_ASSERT(arc_chunk2.olabel == arc_chunk1.olabel); + StateId state_chunk1 = arc_chunk2.nextstate + state_offset; + StateId prev_final_state = arc_chunk1.nextstate; + prev_final_states.push_back(prev_final_state); + // For the later code in this loop, we try to modify the arc_chunk1 // to connect the last frame state of last chunk to the first frame // state of this chunk. These begin and final states are // corresponding to the same Token, guaranteed by unique state labels. - CompactLatticeArc arc_appended_mod(arc_appended); - arc_appended_mod.nextstate = state_appended; - - CompactLatticeWeight weight_offset, weight_offset_final; - weight_offset.SetWeight( - LatticeWeight(0, -state_label_initial_cost_[arc_appended.olabel])); - weight_offset_final.SetWeight( - LatticeWeight(0, -state_label_final_cost_[arc_appended.olabel])); - arc_appended_mod.weight = - Times(arc_appended_mod.weight, arc_unappended.weight); - arc_appended_mod.weight = - Times(arc_appended_mod.weight, olat->Final(prev_final_state)); - arc_appended_mod.weight = Times(arc_appended_mod.weight, weight_offset); - arc_appended_mod.weight = - Times(arc_appended_mod.weight, weight_offset_final); - arc_appended_mod.olabel = 0; - arc_appended_mod.ilabel = 0; - aiter.SetValue(arc_appended_mod); + CompactLatticeArc arc_chunk1_mod(arc_chunk1); + arc_chunk1_mod.nextstate = state_chunk1; + { // Update arc weight in this section + CompactLatticeWeight weight_offset, weight_offset_final; + auto r1 = state_label_initial_cost_.find(arc_chunk1.olabel); + KALDI_ASSERT(r1 != state_label_initial_cost_.end()); + weight_offset.SetWeight(LatticeWeight(0, -r->second)); + auto r2 = state_label_final_cost_.find(arc_chunk1.olabel); + KALDI_ASSERT(r2 != state_label_final_cost_.end()); + weight_offset_final.SetWeight(LatticeWeight(0, -r->second)); + arc_chunk1_mod.weight = Times( + Times(Times(Times(arc_chunk2.weight, olat->Final(prev_final_state)), + weight_offset), + weight_offset_final), + arc_chunk1_mod.weight); + } + // After appending, state labels are of no use and we remove them + arc_chunk1_mod.olabel = 0; + arc_chunk1_mod.ilabel = 0; + aiter_chunk1.SetValue(arc_chunk1_mod); } // otherwise, it has been pruned } - KALDI_ASSERT(prev_final_state != -1); // at least one arc should be appended - // making all unmodified remaining arcs of final_arc_list_prev_ be connected to - // a dead state - olat->SetFinal(prev_final_state, CompactLatticeWeight::Zero()); + KALDI_ASSERT(prev_final_states.size()); // at least one arc should be appended + // Making all unmodified remaining arcs of final_arc_list_prev_ be connected to + // a dead state. The following prev_final_states can be the same or different states + for (auto i:prev_final_states) + olat->SetFinal(i, CompactLatticeWeight::Zero()); } else - olat->SetStart(0); + olat->SetStart(0); // Initialize the first chunk for olat + KALDI_VLOG(2) << "Frame: " << last_frame_of_chunk << " states of chunk: " << clat.NumStates() << " states of the lattice: " << olat->NumStates(); diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index be16789c560..ad6796dd5de 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -418,6 +418,10 @@ class LatticeIncrementalDecoderTpl { /// This does extra work, but not nearly as much as /// determinizing /// a RawLattice from scratch. + /// @param [in] last_frame_of_chunk Pass the last frame of this chunk to + /// the function. We make it not always equal to + /// NumFramesDecoded() to have a delay on the + /// deteriminization /// @param [out] lat The CompactLattice representing what has been decoded /// so far. /// @return reached_final This function will returns true if a state that was @@ -429,21 +433,6 @@ class LatticeIncrementalDecoderTpl { /// LatticeFasterDecoder::GetLattice(). bool GetLattice(bool use_final_probs, bool redeterminize, int32 last_frame_of_chunk, CompactLattice *olat); - CompactLattice lat_; // the compact lattice we obtain - int32 last_get_lattice_frame_; // the last time we call GetLattice - unordered_map state_label_map_; // between Token and state_label - int32 state_label_available_idx_; // we allocate a unique id for each Token - const TransitionModel &trans_model_; // keep it for determinization - std::vector> final_arc_list_; // keep final_arc - std::vector> final_arc_list_prev_; - // We keep tot_cost or extra_cost for each state_label (Token) in final and - // initial arcs. We need them before determinization - // We cancel them after determinization - // TODO use 2 vector to replace this map, since state_label is continuous - // in each frame, and we need 2 frames of them - unordered_map state_label_initial_cost_; - unordered_map state_label_final_cost_; - /// This function is modified from LatticeFasterDecoderTpl::GetRawLattice() /// and specific design for incremental GetLattice /// It does the same thing as GetRawLattice in lattice-faster-decoder.cc except: @@ -464,6 +453,20 @@ class LatticeIncrementalDecoderTpl { int32 frame_end, bool create_initial_state, bool create_final_state); + CompactLattice lat_; // the compact lattice we obtain + int32 last_get_lattice_frame_; // the last time we call GetLattice + unordered_map state_label_map_; // between Token and state_label + int32 state_label_available_idx_; // we allocate a unique id for each Token + // We keep tot_cost or extra_cost for each state_label (Token) in final and + // initial arcs. We need them before determinization + // We cancel them after determinization + unordered_map state_label_initial_cost_; + unordered_map state_label_final_cost_; + + const TransitionModel &trans_model_; // keep it for determinization + std::vector> final_arc_list_; // keep final_arc + std::vector> final_arc_list_prev_; + // Take care of the step 3 in GetLattice, which is to // appending the new chunk in clat to the old one in olat // If not_first_chunk == false, we do not need to append and just copy @@ -475,9 +478,6 @@ class LatticeIncrementalDecoderTpl { void AppendLatticeChunks(CompactLattice clat, bool not_first_chunk, int32 last_frame_of_chunk, CompactLattice *olat); - BaseFloat best_cost_in_chunk_; // for sanity check - unordered_set initial_state_in_chunk_; // for sanity check - unordered_set final_state_in_chunk_; // for sanity check KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); }; From d8907a4790671ffc5b916152d77584f84806b7ba Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sun, 31 Mar 2019 09:07:59 +0800 Subject: [PATCH 11/60] add class LatticeIncrementalDeterminizer --- src/decoder/decoder-wrappers.cc | 2 +- src/decoder/lattice-incremental-decoder.cc | 223 ++++++++++++-------- src/decoder/lattice-incremental-decoder.h | 227 ++++++++++++++------- 3 files changed, 286 insertions(+), 166 deletions(-) diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index 22655878caa..294a2f69117 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -261,7 +261,7 @@ bool DecodeUtteranceLatticeIncremental( // Get lattice, and do determinization if requested. CompactLattice clat; - decoder.GetCompactLattice(&clat); + decoder.GetLattice(&clat); if (clat.NumStates() == 0) KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; if (determinize) { diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 5b361a959f3..639aaf9a3d2 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -34,7 +34,7 @@ LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( delete_fst_(false), config_(config), num_toks_(0), - trans_model_(trans_model) { + determinizer_(config, trans_model) { config.Check(); toks_.SetSize(1000); // just so on the first frame we do something reasonable. } @@ -47,7 +47,7 @@ LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( delete_fst_(true), config_(config), num_toks_(0), - trans_model_(trans_model) { + determinizer_(config, trans_model) { config.Check(); toks_.SetSize(1000); // just so on the first frame we do something reasonable. } @@ -77,15 +77,13 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { toks_.Insert(start_state, start_tok); num_toks_++; - lat_.DeleteStates(); last_get_lattice_frame_ = 0; state_label_map_.clear(); state_label_map_.reserve(std::min((int32)1e5, config_.max_active)); state_label_available_idx_ = config_.max_word_id + 1; - final_arc_list_.clear(); - final_arc_list_prev_.clear(); state_label_initial_cost_.clear(); state_label_final_cost_.clear(); + determinizer_.Init(); ProcessNonemitting(config_.beam); } @@ -108,14 +106,13 @@ bool LatticeIncrementalDecoderTpl::Decode( // The chunk length of determinization is equal to prune_interval // We have a delay on GetLattice to do determinization on more skinny lattices if (NumFramesDecoded() - config_.determinize_delay > 0) - GetLattice(false, false, NumFramesDecoded() - config_.determinize_delay, - &lat_); + GetLattice(false, false, NumFramesDecoded() - config_.determinize_delay); } BaseFloat cost_cutoff = ProcessEmitting(decodable); ProcessNonemitting(cost_cutoff); } FinalizeDecoding(); - GetLattice(true, true, NumFramesDecoded(), &lat_); + GetLattice(true, config_.redeterminize, NumFramesDecoded()); // Returns true if we have any kind of traceback available (not necessarily // to the end state; query ReachedFinal() for that). @@ -124,30 +121,26 @@ bool LatticeIncrementalDecoderTpl::Decode( // Outputs an FST corresponding to the single best path through the lattice. template -bool LatticeIncrementalDecoderTpl::GetBestPath( - Lattice *olat, bool use_final_probs) const { - CompactLattice lat; - ShortestPath(lat_, &lat); - ConvertLattice(lat, olat); +bool LatticeIncrementalDecoderTpl::GetBestPath(Lattice *olat, + bool use_final_probs) { + CompactLattice lat, slat; + GetLattice(use_final_probs, config_.redeterminize, NumFramesDecoded(), &lat); + ShortestPath(lat, &slat); + ConvertLattice(slat, olat); return (olat->NumStates() != 0); } // Outputs an FST corresponding to the raw, state-level lattice template -bool LatticeIncrementalDecoderTpl::GetRawLattice( - Lattice *ofst, bool use_final_probs) const { - ConvertLattice(lat_, ofst); +bool LatticeIncrementalDecoderTpl::GetRawLattice(Lattice *ofst, + bool use_final_probs) { + CompactLattice lat; + GetLattice(use_final_probs, config_.redeterminize, NumFramesDecoded(), &lat); + ConvertLattice(lat, ofst); Connect(ofst); return (ofst->NumStates() != 0); } -template -bool LatticeIncrementalDecoderTpl::GetCompactLattice( - CompactLattice *ofst) const { - *ofst = lat_; - return (ofst->NumStates() != 0); -} - template void LatticeIncrementalDecoderTpl::PossiblyResizeHash(size_t num_toks) { size_t new_sz = @@ -923,6 +916,11 @@ void LatticeIncrementalDecoderTpl::TopSortTokens( (*topsorted_list)[iter->second] = iter->first; } +template +bool LatticeIncrementalDecoderTpl::GetLattice(CompactLattice *olat) { + return GetLattice(true, config_.redeterminize, NumFramesDecoded(), olat); +} + template bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, bool redeterminize, @@ -930,63 +928,36 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, CompactLattice *olat) { using namespace fst; bool not_first_chunk = last_get_lattice_frame_ != 0; + bool ret = true; // last_get_lattice_frame_ is used to record the first frame of the chunk - // last time we obtain from calling this function. If it reaches NumFramesDecoded() + // last time we obtain from calling this function. If it reaches + // last_frame_of_chunk // we cannot generate any more chunk if (last_get_lattice_frame_ < last_frame_of_chunk) { Lattice raw_fst; - // step 1: Get lattice chunk with initial state + // step 1: Get lattice chunk with initial and final states // In this function, we do not create the initial state in // the first chunk, and we do not create the final state in the last chunk if (!GetRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, last_frame_of_chunk, not_first_chunk, !decoding_finalized_)) KALDI_ERR << "Unexpected problem when getting lattice"; - // step 2: Determinize the chunk - CompactLattice clat; - // We do determinization with beam pruning here - // Only if we use a beam larger than (config_.beam+config_.lattice_beam) here, we - // can guarantee no final or initial arcs in clat are pruned by this function. - // These pruned final arcs can hurt oracle WER performance in the final lattice - // (also result in less lattice density) but they seldom hurt 1-best WER. - // Moreover, if we use (config_.beam) as the beam here, the oracle WER - // performance can be similar to that in lattice-faster-decoder (also similar - // lattice density). Hence we decide to use config_.beam here. - if (!DeterminizeLatticePhonePrunedWrapper(trans_model_, &raw_fst, config_.beam, - &clat, config_.det_opts)) - KALDI_WARN << "Determinization finished earlier than the beam"; - - final_arc_list_.swap(final_arc_list_prev_); - final_arc_list_.clear(); - - // step 3: Appending the new chunk in clat to the old one in olat - AppendLatticeChunks(clat, not_first_chunk, last_frame_of_chunk, olat); - } - - last_get_lattice_frame_ = last_frame_of_chunk; - // step 4: re-determinize the final lattice - if (redeterminize) { - DeterminizeLatticePrunedOptions det_opts; - det_opts.delta = config_.det_opts.delta; - det_opts.max_mem = config_.det_opts.max_mem; - Lattice lat; - ConvertLattice(*olat, &lat); - Invert(&lat); - if (lat.Properties(fst::kTopSorted, true) == 0) { - if (!TopSort(&lat)) { - // Cannot topologically sort the lattice -- determinization will fail. - KALDI_ERR << "Topological sorting of state-level lattice failed (probably" - << " your lexicon has empty words or your LM has epsilon cycles" - << ")."; - } - } - if (!DeterminizeLatticePruned(lat, config_.lattice_beam, olat, det_opts)) - KALDI_WARN << "Determinization finished earlier than the beam"; - Connect(olat); // Remove unreachable states... there might be - KALDI_VLOG(2) << "states of the lattice: " << olat->NumStates(); + ret = determinizer_.ProcessChunk(raw_fst, last_get_lattice_frame_, + last_frame_of_chunk, state_label_initial_cost_, + state_label_final_cost_); + last_get_lattice_frame_ = last_frame_of_chunk; + } else if (last_get_lattice_frame_ > last_frame_of_chunk) + KALDI_WARN << "Call GetLattice up to frame: " << last_frame_of_chunk + << " while the determinizer_ has already done up to frame: " + << last_get_lattice_frame_; + + if (decoding_finalized_) ret &= determinizer_.Finalize(redeterminize); + if (olat) { + *olat = determinizer_.GetDeterminizedLattice(); + ret &= (olat->NumStates() > 0); } - return (olat->NumStates() != 0); + return ret; } template @@ -1046,7 +1017,7 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( KALDI_ASSERT(r != state_label_map_.end()); // it should exist int32 id = r->second; // Use cost_offsets to guide DeterminizeLatticePruned() - // later in GetLattice() + // later // For now, we use alpha (tot_cost) from the decoding stage as // the initial weights of arcs connecting to the states in the begin // of this chunk @@ -1101,8 +1072,7 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( } } } - // step 1.3 create final_arc for later appending with the next chunk (in - // GetLattice) + // step 1.3 create final_arc for later appending with the next chunk if (create_final_state) { StateId end_state = ofst->AddState(); // final-state for the chunk ofst->SetFinal(end_state, Weight::One()); @@ -1122,7 +1092,6 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( KALDI_ASSERT(ofst->Final(cur_state) != Weight::Zero()); Weight final_weight = ofst->Final(cur_state); // Use cost_offsets to guide DeterminizeLatticePruned() - // later in GetLattice() // For now, we use extra_cost from the decoding stage , which has some // "future information", as // the final weights of this chunk @@ -1138,11 +1107,57 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( return (ofst->NumStates() > 0); } -template -void LatticeIncrementalDecoderTpl::AppendLatticeChunks( - CompactLattice clat, bool not_first_chunk, int32 last_frame_of_chunk, - CompactLattice *olat) { +template +LatticeIncrementalDeterminizer::LatticeIncrementalDeterminizer( + const LatticeIncrementalDecoderConfig &config, + const TransitionModel &trans_model) + : config_(config), trans_model_(trans_model) {} + +template +void LatticeIncrementalDeterminizer::Init() { + final_arc_list_.clear(); + final_arc_list_prev_.clear(); + lat_.DeleteStates(); + determinization_finalized_ = false; +} + +template +bool LatticeIncrementalDeterminizer::ProcessChunk( + Lattice &raw_fst, int32 first_frame, int32 last_frame, + const unordered_map &state_label_initial_cost, + const unordered_map &state_label_final_cost) { + bool not_first_chunk = first_frame != 0; + // step 2: Determinize the chunk + CompactLattice clat; + // We do determinization with beam pruning here + // Only if we use a beam larger than (config_.beam+config_.lattice_beam) here, we + // can guarantee no final or initial arcs in clat are pruned by this function. + // These pruned final arcs can hurt oracle WER performance in the final lattice + // (also result in less lattice density) but they seldom hurt 1-best WER. + if (!DeterminizeLatticePhonePrunedWrapper(trans_model_, &raw_fst, config_.beam, + &clat, config_.det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam"; + + final_arc_list_.swap(final_arc_list_prev_); + final_arc_list_.clear(); + + // step 3: Appending the new chunk in clat to the old one in lat_ + AppendLatticeChunks(clat, not_first_chunk, state_label_initial_cost, + state_label_final_cost); + KALDI_VLOG(2) << "Frame: ( " << first_frame << " , " << last_frame << " )" + << " states of the chunk: " << clat.NumStates() + << " states of the lattice: " << lat_.NumStates(); + + return (lat_.NumStates() > 0); +} + +template +void LatticeIncrementalDeterminizer::AppendLatticeChunks( + CompactLattice clat, bool not_first_chunk, + const unordered_map &state_label_initial_cost, + const unordered_map &state_label_final_cost) { using namespace fst; + CompactLattice *olat = &lat_; // step 3.1: Appending new chunk to the old one int32 state_offset = olat->NumStates(); if (not_first_chunk) @@ -1180,7 +1195,6 @@ void LatticeIncrementalDecoderTpl::AppendLatticeChunks( if (not_first_chunk && s == 0) { // record initial_arc in this chunk, we will use it right now initial_arc_map[arc.olabel] = aiter.Position(); - if (last_get_lattice_frame_ != 0) // Erase since we are not interested } else { // final_arc // record final_arc in this chunk for the step 3.2 in the next call KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); @@ -1195,7 +1209,7 @@ void LatticeIncrementalDecoderTpl::AppendLatticeChunks( // in clat in the following // Notably, most states and arcs of clat has been copied to olat in step 3.1 // This step is mainly to process the boundary of these two chunks - if (last_get_lattice_frame_ != 0) { + if (not_first_chunk) { KALDI_ASSERT(final_arc_list_prev_.size()); vector prev_final_states; for (auto &i : final_arc_list_prev_) { @@ -1222,12 +1236,12 @@ void LatticeIncrementalDecoderTpl::AppendLatticeChunks( arc_chunk1_mod.nextstate = state_chunk1; { // Update arc weight in this section CompactLatticeWeight weight_offset, weight_offset_final; - auto r1 = state_label_initial_cost_.find(arc_chunk1.olabel); - KALDI_ASSERT(r1 != state_label_initial_cost_.end()); - weight_offset.SetWeight(LatticeWeight(0, -r->second)); - auto r2 = state_label_final_cost_.find(arc_chunk1.olabel); - KALDI_ASSERT(r2 != state_label_final_cost_.end()); - weight_offset_final.SetWeight(LatticeWeight(0, -r->second)); + const auto r1 = state_label_initial_cost.find(arc_chunk1.olabel); + KALDI_ASSERT(r1 != state_label_initial_cost.end()); + weight_offset.SetWeight(LatticeWeight(0, -r1->second)); + const auto r2 = state_label_final_cost.find(arc_chunk1.olabel); + KALDI_ASSERT(r2 != state_label_final_cost.end()); + weight_offset_final.SetWeight(LatticeWeight(0, -r2->second)); arc_chunk1_mod.weight = Times( Times(Times(Times(arc_chunk2.weight, olat->Final(prev_final_state)), weight_offset), @@ -1242,15 +1256,44 @@ void LatticeIncrementalDecoderTpl::AppendLatticeChunks( } KALDI_ASSERT(prev_final_states.size()); // at least one arc should be appended // Making all unmodified remaining arcs of final_arc_list_prev_ be connected to - // a dead state. The following prev_final_states can be the same or different states - for (auto i:prev_final_states) - olat->SetFinal(i, CompactLatticeWeight::Zero()); + // a dead state. The following prev_final_states can be the same or different + // states + for (auto i : prev_final_states) olat->SetFinal(i, CompactLatticeWeight::Zero()); } else olat->SetStart(0); // Initialize the first chunk for olat +} + +template +bool LatticeIncrementalDeterminizer::Finalize(bool redeterminize) { + using namespace fst; + auto *olat = &lat_; + // The lattice determinization only needs to be finalized once + if (determinization_finalized_) return true; + // step 4: re-determinize the final lattice + if (redeterminize) { + Connect(olat); // Remove unreachable states... there might be + DeterminizeLatticePrunedOptions det_opts; + det_opts.delta = config_.det_opts.delta; + det_opts.max_mem = config_.det_opts.max_mem; + Lattice lat; + ConvertLattice(*olat, &lat); + Invert(&lat); + if (lat.Properties(fst::kTopSorted, true) == 0) { + if (!TopSort(&lat)) { + // Cannot topologically sort the lattice -- determinization will fail. + KALDI_ERR << "Topological sorting of state-level lattice failed (probably" + << " your lexicon has empty words or your LM has epsilon cycles" + << ")."; + } + } + if (!DeterminizeLatticePruned(lat, config_.lattice_beam, olat, det_opts)) + KALDI_WARN << "Determinization finished earlier than the beam"; + } + Connect(olat); // Remove unreachable states... there might be + KALDI_VLOG(2) << "states of the lattice: " << olat->NumStates(); + determinization_finalized_ = true; - KALDI_VLOG(2) << "Frame: " << last_frame_of_chunk - << " states of chunk: " << clat.NumStates() - << " states of the lattice: " << olat->NumStates(); + return (olat->NumStates() > 0); } // Instantiate the template for the combination of token types and FST types diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index ad6796dd5de..f0e82275433 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -42,6 +42,7 @@ struct LatticeIncrementalDecoderConfig { BaseFloat lattice_beam; int32 prune_interval; int32 determinize_delay; + bool redeterminize; bool determinize_lattice; // not inspected by this class... used in // command-line program. BaseFloat beam_delta; // has nothing to do with beam_ratio @@ -63,6 +64,7 @@ struct LatticeIncrementalDecoderConfig { lattice_beam(10.0), prune_interval(25), determinize_delay(25), + redeterminize(true), determinize_lattice(true), beam_delta(0.5), hash_ratio(2.0), @@ -84,6 +86,9 @@ struct LatticeIncrementalDecoderConfig { opts->Register("determinize-delay", &determinize_delay, "delay (in frames) at " "which to incrementally determinize lattices"); + opts->Register("redeterminize", &redeterminize, + "whether to re-determinize the lattice after incremental " + "determinization."); opts->Register("determinize-lattice", &determinize_lattice, "If true, " "determinize the lattice (lattice-determinization, keeping only " @@ -104,6 +109,9 @@ struct LatticeIncrementalDecoderConfig { } }; +template +class LatticeIncrementalDeterminizer; + /** This is the "normal" lattice-generating decoder. See \ref lattices_generation \ref decoders_faster and \ref decoders_simple for more information. @@ -165,7 +173,7 @@ class LatticeIncrementalDecoderTpl { /// final-state of the graph then it will include those as final-probs, else /// it will treat all final-probs as one. Note: this just calls GetRawLattice() /// and figures out the shortest path. - bool GetBestPath(Lattice *ofst, bool use_final_probs = true) const; + bool GetBestPath(Lattice *ofst, bool use_final_probs = true); /// Outputs an FST corresponding to the raw, state-level /// tracebacks. Returns true if result is nonempty. @@ -173,13 +181,73 @@ class LatticeIncrementalDecoderTpl { /// of the graph then it will include those as final-probs, else /// it will treat all final-probs as one. /// The raw lattice will be topologically sorted. + /// Notably, the raw lattice from this incremental determinization decoder + /// has already been partially determinized /// /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, /// which also supports a pruning beam, in case for some reason /// you want it pruned tighter than the regular lattice beam. /// We could put that here in future needed. - bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const; - bool GetCompactLattice(CompactLattice *ofst) const; + bool GetRawLattice(Lattice *ofst, bool use_final_probs = true); + + /// The following function is specifically designed for incremental + /// determinization. The function obtains a CompactLattice for + /// the part of this utterance up to the frame last_frame_of_chunk. + /// If you call this multiple times + /// (calling it on every frame would not make sense, + /// but every, say, 10 to 40 frames might make sense) it will spread out the + /// work of determinization over time, which might be useful for online + /// applications. + /// + /// The procedure of incremental determinization is as follow: + /// step 1: Get lattice chunk with initial and final states corresponding + /// to tokens in the first and last frames of this chunk + /// We need to give permanent labels (called "state labels") to these + /// raw-lattice states (Tokens). + /// step 2: Determinize the chunk of above raw lattice using determinization + /// algorithm the same as LatticeFasterDecoder. We call the determinized new + /// chunk "clat" + /// step 3: Appending the new chunk "clat" to the determinized lattice + /// before this chunk. First, for each state-id in clat2 *except* its + /// initial state, allocate a new state-id in the appended + /// compact-lattice. Copy the arcs except whose incoming state is initial + /// state. Secondly, for each final arc in previous chunk, check whether + /// the corresponding initial arc exists in the newly appended chunk. If + /// not, we make the final arc point to a "dead state" + /// Otherwise, we modify this arc to connect to the corresponding next + /// state of the initial arc with a proper weight + /// step 4: We re-determinize the appended lattice if needed. + /// + /// In our implementation, step 1 is done in GetRawLattice(), + /// step 2-4 is taken care by the class + /// LatticeIncrementalDeterminizer + /// + /// @param [in] use_final_probs If true *and* at least one final-state in HCLG + /// was active on the final frame, include final-probs from + /// HCLG + /// in the lattice. Otherwise treat all final-costs of + /// states active + /// on the most recent frame as zero (i.e. Weight::One()). + /// @param [in] redeterminize If true, re-determinize the CompactLattice + /// after appending the most recently decoded chunk to it, + /// to + /// ensure that the output is fully deterministic. + /// This does extra work, but not nearly as much as + /// determinizing + /// a RawLattice from scratch. + /// @param [in] last_frame_of_chunk Pass the last frame of this chunk to + /// the function. We make it not always equal to + /// NumFramesDecoded() to have a delay on the + /// deteriminization + /// @param [out] olat The CompactLattice representing what has been decoded + /// so far. + // If lat == NULL, the CompactLattice won't be outputed. + /// @return ret This function will returns true if the chunk is processed + /// successfully + bool GetLattice(bool use_final_probs, bool redeterminize, + int32 last_frame_of_chunk, CompactLattice *olat = NULL); + /// Specifically design when decoding_finalized_==true + bool GetLattice(CompactLattice *olat); /// InitDecoding initializes the decoding, and should only be used if you /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to @@ -222,9 +290,9 @@ class LatticeIncrementalDecoderTpl { inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } protected: - // we make things protected instead of private, as code in - // LatticeIncrementalOnlineDecoderTpl, which inherits from this, also uses the - // internals. + // we make things protected instead of private, as future code in + // LatticeIncrementalOnlineDecoderTpl, which inherits from this, also will + // use the internals. // Deletes the elements of the singly linked list tok->links. inline static void DeleteForwardLinks(Token *tok); @@ -395,95 +463,104 @@ class LatticeIncrementalDecoderTpl { void ClearActiveTokens(); - /// The following part is specifically designed for incremental determinization - /// - /// The function obtains a CompactLattice for the part of this utterance that has - /// been decoded so far. If you call this multiple times (calling it on - /// every frame would not make sense, - /// but every, say, 10, to 40 frames might make sense) it will spread out - /// the - /// work of determinization over time,which might be useful for online - /// applications. - /// - /// @param [in] use_final_probs If true *and* at least one final-state in HCLG - /// was active on the final frame, include final-probs from - /// HCLG - /// in the lattice. Otherwise treat all final-costs of - /// states active - /// on the most recent frame as zero (i.e. Weight::One()). - /// @param [in] redeterminize If true, re-determinize the CompactLattice - /// after appending the most recently decoded chunk to it, - /// to - /// ensure that the output is fully deterministic. - /// This does extra work, but not nearly as much as - /// determinizing - /// a RawLattice from scratch. - /// @param [in] last_frame_of_chunk Pass the last frame of this chunk to - /// the function. We make it not always equal to - /// NumFramesDecoded() to have a delay on the - /// deteriminization - /// @param [out] lat The CompactLattice representing what has been decoded - /// so far. - /// @return reached_final This function will returns true if a state that was - /// final in - /// HCLG was active on the most recent frame, and false - /// otherwise. - /// CAUTION: this is not the same meaning as the return - /// value of - /// LatticeFasterDecoder::GetLattice(). - bool GetLattice(bool use_final_probs, bool redeterminize, - int32 last_frame_of_chunk, CompactLattice *olat); - /// This function is modified from LatticeFasterDecoderTpl::GetRawLattice() - /// and specific design for incremental GetLattice - /// It does the same thing as GetRawLattice in lattice-faster-decoder.cc except: - /// - /// i) it creates a initial state, and connect - /// all the tokens in the first frame of this chunk to the initial state - /// by an arc with a per-token state-label as its olabel - /// ii) it creates a final state, and connect - /// all the tokens in the last frame of this chunk to the final state - /// by an arc with a per-token state-label as its olabel - /// the state-label for a token in both i) and ii) should be the same - /// frame_begin and frame_end are the first and last frame of this chunk - /// if create_initial_state == false, we will not create initial state and - /// the corresponding state-label arcs. Similar for create_final_state - /// In incremental GetLattice, we do not create the initial state in - /// the first chunk, and we do not create the final state in the last chunk + // The following part is specifically designed for incremental + // This function is modified from LatticeFasterDecoderTpl::GetRawLattice() + // and specific design for step 1 of incremental determinization + // introduced before above GetLattice() + // It does the same thing as GetRawLattice in lattice-faster-decoder.cc except: + // + // i) it creates a initial state, and connect + // all the tokens in the first frame of this chunk to the initial state + // by an arc with a per-token state-label as its olabel + // ii) it creates a final state, and connect + // all the tokens in the last frame of this chunk to the final state + // by an arc with a per-token state-label as its olabel + // the state-label for a token in both i) and ii) should be the same + // frame_begin and frame_end are the first and last frame of this chunk + // if create_initial_state == false, we will not create initial state and + // the corresponding state-label arcs. Similar for create_final_state + // In incremental GetLattice, we do not create the initial state in + // the first chunk, and we do not create the final state in the last chunk bool GetRawLattice(Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, bool create_initial_state, bool create_final_state); - CompactLattice lat_; // the compact lattice we obtain - int32 last_get_lattice_frame_; // the last time we call GetLattice - unordered_map state_label_map_; // between Token and state_label - int32 state_label_available_idx_; // we allocate a unique id for each Token + LatticeIncrementalDeterminizer determinizer_; + int32 last_get_lattice_frame_; // the last time we call GetLattice + // a map from Token to its state_label + unordered_map state_label_map_; + // we allocate a unique id for each Token + int32 state_label_available_idx_; // We keep tot_cost or extra_cost for each state_label (Token) in final and // initial arcs. We need them before determinization // We cancel them after determinization unordered_map state_label_initial_cost_; unordered_map state_label_final_cost_; + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); +}; - const TransitionModel &trans_model_; // keep it for determinization - std::vector> final_arc_list_; // keep final_arc - std::vector> final_arc_list_prev_; +typedef LatticeIncrementalDecoderTpl + LatticeIncrementalDecoder; + +// This class is designed for step 2-4 of incremental determinization +// introduced before above GetLattice() +template +class LatticeIncrementalDeterminizer { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; - // Take care of the step 3 in GetLattice, which is to - // appending the new chunk in clat to the old one in olat + LatticeIncrementalDeterminizer(const LatticeIncrementalDecoderConfig &config, + const TransitionModel &trans_model); + // Reset the lattice determinization data for an utterance + void Init(); + // Output the resultant determinized lattice in the form of CompactLattice + const CompactLattice &GetDeterminizedLattice() const { return lat_; } + + // This function consumes raw_fst generated by step 1 of incremental + // determinization with specific initial and final arcs. + // It does step 2-4 and outputs the resultant CompactLattice if + // needed. Otherwise, it keeps the resultant lattice in lat_ + bool ProcessChunk(Lattice &raw_fst, int32 first_frame, int32 last_frame, + const unordered_map &state_label_initial_cost, + const unordered_map &state_label_final_cost); + + // Step 3 of incremental determinization, + // which is to append the new chunk in clat to the old one in lat_ // If not_first_chunk == false, we do not need to append and just copy // clat into olat // Otherwise, we need to connect the last frame state of // last chunk to the first frame state of this chunk. // These begin and final states are corresponding to the same Token, // guaranteed by unique state labels. - void AppendLatticeChunks(CompactLattice clat, bool not_first_chunk, - int32 last_frame_of_chunk, CompactLattice *olat); + void AppendLatticeChunks( + CompactLattice clat, bool not_first_chunk, + const unordered_map &state_label_initial_cost, + const unordered_map &state_label_final_cost); + + // Step 4 of incremental determinization, + // which either re-determinize above lat_, or simply remove the dead + // states of lat_ + bool Finalize(bool redeterminize); + + private: + const LatticeIncrementalDecoderConfig config_; + const TransitionModel &trans_model_; // keep it for determinization - KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); + // Record whether we have finished determinized the whole utterance + // (including re-determinize) + bool determinization_finalized_; + // keep final_arc for appending later + std::vector> final_arc_list_; + std::vector> final_arc_list_prev_; + // The compact lattice we obtain. It should be reseted before processing a + // new utterance + CompactLattice lat_; + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDeterminizer); }; -typedef LatticeIncrementalDecoderTpl - LatticeIncrementalDecoder; - } // end namespace kaldi. #endif From 228d8a27dd2be5e73a0db70341b804cd4cce4f2d Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Mon, 1 Apr 2019 17:15:19 +0800 Subject: [PATCH 12/60] add config_.determinize_max_active & redeterminize=false --- src/decoder/lattice-incremental-decoder.cc | 33 ++++++++++++++++++---- src/decoder/lattice-incremental-decoder.h | 15 ++++++++-- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 639aaf9a3d2..650bd05c5fa 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -103,10 +103,26 @@ bool LatticeIncrementalDecoderTpl::Decode( while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { if (NumFramesDecoded() % config_.prune_interval == 0) { PruneActiveTokens(config_.lattice_beam * config_.prune_scale); - // The chunk length of determinization is equal to prune_interval + // We always incrementally determinize the lattice after lattice pruning in + // PruneActiveTokens() // We have a delay on GetLattice to do determinization on more skinny lattices - if (NumFramesDecoded() - config_.determinize_delay > 0) - GetLattice(false, false, NumFramesDecoded() - config_.determinize_delay); + int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; + int32 frame_det_least = config_.prune_interval + last_get_lattice_frame_; + if (frame_det_most > 0) { + // To adaptively decide the length of chunk, we further compare the number of + // tokens in each frame and a pre-defined threshold. + // If the number of tokens in a certain frame is less than + // config_.determinize_max_active, the lattice can be determinized up to this + // frame. And we try to determinize as most frames as possible so we check + // numbers from frame_det_most_ up to last_get_lattice_frame_ + for (int32 f = frame_det_most; f >= frame_det_least; f--) { + if (GetNumToksForFrame(f) < config_.determinize_max_active) { + KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() + << " incremental determinization up to " << f; + GetLattice(false, false, f); + } + } + } } BaseFloat cost_cutoff = ProcessEmitting(decodable); ProcessNonemitting(cost_cutoff); @@ -1107,6 +1123,13 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( return (ofst->NumStates() > 0); } +template +int32 LatticeIncrementalDecoderTpl::GetNumToksForFrame(int32 frame) { + int32 r = 0; + for (Token *tok = active_toks_[frame].toks; tok; tok = tok->next) r++; + return r; +} + template LatticeIncrementalDeterminizer::LatticeIncrementalDeterminizer( const LatticeIncrementalDecoderConfig &config, @@ -1134,8 +1157,8 @@ bool LatticeIncrementalDeterminizer::ProcessChunk( // can guarantee no final or initial arcs in clat are pruned by this function. // These pruned final arcs can hurt oracle WER performance in the final lattice // (also result in less lattice density) but they seldom hurt 1-best WER. - if (!DeterminizeLatticePhonePrunedWrapper(trans_model_, &raw_fst, config_.beam, - &clat, config_.det_opts)) + if (!DeterminizeLatticePhonePrunedWrapper( + trans_model_, &raw_fst, config_.lattice_beam, &clat, config_.det_opts)) KALDI_WARN << "Determinization finished earlier than the beam"; final_arc_list_.swap(final_arc_list_prev_); diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index f0e82275433..7ab60ae3686 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -42,6 +42,7 @@ struct LatticeIncrementalDecoderConfig { BaseFloat lattice_beam; int32 prune_interval; int32 determinize_delay; + int32 determinize_max_active; bool redeterminize; bool determinize_lattice; // not inspected by this class... used in // command-line program. @@ -63,8 +64,9 @@ struct LatticeIncrementalDecoderConfig { min_active(200), lattice_beam(10.0), prune_interval(25), - determinize_delay(25), - redeterminize(true), + determinize_delay(0), + determinize_max_active(50), + redeterminize(false), determinize_lattice(true), beam_delta(0.5), hash_ratio(2.0), @@ -86,6 +88,11 @@ struct LatticeIncrementalDecoderConfig { opts->Register("determinize-delay", &determinize_delay, "delay (in frames) at " "which to incrementally determinize lattices"); + opts->Register("determinize-max-active", &determinize_max_active, + "This option is to adaptively decide --determinize-delay. " + "If the number of active tokens(in a certain frame) is less " + "than this number, we will start to incrementally " + "determinize lattices up to this frame."); opts->Register("redeterminize", &redeterminize, "whether to re-determinize the lattice after incremental " "determinization."); @@ -484,7 +491,9 @@ class LatticeIncrementalDecoderTpl { bool GetRawLattice(Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, bool create_initial_state, bool create_final_state); - + // Get the number of tokens in each frame + // It is useful, e.g. in using config_.determinize_max_active + int32 GetNumToksForFrame(int32 frame); LatticeIncrementalDeterminizer determinizer_; int32 last_get_lattice_frame_; // the last time we call GetLattice // a map from Token to its state_label From c350305cbef28bc80402202710b1da554f22a7ef Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Mon, 1 Apr 2019 20:36:13 +0800 Subject: [PATCH 13/60] update best config; add re-determinization from frame 0 if AppendLatticeChunks failed --- src/decoder/lattice-incremental-decoder.cc | 46 +++++++++++++++++----- src/decoder/lattice-incremental-decoder.h | 12 +++--- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 650bd05c5fa..f4822ce5162 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -22,6 +22,7 @@ #include "decoder/lattice-incremental-decoder.h" #include "lat/lattice-functions.h" +#include "base/timer.h" namespace kaldi { @@ -107,7 +108,10 @@ bool LatticeIncrementalDecoderTpl::Decode( // PruneActiveTokens() // We have a delay on GetLattice to do determinization on more skinny lattices int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; - int32 frame_det_least = config_.prune_interval + last_get_lattice_frame_; + // The minimum length of chunk is config_.prune_interval. We make it + // identical to PruneActiveTokens since we need extra_cost as the weights + // of final arcs to denote the "future" information of final states (Tokens) + int32 frame_det_least = last_get_lattice_frame_ + config_.prune_interval; if (frame_det_most > 0) { // To adaptively decide the length of chunk, we further compare the number of // tokens in each frame and a pre-defined threshold. @@ -120,6 +124,7 @@ bool LatticeIncrementalDecoderTpl::Decode( KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() << " incremental determinization up to " << f; GetLattice(false, false, f); + break; } } } @@ -127,8 +132,10 @@ bool LatticeIncrementalDecoderTpl::Decode( BaseFloat cost_cutoff = ProcessEmitting(decodable); ProcessNonemitting(cost_cutoff); } + Timer timer; FinalizeDecoding(); GetLattice(true, config_.redeterminize, NumFramesDecoded()); + KALDI_VLOG(2) << "Delay time after decoding finalized (secs): " << timer.Elapsed(); // Returns true if we have any kind of traceback available (not necessarily // to the end state; query ReachedFinal() for that). @@ -958,6 +965,7 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, if (!GetRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, last_frame_of_chunk, not_first_chunk, !decoding_finalized_)) KALDI_ERR << "Unexpected problem when getting lattice"; + // step 2-3 ret = determinizer_.ProcessChunk(raw_fst, last_get_lattice_frame_, last_frame_of_chunk, state_label_initial_cost_, state_label_final_cost_); @@ -967,11 +975,21 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, << " while the determinizer_ has already done up to frame: " << last_get_lattice_frame_; + // step 4 if (decoding_finalized_) ret &= determinizer_.Finalize(redeterminize); if (olat) { *olat = determinizer_.GetDeterminizedLattice(); ret &= (olat->NumStates() > 0); } + if (!ret) { + KALDI_WARN << "Last chunk processing failed." + << " We will retry from frame 0."; + // Reset determinizer_ and re-determinize from + // frame 0 to last_frame_of_chunk + last_get_lattice_frame_ = 0; + determinizer_.Init(); + ret = GetLattice(use_final_probs, redeterminize, last_frame_of_chunk, olat); + } return ret; } @@ -1150,6 +1168,7 @@ bool LatticeIncrementalDeterminizer::ProcessChunk( const unordered_map &state_label_initial_cost, const unordered_map &state_label_final_cost) { bool not_first_chunk = first_frame != 0; + bool ret = true; // step 2: Determinize the chunk CompactLattice clat; // We do determinization with beam pruning here @@ -1157,25 +1176,25 @@ bool LatticeIncrementalDeterminizer::ProcessChunk( // can guarantee no final or initial arcs in clat are pruned by this function. // These pruned final arcs can hurt oracle WER performance in the final lattice // (also result in less lattice density) but they seldom hurt 1-best WER. - if (!DeterminizeLatticePhonePrunedWrapper( - trans_model_, &raw_fst, config_.lattice_beam, &clat, config_.det_opts)) - KALDI_WARN << "Determinization finished earlier than the beam"; + ret &= DeterminizeLatticePhonePrunedWrapper( + trans_model_, &raw_fst, config_.lattice_beam, &clat, config_.det_opts); final_arc_list_.swap(final_arc_list_prev_); final_arc_list_.clear(); // step 3: Appending the new chunk in clat to the old one in lat_ - AppendLatticeChunks(clat, not_first_chunk, state_label_initial_cost, - state_label_final_cost); + ret &= AppendLatticeChunks(clat, not_first_chunk, state_label_initial_cost, + state_label_final_cost); + + ret &= (lat_.NumStates() > 0); KALDI_VLOG(2) << "Frame: ( " << first_frame << " , " << last_frame << " )" << " states of the chunk: " << clat.NumStates() << " states of the lattice: " << lat_.NumStates(); - - return (lat_.NumStates() > 0); + return ret; } template -void LatticeIncrementalDeterminizer::AppendLatticeChunks( +bool LatticeIncrementalDeterminizer::AppendLatticeChunks( CompactLattice clat, bool not_first_chunk, const unordered_map &state_label_initial_cost, const unordered_map &state_label_final_cost) { @@ -1277,13 +1296,20 @@ void LatticeIncrementalDeterminizer::AppendLatticeChunks( aiter_chunk1.SetValue(arc_chunk1_mod); } // otherwise, it has been pruned } - KALDI_ASSERT(prev_final_states.size()); // at least one arc should be appended + // If at least one arc connects two chunks, the function will return true + // Otherwise, return false + // It is possible to fail in connecting two chunks since the old chunk is pruned + // by DeterminizeLatticePhonePrunedWrapper while the new chunk is pruned by + // PruneActiveTokens. Their pruning behaviors are not totally the same + // If returning false in this function, we will later re-determinize from frame 0 + if (!prev_final_states.size()) return false; // Making all unmodified remaining arcs of final_arc_list_prev_ be connected to // a dead state. The following prev_final_states can be the same or different // states for (auto i : prev_final_states) olat->SetFinal(i, CompactLatticeWeight::Zero()); } else olat->SetStart(0); // Initialize the first chunk for olat + return true; } template diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 7ab60ae3686..c5a5bfd79a6 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -1,4 +1,4 @@ -// decoder/lattice-incremental-decoder.h +// decoder/lattice-incremental-decoder.h // Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; // 2013-2014 Johns Hopkins University (Author: Daniel Povey) @@ -246,11 +246,11 @@ class LatticeIncrementalDecoderTpl { /// the function. We make it not always equal to /// NumFramesDecoded() to have a delay on the /// deteriminization - /// @param [out] olat The CompactLattice representing what has been decoded + /// @param [out] olat The CompactLattice representing what has been decoded /// so far. - // If lat == NULL, the CompactLattice won't be outputed. - /// @return ret This function will returns true if the chunk is processed - /// successfully + // If lat == NULL, the CompactLattice won't be outputed. + /// @return ret This function will returns true if the chunk is processed + /// successfully bool GetLattice(bool use_final_probs, bool redeterminize, int32 last_frame_of_chunk, CompactLattice *olat = NULL); /// Specifically design when decoding_finalized_==true @@ -544,7 +544,7 @@ class LatticeIncrementalDeterminizer { // last chunk to the first frame state of this chunk. // These begin and final states are corresponding to the same Token, // guaranteed by unique state labels. - void AppendLatticeChunks( + bool AppendLatticeChunks( CompactLattice clat, bool not_first_chunk, const unordered_map &state_label_initial_cost, const unordered_map &state_label_final_cost); From 38443890cd70ebe874b7290859ea61c369ed8637 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Tue, 2 Apr 2019 01:35:31 -0400 Subject: [PATCH 14/60] 1. add time profiling for baseline lattice-faster-decoder for comparison (remove it later) 2. add determinize-beam-offset. By this way, the beam used in lattice determinization is (determinize_beam_offset + lattice_beam) --- src/decoder/decoder-wrappers.cc | 3 +++ src/decoder/lattice-faster-decoder.cc | 2 ++ src/decoder/lattice-incremental-decoder.cc | 3 ++- src/decoder/lattice-incremental-decoder.h | 5 +++++ 4 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index 294a2f69117..abb326eb012 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -22,6 +22,7 @@ #include "decoder/lattice-faster-decoder.h" #include "decoder/grammar-fst.h" #include "lat/lattice-functions.h" +#include "base/timer.h" namespace kaldi { @@ -353,6 +354,7 @@ bool DecodeUtteranceLatticeFaster( // Get lattice, and do determinization if requested. Lattice lat; + Timer timer; decoder.GetRawLattice(&lat); if (lat.NumStates() == 0) KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; @@ -377,6 +379,7 @@ bool DecodeUtteranceLatticeFaster( fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &lat); lattice_writer->Write(utt, lat); } + KALDI_VLOG(2) << "Delay time after decoding finalized (secs): " << timer.Elapsed(); KALDI_LOG << "Log-like per frame for utterance " << utt << " is " << (likelihood / num_frames) << " over " << num_frames << " frames."; diff --git a/src/decoder/lattice-faster-decoder.cc b/src/decoder/lattice-faster-decoder.cc index 2bc8c7cdef4..ed78ba5fddb 100644 --- a/src/decoder/lattice-faster-decoder.cc +++ b/src/decoder/lattice-faster-decoder.cc @@ -89,7 +89,9 @@ bool LatticeFasterDecoderTpl::Decode(DecodableInterface *decodable) BaseFloat cost_cutoff = ProcessEmitting(decodable); ProcessNonemitting(cost_cutoff); } + Timer timer; FinalizeDecoding(); + KALDI_VLOG(2) << "Delay0 time after decoding finalized (secs): " << timer.Elapsed(); // Returns true if we have any kind of traceback available (not necessarily // to the end state; query ReachedFinal() for that). diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index f4822ce5162..fc73b0efc9f 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1177,7 +1177,8 @@ bool LatticeIncrementalDeterminizer::ProcessChunk( // These pruned final arcs can hurt oracle WER performance in the final lattice // (also result in less lattice density) but they seldom hurt 1-best WER. ret &= DeterminizeLatticePhonePrunedWrapper( - trans_model_, &raw_fst, config_.lattice_beam, &clat, config_.det_opts); + trans_model_, &raw_fst, config_.determinize_beam_offset + + config_.lattice_beam, &clat, config_.det_opts); final_arc_list_.swap(final_arc_list_prev_); final_arc_list_.clear(); diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index c5a5bfd79a6..0a6f92e8306 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -43,6 +43,7 @@ struct LatticeIncrementalDecoderConfig { int32 prune_interval; int32 determinize_delay; int32 determinize_max_active; + BaseFloat determinize_beam_offset; bool redeterminize; bool determinize_lattice; // not inspected by this class... used in // command-line program. @@ -66,6 +67,7 @@ struct LatticeIncrementalDecoderConfig { prune_interval(25), determinize_delay(0), determinize_max_active(50), + determinize_beam_offset(0), redeterminize(false), determinize_lattice(true), beam_delta(0.5), @@ -93,6 +95,9 @@ struct LatticeIncrementalDecoderConfig { "If the number of active tokens(in a certain frame) is less " "than this number, we will start to incrementally " "determinize lattices up to this frame."); + opts->Register("determinize-beam-offset", &determinize_beam_offset, + "the beam used in lattice determinization is " + "(determinize_beam_offset + lattice_beam) ."); opts->Register("redeterminize", &redeterminize, "whether to re-determinize the lattice after incremental " "determinization."); From 90f3ea72f0063259e5dc213196accfba20efc5c7 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Tue, 2 Apr 2019 04:18:21 -0400 Subject: [PATCH 15/60] code refine --- src/decoder/lattice-incremental-decoder.cc | 20 +++++++---- src/decoder/lattice-incremental-decoder.h | 39 ++++++++++------------ 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index fc73b0efc9f..3b28cc555bc 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -118,9 +118,10 @@ bool LatticeIncrementalDecoderTpl::Decode( // If the number of tokens in a certain frame is less than // config_.determinize_max_active, the lattice can be determinized up to this // frame. And we try to determinize as most frames as possible so we check - // numbers from frame_det_most_ up to last_get_lattice_frame_ + // numbers from frame_det_most to frame_det_least for (int32 f = frame_det_most; f >= frame_det_least; f--) { - if (GetNumToksForFrame(f) < config_.determinize_max_active) { + if (config_.determinize_max_active == std::numeric_limits::max() + || GetNumToksForFrame(f) < config_.determinize_max_active) { KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() << " incremental determinization up to " << f; GetLattice(false, false, f); @@ -962,7 +963,7 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, // step 1: Get lattice chunk with initial and final states // In this function, we do not create the initial state in // the first chunk, and we do not create the final state in the last chunk - if (!GetRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, + if (!GetIncrementalRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, last_frame_of_chunk, not_first_chunk, !decoding_finalized_)) KALDI_ERR << "Unexpected problem when getting lattice"; // step 2-3 @@ -995,7 +996,7 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, } template -bool LatticeIncrementalDecoderTpl::GetRawLattice( +bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, bool create_initial_state, bool create_final_state) { typedef LatticeArc Arc; @@ -1005,7 +1006,7 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( if (decoding_finalized_ && !use_final_probs) KALDI_ERR << "You cannot call FinalizeDecoding() and then call " - << "GetRawLattice() with use_final_probs == false"; + << "GetIncrementalRawLattice() with use_final_probs == false"; unordered_map final_costs_local; @@ -1025,7 +1026,7 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( std::vector token_list; for (int32 f = frame_begin; f <= frame_end; f++) { if (active_toks_[f].toks == NULL) { - KALDI_WARN << "GetRawLattice: no tokens active on frame " << f + KALDI_WARN << "GetIncrementalRawLattice: no tokens active on frame " << f << ": not producing lattice.\n"; return false; } @@ -1063,7 +1064,7 @@ bool LatticeIncrementalDecoderTpl::GetRawLattice( ofst->AddArc(begin_state, arc); } } - // step 1.2: create all arcs as GetRawLattice() + // step 1.2: create all arcs as GetRawLattice() of LatticeFasterDecoder for (int32 f = frame_begin; f <= frame_end; f++) { for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; @@ -1176,6 +1177,11 @@ bool LatticeIncrementalDeterminizer::ProcessChunk( // can guarantee no final or initial arcs in clat are pruned by this function. // These pruned final arcs can hurt oracle WER performance in the final lattice // (also result in less lattice density) but they seldom hurt 1-best WER. + // Since pruning behaviors in DeterminizeLatticePhonePrunedWrapper and + // PruneActiveTokens are not the same, to get similar lattice density as + // LatticeFasterDecoder, we need to use a slightly larger beam here + // than the lattice_beam used PruneActiveTokens. Hence the beam we use is + // (config_.determinize_beam_offset + config_.lattice_beam) ret &= DeterminizeLatticePhonePrunedWrapper( trans_model_, &raw_fst, config_.determinize_beam_offset + config_.lattice_beam, &clat, config_.det_opts); diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 0a6f92e8306..9f4e0983cf2 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -65,9 +65,9 @@ struct LatticeIncrementalDecoderConfig { min_active(200), lattice_beam(10.0), prune_interval(25), - determinize_delay(0), - determinize_max_active(50), - determinize_beam_offset(0), + determinize_delay(25), + determinize_max_active(std::numeric_limits::max()), + determinize_beam_offset(1), redeterminize(false), determinize_lattice(true), beam_delta(0.5), @@ -187,21 +187,6 @@ class LatticeIncrementalDecoderTpl { /// and figures out the shortest path. bool GetBestPath(Lattice *ofst, bool use_final_probs = true); - /// Outputs an FST corresponding to the raw, state-level - /// tracebacks. Returns true if result is nonempty. - /// If "use_final_probs" is true AND we reached the final-state - /// of the graph then it will include those as final-probs, else - /// it will treat all final-probs as one. - /// The raw lattice will be topologically sorted. - /// Notably, the raw lattice from this incremental determinization decoder - /// has already been partially determinized - /// - /// See also GetRawLatticePruned in lattice-faster-online-decoder.h, - /// which also supports a pruning beam, in case for some reason - /// you want it pruned tighter than the regular lattice beam. - /// We could put that here in future needed. - bool GetRawLattice(Lattice *ofst, bool use_final_probs = true); - /// The following function is specifically designed for incremental /// determinization. The function obtains a CompactLattice for /// the part of this utterance up to the frame last_frame_of_chunk. @@ -261,6 +246,16 @@ class LatticeIncrementalDecoderTpl { /// Specifically design when decoding_finalized_==true bool GetLattice(CompactLattice *olat); + /// This function is to keep forwards compatibility. + /// It outputs an FST corresponding to the raw, state-level + /// tracebacks. Returns true if result is nonempty. + /// If "use_final_probs" is true AND we reached the final-state + /// of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + /// Notably, the raw lattice from this incremental determinization decoder + /// has already been partially determinized + bool GetRawLattice(Lattice *ofst, bool use_final_probs = true); + /// InitDecoding initializes the decoding, and should only be used if you /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to /// call this. You can also call InitDecoding if you have already decoded an @@ -475,7 +470,7 @@ class LatticeIncrementalDecoderTpl { void ClearActiveTokens(); - // The following part is specifically designed for incremental + // The following part is specifically designed for incremental determinization // This function is modified from LatticeFasterDecoderTpl::GetRawLattice() // and specific design for step 1 of incremental determinization // introduced before above GetLattice() @@ -493,9 +488,9 @@ class LatticeIncrementalDecoderTpl { // the corresponding state-label arcs. Similar for create_final_state // In incremental GetLattice, we do not create the initial state in // the first chunk, and we do not create the final state in the last chunk - bool GetRawLattice(Lattice *ofst, bool use_final_probs, int32 frame_begin, - int32 frame_end, bool create_initial_state, - bool create_final_state); + bool GetIncrementalRawLattice(Lattice *ofst, bool use_final_probs, + int32 frame_begin, int32 frame_end, + bool create_initial_state, bool create_final_state); // Get the number of tokens in each frame // It is useful, e.g. in using config_.determinize_max_active int32 GetNumToksForFrame(int32 frame); From 9111173c465f452e212feb821532c74e90b16645 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Tue, 9 Apr 2019 21:54:14 +0800 Subject: [PATCH 16/60] update final weight by extra_cost-alpha, see sheet "ver 3" --- src/decoder/lattice-incremental-decoder.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 3b28cc555bc..284c439280e 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -989,7 +989,6 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, // frame 0 to last_frame_of_chunk last_get_lattice_frame_ = 0; determinizer_.Init(); - ret = GetLattice(use_final_probs, redeterminize, last_frame_of_chunk, olat); } return ret; @@ -1130,7 +1129,7 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( // For now, we use extra_cost from the decoding stage , which has some // "future information", as // the final weights of this chunk - BaseFloat cost_offset = tok->extra_cost; + BaseFloat cost_offset = tok->extra_cost-tok->tot_cost; // We record these cost_offset, and after we appending two chunks // we will cancel them out state_label_final_cost_[id] = cost_offset; @@ -1184,7 +1183,7 @@ bool LatticeIncrementalDeterminizer::ProcessChunk( // (config_.determinize_beam_offset + config_.lattice_beam) ret &= DeterminizeLatticePhonePrunedWrapper( trans_model_, &raw_fst, config_.determinize_beam_offset + - config_.lattice_beam, &clat, config_.det_opts); + config_.lattice_beam + 0.1, &clat, config_.det_opts); final_arc_list_.swap(final_arc_list_prev_); final_arc_list_.clear(); From 6af8f62d0e590c51694b7e1e704104f37d62de2e Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Wed, 10 Apr 2019 20:46:50 +0800 Subject: [PATCH 17/60] WIP --- src/decoder/lattice-incremental-decoder.cc | 236 ++++++++++++--------- src/decoder/lattice-incremental-decoder.h | 21 +- 2 files changed, 152 insertions(+), 105 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 284c439280e..1e93b9d968e 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -968,8 +968,7 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, KALDI_ERR << "Unexpected problem when getting lattice"; // step 2-3 ret = determinizer_.ProcessChunk(raw_fst, last_get_lattice_frame_, - last_frame_of_chunk, state_label_initial_cost_, - state_label_final_cost_); + last_frame_of_chunk, state_label_initial_cost_); last_get_lattice_frame_ = last_frame_of_chunk; } else if (last_get_lattice_frame_ > last_frame_of_chunk) KALDI_WARN << "Call GetLattice up to frame: " << last_frame_of_chunk @@ -994,6 +993,9 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, return ret; } +// sanity check +unordered_map g_last2first_state; + template bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, @@ -1015,7 +1017,10 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( ComputeFinalCosts(&final_costs_local, NULL, NULL); ofst->DeleteStates(); - if (create_initial_state) ofst->AddState(); // initial-state for the chunk + unordered_multimap state_label2state_map; // for GetInitialRawLattice + // initial arcs for the chunk + if (create_initial_state) + determinizer_.GetInitialRawLattice(ofst, &state_label2state_map, state_label_final_cost_); // num-frames plus one (since frames are one-based, and we have // an extra frame for the start-state). KALDI_ASSERT(frame_end > 0); @@ -1035,8 +1040,8 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( } // The next statement sets the start state of the output FST. // No matter create_initial_state or not , state zero must be the start-state. - StateId begin_state = 0; - ofst->SetStart(begin_state); + StateId start_state = 0; + ofst->SetStart(start_state); KALDI_VLOG(4) << "init:" << num_toks_ / 2 + 3 << " buckets:" << tok_map.bucket_count() @@ -1050,17 +1055,27 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( auto r = state_label_map_.find(tok); KALDI_ASSERT(r != state_label_map_.end()); // it should exist int32 id = r->second; - // Use cost_offsets to guide DeterminizeLatticePruned() - // later - // For now, we use alpha (tot_cost) from the decoding stage as - // the initial weights of arcs connecting to the states in the begin - // of this chunk - BaseFloat cost_offset = tok->tot_cost; - // We record these cost_offset, and after we appending two chunks - // we will cancel them out - state_label_initial_cost_[id] = cost_offset; - Arc arc(0, id, Weight(0, cost_offset), cur_state); - ofst->AddArc(begin_state, arc); + auto r2 = state_label2state_map.find(id); + auto r3 = r2; //TODO + KALDI_ASSERT(r2!=state_label2state_map.end()); + // sanity check + auto& forward_costs = determinizer_.GetForwardCosts(); + BaseFloat best_cost = std::numeric_limits::infinity(); + std::vector tmp_vec; + for (; r2!=state_label2state_map.end(); ++r2) { + // the destination state of the last of the sequence of arcs w.r.t the id here (state label) created by GetInitialRawLattice + auto state_last_initial = r2->second; + // connect it to the state correponding to the token w.r.t the id here (state label) + Arc arc(0, 0, Weight::One(), cur_state); + ofst->AddArc(state_last_initial, arc); + // sanity check + auto r = g_last2first_state.find(state_last_initial); + KALDI_ASSERT(r!=g_last2first_state.end()); + tmp_vec.push_back(r->second); + best_cost = std::min(best_cost, r->second); + } + // sanity check + KALDI_ASSERT(abs(best_cost - tok->tot_cost) < 0.1); } } // step 1.2: create all arcs as GetRawLattice() of LatticeFasterDecoder @@ -1076,10 +1091,12 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( KALDI_ASSERT(iter != tok_map.end()); StateId nextstate = iter->second; BaseFloat cost_offset = 0.0; + /* // TODO if (l->ilabel != 0) { // emitting.. KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); cost_offset = cost_offsets_[f]; } + */ Arc arc(l->ilabel, l->olabel, Weight(l->graph_cost, l->acoustic_cost - cost_offset), nextstate); ofst->AddArc(cur_state, arc); @@ -1160,13 +1177,14 @@ void LatticeIncrementalDeterminizer::Init() { final_arc_list_prev_.clear(); lat_.DeleteStates(); determinization_finalized_ = false; + forward_costs_.clear(); + state_last_initial_offset_ = 2*config_.max_word_id; } template bool LatticeIncrementalDeterminizer::ProcessChunk( Lattice &raw_fst, int32 first_frame, int32 last_frame, - const unordered_map &state_label_initial_cost, - const unordered_map &state_label_final_cost) { + const unordered_map &state_label_initial_cost) { bool not_first_chunk = first_frame != 0; bool ret = true; // step 2: Determinize the chunk @@ -1181,16 +1199,15 @@ bool LatticeIncrementalDeterminizer::ProcessChunk( // LatticeFasterDecoder, we need to use a slightly larger beam here // than the lattice_beam used PruneActiveTokens. Hence the beam we use is // (config_.determinize_beam_offset + config_.lattice_beam) +#if 0 ret &= DeterminizeLatticePhonePrunedWrapper( trans_model_, &raw_fst, config_.determinize_beam_offset + config_.lattice_beam + 0.1, &clat, config_.det_opts); - - final_arc_list_.swap(final_arc_list_prev_); - final_arc_list_.clear(); - +#else + ConvertLattice(raw_fst, &clat); +#endif // step 3: Appending the new chunk in clat to the old one in lat_ - ret &= AppendLatticeChunks(clat, not_first_chunk, state_label_initial_cost, - state_label_final_cost); + ret &= AppendLatticeChunks(clat, not_first_chunk, state_label_initial_cost); ret &= (lat_.NumStates() > 0); KALDI_VLOG(2) << "Frame: ( " << first_frame << " , " << last_frame << " )" @@ -1198,123 +1215,144 @@ bool LatticeIncrementalDeterminizer::ProcessChunk( << " states of the lattice: " << lat_.NumStates(); return ret; } +// We have multiple states for one state label after determinization +template +void LatticeIncrementalDeterminizer::GetInitialRawLattice( + Lattice *olat, + unordered_multimap *state_label2state_map, + const unordered_map &state_label_final_cost) { + using namespace fst; + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + KALDI_ASSERT(final_arc_list_prev_.size()); + olat->DeleteStates(); + state_label2state_map->clear(); + + auto start_state = olat->AddState(); + olat->SetStart(start_state); + for (auto &i : final_arc_list_prev_) { + ArcIterator aiter_chunk1(lat_, i.first); + aiter_chunk1.Seek(i.second); + // Obtain the appended final arcs in the previous chunk + const auto &arc_chunk1 = aiter_chunk1.Value(); + KALDI_ASSERT(arc_chunk1.olabel > config_.max_word_id); + StateId prev_final_state = arc_chunk1.nextstate; + CompactLatticeWeight weight_offset; + KALDI_ASSERT(i.first < forward_costs_.size()); + // TODO: description + const auto r2 = state_label_final_cost.find(arc_chunk1.olabel); + KALDI_ASSERT(r2 != state_label_final_cost.end()); + weight_offset.SetWeight(LatticeWeight(0, forward_costs_[i.first]-r2->second)); + auto initial_weight = Times(Times(arc_chunk1.weight, lat_.Final(prev_final_state)), weight_offset); + + // create a state representing the i.first state (source state) in appended lattice + auto source_state = olat->AddState(); + // we need a special label in the arc that corresponds to the identity of the source-state of the last arc, we use its StateId and a offset here + int id = i.first + state_last_initial_offset_; + Arc arc(0, id, initial_weight.Weight(), source_state); + olat->AddArc(start_state, arc); + // We generate a linear sequence of arcs sufficient to contain all the transition-ids on the string + auto prev_state = source_state; + for (auto &j : initial_weight.String()) { + auto cur_state = olat->AddState(); + Arc arc(j, 0, LatticeWeight::One(), cur_state); + olat->AddArc(prev_state, arc); + prev_state = cur_state; + } + auto last_state = olat->NumStates()-1; + // the destination state of the last of the sequence of arcs will be recorded and connected to the state corresponding to token w.r.t arc_chunk1.olabel + state_label2state_map->insert( + std::pair(arc_chunk1.olabel, last_state)); + // sanity check + g_last2first_state[last_state]=initial_weight.Weight().Value1()+initial_weight.Weight().Value2(); + } +} template bool LatticeIncrementalDeterminizer::AppendLatticeChunks( CompactLattice clat, bool not_first_chunk, - const unordered_map &state_label_initial_cost, - const unordered_map &state_label_final_cost) { + const unordered_map &state_label_initial_cost) { using namespace fst; CompactLattice *olat = &lat_; + + // later we need to calculate forward_costs_ for clat + TopSortCompactLatticeIfNeeded(&clat); + // step 3.1: Appending new chunk to the old one int32 state_offset = olat->NumStates(); if (not_first_chunk) state_offset--; // since we do not append initial state in the first chunk + else + forward_costs_.push_back(0); // for the first state + // TODO assert whether it is similar to tot_cost + forward_costs_.resize(state_offset+clat.NumStates()); // we append all states except the initial state - // A map from state label to the arc position (index) - // the incoming states of these arcs are initial states of the chunk - // and the olabel of these arcs are the key of this map (state label) - // The arc position are obtained from ArcIterator corresponding to the state - unordered_map initial_arc_map; - initial_arc_map.reserve(std::min((int32)1e5, config_.max_active)); for (StateIterator siter(clat); !siter.Done(); siter.Next()) { auto s = siter.Value(); - StateId state_appended = -1; + StateId state_appended = kNoStateId; // We do not copy initial state, which exists except the first chunk if (!not_first_chunk || s != 0) { state_appended = s + state_offset; KALDI_ASSERT(state_appended == olat->AddState()); olat->SetFinal(state_appended, clat.Final(s)); } - for (ArcIterator aiter(clat, s); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); + StateId source_state = kNoStateId; // We do not copy initial arcs, which exists except the first chunk. // These arcs will be taken care later in step 3.2 + CompactLatticeArc arc_appended(arc); + arc_appended.nextstate += state_offset; + // In the first chunk, there could be a final arc starting from state 0, and we process it here + // In the last chunk, there could be a initial arc ending in final state, and we process it in "process initial arcs" in the following if (!not_first_chunk || s != 0) { - CompactLatticeArc arc_appended(arc); - arc_appended.nextstate += state_offset; - olat->AddArc(state_appended, arc_appended); - } - // Process state labels, which will be used in step 3.2 - if (arc.olabel > config_.max_word_id) { // initial_arc - // In first chunk, there could be a final arc starting from state 0 - // In the last chunk, there could be a initial arc ending in final state - if (not_first_chunk && - s == 0) { // record initial_arc in this chunk, we will use it right now - initial_arc_map[arc.olabel] = aiter.Position(); - } else { // final_arc + KALDI_ASSERT(state_appended != kNoStateId); + source_state = state_appended; + if (arc.olabel > config_.max_word_id) { // record final_arc in this chunk for the step 3.2 in the next call + KALDI_ASSERT(arc.olabel < state_last_initial_offset_); KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); final_arc_list_.push_back( pair(state_appended, aiter.Position())); } + } else { // process initial arcs + KALDI_ASSERT(arc.olabel > config_.max_word_id); + KALDI_ASSERT(arc.olabel >= state_last_initial_offset_); + source_state = arc.olabel - state_last_initial_offset_; // TODO description + arc_appended.olabel = 0; + arc_appended.ilabel = 0; + // TODO: remove alpha in weight } + KALDI_ASSERT(source_state != kNoStateId); + olat->AddArc(source_state, arc_appended); + // update forward_costs_ + KALDI_ASSERT(arc_appended.nextstate < forward_costs_.size()); + auto& alpha_nextstate = forward_costs_[arc_appended.nextstate]; + auto& weight = arc_appended.weight.Weight(); + alpha_nextstate = std::min(alpha_nextstate, forward_costs_[source_state] + weight.Value1() + weight.Value2()); } } - // step 3.2: connect the states between two chunks, i.e. chunk1 in olat and chunk2 - // in clat in the following - // Notably, most states and arcs of clat has been copied to olat in step 3.1 - // This step is mainly to process the boundary of these two chunks + // Making all remaining arcs of final_arc_list_prev_ be connected to + // a dead state. + // final states are always the same state) if (not_first_chunk) { KALDI_ASSERT(final_arc_list_prev_.size()); - vector prev_final_states; for (auto &i : final_arc_list_prev_) { - MutableArcIterator aiter_chunk1(olat, i.first); + ArcIterator aiter_chunk1(*olat, i.first); aiter_chunk1.Seek(i.second); // Obtain the appended final arcs in the previous chunk auto &arc_chunk1 = aiter_chunk1.Value(); - // Find out whether its corresponding Token still exists in the begin - // of this chunk. If not, it is pruned by PruneActiveTokens() - auto r = initial_arc_map.find(arc_chunk1.olabel); - if (r != initial_arc_map.end()) { - ArcIterator aiter_chunk2(clat, 0); // initial state - aiter_chunk2.Seek(r->second); - const auto &arc_chunk2 = aiter_chunk2.Value(); - KALDI_ASSERT(arc_chunk2.olabel == arc_chunk1.olabel); - StateId state_chunk1 = arc_chunk2.nextstate + state_offset; - StateId prev_final_state = arc_chunk1.nextstate; - prev_final_states.push_back(prev_final_state); - // For the later code in this loop, we try to modify the arc_chunk1 - // to connect the last frame state of last chunk to the first frame - // state of this chunk. These begin and final states are - // corresponding to the same Token, guaranteed by unique state labels. - CompactLatticeArc arc_chunk1_mod(arc_chunk1); - arc_chunk1_mod.nextstate = state_chunk1; - { // Update arc weight in this section - CompactLatticeWeight weight_offset, weight_offset_final; - const auto r1 = state_label_initial_cost.find(arc_chunk1.olabel); - KALDI_ASSERT(r1 != state_label_initial_cost.end()); - weight_offset.SetWeight(LatticeWeight(0, -r1->second)); - const auto r2 = state_label_final_cost.find(arc_chunk1.olabel); - KALDI_ASSERT(r2 != state_label_final_cost.end()); - weight_offset_final.SetWeight(LatticeWeight(0, -r2->second)); - arc_chunk1_mod.weight = Times( - Times(Times(Times(arc_chunk2.weight, olat->Final(prev_final_state)), - weight_offset), - weight_offset_final), - arc_chunk1_mod.weight); - } - // After appending, state labels are of no use and we remove them - arc_chunk1_mod.olabel = 0; - arc_chunk1_mod.ilabel = 0; - aiter_chunk1.SetValue(arc_chunk1_mod); - } // otherwise, it has been pruned + olat->SetFinal(arc_chunk1.nextstate, CompactLatticeWeight::Zero()); } - // If at least one arc connects two chunks, the function will return true - // Otherwise, return false - // It is possible to fail in connecting two chunks since the old chunk is pruned - // by DeterminizeLatticePhonePrunedWrapper while the new chunk is pruned by - // PruneActiveTokens. Their pruning behaviors are not totally the same - // If returning false in this function, we will later re-determinize from frame 0 - if (!prev_final_states.size()) return false; - // Making all unmodified remaining arcs of final_arc_list_prev_ be connected to - // a dead state. The following prev_final_states can be the same or different - // states - for (auto i : prev_final_states) olat->SetFinal(i, CompactLatticeWeight::Zero()); } else olat->SetStart(0); // Initialize the first chunk for olat + + final_arc_list_.swap(final_arc_list_prev_); + final_arc_list_.clear(); + return true; } diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 9f4e0983cf2..5d156ee62e2 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -503,8 +503,8 @@ class LatticeIncrementalDecoderTpl { // We keep tot_cost or extra_cost for each state_label (Token) in final and // initial arcs. We need them before determinization // We cancel them after determinization - unordered_map state_label_initial_cost_; - unordered_map state_label_final_cost_; + unordered_map state_label_initial_cost_; // TODO remove it + unordered_map state_label_final_cost_; // TODO remove it KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); }; @@ -533,8 +533,7 @@ class LatticeIncrementalDeterminizer { // It does step 2-4 and outputs the resultant CompactLattice if // needed. Otherwise, it keeps the resultant lattice in lat_ bool ProcessChunk(Lattice &raw_fst, int32 first_frame, int32 last_frame, - const unordered_map &state_label_initial_cost, - const unordered_map &state_label_final_cost); + const unordered_map &state_label_initial_cost); // Step 3 of incremental determinization, // which is to append the new chunk in clat to the old one in lat_ @@ -546,13 +545,18 @@ class LatticeIncrementalDeterminizer { // guaranteed by unique state labels. bool AppendLatticeChunks( CompactLattice clat, bool not_first_chunk, - const unordered_map &state_label_initial_cost, - const unordered_map &state_label_final_cost); + const unordered_map &state_label_initial_cost); // Step 4 of incremental determinization, // which either re-determinize above lat_, or simply remove the dead // states of lat_ bool Finalize(bool redeterminize); + std::vector& GetForwardCosts() { + return forward_costs_; + } + void GetInitialRawLattice(Lattice *olat, + unordered_multimap *state_label2state_map, + const unordered_map &state_label_final_cost); private: const LatticeIncrementalDecoderConfig config_; @@ -564,6 +568,11 @@ class LatticeIncrementalDeterminizer { // keep final_arc for appending later std::vector> final_arc_list_; std::vector> final_arc_list_prev_; + // alpha of each state in lat_ + std::vector forward_costs_; + // we allocate a unique id for each source-state of the last arc of a series of initial arcs in GetInitialRawLattice + int32 state_last_initial_offset_; + // The compact lattice we obtain. It should be reseted before processing a // new utterance CompactLattice lat_; From 40cf7ffe9202147a36aa0c05ddccd8f3bd445eaf Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Wed, 10 Apr 2019 21:55:21 +0800 Subject: [PATCH 18/60] fix bugs and add sanity check --- src/decoder/lattice-incremental-decoder.cc | 67 +++++++++++++++++++--- 1 file changed, 58 insertions(+), 9 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 1e93b9d968e..b9b38b686ec 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -944,6 +944,8 @@ template bool LatticeIncrementalDecoderTpl::GetLattice(CompactLattice *olat) { return GetLattice(true, config_.redeterminize, NumFramesDecoded(), olat); } +// sanity check +BaseFloat best_cost_in_chunk_; template bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, @@ -1055,16 +1057,15 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( auto r = state_label_map_.find(tok); KALDI_ASSERT(r != state_label_map_.end()); // it should exist int32 id = r->second; - auto r2 = state_label2state_map.find(id); - auto r3 = r2; //TODO - KALDI_ASSERT(r2!=state_label2state_map.end()); + auto range = state_label2state_map.equal_range(id); + KALDI_ASSERT(range.first != range.second); // sanity check auto& forward_costs = determinizer_.GetForwardCosts(); BaseFloat best_cost = std::numeric_limits::infinity(); std::vector tmp_vec; - for (; r2!=state_label2state_map.end(); ++r2) { + for (auto it = range.first; it!=range.second; ++it) { // the destination state of the last of the sequence of arcs w.r.t the id here (state label) created by GetInitialRawLattice - auto state_last_initial = r2->second; + auto state_last_initial = it->second; // connect it to the state correponding to the token w.r.t the id here (state label) Arc arc(0, 0, Weight::One(), cur_state); ofst->AddArc(state_last_initial, arc); @@ -1078,6 +1079,8 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( KALDI_ASSERT(abs(best_cost - tok->tot_cost) < 0.1); } } + // for sanity check + best_cost_in_chunk_ = std::numeric_limits::infinity(); // step 1.2: create all arcs as GetRawLattice() of LatticeFasterDecoder for (int32 f = frame_begin; f <= frame_end; f++) { for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { @@ -1120,6 +1123,10 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( weight = LatticeWeight::Zero(); } ofst->SetFinal(cur_state, weight); + // for sanity check + // we will use extra_cost in step 1.3 (see the following code) + best_cost_in_chunk_ = std::min( + best_cost_in_chunk_, tok->tot_cost + weight.Value1() + weight.Value2()); } } } @@ -1146,7 +1153,7 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( // For now, we use extra_cost from the decoding stage , which has some // "future information", as // the final weights of this chunk - BaseFloat cost_offset = tok->extra_cost-tok->tot_cost; + BaseFloat cost_offset = 0; // TODO tok->extra_cost-tok->tot_cost; // We record these cost_offset, and after we appending two chunks // we will cancel them out state_label_final_cost_[id] = cost_offset; @@ -1206,8 +1213,41 @@ bool LatticeIncrementalDeterminizer::ProcessChunk( #else ConvertLattice(raw_fst, &clat); #endif + { + // sanity check, remove them later + CompactLattice cdecoded; + Lattice decoded; + ShortestPath(clat, &cdecoded); + ConvertLattice(cdecoded, &decoded); + LatticeWeight weight; + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + KALDI_ASSERT(alignment.size() == last_frame-first_frame); + // TODO: the following KALDI_ASSERT will fail some time, which is unexpected + // for sanity check + KALDI_ASSERT(std::abs(best_cost_in_chunk_ - (weight.Value1() + + weight.Value2())) < 1e-1); + } + // step 3: Appending the new chunk in clat to the old one in lat_ ret &= AppendLatticeChunks(clat, not_first_chunk, state_label_initial_cost); + { + // sanity check, remove them later + CompactLattice cdecoded; + Lattice decoded; + ShortestPath(lat_, &cdecoded); + ConvertLattice(cdecoded, &decoded); + LatticeWeight weight; + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + KALDI_ASSERT(alignment.size() == last_frame); + // TODO: the following KALDI_ASSERT will fail some time, which is unexpected + // for sanity check + KALDI_ASSERT(std::abs(best_cost_in_chunk_ - (weight.Value1() + + weight.Value2())) < 1e-1); + } ret &= (lat_.NumStates() > 0); KALDI_VLOG(2) << "Frame: ( " << first_frame << " , " << last_frame << " )" @@ -1232,6 +1272,8 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( auto start_state = olat->AddState(); olat->SetStart(start_state); + // sanity check + BaseFloat best_cost = std::numeric_limits::infinity(); for (auto &i : final_arc_list_prev_) { ArcIterator aiter_chunk1(lat_, i.first); aiter_chunk1.Seek(i.second); @@ -1240,10 +1282,10 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( KALDI_ASSERT(arc_chunk1.olabel > config_.max_word_id); StateId prev_final_state = arc_chunk1.nextstate; CompactLatticeWeight weight_offset; - KALDI_ASSERT(i.first < forward_costs_.size()); // TODO: description const auto r2 = state_label_final_cost.find(arc_chunk1.olabel); KALDI_ASSERT(r2 != state_label_final_cost.end()); + KALDI_ASSERT(i.first < forward_costs_.size()); weight_offset.SetWeight(LatticeWeight(0, forward_costs_[i.first]-r2->second)); auto initial_weight = Times(Times(arc_chunk1.weight, lat_.Final(prev_final_state)), weight_offset); @@ -1267,7 +1309,11 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( std::pair(arc_chunk1.olabel, last_state)); // sanity check g_last2first_state[last_state]=initial_weight.Weight().Value1()+initial_weight.Weight().Value2(); + best_cost=std::min(best_cost,initial_weight.Weight().Value1()+initial_weight.Weight().Value2()); } + // sanity check + KALDI_ASSERT(std::abs(best_cost_in_chunk_ - best_cost) < 1e-1); + } template @@ -1287,7 +1333,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks( else forward_costs_.push_back(0); // for the first state // TODO assert whether it is similar to tot_cost - forward_costs_.resize(state_offset+clat.NumStates()); // we append all states except the initial state + forward_costs_.resize(state_offset+clat.NumStates(), std::numeric_limits::infinity()); // we append all states except the initial state for (StateIterator siter(clat); !siter.Done(); siter.Next()) { auto s = siter.Value(); @@ -1323,7 +1369,10 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks( source_state = arc.olabel - state_last_initial_offset_; // TODO description arc_appended.olabel = 0; arc_appended.ilabel = 0; - // TODO: remove alpha in weight + CompactLatticeWeight weight_offset; + weight_offset.SetWeight(LatticeWeight(0, -forward_costs_[source_state])); + // remove alpha in weight + arc_appended.weight=Times(arc_appended.weight, weight_offset); } KALDI_ASSERT(source_state != kNoStateId); olat->AddArc(source_state, arc_appended); From c5f0a8e14f18458928fc6790c9c1e8e75380f07a Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Wed, 10 Apr 2019 21:56:38 +0800 Subject: [PATCH 19/60] enable det --- src/decoder/lattice-incremental-decoder.cc | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index b9b38b686ec..e346ee4dae2 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1125,7 +1125,11 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( ofst->SetFinal(cur_state, weight); // for sanity check // we will use extra_cost in step 1.3 (see the following code) - best_cost_in_chunk_ = std::min( + if (create_final_state) + best_cost_in_chunk_ = std::min( + best_cost_in_chunk_, tok->tot_cost + (tok->extra_cost-tok->tot_cost) + weight.Value1() + weight.Value2()); + else + best_cost_in_chunk_ = std::min( best_cost_in_chunk_, tok->tot_cost + weight.Value1() + weight.Value2()); } } @@ -1153,7 +1157,7 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( // For now, we use extra_cost from the decoding stage , which has some // "future information", as // the final weights of this chunk - BaseFloat cost_offset = 0; // TODO tok->extra_cost-tok->tot_cost; + BaseFloat cost_offset = tok->extra_cost-tok->tot_cost; // We record these cost_offset, and after we appending two chunks // we will cancel them out state_label_final_cost_[id] = cost_offset; @@ -1206,7 +1210,7 @@ bool LatticeIncrementalDeterminizer::ProcessChunk( // LatticeFasterDecoder, we need to use a slightly larger beam here // than the lattice_beam used PruneActiveTokens. Hence the beam we use is // (config_.determinize_beam_offset + config_.lattice_beam) -#if 0 +#if 1 ret &= DeterminizeLatticePhonePrunedWrapper( trans_model_, &raw_fst, config_.determinize_beam_offset + config_.lattice_beam + 0.1, &clat, config_.det_opts); @@ -1223,7 +1227,8 @@ bool LatticeIncrementalDeterminizer::ProcessChunk( std::vector alignment; std::vector words; GetLinearSymbolSequence(decoded, &alignment, &words, &weight); - KALDI_ASSERT(alignment.size() == last_frame-first_frame); + // remove the following because the chunk starts from the last final arc, and we cannot decide which frame it is in + //KALDI_ASSERT(alignment.size() == last_frame-first_frame); // TODO: the following KALDI_ASSERT will fail some time, which is unexpected // for sanity check KALDI_ASSERT(std::abs(best_cost_in_chunk_ - (weight.Value1() + @@ -1272,8 +1277,6 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( auto start_state = olat->AddState(); olat->SetStart(start_state); - // sanity check - BaseFloat best_cost = std::numeric_limits::infinity(); for (auto &i : final_arc_list_prev_) { ArcIterator aiter_chunk1(lat_, i.first); aiter_chunk1.Seek(i.second); @@ -1309,11 +1312,7 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( std::pair(arc_chunk1.olabel, last_state)); // sanity check g_last2first_state[last_state]=initial_weight.Weight().Value1()+initial_weight.Weight().Value2(); - best_cost=std::min(best_cost,initial_weight.Weight().Value1()+initial_weight.Weight().Value2()); } - // sanity check - KALDI_ASSERT(std::abs(best_cost_in_chunk_ - best_cost) < 1e-1); - } template From 467abd8be6a5d5c96979303c1b52627842395b38 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Wed, 10 Apr 2019 23:40:44 +0800 Subject: [PATCH 20/60] clean code --- src/decoder/lattice-incremental-decoder.cc | 296 +++++++++------------ src/decoder/lattice-incremental-decoder.h | 198 +++++++------- 2 files changed, 238 insertions(+), 256 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index e346ee4dae2..fe702407eae 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -79,11 +79,10 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { num_toks_++; last_get_lattice_frame_ = 0; - state_label_map_.clear(); - state_label_map_.reserve(std::min((int32)1e5, config_.max_active)); - state_label_available_idx_ = config_.max_word_id + 1; - state_label_initial_cost_.clear(); - state_label_final_cost_.clear(); + token_label_map_.clear(); + token_label_map_.reserve(std::min((int32)1e5, config_.max_active)); + token_label_available_idx_ = config_.max_word_id + 1; + token_label_final_cost_.clear(); determinizer_.Init(); ProcessNonemitting(config_.beam); @@ -112,16 +111,16 @@ bool LatticeIncrementalDecoderTpl::Decode( // identical to PruneActiveTokens since we need extra_cost as the weights // of final arcs to denote the "future" information of final states (Tokens) int32 frame_det_least = last_get_lattice_frame_ + config_.prune_interval; - if (frame_det_most > 0) { + if (config_.determinize_lattice && frame_det_most > 0) { // To adaptively decide the length of chunk, we further compare the number of // tokens in each frame and a pre-defined threshold. // If the number of tokens in a certain frame is less than // config_.determinize_max_active, the lattice can be determinized up to this // frame. And we try to determinize as most frames as possible so we check - // numbers from frame_det_most to frame_det_least + // numbers from frame_det_most to frame_det_least for (int32 f = frame_det_most; f >= frame_det_least; f--) { - if (config_.determinize_max_active == std::numeric_limits::max() - || GetNumToksForFrame(f) < config_.determinize_max_active) { + if (config_.determinize_max_active == std::numeric_limits::max() || + GetNumToksForFrame(f) < config_.determinize_max_active) { KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() << " incremental determinization up to " << f; GetLattice(false, false, f); @@ -135,8 +134,10 @@ bool LatticeIncrementalDecoderTpl::Decode( } Timer timer; FinalizeDecoding(); - GetLattice(true, config_.redeterminize, NumFramesDecoded()); - KALDI_VLOG(2) << "Delay time after decoding finalized (secs): " << timer.Elapsed(); + if (config_.determinize_lattice) + GetLattice(true, config_.redeterminize, NumFramesDecoded()); + KALDI_VLOG(2) << "Delay time during and after decoding finalization (secs): " + << timer.Elapsed(); // Returns true if we have any kind of traceback available (not necessarily // to the end state; query ReachedFinal() for that). @@ -944,8 +945,6 @@ template bool LatticeIncrementalDecoderTpl::GetLattice(CompactLattice *olat) { return GetLattice(true, config_.redeterminize, NumFramesDecoded(), olat); } -// sanity check -BaseFloat best_cost_in_chunk_; template bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, @@ -966,11 +965,12 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, // In this function, we do not create the initial state in // the first chunk, and we do not create the final state in the last chunk if (!GetIncrementalRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, - last_frame_of_chunk, not_first_chunk, !decoding_finalized_)) + last_frame_of_chunk, not_first_chunk, + !decoding_finalized_)) KALDI_ERR << "Unexpected problem when getting lattice"; // step 2-3 ret = determinizer_.ProcessChunk(raw_fst, last_get_lattice_frame_, - last_frame_of_chunk, state_label_initial_cost_); + last_frame_of_chunk); last_get_lattice_frame_ = last_frame_of_chunk; } else if (last_get_lattice_frame_ > last_frame_of_chunk) KALDI_WARN << "Call GetLattice up to frame: " << last_frame_of_chunk @@ -995,9 +995,6 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, return ret; } -// sanity check -unordered_map g_last2first_state; - template bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, @@ -1019,10 +1016,12 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( ComputeFinalCosts(&final_costs_local, NULL, NULL); ofst->DeleteStates(); - unordered_multimap state_label2state_map; // for GetInitialRawLattice - // initial arcs for the chunk - if (create_initial_state) - determinizer_.GetInitialRawLattice(ofst, &state_label2state_map, state_label_final_cost_); + unordered_multimap + token_label2last_state_map; // for GetInitialRawLattice + // initial arcs for the chunk + if (create_initial_state) + determinizer_.GetInitialRawLattice(ofst, &token_label2last_state_map, + token_label_final_cost_); // num-frames plus one (since frames are one-based, and we have // an extra frame for the start-state). KALDI_ASSERT(frame_end > 0); @@ -1053,34 +1052,25 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( if (create_initial_state) { for (Token *tok = active_toks_[frame_begin].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; - // state_label_map_ is construct during create_final_state - auto r = state_label_map_.find(tok); - KALDI_ASSERT(r != state_label_map_.end()); // it should exist - int32 id = r->second; - auto range = state_label2state_map.equal_range(id); + // token_label_map_ is construct during create_final_state + auto r = token_label_map_.find(tok); + KALDI_ASSERT(r != token_label_map_.end()); // it should exist + int32 token_label = r->second; + auto range = token_label2last_state_map.equal_range(token_label); KALDI_ASSERT(range.first != range.second); - // sanity check - auto& forward_costs = determinizer_.GetForwardCosts(); - BaseFloat best_cost = std::numeric_limits::infinity(); std::vector tmp_vec; - for (auto it = range.first; it!=range.second; ++it) { - // the destination state of the last of the sequence of arcs w.r.t the id here (state label) created by GetInitialRawLattice + for (auto it = range.first; it != range.second; ++it) { + // the destination state of the last of the sequence of arcs w.r.t the token + // label + // here created by GetInitialRawLattice auto state_last_initial = it->second; - // connect it to the state correponding to the token w.r.t the id here (state label) + // connect it to the state correponding to the token w.r.t the token label + // here Arc arc(0, 0, Weight::One(), cur_state); ofst->AddArc(state_last_initial, arc); - // sanity check - auto r = g_last2first_state.find(state_last_initial); - KALDI_ASSERT(r!=g_last2first_state.end()); - tmp_vec.push_back(r->second); - best_cost = std::min(best_cost, r->second); } - // sanity check - KALDI_ASSERT(abs(best_cost - tok->tot_cost) < 0.1); } } - // for sanity check - best_cost_in_chunk_ = std::numeric_limits::infinity(); // step 1.2: create all arcs as GetRawLattice() of LatticeFasterDecoder for (int32 f = frame_begin; f <= frame_end; f++) { for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { @@ -1094,12 +1084,10 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( KALDI_ASSERT(iter != tok_map.end()); StateId nextstate = iter->second; BaseFloat cost_offset = 0.0; - /* // TODO if (l->ilabel != 0) { // emitting.. KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); cost_offset = cost_offsets_[f]; } - */ Arc arc(l->ilabel, l->olabel, Weight(l->graph_cost, l->acoustic_cost - cost_offset), nextstate); ofst->AddArc(cur_state, arc); @@ -1123,14 +1111,6 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( weight = LatticeWeight::Zero(); } ofst->SetFinal(cur_state, weight); - // for sanity check - // we will use extra_cost in step 1.3 (see the following code) - if (create_final_state) - best_cost_in_chunk_ = std::min( - best_cost_in_chunk_, tok->tot_cost + (tok->extra_cost-tok->tot_cost) + weight.Value1() + weight.Value2()); - else - best_cost_in_chunk_ = std::min( - best_cost_in_chunk_, tok->tot_cost + weight.Value1() + weight.Value2()); } } } @@ -1139,14 +1119,14 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( StateId end_state = ofst->AddState(); // final-state for the chunk ofst->SetFinal(end_state, Weight::One()); - state_label_map_.clear(); - state_label_map_.reserve(std::min((int32)1e5, config_.max_active)); + token_label_map_.clear(); + token_label_map_.reserve(std::min((int32)1e5, config_.max_active)); for (Token *tok = active_toks_[frame_end].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; // We assign an unique state label for each of the token in the last frame // of this chunk - int32 id = state_label_available_idx_++; - state_label_map_[tok] = id; + int32 id = token_label_available_idx_++; + token_label_map_[tok] = id; // The final weight has been worked out in the previous for loop and // store in the states // Here, we create a specific final state, and move the final costs to @@ -1157,10 +1137,10 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( // For now, we use extra_cost from the decoding stage , which has some // "future information", as // the final weights of this chunk - BaseFloat cost_offset = tok->extra_cost-tok->tot_cost; + BaseFloat cost_offset = tok->extra_cost - tok->tot_cost; // We record these cost_offset, and after we appending two chunks // we will cancel them out - state_label_final_cost_[id] = cost_offset; + token_label_final_cost_[id] = cost_offset; Arc arc(0, id, Times(final_weight, Weight(0, cost_offset)), end_state); ofst->AddArc(cur_state, arc); ofst->SetFinal(cur_state, Weight::Zero()); @@ -1189,83 +1169,16 @@ void LatticeIncrementalDeterminizer::Init() { lat_.DeleteStates(); determinization_finalized_ = false; forward_costs_.clear(); - state_last_initial_offset_ = 2*config_.max_word_id; + state_last_initial_offset_ = 2 * config_.max_word_id; } -template -bool LatticeIncrementalDeterminizer::ProcessChunk( - Lattice &raw_fst, int32 first_frame, int32 last_frame, - const unordered_map &state_label_initial_cost) { - bool not_first_chunk = first_frame != 0; - bool ret = true; - // step 2: Determinize the chunk - CompactLattice clat; - // We do determinization with beam pruning here - // Only if we use a beam larger than (config_.beam+config_.lattice_beam) here, we - // can guarantee no final or initial arcs in clat are pruned by this function. - // These pruned final arcs can hurt oracle WER performance in the final lattice - // (also result in less lattice density) but they seldom hurt 1-best WER. - // Since pruning behaviors in DeterminizeLatticePhonePrunedWrapper and - // PruneActiveTokens are not the same, to get similar lattice density as - // LatticeFasterDecoder, we need to use a slightly larger beam here - // than the lattice_beam used PruneActiveTokens. Hence the beam we use is - // (config_.determinize_beam_offset + config_.lattice_beam) -#if 1 - ret &= DeterminizeLatticePhonePrunedWrapper( - trans_model_, &raw_fst, config_.determinize_beam_offset + - config_.lattice_beam + 0.1, &clat, config_.det_opts); -#else - ConvertLattice(raw_fst, &clat); -#endif - { - // sanity check, remove them later - CompactLattice cdecoded; - Lattice decoded; - ShortestPath(clat, &cdecoded); - ConvertLattice(cdecoded, &decoded); - LatticeWeight weight; - std::vector alignment; - std::vector words; - GetLinearSymbolSequence(decoded, &alignment, &words, &weight); - // remove the following because the chunk starts from the last final arc, and we cannot decide which frame it is in - //KALDI_ASSERT(alignment.size() == last_frame-first_frame); - // TODO: the following KALDI_ASSERT will fail some time, which is unexpected - // for sanity check - KALDI_ASSERT(std::abs(best_cost_in_chunk_ - (weight.Value1() + - weight.Value2())) < 1e-1); - } - - // step 3: Appending the new chunk in clat to the old one in lat_ - ret &= AppendLatticeChunks(clat, not_first_chunk, state_label_initial_cost); - { - // sanity check, remove them later - CompactLattice cdecoded; - Lattice decoded; - ShortestPath(lat_, &cdecoded); - ConvertLattice(cdecoded, &decoded); - LatticeWeight weight; - std::vector alignment; - std::vector words; - GetLinearSymbolSequence(decoded, &alignment, &words, &weight); - KALDI_ASSERT(alignment.size() == last_frame); - // TODO: the following KALDI_ASSERT will fail some time, which is unexpected - // for sanity check - KALDI_ASSERT(std::abs(best_cost_in_chunk_ - (weight.Value1() + - weight.Value2())) < 1e-1); - } - - ret &= (lat_.NumStates() > 0); - KALDI_VLOG(2) << "Frame: ( " << first_frame << " , " << last_frame << " )" - << " states of the chunk: " << clat.NumStates() - << " states of the lattice: " << lat_.NumStates(); - return ret; -} -// We have multiple states for one state label after determinization +// This function is specifically designed to obtain the initial arcs for a chunk +// We have multiple states for one token label after determinization template void LatticeIncrementalDeterminizer::GetInitialRawLattice( Lattice *olat, - unordered_multimap *state_label2state_map, - const unordered_map &state_label_final_cost) { + unordered_multimap *token_label2last_state_map, + const unordered_map &token_label_final_cost) { using namespace fst; typedef LatticeArc Arc; typedef Arc::StateId StateId; @@ -1273,7 +1186,7 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( typedef Arc::Label Label; KALDI_ASSERT(final_arc_list_prev_.size()); olat->DeleteStates(); - state_label2state_map->clear(); + token_label2last_state_map->clear(); auto start_state = olat->AddState(); olat->SetStart(start_state); @@ -1285,20 +1198,33 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( KALDI_ASSERT(arc_chunk1.olabel > config_.max_word_id); StateId prev_final_state = arc_chunk1.nextstate; CompactLatticeWeight weight_offset; - // TODO: description - const auto r2 = state_label_final_cost.find(arc_chunk1.olabel); - KALDI_ASSERT(r2 != state_label_final_cost.end()); + // To cancel out the weight on the final arcs, which is (extra cost - forward + // cost). + // see token_label_final_cost for more details + const auto r = token_label_final_cost.find(arc_chunk1.olabel); + KALDI_ASSERT(r != token_label_final_cost.end()); + auto cost_offset = r->second; + // Moreover, we need to use the forward coast (alpha) of this determinized and + // appended state to guide the determinization later KALDI_ASSERT(i.first < forward_costs_.size()); - weight_offset.SetWeight(LatticeWeight(0, forward_costs_[i.first]-r2->second)); - auto initial_weight = Times(Times(arc_chunk1.weight, lat_.Final(prev_final_state)), weight_offset); - - // create a state representing the i.first state (source state) in appended lattice + auto alpha_cost = forward_costs_[i.first]; + weight_offset.SetWeight(LatticeWeight(0, alpha_cost - cost_offset)); + // The initial_weight is a combination of above cost_offset, alpha_cost and the + // weights on the previous final arc and the final state + auto initial_weight = + Times(Times(arc_chunk1.weight, lat_.Final(prev_final_state)), weight_offset); + + // create a state representing the i.first state (source state) in appended + // lattice auto source_state = olat->AddState(); - // we need a special label in the arc that corresponds to the identity of the source-state of the last arc, we use its StateId and a offset here - int id = i.first + state_last_initial_offset_; - Arc arc(0, id, initial_weight.Weight(), source_state); + // we need a special label in the arc that corresponds to the identity of the + // source-state of the last arc, we use its StateId and a offset here, called + // state_label + int state_label = i.first + state_last_initial_offset_; + Arc arc(0, state_label, initial_weight.Weight(), source_state); olat->AddArc(start_state, arc); - // We generate a linear sequence of arcs sufficient to contain all the transition-ids on the string + // We generate a linear sequence of arcs sufficient to contain all the + // transition-ids on the string auto prev_state = source_state; for (auto &j : initial_weight.String()) { auto cur_state = olat->AddState(); @@ -1306,19 +1232,50 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( olat->AddArc(prev_state, arc); prev_state = cur_state; } - auto last_state = olat->NumStates()-1; - // the destination state of the last of the sequence of arcs will be recorded and connected to the state corresponding to token w.r.t arc_chunk1.olabel - state_label2state_map->insert( + // the destination state of the last of the sequence of arcs will be recorded and + // connected to the state corresponding to token w.r.t arc_chunk1.olabel + // Notably, we have multiple states for one token label after determinization, + // hence we use multiset here + auto last_state = olat->NumStates() - 1; + token_label2last_state_map->insert( std::pair(arc_chunk1.olabel, last_state)); - // sanity check - g_last2first_state[last_state]=initial_weight.Weight().Value1()+initial_weight.Weight().Value2(); } } template -bool LatticeIncrementalDeterminizer::AppendLatticeChunks( - CompactLattice clat, bool not_first_chunk, - const unordered_map &state_label_initial_cost) { +bool LatticeIncrementalDeterminizer::ProcessChunk(Lattice &raw_fst, + int32 first_frame, + int32 last_frame) { + bool not_first_chunk = first_frame != 0; + bool ret = true; + // step 2: Determinize the chunk + CompactLattice clat; + // We do determinization with beam pruning here + // Only if we use a beam larger than (config_.beam+config_.lattice_beam) here, we + // can guarantee no final or initial arcs in clat are pruned by this function. + // These pruned final arcs can hurt oracle WER performance in the final lattice + // (also result in less lattice density) but they seldom hurt 1-best WER. + // Since pruning behaviors in DeterminizeLatticePhonePrunedWrapper and + // PruneActiveTokens are not the same, to get similar lattice density as + // LatticeFasterDecoder, we need to use a slightly larger beam here + // than the lattice_beam used PruneActiveTokens. Hence the beam we use is + // (0.1 + config_.lattice_beam) + ret &= DeterminizeLatticePhonePrunedWrapper( + trans_model_, &raw_fst, (config_.lattice_beam + 0.1), &clat, config_.det_opts); + + // step 3: Appending the new chunk in clat to the old one in lat_ + ret &= AppendLatticeChunks(clat, not_first_chunk); + + ret &= (lat_.NumStates() > 0); + KALDI_VLOG(2) << "Frame: ( " << first_frame << " , " << last_frame << " )" + << " states of the chunk: " << clat.NumStates() + << " states of the lattice: " << lat_.NumStates(); + return ret; +} + +template +bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice clat, + bool not_first_chunk) { using namespace fst; CompactLattice *olat = &lat_; @@ -1331,9 +1288,8 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks( state_offset--; // since we do not append initial state in the first chunk else forward_costs_.push_back(0); // for the first state - // TODO assert whether it is similar to tot_cost - forward_costs_.resize(state_offset+clat.NumStates(), std::numeric_limits::infinity()); // we append all states except the initial state - + forward_costs_.resize(state_offset + clat.NumStates(), + std::numeric_limits::infinity()); for (StateIterator siter(clat); !siter.Done(); siter.Next()) { auto s = siter.Value(); StateId state_appended = kNoStateId; @@ -1350,8 +1306,10 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks( // These arcs will be taken care later in step 3.2 CompactLatticeArc arc_appended(arc); arc_appended.nextstate += state_offset; - // In the first chunk, there could be a final arc starting from state 0, and we process it here - // In the last chunk, there could be a initial arc ending in final state, and we process it in "process initial arcs" in the following + // In the first chunk, there could be a final arc starting from state 0, and we + // process it here + // In the last chunk, there could be a initial arc ending in final state, and + // we process it in "process initial arcs" in the following if (!not_first_chunk || s != 0) { KALDI_ASSERT(state_appended != kNoStateId); source_state = state_appended; @@ -1363,28 +1321,34 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks( pair(state_appended, aiter.Position())); } } else { // process initial arcs - KALDI_ASSERT(arc.olabel > config_.max_word_id); - KALDI_ASSERT(arc.olabel >= state_last_initial_offset_); - source_state = arc.olabel - state_last_initial_offset_; // TODO description + // a special olabel in the arc that corresponds to the identity of the + // source-state of the last arc, we use its StateId and a offset here, called + // state_label + auto state_label = arc.olabel; + KALDI_ASSERT(state_label > config_.max_word_id); + KALDI_ASSERT(state_label >= state_last_initial_offset_); + source_state = state_label - state_last_initial_offset_; arc_appended.olabel = 0; arc_appended.ilabel = 0; CompactLatticeWeight weight_offset; - weight_offset.SetWeight(LatticeWeight(0, -forward_costs_[source_state])); // remove alpha in weight - arc_appended.weight=Times(arc_appended.weight, weight_offset); + weight_offset.SetWeight(LatticeWeight(0, -forward_costs_[source_state])); + arc_appended.weight = Times(arc_appended.weight, weight_offset); } KALDI_ASSERT(source_state != kNoStateId); olat->AddArc(source_state, arc_appended); - // update forward_costs_ + // update forward_costs_ (alpha) KALDI_ASSERT(arc_appended.nextstate < forward_costs_.size()); - auto& alpha_nextstate = forward_costs_[arc_appended.nextstate]; - auto& weight = arc_appended.weight.Weight(); - alpha_nextstate = std::min(alpha_nextstate, forward_costs_[source_state] + weight.Value1() + weight.Value2()); + auto &alpha_nextstate = forward_costs_[arc_appended.nextstate]; + auto &weight = arc_appended.weight.Weight(); + alpha_nextstate = + std::min(alpha_nextstate, + forward_costs_[source_state] + weight.Value1() + weight.Value2()); } } // Making all remaining arcs of final_arc_list_prev_ be connected to - // a dead state. + // a dead state. // final states are always the same state) if (not_first_chunk) { KALDI_ASSERT(final_arc_list_prev_.size()); @@ -1393,7 +1357,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks( aiter_chunk1.Seek(i.second); // Obtain the appended final arcs in the previous chunk auto &arc_chunk1 = aiter_chunk1.Value(); - olat->SetFinal(arc_chunk1.nextstate, CompactLatticeWeight::Zero()); + olat->SetFinal(arc_chunk1.nextstate, CompactLatticeWeight::Zero()); } } else olat->SetStart(0); // Initialize the first chunk for olat diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 5d156ee62e2..6093e3b16f7 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -43,7 +43,6 @@ struct LatticeIncrementalDecoderConfig { int32 prune_interval; int32 determinize_delay; int32 determinize_max_active; - BaseFloat determinize_beam_offset; bool redeterminize; bool determinize_lattice; // not inspected by this class... used in // command-line program. @@ -67,7 +66,6 @@ struct LatticeIncrementalDecoderConfig { prune_interval(25), determinize_delay(25), determinize_max_active(std::numeric_limits::max()), - determinize_beam_offset(1), redeterminize(false), determinize_lattice(true), beam_delta(0.5), @@ -95,9 +93,6 @@ struct LatticeIncrementalDecoderConfig { "If the number of active tokens(in a certain frame) is less " "than this number, we will start to incrementally " "determinize lattices up to this frame."); - opts->Register("determinize-beam-offset", &determinize_beam_offset, - "the beam used in lattice determinization is " - "(determinize_beam_offset + lattice_beam) ."); opts->Register("redeterminize", &redeterminize, "whether to re-determinize the lattice after incremental " "determinization."); @@ -187,60 +182,73 @@ class LatticeIncrementalDecoderTpl { /// and figures out the shortest path. bool GetBestPath(Lattice *ofst, bool use_final_probs = true); - /// The following function is specifically designed for incremental - /// determinization. The function obtains a CompactLattice for - /// the part of this utterance up to the frame last_frame_of_chunk. - /// If you call this multiple times - /// (calling it on every frame would not make sense, - /// but every, say, 10 to 40 frames might make sense) it will spread out the - /// work of determinization over time, which might be useful for online - /// applications. - /// - /// The procedure of incremental determinization is as follow: - /// step 1: Get lattice chunk with initial and final states corresponding - /// to tokens in the first and last frames of this chunk - /// We need to give permanent labels (called "state labels") to these - /// raw-lattice states (Tokens). - /// step 2: Determinize the chunk of above raw lattice using determinization - /// algorithm the same as LatticeFasterDecoder. We call the determinized new - /// chunk "clat" - /// step 3: Appending the new chunk "clat" to the determinized lattice - /// before this chunk. First, for each state-id in clat2 *except* its - /// initial state, allocate a new state-id in the appended - /// compact-lattice. Copy the arcs except whose incoming state is initial - /// state. Secondly, for each final arc in previous chunk, check whether - /// the corresponding initial arc exists in the newly appended chunk. If - /// not, we make the final arc point to a "dead state" - /// Otherwise, we modify this arc to connect to the corresponding next - /// state of the initial arc with a proper weight - /// step 4: We re-determinize the appended lattice if needed. - /// - /// In our implementation, step 1 is done in GetRawLattice(), - /// step 2-4 is taken care by the class - /// LatticeIncrementalDeterminizer - /// - /// @param [in] use_final_probs If true *and* at least one final-state in HCLG - /// was active on the final frame, include final-probs from - /// HCLG - /// in the lattice. Otherwise treat all final-costs of - /// states active - /// on the most recent frame as zero (i.e. Weight::One()). - /// @param [in] redeterminize If true, re-determinize the CompactLattice - /// after appending the most recently decoded chunk to it, - /// to - /// ensure that the output is fully deterministic. - /// This does extra work, but not nearly as much as - /// determinizing - /// a RawLattice from scratch. - /// @param [in] last_frame_of_chunk Pass the last frame of this chunk to - /// the function. We make it not always equal to - /// NumFramesDecoded() to have a delay on the - /// deteriminization - /// @param [out] olat The CompactLattice representing what has been decoded - /// so far. + // The following function is specifically designed for incremental + // determinization. The function obtains a CompactLattice for + // the part of this utterance up to the frame last_frame_of_chunk. + // If you call this multiple times + // (calling it on every frame would not make sense, + // but every, say, 10 to 40 frames might make sense) it will spread out the + // work of determinization over time, which might be useful for online + // applications. + // + // The procedure of incremental determinization is as follow: + // step 1: Get lattice chunk with initial and final states and arcs, called `raw + // lattice`. + // Here, we define a `final arc` as an arc to a final-state, and the source state + // of it as a `pre-final state` + // Similarly, we define a `initial arc` as an arc from a initial-state, and the + // destination state of it as a `post-initial state` + // The initial states are constructed corresponding to pre-final states in the + // determinized and appended lattice before this chunk + // The final states are constructed correponding to tokens in the last frames of + // this chunk + // Since the StateId can change during determinization, we need to give permanent + // unique labels (as olabel) to these + // raw-lattice states for latter appending. + // We give each token an olabel id, called `token_label`, and each determinized and + // appended state an olabel id, called `state_label` + // step 2: Determinize the chunk of above raw lattice using determinization + // algorithm the same as LatticeFasterDecoder. Benefit from above `state_label` and + // `token_label` in initial and final arcs, each pre-final state in the last chunk + // w.r.t the initial arc of this chunk can be treated uniquely and each token in + // the last frame of this chunk can also be treated uniquely. We call the + // determinized new + // chunk `compact lattice (clat)` + // step 3: Appending the new chunk `clat` to the determinized lattice + // before this chunk. First, for each StateId in clat except its + // initial state, allocate a new StateId in the appended + // compact lattice. Copy the arcs except whose incoming state is initial + // state. Secondly, for each initial arcs, change its source state to the state + // corresponding to its `state_label`, which is a determinized and appended state + // Finally, we make the previous final arcs point to a "dead state" + // step 4 (optional): We re-determinize the appended lattice if needed. + // + // In our implementation, step 1 is done in GetIncrementalRawLattice(), + // step 2-4 is taken care by the class + // LatticeIncrementalDeterminizer + // + // @param [in] use_final_probs If true *and* at least one final-state in HCLG + // was active on the final frame, include final-probs from + // HCLG + // in the lattice. Otherwise treat all final-costs of + // states active + // on the most recent frame as zero (i.e. Weight::One()). + // @param [in] redeterminize If true, re-determinize the CompactLattice + // after appending the most recently decoded chunk to it, + // to + // ensure that the output is fully deterministic. + // This does extra work, but not nearly as much as + // determinizing + // a RawLattice from scratch. + // @param [in] last_frame_of_chunk Pass the last frame of this chunk to + // the function. We make it not always equal to + // NumFramesDecoded() to have a delay on the + // deteriminization + // @param [out] olat The CompactLattice representing what has been decoded + // so far. // If lat == NULL, the CompactLattice won't be outputed. - /// @return ret This function will returns true if the chunk is processed - /// successfully + // @return ret This function will returns true if the chunk is processed + // successfully bool GetLattice(bool use_final_probs, bool redeterminize, int32 last_frame_of_chunk, CompactLattice *olat = NULL); /// Specifically design when decoding_finalized_==true @@ -477,41 +485,43 @@ class LatticeIncrementalDecoderTpl { // It does the same thing as GetRawLattice in lattice-faster-decoder.cc except: // // i) it creates a initial state, and connect - // all the tokens in the first frame of this chunk to the initial state - // by an arc with a per-token state-label as its olabel + // each token in the first frame of this chunk to the initial state + // by one or more arcs with a state_label correponding to the pre-final state w.r.t + // this token(the pre-final state is appended in the last chunk) as its olabel // ii) it creates a final state, and connect // all the tokens in the last frame of this chunk to the final state - // by an arc with a per-token state-label as its olabel - // the state-label for a token in both i) and ii) should be the same - // frame_begin and frame_end are the first and last frame of this chunk - // if create_initial_state == false, we will not create initial state and - // the corresponding state-label arcs. Similar for create_final_state + // by an arc with a per-token token_label as its olabel + // `frame_begin` and `frame_end` are the first and last frame of this chunk + // if `create_initial_state` == false, we will not create initial state and + // the corresponding initial arcs. Similar for `create_final_state` // In incremental GetLattice, we do not create the initial state in // the first chunk, and we do not create the final state in the last chunk - bool GetIncrementalRawLattice(Lattice *ofst, bool use_final_probs, - int32 frame_begin, int32 frame_end, - bool create_initial_state, bool create_final_state); + bool GetIncrementalRawLattice(Lattice *ofst, bool use_final_probs, + int32 frame_begin, int32 frame_end, + bool create_initial_state, bool create_final_state); // Get the number of tokens in each frame // It is useful, e.g. in using config_.determinize_max_active int32 GetNumToksForFrame(int32 frame); + + // The incremental lattice determinizer to take care of step 2-4 LatticeIncrementalDeterminizer determinizer_; int32 last_get_lattice_frame_; // the last time we call GetLattice - // a map from Token to its state_label - unordered_map state_label_map_; + // a map from Token to its token_label + unordered_map token_label_map_; // we allocate a unique id for each Token - int32 state_label_available_idx_; - // We keep tot_cost or extra_cost for each state_label (Token) in final and - // initial arcs. We need them before determinization + int32 token_label_available_idx_; + // We keep cost_offset for each token_label (Token) in final arcs. We need them to + // guide determinization // We cancel them after determinization - unordered_map state_label_initial_cost_; // TODO remove it - unordered_map state_label_final_cost_; // TODO remove it + unordered_map token_label_final_cost_; KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); }; typedef LatticeIncrementalDecoderTpl LatticeIncrementalDecoder; -// This class is designed for step 2-4 of incremental determinization +// This class is designed for step 2-4 and part of step 1 of incremental +// determinization // introduced before above GetLattice() template class LatticeIncrementalDeterminizer { @@ -528,35 +538,42 @@ class LatticeIncrementalDeterminizer { // Output the resultant determinized lattice in the form of CompactLattice const CompactLattice &GetDeterminizedLattice() const { return lat_; } + // Part of step 1 of incremental determinization, + // where the initial states are constructed corresponding to pre-final states in + // the determinized and appended lattice before this chunk + // We give each determinized and appended state an olabel id, called `state_label` + // We maintain a map (`token_label2last_state_map`) from token label (obtained from + // final arcs) to the destination state of the last of the sequence of initial arcs + // w.r.t the token label here + // Notably, we have multiple states for one token label after determinization, + // hence we use multiset here + // We need `token_label_final_cost` to cancel out the cost offset used in guiding + // DeterminizeLatticePhonePrunedWrapper + void GetInitialRawLattice( + Lattice *olat, + unordered_multimap *token_label2last_state_map, + const unordered_map &token_label_final_cost); // This function consumes raw_fst generated by step 1 of incremental // determinization with specific initial and final arcs. // It does step 2-4 and outputs the resultant CompactLattice if // needed. Otherwise, it keeps the resultant lattice in lat_ - bool ProcessChunk(Lattice &raw_fst, int32 first_frame, int32 last_frame, - const unordered_map &state_label_initial_cost); + bool ProcessChunk(Lattice &raw_fst, int32 first_frame, int32 last_frame); // Step 3 of incremental determinization, // which is to append the new chunk in clat to the old one in lat_ // If not_first_chunk == false, we do not need to append and just copy // clat into olat - // Otherwise, we need to connect the last frame state of - // last chunk to the first frame state of this chunk. + // Otherwise, we need to connect states of the last frame of + // the last chunk to states of the first frame of this chunk. // These begin and final states are corresponding to the same Token, // guaranteed by unique state labels. - bool AppendLatticeChunks( - CompactLattice clat, bool not_first_chunk, - const unordered_map &state_label_initial_cost); + bool AppendLatticeChunks(CompactLattice clat, bool not_first_chunk); // Step 4 of incremental determinization, // which either re-determinize above lat_, or simply remove the dead // states of lat_ bool Finalize(bool redeterminize); - std::vector& GetForwardCosts() { - return forward_costs_; - } - void GetInitialRawLattice(Lattice *olat, - unordered_multimap *state_label2state_map, - const unordered_map &state_label_final_cost); + std::vector &GetForwardCosts() { return forward_costs_; } private: const LatticeIncrementalDecoderConfig config_; @@ -570,7 +587,8 @@ class LatticeIncrementalDeterminizer { std::vector> final_arc_list_prev_; // alpha of each state in lat_ std::vector forward_costs_; - // we allocate a unique id for each source-state of the last arc of a series of initial arcs in GetInitialRawLattice + // we allocate a unique id for each source-state of the last arc of a series of + // initial arcs in GetInitialRawLattice int32 state_last_initial_offset_; // The compact lattice we obtain. It should be reseted before processing a From 612d398f6a44a536d57c767d25fe5edb07d89864 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Mon, 29 Apr 2019 21:48:44 +0800 Subject: [PATCH 21/60] [experimental] new det algorithm (#31) the new algorithm is to determinize "states in the appended lattice with final-arcs to also have non-final arcs leaving them" --- src/decoder/lattice-incremental-decoder.cc | 254 +++++++++++++++------ src/decoder/lattice-incremental-decoder.h | 29 ++- 2 files changed, 215 insertions(+), 68 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index fe702407eae..b34f6687ac9 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1146,6 +1146,7 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( ofst->SetFinal(cur_state, Weight::Zero()); } } + TopSortLatticeIfNeeded(ofst); return (ofst->NumStates() > 0); } @@ -1169,9 +1170,137 @@ void LatticeIncrementalDeterminizer::Init() { lat_.DeleteStates(); determinization_finalized_ = false; forward_costs_.clear(); + forward_costs_for_redet_.clear(); state_last_initial_offset_ = 2 * config_.max_word_id; + redeterminized_states_.clear(); +} +template +bool LatticeIncrementalDeterminizer::AddRedeterminizedState( + Lattice::StateId nextstate, Lattice *olat, Lattice::StateId *nextstate_copy) { + using namespace fst; + bool modified = false; + StateId nextstate_insert = kNoStateId; + auto r = redeterminized_states_.insert({nextstate, nextstate_insert}); + if (r.second) { // didn't exist, successfully insert here + // create a new state w.r.t state + nextstate_insert = olat->AddState(); + // map from arc.nextstate to nextstate_insert + r.first->second = nextstate_insert; + modified = true; + } else { // else already exist + // get nextstate_insert + nextstate_insert = r.first->second; + KALDI_ASSERT(nextstate_insert != kNoStateId); + modified = false; + } + if (nextstate_copy) *nextstate_copy = nextstate_insert; + return modified; } +template +void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( + StateId start_state, StateId state, + const unordered_map &token_label_final_cost, + unordered_multimap *token_label2last_state_map, + Lattice *olat) { + using namespace fst; + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + + auto r = redeterminized_states_.find(state); + KALDI_ASSERT(r != redeterminized_states_.end()); + auto state_copy = r->second; + KALDI_ASSERT(state_copy != kNoStateId); + ArcIterator aiter(lat_, state); + + // use state_label in initial arcs + int state_label = state + state_last_initial_offset_; + // Moreover, we need to use the forward coast (alpha) of this determinized and + // appended state to guide the determinization later + KALDI_ASSERT(state < forward_costs_for_redet_.size()); + auto alpha_cost = forward_costs_for_redet_[state]; + Arc arc_initial(0, state_label, LatticeWeight(0, alpha_cost), state_copy); + if (alpha_cost != std::numeric_limits::infinity()) + olat->AddArc(start_state, arc_initial); + + for (; !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + auto laststate_copy = kNoStateId; + bool proc_nextstate = false; + auto arc_weight = arc.weight; + + KALDI_ASSERT(arc.olabel == arc.ilabel); + auto arc_olabel = arc.olabel; + + // the destination of the arc is the final state + if (lat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { + KALDI_ASSERT(arc_olabel > config_.max_word_id && + arc_olabel < state_last_initial_offset_); // token label + // create a initial arc + + // Get arc weight here + // We will include it in arc_last in the following + CompactLatticeWeight weight_offset; + // To cancel out the weight on the final arcs, which is (extra cost - forward + // cost). + // see token_label_final_cost for more details + const auto r = token_label_final_cost.find(arc_olabel); + KALDI_ASSERT(r != token_label_final_cost.end()); + auto cost_offset = r->second; + weight_offset.SetWeight(LatticeWeight(0, -cost_offset)); + // The arc weight is a combination of original arc weight, above cost_offset + // and the weights on the final state + arc_weight = + Times(Times(arc_weight, lat_.Final(arc.nextstate)), weight_offset); + + // We create a respective destination state for each final arc + // later we will connect it to the state correponding to the token w.r.t + // arc_olabel + laststate_copy = olat->AddState(); + // the destination state of the last of the sequence of arcs will be recorded + // and connected to the state corresponding to token w.r.t arc_olabel + // Notably, we have multiple states for one token label after determinization, + // hence we use multiset here + token_label2last_state_map->insert( + std::pair(arc_olabel, laststate_copy)); + arc_olabel = 0; // remove token label + } else { + // the arc connects to a non-final state (redeterminized state) + KALDI_ASSERT(arc_olabel < config_.max_word_id); // no token label + KALDI_ASSERT(arc_olabel); + // get the nextstate_copy w.r.t arc.nextstate + StateId nextstate_copy = kNoStateId; + proc_nextstate = AddRedeterminizedState(arc.nextstate, olat, &nextstate_copy); + KALDI_ASSERT(nextstate_copy != kNoStateId); + laststate_copy = nextstate_copy; + } + auto &state_seqs = arc_weight.String(); + // create new arcs w.r.t arc + // the following is for a normal arc + // We generate a linear sequence of arcs sufficient to contain all the + // transition-ids on the string + auto prev_state = state_copy; // from state_copy + for (auto &j : state_seqs) { + auto cur_state = olat->AddState(); + Arc arc(j, 0, LatticeWeight::One(), cur_state); + olat->AddArc(prev_state, arc); + prev_state = cur_state; + } + + // connect previous sequence of arcs to the laststate_copy + // the weight on the previous arc is stored in the arc to laststate_copy here + Arc arc_last(0, arc_olabel, arc_weight.Weight(), laststate_copy); + olat->AddArc(prev_state, arc_last); + + // not final state && previously didn't process this state + if (proc_nextstate) + GetRawLatticeForRedeterminizedStates(start_state, arc.nextstate, + token_label_final_cost, + token_label2last_state_map, olat); + } +} // This function is specifically designed to obtain the initial arcs for a chunk // We have multiple states for one token label after determinization template @@ -1190,55 +1319,14 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( auto start_state = olat->AddState(); olat->SetStart(start_state); + // go over all prefinal state for (auto &i : final_arc_list_prev_) { - ArcIterator aiter_chunk1(lat_, i.first); - aiter_chunk1.Seek(i.second); - // Obtain the appended final arcs in the previous chunk - const auto &arc_chunk1 = aiter_chunk1.Value(); - KALDI_ASSERT(arc_chunk1.olabel > config_.max_word_id); - StateId prev_final_state = arc_chunk1.nextstate; - CompactLatticeWeight weight_offset; - // To cancel out the weight on the final arcs, which is (extra cost - forward - // cost). - // see token_label_final_cost for more details - const auto r = token_label_final_cost.find(arc_chunk1.olabel); - KALDI_ASSERT(r != token_label_final_cost.end()); - auto cost_offset = r->second; - // Moreover, we need to use the forward coast (alpha) of this determinized and - // appended state to guide the determinization later - KALDI_ASSERT(i.first < forward_costs_.size()); - auto alpha_cost = forward_costs_[i.first]; - weight_offset.SetWeight(LatticeWeight(0, alpha_cost - cost_offset)); - // The initial_weight is a combination of above cost_offset, alpha_cost and the - // weights on the previous final arc and the final state - auto initial_weight = - Times(Times(arc_chunk1.weight, lat_.Final(prev_final_state)), weight_offset); - - // create a state representing the i.first state (source state) in appended - // lattice - auto source_state = olat->AddState(); - // we need a special label in the arc that corresponds to the identity of the - // source-state of the last arc, we use its StateId and a offset here, called - // state_label - int state_label = i.first + state_last_initial_offset_; - Arc arc(0, state_label, initial_weight.Weight(), source_state); - olat->AddArc(start_state, arc); - // We generate a linear sequence of arcs sufficient to contain all the - // transition-ids on the string - auto prev_state = source_state; - for (auto &j : initial_weight.String()) { - auto cur_state = olat->AddState(); - Arc arc(j, 0, LatticeWeight::One(), cur_state); - olat->AddArc(prev_state, arc); - prev_state = cur_state; - } - // the destination state of the last of the sequence of arcs will be recorded and - // connected to the state corresponding to token w.r.t arc_chunk1.olabel - // Notably, we have multiple states for one token label after determinization, - // hence we use multiset here - auto last_state = olat->NumStates() - 1; - token_label2last_state_map->insert( - std::pair(arc_chunk1.olabel, last_state)); + auto prefinal_state = i.first; + bool modified = AddRedeterminizedState(prefinal_state, olat); + if (modified) + GetRawLatticeForRedeterminizedStates(start_state, prefinal_state, + token_label_final_cost, + token_label2last_state_map, olat); } } @@ -1284,12 +1372,35 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla // step 3.1: Appending new chunk to the old one int32 state_offset = olat->NumStates(); - if (not_first_chunk) + if (not_first_chunk) { state_offset--; // since we do not append initial state in the first chunk - else + // remove arcs from redeterminized_states_ + for (auto i : redeterminized_states_) { + olat->DeleteArcs(i.first); + olat->SetFinal(i.first, CompactLatticeWeight::Zero()); + } + redeterminized_states_.clear(); + } else { forward_costs_.push_back(0); // for the first state + forward_costs_for_redet_.push_back(0); + } forward_costs_.resize(state_offset + clat.NumStates(), std::numeric_limits::infinity()); + forward_costs_for_redet_.resize(state_offset + clat.NumStates(), + std::numeric_limits::infinity()); + + std::unordered_set prefinal_states; + for (StateIterator siter(clat); !siter.Done(); siter.Next()) { + auto s = siter.Value(); + for (ArcIterator aiter(clat, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + if (clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { + prefinal_states.insert(s); + break; + } + } + } + for (StateIterator siter(clat); !siter.Done(); siter.Next()) { auto s = siter.Value(); StateId state_appended = kNoStateId; @@ -1299,8 +1410,10 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla KALDI_ASSERT(state_appended == olat->AddState()); olat->SetFinal(state_appended, clat.Final(s)); } + for (ArcIterator aiter(clat, s); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); + StateId source_state = kNoStateId; // We do not copy initial arcs, which exists except the first chunk. // These arcs will be taken care later in step 3.2 @@ -1310,9 +1423,12 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla // process it here // In the last chunk, there could be a initial arc ending in final state, and // we process it in "process initial arcs" in the following - if (!not_first_chunk || s != 0) { + bool is_initial_state = (not_first_chunk && s == 0); + if (!is_initial_state) { KALDI_ASSERT(state_appended != kNoStateId); + KALDI_ASSERT(arc.olabel < state_last_initial_offset_); source_state = state_appended; + // process final arcs if (arc.olabel > config_.max_word_id) { // record final_arc in this chunk for the step 3.2 in the next call KALDI_ASSERT(arc.olabel < state_last_initial_offset_); @@ -1332,7 +1448,8 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla arc_appended.ilabel = 0; CompactLatticeWeight weight_offset; // remove alpha in weight - weight_offset.SetWeight(LatticeWeight(0, -forward_costs_[source_state])); + weight_offset.SetWeight( + LatticeWeight(0, -forward_costs_for_redet_[source_state])); arc_appended.weight = Times(arc_appended.weight, weight_offset); } KALDI_ASSERT(source_state != kNoStateId); @@ -1344,23 +1461,28 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla alpha_nextstate = std::min(alpha_nextstate, forward_costs_[source_state] + weight.Value1() + weight.Value2()); + // The forward_costs_for_redet_ is a version of the alpha that only includes + // contributions from non-redeterminized states, and will later be used to set + // the costs on arcs from the initial-state of the raw lattice + // Hence if it is a prefinal state, we do not update the arcs out going from it + bool is_prefinal_state = (prefinal_states.find(s) != prefinal_states.end()); + if (!is_prefinal_state) { + auto &alpha_nextstate_for_redet = + forward_costs_for_redet_[arc_appended.nextstate]; + // If the state is a initial state, the source state is from the last chunk + // We use the forward_costs_ but not forward_costs_for_redet_ to include the + // alpha with the contributions from redeterminized states + auto &alpha_state = is_initial_state + ? forward_costs_[source_state] + : forward_costs_for_redet_[source_state]; + alpha_nextstate_for_redet = + std::min(alpha_nextstate_for_redet, + alpha_state + weight.Value1() + weight.Value2()); + } } } - // Making all remaining arcs of final_arc_list_prev_ be connected to - // a dead state. - // final states are always the same state) - if (not_first_chunk) { - KALDI_ASSERT(final_arc_list_prev_.size()); - for (auto &i : final_arc_list_prev_) { - ArcIterator aiter_chunk1(*olat, i.first); - aiter_chunk1.Seek(i.second); - // Obtain the appended final arcs in the previous chunk - auto &arc_chunk1 = aiter_chunk1.Value(); - olat->SetFinal(arc_chunk1.nextstate, CompactLatticeWeight::Zero()); - } - } else - olat->SetStart(0); // Initialize the first chunk for olat + if (!not_first_chunk) olat->SetStart(0); // Initialize the first chunk for olat final_arc_list_.swap(final_arc_list_prev_); final_arc_list_.clear(); diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 6093e3b16f7..9f18f3b21b7 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -539,8 +539,9 @@ class LatticeIncrementalDeterminizer { const CompactLattice &GetDeterminizedLattice() const { return lat_; } // Part of step 1 of incremental determinization, - // where the initial states are constructed corresponding to pre-final states in - // the determinized and appended lattice before this chunk + // where the initial states are constructed corresponding to redeterminized + // states (see the description in redeterminized_states_) in the + // determinized and appended lattice before this chunk. // We give each determinized and appended state an olabel id, called `state_label` // We maintain a map (`token_label2last_state_map`) from token label (obtained from // final arcs) to the destination state of the last of the sequence of initial arcs @@ -576,6 +577,21 @@ class LatticeIncrementalDeterminizer { std::vector &GetForwardCosts() { return forward_costs_; } private: + // This function either locates a redeterminized state w.r.t nextstate previously + // added, or if necessary inserts a new one. + // The new one is inserted in olat and kept by the map (redeterminized_states_) + // which is from the state in the appended compact lattice to the state_copy in the + // raw lattice. The function returns whether a new one is inserted + // The StateId of the redeterminized state will be outputed by nextstate_copy + bool AddRedeterminizedState(Lattice::StateId nextstate, Lattice *olat, + Lattice::StateId *nextstate_copy = NULL); + // Sub function of GetInitialRawLattice(). Refer to description there + void GetRawLatticeForRedeterminizedStates( + StateId start_state, StateId state, + const unordered_map &token_label_final_cost, + unordered_multimap *token_label2last_state_map, + Lattice *olat); + const LatticeIncrementalDecoderConfig config_; const TransitionModel &trans_model_; // keep it for determinization @@ -587,9 +603,18 @@ class LatticeIncrementalDeterminizer { std::vector> final_arc_list_prev_; // alpha of each state in lat_ std::vector forward_costs_; + std::vector forward_costs_for_redet_; // we allocate a unique id for each source-state of the last arc of a series of // initial arcs in GetInitialRawLattice int32 state_last_initial_offset_; + // We define a state in the appended lattice as a 'redeterminized-state' (meaning: + // one that will be redeterminized), if it is: prefinal, or there exists an arc + // from a redeterminized state to this state. We keep reapplying this rule until + // there are no more redeterminized states. The final state is not included. + // These redeterminized states will be stored in this map + // which is a map from the state in the appended compact lattice to the + // state_copy in the newly-created raw lattice. + std::unordered_map redeterminized_states_; // The compact lattice we obtain. It should be reseted before processing a // new utterance From 92ce13c029c03bd62abf6c4f6e767923dea4ad06 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Wed, 1 May 2019 11:18:22 +0800 Subject: [PATCH 22/60] adding redet frames --- src/decoder/lattice-incremental-decoder.cc | 119 +++++++++++++-------- src/decoder/lattice-incremental-decoder.h | 31 +++++- 2 files changed, 103 insertions(+), 47 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index b34f6687ac9..a34288a43f2 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1170,7 +1170,6 @@ void LatticeIncrementalDeterminizer::Init() { lat_.DeleteStates(); determinization_finalized_ = false; forward_costs_.clear(); - forward_costs_for_redet_.clear(); state_last_initial_offset_ = 2 * config_.max_word_id; redeterminized_states_.clear(); } @@ -1219,8 +1218,8 @@ void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( int state_label = state + state_last_initial_offset_; // Moreover, we need to use the forward coast (alpha) of this determinized and // appended state to guide the determinization later - KALDI_ASSERT(state < forward_costs_for_redet_.size()); - auto alpha_cost = forward_costs_for_redet_[state]; + KALDI_ASSERT(state < forward_costs_.size()); + auto alpha_cost = forward_costs_[state]; Arc arc_initial(0, state_label, LatticeWeight(0, alpha_cost), state_copy); if (alpha_cost != std::numeric_limits::infinity()) olat->AddArc(start_state, arc_initial); @@ -1301,6 +1300,55 @@ void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( token_label2last_state_map, olat); } } +template +void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { + using namespace fst; + processed_prefinal_states_.clear(); + // go over all prefinal state + KALDI_ASSERT(final_arc_list_prev_.size()); + for (auto &i : final_arc_list_prev_) { + auto prefinal_state = i.first; + if (processed_prefinal_states_.find(prefinal_state) != + processed_prefinal_states_.end()) + continue; + ArcIterator aiter(lat_, prefinal_state); + aiter.Seek(i.second); + auto final_arc = aiter.Value(); + auto final_weight = lat_.Final(final_arc.nextstate); + KALDI_ASSERT(final_weight != CompactLatticeWeight::Zero()); + auto num_frames = Times(final_arc.weight, final_weight).String().size(); + if (num_frames <= config_.redeterminize_max_frames) + processed_prefinal_states_[prefinal_state] = prefinal_state; + else { + KALDI_VLOG(7) << "Impose a limit of " << config_.redeterminize_max_frames + << " on how far back in time we will redeterminize states. " + << num_frames << " frames in this arc. "; + + auto new_prefinal_state = lat_.AddState(); + forward_costs_.resize(new_prefinal_state + 1); + forward_costs_[new_prefinal_state] = forward_costs_[prefinal_state]; + + std::vector arcs_remained; + for (aiter.Reset(); !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); + if (arc.olabel > config_.max_word_id) { // final arc + KALDI_ASSERT(arc.olabel < state_last_initial_offset_); + KALDI_ASSERT(lat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()); + lat_.AddArc(new_prefinal_state, arc); + } else + arcs_remained.push_back(arc); + } + CompactLatticeArc arc_to_new; + arc_to_new.nextstate = new_prefinal_state; + arcs_remained.push_back(arc_to_new); + + lat_.DeleteArcs(prefinal_state); + for (auto &i : arcs_remained) lat_.AddArc(prefinal_state, i); + processed_prefinal_states_[prefinal_state] = new_prefinal_state; + } + } +} + // This function is specifically designed to obtain the initial arcs for a chunk // We have multiple states for one token label after determinization template @@ -1313,15 +1361,17 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( typedef Arc::StateId StateId; typedef Arc::Weight Weight; typedef Arc::Label Label; - KALDI_ASSERT(final_arc_list_prev_.size()); + + GetRedeterminizedStates(); + olat->DeleteStates(); token_label2last_state_map->clear(); auto start_state = olat->AddState(); olat->SetStart(start_state); - // go over all prefinal state - for (auto &i : final_arc_list_prev_) { - auto prefinal_state = i.first; + // go over all prefinal states after preprocessing + for (auto &i : processed_prefinal_states_) { + auto prefinal_state = i.second; bool modified = AddRedeterminizedState(prefinal_state, olat); if (modified) GetRawLatticeForRedeterminizedStates(start_state, prefinal_state, @@ -1382,24 +1432,9 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla redeterminized_states_.clear(); } else { forward_costs_.push_back(0); // for the first state - forward_costs_for_redet_.push_back(0); } forward_costs_.resize(state_offset + clat.NumStates(), std::numeric_limits::infinity()); - forward_costs_for_redet_.resize(state_offset + clat.NumStates(), - std::numeric_limits::infinity()); - - std::unordered_set prefinal_states; - for (StateIterator siter(clat); !siter.Done(); siter.Next()) { - auto s = siter.Value(); - for (ArcIterator aiter(clat, s); !aiter.Done(); aiter.Next()) { - const auto &arc = aiter.Value(); - if (clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { - prefinal_states.insert(s); - break; - } - } - } for (StateIterator siter(clat); !siter.Done(); siter.Next()) { auto s = siter.Value(); @@ -1448,8 +1483,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla arc_appended.ilabel = 0; CompactLatticeWeight weight_offset; // remove alpha in weight - weight_offset.SetWeight( - LatticeWeight(0, -forward_costs_for_redet_[source_state])); + weight_offset.SetWeight(LatticeWeight(0, -forward_costs_[source_state])); arc_appended.weight = Times(arc_appended.weight, weight_offset); } KALDI_ASSERT(source_state != kNoStateId); @@ -1461,28 +1495,27 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla alpha_nextstate = std::min(alpha_nextstate, forward_costs_[source_state] + weight.Value1() + weight.Value2()); - // The forward_costs_for_redet_ is a version of the alpha that only includes - // contributions from non-redeterminized states, and will later be used to set - // the costs on arcs from the initial-state of the raw lattice - // Hence if it is a prefinal state, we do not update the arcs out going from it - bool is_prefinal_state = (prefinal_states.find(s) != prefinal_states.end()); - if (!is_prefinal_state) { - auto &alpha_nextstate_for_redet = - forward_costs_for_redet_[arc_appended.nextstate]; - // If the state is a initial state, the source state is from the last chunk - // We use the forward_costs_ but not forward_costs_for_redet_ to include the - // alpha with the contributions from redeterminized states - auto &alpha_state = is_initial_state - ? forward_costs_[source_state] - : forward_costs_for_redet_[source_state]; - alpha_nextstate_for_redet = - std::min(alpha_nextstate_for_redet, - alpha_state + weight.Value1() + weight.Value2()); - } } } - if (!not_first_chunk) olat->SetStart(0); // Initialize the first chunk for olat + if (!not_first_chunk) { + olat->SetStart(0); // Initialize the first chunk for olat + } else { + // The extra prefinal states generated by + // GetRedeterminizedStates are removed here, while splicing + // the compact lattices together + for (auto &i : processed_prefinal_states_) { + auto prefinal_state = i.first; + auto new_prefinal_state = i.second; + // It is without an extra prefinal state, hence do not need to process + if (prefinal_state == new_prefinal_state) continue; + for (ArcIterator aiter(lat_, new_prefinal_state); + !aiter.Done(); aiter.Next()) + lat_.AddArc(prefinal_state, aiter.Value()); + lat_.DeleteArcs(new_prefinal_state); + lat_.SetFinal(new_prefinal_state, CompactLatticeWeight::Zero()); + } + } final_arc_list_.swap(final_arc_list_prev_); final_arc_list_.clear(); diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 9f18f3b21b7..14aa853e048 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -44,6 +44,7 @@ struct LatticeIncrementalDecoderConfig { int32 determinize_delay; int32 determinize_max_active; bool redeterminize; + int32 redeterminize_max_frames; bool determinize_lattice; // not inspected by this class... used in // command-line program. BaseFloat beam_delta; // has nothing to do with beam_ratio @@ -67,6 +68,7 @@ struct LatticeIncrementalDecoderConfig { determinize_delay(25), determinize_max_active(std::numeric_limits::max()), redeterminize(false), + redeterminize_max_frames(std::numeric_limits::max()), determinize_lattice(true), beam_delta(0.5), hash_ratio(2.0), @@ -96,6 +98,11 @@ struct LatticeIncrementalDecoderConfig { opts->Register("redeterminize", &redeterminize, "whether to re-determinize the lattice after incremental " "determinization."); + opts->Register("redeterminize_max_frames", &redeterminize_max_frames, + "To impose a limit on how far back in time we will " + "redeterminize states. This is mainly intended to avoid " + "pathological cases. You could set it infinite to get a fully " + "determinized lattice."); opts->Register("determinize-lattice", &determinize_lattice, "If true, " "determinize the lattice (lattice-determinization, keeping only " @@ -539,8 +546,8 @@ class LatticeIncrementalDeterminizer { const CompactLattice &GetDeterminizedLattice() const { return lat_; } // Part of step 1 of incremental determinization, - // where the initial states are constructed corresponding to redeterminized - // states (see the description in redeterminized_states_) in the + // where the initial states are constructed corresponding to redeterminized + // states (see the description in redeterminized_states_) in the // determinized and appended lattice before this chunk. // We give each determinized and appended state an olabel id, called `state_label` // We maintain a map (`token_label2last_state_map`) from token label (obtained from @@ -591,6 +598,18 @@ class LatticeIncrementalDeterminizer { const unordered_map &token_label_final_cost, unordered_multimap *token_label2last_state_map, Lattice *olat); + // This function is to preprocess the appended compact lattice before + // generating raw lattices for the next chunk. + // After identifying prefinal-states, for any such state that is separated by + // more than config_.redeterminize_max_frames from the end of the current + // appended lattice, we create an extra state for it; we add an epsilon arc + // from that prefinal state to the extra state; we copy any final arcs from + // the prefinal state to its extra state and we remove those final arcs from + // the original prefinal-state. Now this extra state is the prefinal state to + // redeterminize and the original prefinal state does not need to redeterminize + // The epsilon would be removed later on in AppendLatticeChunks, while + // splicing the compact lattices together + void GetRedeterminizedStates(); const LatticeIncrementalDecoderConfig config_; const TransitionModel &trans_model_; // keep it for determinization @@ -603,7 +622,6 @@ class LatticeIncrementalDeterminizer { std::vector> final_arc_list_prev_; // alpha of each state in lat_ std::vector forward_costs_; - std::vector forward_costs_for_redet_; // we allocate a unique id for each source-state of the last arc of a series of // initial arcs in GetInitialRawLattice int32 state_last_initial_offset_; @@ -614,7 +632,12 @@ class LatticeIncrementalDeterminizer { // These redeterminized states will be stored in this map // which is a map from the state in the appended compact lattice to the // state_copy in the newly-created raw lattice. - std::unordered_map redeterminized_states_; + unordered_map redeterminized_states_; + // It is a map used in GetRedeterminizedStates (see the description there) + // A map from the original prefinal state to the prefinal states (i.e. the + // original prefinal state or an extra state generated by + // GetRedeterminizedStates) used for generating raw lattices of the next chunk. + unordered_map processed_prefinal_states_; // The compact lattice we obtain. It should be reseted before processing a // new utterance From b4ed30c61f1164bc3eb64645318528cac9dd5c46 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Thu, 2 May 2019 10:31:00 +0800 Subject: [PATCH 23/60] add eps removal; 1oco --- src/decoder/lattice-incremental-decoder.cc | 24 ++++++++++++++++++++-- src/decoder/lattice-incremental-decoder.h | 4 ++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index a34288a43f2..8cc22987cfb 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1471,6 +1471,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla final_arc_list_.push_back( pair(state_appended, aiter.Position())); } + olat->AddArc(source_state, arc_appended); } else { // process initial arcs // a special olabel in the arc that corresponds to the identity of the // source-state of the last arc, we use its StateId and a offset here, called @@ -1485,9 +1486,28 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla // remove alpha in weight weight_offset.SetWeight(LatticeWeight(0, -forward_costs_[source_state])); arc_appended.weight = Times(arc_appended.weight, weight_offset); + + if (!config_.epsilon_removal || + clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { + // it should be the last chunk + olat->AddArc(source_state, arc_appended); + } else { + // append lattice chunk and remove Epsilon together + for (ArcIterator aiter_postinitial(clat, arc.nextstate); + !aiter_postinitial.Done(); aiter_postinitial.Next()) { + auto arc_postinitial(aiter_postinitial.Value()); + arc_postinitial.weight = + Times(arc_postinitial.weight, arc_appended.weight); + arc_postinitial.nextstate += state_offset; + olat->AddArc(source_state, arc_postinitial); + if (arc_postinitial.olabel > config_.max_word_id) { + KALDI_ASSERT(arc_postinitial.olabel < state_last_initial_offset_); + final_arc_list_.push_back( + pair(source_state, aiter_postinitial.Position())); + } + } + } } - KALDI_ASSERT(source_state != kNoStateId); - olat->AddArc(source_state, arc_appended); // update forward_costs_ (alpha) KALDI_ASSERT(arc_appended.nextstate < forward_costs_.size()); auto &alpha_nextstate = forward_costs_[arc_appended.nextstate]; diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 14aa853e048..25fe7f2a7e3 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -45,6 +45,7 @@ struct LatticeIncrementalDecoderConfig { int32 determinize_max_active; bool redeterminize; int32 redeterminize_max_frames; + bool epsilon_removal; bool determinize_lattice; // not inspected by this class... used in // command-line program. BaseFloat beam_delta; // has nothing to do with beam_ratio @@ -69,6 +70,7 @@ struct LatticeIncrementalDecoderConfig { determinize_max_active(std::numeric_limits::max()), redeterminize(false), redeterminize_max_frames(std::numeric_limits::max()), + epsilon_removal(false), determinize_lattice(true), beam_delta(0.5), hash_ratio(2.0), @@ -103,6 +105,8 @@ struct LatticeIncrementalDecoderConfig { "redeterminize states. This is mainly intended to avoid " "pathological cases. You could set it infinite to get a fully " "determinized lattice."); + opts->Register("epsilon-removal", &epsilon_removal, + "whether to remove epsilon when appending two adjacent chunks."); opts->Register("determinize-lattice", &determinize_lattice, "If true, " "determinize the lattice (lattice-determinization, keeping only " From 5651200f4d04a8792759789c3c0e36baa6c79adb Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sun, 5 May 2019 09:35:14 +0800 Subject: [PATCH 24/60] bug fix when --epsilon-removal=1 --redeterminize-max-frames=10 --- src/decoder/lattice-incremental-decoder.cc | 34 +++++++++++++++++----- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 8cc22987cfb..f758a304e05 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1172,6 +1172,7 @@ void LatticeIncrementalDeterminizer::Init() { forward_costs_.clear(); state_last_initial_offset_ = 2 * config_.max_word_id; redeterminized_states_.clear(); + processed_prefinal_states_.clear(); } template bool LatticeIncrementalDeterminizer::AddRedeterminizedState( @@ -1312,6 +1313,7 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { processed_prefinal_states_.end()) continue; ArcIterator aiter(lat_, prefinal_state); + KALDI_ASSERT(lat_.NumArcs(prefinal_state) > i.second); aiter.Seek(i.second); auto final_arc = aiter.Value(); auto final_weight = lat_.Final(final_arc.nextstate); @@ -1338,8 +1340,7 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { } else arcs_remained.push_back(arc); } - CompactLatticeArc arc_to_new; - arc_to_new.nextstate = new_prefinal_state; + CompactLatticeArc arc_to_new(0, 0, CompactLatticeWeight::One(), new_prefinal_state); arcs_remained.push_back(arc_to_new); lat_.DeleteArcs(prefinal_state); @@ -1347,6 +1348,7 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { processed_prefinal_states_[prefinal_state] = new_prefinal_state; } } + KALDI_VLOG(8) << "states of the lattice after GetRedeterminizedStates: " << lat_.NumStates(); } // This function is specifically designed to obtain the initial arcs for a chunk @@ -1436,6 +1438,11 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla forward_costs_.resize(state_offset + clat.NumStates(), std::numeric_limits::infinity()); + // Here we construct a map from the original prefinal state to the prefinal states for later use + unordered_map invert_processed_prefinal_states; + invert_processed_prefinal_states.reserve(processed_prefinal_states_.size()); + for (auto i:processed_prefinal_states_) + invert_processed_prefinal_states[i.second]=i.first; for (StateIterator siter(clat); !siter.Done(); siter.Next()) { auto s = siter.Value(); StateId state_appended = kNoStateId; @@ -1468,6 +1475,8 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla // record final_arc in this chunk for the step 3.2 in the next call KALDI_ASSERT(arc.olabel < state_last_initial_offset_); KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); + // state_appended shouldn't be in invert_processed_prefinal_states + // So we do not need to map it final_arc_list_.push_back( pair(state_appended, aiter.Position())); } @@ -1487,11 +1496,20 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla weight_offset.SetWeight(LatticeWeight(0, -forward_costs_[source_state])); arc_appended.weight = Times(arc_appended.weight, weight_offset); + // if it is an extra prefinal state, we should use its original prefinal state + int arc_offset = 0; + auto r = invert_processed_prefinal_states.find(source_state); + if (r != invert_processed_prefinal_states.end() && r->second != r->first) { + source_state = r->second; + arc_offset = olat->NumArcs(source_state); + } + if (!config_.epsilon_removal || clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { // it should be the last chunk olat->AddArc(source_state, arc_appended); } else { + // append lattice chunk and remove Epsilon together for (ArcIterator aiter_postinitial(clat, arc.nextstate); !aiter_postinitial.Done(); aiter_postinitial.Next()) { @@ -1503,7 +1521,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla if (arc_postinitial.olabel > config_.max_word_id) { KALDI_ASSERT(arc_postinitial.olabel < state_last_initial_offset_); final_arc_list_.push_back( - pair(source_state, aiter_postinitial.Position())); + pair(source_state, aiter_postinitial.Position() + arc_offset)); } } } @@ -1517,6 +1535,8 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla forward_costs_[source_state] + weight.Value1() + weight.Value2()); } } + KALDI_ASSERT(olat->NumStates() == clat.NumStates() + state_offset); + KALDI_VLOG(8) << "states of the lattice: " << olat->NumStates(); if (!not_first_chunk) { olat->SetStart(0); // Initialize the first chunk for olat @@ -1529,11 +1549,11 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla auto new_prefinal_state = i.second; // It is without an extra prefinal state, hence do not need to process if (prefinal_state == new_prefinal_state) continue; - for (ArcIterator aiter(lat_, new_prefinal_state); + for (ArcIterator aiter(*olat, new_prefinal_state); !aiter.Done(); aiter.Next()) - lat_.AddArc(prefinal_state, aiter.Value()); - lat_.DeleteArcs(new_prefinal_state); - lat_.SetFinal(new_prefinal_state, CompactLatticeWeight::Zero()); + olat->AddArc(prefinal_state, aiter.Value()); + olat->DeleteArcs(new_prefinal_state); + olat->SetFinal(new_prefinal_state, CompactLatticeWeight::Zero()); } } From 7401fe44d7297956f9074f41479c7e5870e18d38 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sun, 5 May 2019 13:10:57 +0800 Subject: [PATCH 25/60] code refine --- src/bin/Makefile | 3 +- src/bin/latgen-incremental-mapped.cc | 4 +- src/decoder/Makefile | 4 +- src/decoder/decoder-wrappers.cc | 3 - src/decoder/lattice-faster-decoder.cc | 2 - src/decoder/lattice-incremental-decoder.cc | 76 ++++++++++----------- src/decoder/lattice-incremental-decoder.h | 78 +++++++++++----------- 7 files changed, 80 insertions(+), 90 deletions(-) diff --git a/src/bin/Makefile b/src/bin/Makefile index 8046f6c9ab2..02c95ff4804 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -17,13 +17,12 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ post-to-weights sum-tree-stats weight-post post-to-tacc copy-matrix \ copy-vector copy-int-vector sum-post sum-matrices draw-tree \ align-mapped align-compiled-mapped latgen-faster-mapped latgen-faster-mapped-parallel \ - latgen-incremental-mapped \ hmm-info analyze-counts post-to-phone-post \ post-to-pdf-post logprob-to-post prob-to-post copy-post \ matrix-sum build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ transform-vec align-text matrix-dim post-to-smat compile-graph \ - compare-int-vector + compare-int-vector latgen-incremental-mapped OBJFILES = diff --git a/src/bin/latgen-incremental-mapped.cc b/src/bin/latgen-incremental-mapped.cc index 164a513f2d6..276e632391c 100644 --- a/src/bin/latgen-incremental-mapped.cc +++ b/src/bin/latgen-incremental-mapped.cc @@ -1,8 +1,6 @@ // bin/latgen-incremental-mapped.cc -// Copyright 2009-2012 Microsoft Corporation, Karel Vesely -// 2013 Johns Hopkins University (author: Daniel Povey) -// 2014 Guoguo Chen +// Copyright 2019 Zhehuai Chen // See ../../COPYING for clarification regarding multiple authors // diff --git a/src/decoder/Makefile b/src/decoder/Makefile index 849947e493f..ebac90e65ac 100644 --- a/src/decoder/Makefile +++ b/src/decoder/Makefile @@ -7,8 +7,8 @@ TESTFILES = OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \ lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \ - lattice-incremental-decoder.o \ - decoder-wrappers.o grammar-fst.o decodable-matrix.o + decoder-wrappers.o grammar-fst.o decodable-matrix.o \ + lattice-incremental-decoder.o LIBNAME = kaldi-decoder diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index abb326eb012..294a2f69117 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -22,7 +22,6 @@ #include "decoder/lattice-faster-decoder.h" #include "decoder/grammar-fst.h" #include "lat/lattice-functions.h" -#include "base/timer.h" namespace kaldi { @@ -354,7 +353,6 @@ bool DecodeUtteranceLatticeFaster( // Get lattice, and do determinization if requested. Lattice lat; - Timer timer; decoder.GetRawLattice(&lat); if (lat.NumStates() == 0) KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; @@ -379,7 +377,6 @@ bool DecodeUtteranceLatticeFaster( fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &lat); lattice_writer->Write(utt, lat); } - KALDI_VLOG(2) << "Delay time after decoding finalized (secs): " << timer.Elapsed(); KALDI_LOG << "Log-like per frame for utterance " << utt << " is " << (likelihood / num_frames) << " over " << num_frames << " frames."; diff --git a/src/decoder/lattice-faster-decoder.cc b/src/decoder/lattice-faster-decoder.cc index ed78ba5fddb..2bc8c7cdef4 100644 --- a/src/decoder/lattice-faster-decoder.cc +++ b/src/decoder/lattice-faster-decoder.cc @@ -89,9 +89,7 @@ bool LatticeFasterDecoderTpl::Decode(DecodableInterface *decodable) BaseFloat cost_cutoff = ProcessEmitting(decodable); ProcessNonemitting(cost_cutoff); } - Timer timer; FinalizeDecoding(); - KALDI_VLOG(2) << "Delay0 time after decoding finalized (secs): " << timer.Elapsed(); // Returns true if we have any kind of traceback available (not necessarily // to the end state; query ReachedFinal() for that). diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index f758a304e05..12e400ac100 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1,9 +1,6 @@ // decoder/lattice-incremental-decoder.cc -// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann -// 2013-2018 Johns Hopkins University (Author: Daniel Povey) -// 2014 Guoguo Chen -// 2018 Zhehuai Chen +// Copyright 2019 Zhehuai Chen // See ../../COPYING for clarification regarding multiple authors // @@ -103,29 +100,30 @@ bool LatticeIncrementalDecoderTpl::Decode( while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { if (NumFramesDecoded() % config_.prune_interval == 0) { PruneActiveTokens(config_.lattice_beam * config_.prune_scale); - // We always incrementally determinize the lattice after lattice pruning in - // PruneActiveTokens() - // We have a delay on GetLattice to do determinization on more skinny lattices - int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; - // The minimum length of chunk is config_.prune_interval. We make it - // identical to PruneActiveTokens since we need extra_cost as the weights - // of final arcs to denote the "future" information of final states (Tokens) + } + + // We always incrementally determinize the lattice after lattice pruning in + // PruneActiveTokens() + // We have a delay on GetLattice to do determinization on more skinny lattices + int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; + // The minimum length of chunk is config_.prune_interval. We make it + // identical to PruneActiveTokens since we need extra_cost as the weights + // of final arcs to denote the "future" information of final states (Tokens) + if (frame_det_most % config_.prune_interval == 0) { int32 frame_det_least = last_get_lattice_frame_ + config_.prune_interval; - if (config_.determinize_lattice && frame_det_most > 0) { - // To adaptively decide the length of chunk, we further compare the number of - // tokens in each frame and a pre-defined threshold. - // If the number of tokens in a certain frame is less than - // config_.determinize_max_active, the lattice can be determinized up to this - // frame. And we try to determinize as most frames as possible so we check - // numbers from frame_det_most to frame_det_least - for (int32 f = frame_det_most; f >= frame_det_least; f--) { - if (config_.determinize_max_active == std::numeric_limits::max() || - GetNumToksForFrame(f) < config_.determinize_max_active) { - KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() - << " incremental determinization up to " << f; - GetLattice(false, false, f); - break; - } + // To adaptively decide the length of chunk, we further compare the number of + // tokens in each frame and a pre-defined threshold. + // If the number of tokens in a certain frame is less than + // config_.determinize_max_active, the lattice can be determinized up to this + // frame. And we try to determinize as most frames as possible so we check + // numbers from frame_det_most to frame_det_least + for (int32 f = frame_det_most; f >= frame_det_least; f--) { + if (config_.determinize_max_active == std::numeric_limits::max() || + GetNumToksForFrame(f) < config_.determinize_max_active) { + KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() + << " incremental determinization up to " << f; + GetLattice(false, false, f); + break; } } } @@ -134,8 +132,7 @@ bool LatticeIncrementalDecoderTpl::Decode( } Timer timer; FinalizeDecoding(); - if (config_.determinize_lattice) - GetLattice(true, config_.redeterminize, NumFramesDecoded()); + GetLattice(true, config_.redeterminize, NumFramesDecoded()); KALDI_VLOG(2) << "Delay time during and after decoding finalization (secs): " << timer.Elapsed(); @@ -1340,7 +1337,8 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { } else arcs_remained.push_back(arc); } - CompactLatticeArc arc_to_new(0, 0, CompactLatticeWeight::One(), new_prefinal_state); + CompactLatticeArc arc_to_new(0, 0, CompactLatticeWeight::One(), + new_prefinal_state); arcs_remained.push_back(arc_to_new); lat_.DeleteArcs(prefinal_state); @@ -1348,7 +1346,8 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { processed_prefinal_states_[prefinal_state] = new_prefinal_state; } } - KALDI_VLOG(8) << "states of the lattice after GetRedeterminizedStates: " << lat_.NumStates(); + KALDI_VLOG(8) << "states of the lattice after GetRedeterminizedStates: " + << lat_.NumStates(); } // This function is specifically designed to obtain the initial arcs for a chunk @@ -1438,11 +1437,12 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla forward_costs_.resize(state_offset + clat.NumStates(), std::numeric_limits::infinity()); - // Here we construct a map from the original prefinal state to the prefinal states for later use + // Here we construct a map from the original prefinal state to the prefinal states + // for later use unordered_map invert_processed_prefinal_states; invert_processed_prefinal_states.reserve(processed_prefinal_states_.size()); - for (auto i:processed_prefinal_states_) - invert_processed_prefinal_states[i.second]=i.first; + for (auto i : processed_prefinal_states_) + invert_processed_prefinal_states[i.second] = i.first; for (StateIterator siter(clat); !siter.Done(); siter.Next()) { auto s = siter.Value(); StateId state_appended = kNoStateId; @@ -1496,7 +1496,8 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla weight_offset.SetWeight(LatticeWeight(0, -forward_costs_[source_state])); arc_appended.weight = Times(arc_appended.weight, weight_offset); - // if it is an extra prefinal state, we should use its original prefinal state + // if it is an extra prefinal state, we should use its original prefinal + // state int arc_offset = 0; auto r = invert_processed_prefinal_states.find(source_state); if (r != invert_processed_prefinal_states.end() && r->second != r->first) { @@ -1509,7 +1510,6 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla // it should be the last chunk olat->AddArc(source_state, arc_appended); } else { - // append lattice chunk and remove Epsilon together for (ArcIterator aiter_postinitial(clat, arc.nextstate); !aiter_postinitial.Done(); aiter_postinitial.Next()) { @@ -1520,8 +1520,8 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla olat->AddArc(source_state, arc_postinitial); if (arc_postinitial.olabel > config_.max_word_id) { KALDI_ASSERT(arc_postinitial.olabel < state_last_initial_offset_); - final_arc_list_.push_back( - pair(source_state, aiter_postinitial.Position() + arc_offset)); + final_arc_list_.push_back(pair( + source_state, aiter_postinitial.Position() + arc_offset)); } } } @@ -1551,7 +1551,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla if (prefinal_state == new_prefinal_state) continue; for (ArcIterator aiter(*olat, new_prefinal_state); !aiter.Done(); aiter.Next()) - olat->AddArc(prefinal_state, aiter.Value()); + olat->AddArc(prefinal_state, aiter.Value()); olat->DeleteArcs(new_prefinal_state); olat->SetFinal(new_prefinal_state, CompactLatticeWeight::Zero()); } diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 25fe7f2a7e3..f810fd0e3fa 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -1,9 +1,6 @@ // decoder/lattice-incremental-decoder.h -// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; -// 2013-2014 Johns Hopkins University (Author: Daniel Povey) -// 2014 Guoguo Chen -// 2018 Zhehuai Chen +// Copyright 2019 Zhehuai Chen // See ../../COPYING for clarification regarding multiple authors // @@ -46,9 +43,7 @@ struct LatticeIncrementalDecoderConfig { bool redeterminize; int32 redeterminize_max_frames; bool epsilon_removal; - bool determinize_lattice; // not inspected by this class... used in - // command-line program. - BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat beam_delta; // has nothing to do with beam_ratio BaseFloat hash_ratio; BaseFloat prune_scale; // Note: we don't make this configurable on the command line, @@ -71,7 +66,6 @@ struct LatticeIncrementalDecoderConfig { redeterminize(false), redeterminize_max_frames(std::numeric_limits::max()), epsilon_removal(false), - determinize_lattice(true), beam_delta(0.5), hash_ratio(2.0), prune_scale(0.1), @@ -90,13 +84,17 @@ struct LatticeIncrementalDecoderConfig { "Interval (in frames) at " "which to prune tokens"); opts->Register("determinize-delay", &determinize_delay, - "delay (in frames) at " - "which to incrementally determinize lattices"); + "delay (in frames, typically larger than --prune-interval) " + "at which to incrementally determinize lattices."); opts->Register("determinize-max-active", &determinize_max_active, - "This option is to adaptively decide --determinize-delay. " + "This option is to adaptively decide the size of the chunk " + "to be determinized. " "If the number of active tokens(in a certain frame) is less " - "than this number, we will start to incrementally " - "determinize lattices up to this frame."); + "than this number (typically 50), we will start to " + "incrementally determinize lattices from the last frame we " + "determinized up to this frame. It can work with " + "--determinize-delay to further reduce the computation " + "introduced by incremental determinization. "); opts->Register("redeterminize", &redeterminize, "whether to re-determinize the lattice after incremental " "determinization."); @@ -107,10 +105,6 @@ struct LatticeIncrementalDecoderConfig { "determinized lattice."); opts->Register("epsilon-removal", &epsilon_removal, "whether to remove epsilon when appending two adjacent chunks."); - opts->Register("determinize-lattice", &determinize_lattice, - "If true, " - "determinize the lattice (lattice-determinization, keeping only " - "best pdf-sequence for each word-sequence)."); opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " "parameter is obscure and relates to a speedup in the way the " @@ -122,17 +116,21 @@ struct LatticeIncrementalDecoderConfig { void Check() const { KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 && min_active <= max_active && prune_interval > 0 && - beam_delta > 0.0 && hash_ratio >= 1.0 && prune_scale > 0.0 && - prune_scale < 1.0); + determinize_delay >= 0 && determinize_max_active >= 0 && + redeterminize_max_frames >= 0 && beam_delta > 0.0 && + hash_ratio >= 1.0 && prune_scale > 0.0 && prune_scale < 1.0); } }; template class LatticeIncrementalDeterminizer; -/** This is the "normal" lattice-generating decoder. - See \ref lattices_generation \ref decoders_faster and \ref decoders_simple - for more information. +/* This is an extention to the "normal" lattice-generating decoder. + See \ref lattices_generation \ref decoders_faster and \ref decoders_simple + for more information. + + The main difference is the incremental determinization which will be + discussed in the function GetLattice(). The decoder is templated on the FST type and the token type. The token type will normally be StdToken, but also may be BackpointerToken which is to support @@ -209,10 +207,10 @@ class LatticeIncrementalDecoderTpl { // of it as a `pre-final state` // Similarly, we define a `initial arc` as an arc from a initial-state, and the // destination state of it as a `post-initial state` - // The initial states are constructed corresponding to pre-final states in the - // determinized and appended lattice before this chunk - // The final states are constructed correponding to tokens in the last frames of - // this chunk + // The post-initial states are constructed corresponding to pre-final states + // in the determinized and appended lattice before this chunk + // The pre-final states are constructed correponding to tokens in the last frames + // of this chunk. // Since the StateId can change during determinization, we need to give permanent // unique labels (as olabel) to these // raw-lattice states for latter appending. @@ -550,8 +548,8 @@ class LatticeIncrementalDeterminizer { const CompactLattice &GetDeterminizedLattice() const { return lat_; } // Part of step 1 of incremental determinization, - // where the initial states are constructed corresponding to redeterminized - // states (see the description in redeterminized_states_) in the + // where the post-initial states are constructed corresponding to + // redeterminized states (see the description in redeterminized_states_) in the // determinized and appended lattice before this chunk. // We give each determinized and appended state an olabel id, called `state_label` // We maintain a map (`token_label2last_state_map`) from token label (obtained from @@ -577,7 +575,7 @@ class LatticeIncrementalDeterminizer { // clat into olat // Otherwise, we need to connect states of the last frame of // the last chunk to states of the first frame of this chunk. - // These begin and final states are corresponding to the same Token, + // These post-initial and pre-final states are corresponding to the same Token, // guaranteed by unique state labels. bool AppendLatticeChunks(CompactLattice clat, bool not_first_chunk); @@ -604,13 +602,13 @@ class LatticeIncrementalDeterminizer { Lattice *olat); // This function is to preprocess the appended compact lattice before // generating raw lattices for the next chunk. - // After identifying prefinal-states, for any such state that is separated by + // After identifying pre-final states, for any such state that is separated by // more than config_.redeterminize_max_frames from the end of the current // appended lattice, we create an extra state for it; we add an epsilon arc - // from that prefinal state to the extra state; we copy any final arcs from - // the prefinal state to its extra state and we remove those final arcs from - // the original prefinal-state. Now this extra state is the prefinal state to - // redeterminize and the original prefinal state does not need to redeterminize + // from that pre-final state to the extra state; we copy any final arcs from + // the pre-final state to its extra state and we remove those final arcs from + // the original pre-final state. Now this extra state is the pre-final state to + // redeterminize and the original pre-final state does not need to redeterminize // The epsilon would be removed later on in AppendLatticeChunks, while // splicing the compact lattices together void GetRedeterminizedStates(); @@ -630,16 +628,16 @@ class LatticeIncrementalDeterminizer { // initial arcs in GetInitialRawLattice int32 state_last_initial_offset_; // We define a state in the appended lattice as a 'redeterminized-state' (meaning: - // one that will be redeterminized), if it is: prefinal, or there exists an arc - // from a redeterminized state to this state. We keep reapplying this rule until - // there are no more redeterminized states. The final state is not included. - // These redeterminized states will be stored in this map + // one that will be redeterminized), if it is: a pre-final state, or there + // exists an arc from a redeterminized state to this state. We keep reapplying + // this rule until there are no more redeterminized states. The final state + // is not included. These redeterminized states will be stored in this map // which is a map from the state in the appended compact lattice to the // state_copy in the newly-created raw lattice. unordered_map redeterminized_states_; // It is a map used in GetRedeterminizedStates (see the description there) - // A map from the original prefinal state to the prefinal states (i.e. the - // original prefinal state or an extra state generated by + // A map from the original pre-final state to the pre-final states (i.e. the + // original pre-final state or an extra state generated by // GetRedeterminizedStates) used for generating raw lattices of the next chunk. unordered_map processed_prefinal_states_; From e5cef12659f4f5bc3cd01000c4bffbdd88d73d1f Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Thu, 9 May 2019 03:02:36 -0400 Subject: [PATCH 26/60] bug fix add warning add determinize_chunk_size --- src/bin/latgen-incremental-mapped.cc | 2 +- src/decoder/lattice-incremental-decoder.cc | 15 +++++++++++---- src/decoder/lattice-incremental-decoder.h | 7 +++++++ 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/bin/latgen-incremental-mapped.cc b/src/bin/latgen-incremental-mapped.cc index 276e632391c..a5b7619c0ab 100644 --- a/src/bin/latgen-incremental-mapped.cc +++ b/src/bin/latgen-incremental-mapped.cc @@ -71,7 +71,7 @@ int main(int argc, char *argv[]) { TransitionModel trans_model; ReadKaldiObject(model_in_filename, &trans_model); - bool determinize = config.determinize_lattice; + bool determinize = true; CompactLatticeWriter compact_lattice_writer; LatticeWriter lattice_writer; if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 12e400ac100..8101db1a850 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -106,11 +106,12 @@ bool LatticeIncrementalDecoderTpl::Decode( // PruneActiveTokens() // We have a delay on GetLattice to do determinization on more skinny lattices int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; - // The minimum length of chunk is config_.prune_interval. We make it + // The minimum length of chunk is config_.determinize_chunk_size. We make it // identical to PruneActiveTokens since we need extra_cost as the weights // of final arcs to denote the "future" information of final states (Tokens) - if (frame_det_most % config_.prune_interval == 0) { - int32 frame_det_least = last_get_lattice_frame_ + config_.prune_interval; + if (frame_det_most % config_.determinize_chunk_size == 0) { + int32 frame_det_least = last_get_lattice_frame_ + + config_.determinize_chunk_size; // To adaptively decide the length of chunk, we further compare the number of // tokens in each frame and a pre-defined threshold. // If the number of tokens in a certain frame is less than @@ -1054,7 +1055,13 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( KALDI_ASSERT(r != token_label_map_.end()); // it should exist int32 token_label = r->second; auto range = token_label2last_state_map.equal_range(token_label); - KALDI_ASSERT(range.first != range.second); + if (range.first == range.second) { + KALDI_WARN << "The token in the first frame of this chunk does not " + "exist in the last frame of previous chunk. It should be seldom" + " happen and probably caused by over-pruning in determinization," + "e.g. the lattice reaches --max-mem constrain."; + continue; + } std::vector tmp_vec; for (auto it = range.first; it != range.second; ++it) { // the destination state of the last of the sequence of arcs w.r.t the token diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index f810fd0e3fa..cb0dd04c8e7 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -39,6 +39,7 @@ struct LatticeIncrementalDecoderConfig { BaseFloat lattice_beam; int32 prune_interval; int32 determinize_delay; + int32 determinize_chunk_size; int32 determinize_max_active; bool redeterminize; int32 redeterminize_max_frames; @@ -62,6 +63,7 @@ struct LatticeIncrementalDecoderConfig { lattice_beam(10.0), prune_interval(25), determinize_delay(25), + determinize_chunk_size(20), determinize_max_active(std::numeric_limits::max()), redeterminize(false), redeterminize_max_frames(std::numeric_limits::max()), @@ -86,6 +88,10 @@ struct LatticeIncrementalDecoderConfig { opts->Register("determinize-delay", &determinize_delay, "delay (in frames, typically larger than --prune-interval) " "at which to incrementally determinize lattices."); + opts->Register("determinize-chunk-size", &determinize_chunk_size, + "the size (in frames) of chunk to do incrementally " + "determinization. If working with --determinize-max-active," + "it will become a lower bound of the size of chunk."); opts->Register("determinize-max-active", &determinize_max_active, "This option is to adaptively decide the size of the chunk " "to be determinized. " @@ -117,6 +123,7 @@ struct LatticeIncrementalDecoderConfig { KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 && min_active <= max_active && prune_interval > 0 && determinize_delay >= 0 && determinize_max_active >= 0 && + determinize_chunk_size >= 0 && redeterminize_max_frames >= 0 && beam_delta > 0.0 && hash_ratio >= 1.0 && prune_scale > 0.0 && prune_scale < 1.0); } From ecae786bb3a798164f6b0088276a5c4cdc87da48 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Mon, 13 May 2019 09:05:53 +0800 Subject: [PATCH 27/60] code refine --- src/decoder/decoder-wrappers.cc | 16 +- src/decoder/lattice-incremental-decoder.cc | 90 +++------ src/decoder/lattice-incremental-decoder.h | 213 ++++++++++----------- 3 files changed, 130 insertions(+), 189 deletions(-) diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index 294a2f69117..a912af41077 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -211,7 +211,6 @@ bool DecodeUtteranceLatticeIncremental( LatticeWriter *lattice_writer, double *like_ptr) { // puts utterance's like in like_ptr on success. using fst::VectorFst; - if (!decoder.Decode(&decodable)) { KALDI_WARN << "Failed to decode file " << utt; return false; @@ -264,19 +263,10 @@ bool DecodeUtteranceLatticeIncremental( decoder.GetLattice(&clat); if (clat.NumStates() == 0) KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; - if (determinize) { // We'll write the lattice without acoustic scaling. - if (acoustic_scale != 0.0) - fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); - compact_lattice_writer->Write(utt, clat); - } else { - Lattice lat; - decoder.GetRawLattice(&lat); - // We'll write the lattice without acoustic scaling. - if (acoustic_scale != 0.0) - fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &lat); - lattice_writer->Write(utt, lat); - } + if (acoustic_scale != 0.0) + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); + compact_lattice_writer->Write(utt, clat); KALDI_LOG << "Log-like per frame for utterance " << utt << " is " << (likelihood / num_frames) << " over " << num_frames << " frames."; diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 8101db1a850..7c860412a95 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -89,8 +89,7 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { // a final state). It should only very rarely return false; this indicates // an unusual search error. template -bool LatticeIncrementalDecoderTpl::Decode( - DecodableInterface *decodable) { +bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decodable) { InitDecoding(); // We use 1-based indexing for frames in this decoder (if you view it in @@ -103,15 +102,15 @@ bool LatticeIncrementalDecoderTpl::Decode( } // We always incrementally determinize the lattice after lattice pruning in - // PruneActiveTokens() - // We have a delay on GetLattice to do determinization on more skinny lattices - int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; - // The minimum length of chunk is config_.determinize_chunk_size. We make it - // identical to PruneActiveTokens since we need extra_cost as the weights + // PruneActiveTokens() since we need extra_cost as the weights // of final arcs to denote the "future" information of final states (Tokens) + // Moreover, the delay on GetLattice to do determinization + // make it process more skinny lattices which reduces the computation overheads. + int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; + // The minimum length of chunk is config_.determinize_chunk_size. if (frame_det_most % config_.determinize_chunk_size == 0) { - int32 frame_det_least = last_get_lattice_frame_ + - config_.determinize_chunk_size; + int32 frame_det_least = + last_get_lattice_frame_ + config_.determinize_chunk_size; // To adaptively decide the length of chunk, we further compare the number of // tokens in each frame and a pre-defined threshold. // If the number of tokens in a certain frame is less than @@ -123,7 +122,7 @@ bool LatticeIncrementalDecoderTpl::Decode( GetNumToksForFrame(f) < config_.determinize_max_active) { KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() << " incremental determinization up to " << f; - GetLattice(false, false, f); + GetLattice(false, f); break; } } @@ -133,9 +132,9 @@ bool LatticeIncrementalDecoderTpl::Decode( } Timer timer; FinalizeDecoding(); - GetLattice(true, config_.redeterminize, NumFramesDecoded()); - KALDI_VLOG(2) << "Delay time during and after decoding finalization (secs): " - << timer.Elapsed(); + GetLattice(true, NumFramesDecoded()); + KALDI_VLOG(2) << "Delay time during and after FinalizeDecoding()" + << "(secs): " << timer.Elapsed(); // Returns true if we have any kind of traceback available (not necessarily // to the end state; query ReachedFinal() for that). @@ -147,23 +146,12 @@ template bool LatticeIncrementalDecoderTpl::GetBestPath(Lattice *olat, bool use_final_probs) { CompactLattice lat, slat; - GetLattice(use_final_probs, config_.redeterminize, NumFramesDecoded(), &lat); + GetLattice(use_final_probs, NumFramesDecoded(), &lat); ShortestPath(lat, &slat); ConvertLattice(slat, olat); return (olat->NumStates() != 0); } -// Outputs an FST corresponding to the raw, state-level lattice -template -bool LatticeIncrementalDecoderTpl::GetRawLattice(Lattice *ofst, - bool use_final_probs) { - CompactLattice lat; - GetLattice(use_final_probs, config_.redeterminize, NumFramesDecoded(), &lat); - ConvertLattice(lat, ofst); - Connect(ofst); - return (ofst->NumStates() != 0); -} - template void LatticeIncrementalDecoderTpl::PossiblyResizeHash(size_t num_toks) { size_t new_sz = @@ -941,12 +929,11 @@ void LatticeIncrementalDecoderTpl::TopSortTokens( template bool LatticeIncrementalDecoderTpl::GetLattice(CompactLattice *olat) { - return GetLattice(true, config_.redeterminize, NumFramesDecoded(), olat); + return GetLattice(true, NumFramesDecoded(), olat); } template bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, - bool redeterminize, int32 last_frame_of_chunk, CompactLattice *olat) { using namespace fst; @@ -976,7 +963,7 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, << last_get_lattice_frame_; // step 4 - if (decoding_finalized_) ret &= determinizer_.Finalize(redeterminize); + if (decoding_finalized_) ret &= determinizer_.Finalize(); if (olat) { *olat = determinizer_.GetDeterminizedLattice(); ret &= (olat->NumStates() > 0); @@ -1056,10 +1043,11 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( int32 token_label = r->second; auto range = token_label2last_state_map.equal_range(token_label); if (range.first == range.second) { - KALDI_WARN << "The token in the first frame of this chunk does not " - "exist in the last frame of previous chunk. It should be seldom" - " happen and probably caused by over-pruning in determinization," - "e.g. the lattice reaches --max-mem constrain."; + KALDI_WARN + << "The token in the first frame of this chunk does not " + "exist in the last frame of previous chunk. It should be seldom" + " happen and probably caused by over-pruning in determinization," + "e.g. the lattice reaches --max-mem constrain."; continue; } std::vector tmp_vec; @@ -1163,8 +1151,7 @@ int32 LatticeIncrementalDecoderTpl::GetNumToksForFrame(int32 frame) template LatticeIncrementalDeterminizer::LatticeIncrementalDeterminizer( - const LatticeIncrementalDecoderConfig &config, - const TransitionModel &trans_model) + const LatticeIncrementalDecoderConfig &config, const TransitionModel &trans_model) : config_(config), trans_model_(trans_model) {} template @@ -1256,8 +1243,7 @@ void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( weight_offset.SetWeight(LatticeWeight(0, -cost_offset)); // The arc weight is a combination of original arc weight, above cost_offset // and the weights on the final state - arc_weight = - Times(Times(arc_weight, lat_.Final(arc.nextstate)), weight_offset); + arc_weight = Times(Times(arc_weight, lat_.Final(arc.nextstate)), weight_offset); // We create a respective destination state for each final arc // later we will connect it to the state correponding to the token w.r.t @@ -1456,7 +1442,8 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla // We do not copy initial state, which exists except the first chunk if (!not_first_chunk || s != 0) { state_appended = s + state_offset; - KALDI_ASSERT(state_appended == olat->AddState()); + auto r = olat->AddState(); + KALDI_ASSERT(state_appended == r); olat->SetFinal(state_appended, clat.Final(s)); } @@ -1512,8 +1499,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla arc_offset = olat->NumArcs(source_state); } - if (!config_.epsilon_removal || - clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { + if (clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { // it should be the last chunk olat->AddArc(source_state, arc_appended); } else { @@ -1571,31 +1557,12 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla } template -bool LatticeIncrementalDeterminizer::Finalize(bool redeterminize) { +bool LatticeIncrementalDeterminizer::Finalize() { using namespace fst; auto *olat = &lat_; // The lattice determinization only needs to be finalized once if (determinization_finalized_) return true; - // step 4: re-determinize the final lattice - if (redeterminize) { - Connect(olat); // Remove unreachable states... there might be - DeterminizeLatticePrunedOptions det_opts; - det_opts.delta = config_.det_opts.delta; - det_opts.max_mem = config_.det_opts.max_mem; - Lattice lat; - ConvertLattice(*olat, &lat); - Invert(&lat); - if (lat.Properties(fst::kTopSorted, true) == 0) { - if (!TopSort(&lat)) { - // Cannot topologically sort the lattice -- determinization will fail. - KALDI_ERR << "Topological sorting of state-level lattice failed (probably" - << " your lexicon has empty words or your LM has epsilon cycles" - << ")."; - } - } - if (!DeterminizeLatticePruned(lat, config_.lattice_beam, olat, det_opts)) - KALDI_WARN << "Determinization finished earlier than the beam"; - } + // step 4: remove dead states Connect(olat); // Remove unreachable states... there might be KALDI_VLOG(2) << "states of the lattice: " << olat->NumStates(); determinization_finalized_ = true; @@ -1605,8 +1572,7 @@ bool LatticeIncrementalDeterminizer::Finalize(bool redeterminize) { // Instantiate the template for the combination of token types and FST types // that we'll need. -template class LatticeIncrementalDecoderTpl, - decoder::StdToken>; +template class LatticeIncrementalDecoderTpl, decoder::StdToken>; template class LatticeIncrementalDecoderTpl, decoder::StdToken>; template class LatticeIncrementalDecoderTpl, diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index cb0dd04c8e7..3c0b3024c8d 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -41,15 +41,12 @@ struct LatticeIncrementalDecoderConfig { int32 determinize_delay; int32 determinize_chunk_size; int32 determinize_max_active; - bool redeterminize; int32 redeterminize_max_frames; - bool epsilon_removal; BaseFloat beam_delta; // has nothing to do with beam_ratio BaseFloat hash_ratio; - BaseFloat - prune_scale; // Note: we don't make this configurable on the command line, - // it's not a very important parameter. It affects the - // algorithm that prunes the tokens as we go. + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. // Most of the options inside det_opts are not actually queried by the // LatticeIncrementalDecoder class itself, but by the code that calls it, for // example in the function DecodeUtteranceLatticeIncremental. @@ -65,9 +62,7 @@ struct LatticeIncrementalDecoderConfig { determinize_delay(25), determinize_chunk_size(20), determinize_max_active(std::numeric_limits::max()), - redeterminize(false), redeterminize_max_frames(std::numeric_limits::max()), - epsilon_removal(false), beam_delta(0.5), hash_ratio(2.0), prune_scale(0.1), @@ -86,10 +81,12 @@ struct LatticeIncrementalDecoderConfig { "Interval (in frames) at " "which to prune tokens"); opts->Register("determinize-delay", &determinize_delay, - "delay (in frames, typically larger than --prune-interval) " - "at which to incrementally determinize lattices."); + "Delay (in frames) at which to incrementally determinize " + "lattices. A larger delay reduces the computational " + "overheads of incremental deteriminization while increasing" + "the length of the last chunk which may increase latencies."); opts->Register("determinize-chunk-size", &determinize_chunk_size, - "the size (in frames) of chunk to do incrementally " + "The size (in frames) of chunk to do incrementally " "determinization. If working with --determinize-max-active," "it will become a lower bound of the size of chunk."); opts->Register("determinize-max-active", &determinize_max_active, @@ -101,16 +98,13 @@ struct LatticeIncrementalDecoderConfig { "determinized up to this frame. It can work with " "--determinize-delay to further reduce the computation " "introduced by incremental determinization. "); - opts->Register("redeterminize", &redeterminize, - "whether to re-determinize the lattice after incremental " - "determinization."); opts->Register("redeterminize_max_frames", &redeterminize_max_frames, "To impose a limit on how far back in time we will " "redeterminize states. This is mainly intended to avoid " - "pathological cases. You could set it infinite to get a fully " + "pathological cases. Smaller value leads to less " + "deterministic but less likely to blow up the processing" + "time in bad cases. You could set it infinite to get a fully " "determinized lattice."); - opts->Register("epsilon-removal", &epsilon_removal, - "whether to remove epsilon when appending two adjacent chunks."); opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " "parameter is obscure and relates to a speedup in the way the " @@ -123,9 +117,9 @@ struct LatticeIncrementalDecoderConfig { KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 && min_active <= max_active && prune_interval > 0 && determinize_delay >= 0 && determinize_max_active >= 0 && - determinize_chunk_size >= 0 && - redeterminize_max_frames >= 0 && beam_delta > 0.0 && - hash_ratio >= 1.0 && prune_scale > 0.0 && prune_scale < 1.0); + determinize_chunk_size >= 0 && redeterminize_max_frames >= 0 && + beam_delta > 0.0 && hash_ratio >= 1.0 && prune_scale > 0.0 && + prune_scale < 1.0); } }; @@ -170,18 +164,24 @@ class LatticeIncrementalDecoderTpl { LatticeIncrementalDecoderTpl(const LatticeIncrementalDecoderConfig &config, FST *fst, const TransitionModel &trans_model); - void SetOptions(const LatticeIncrementalDecoderConfig &config) { - config_ = config; - } + void SetOptions(const LatticeIncrementalDecoderConfig &config) { config_ = config; } const LatticeIncrementalDecoderConfig &GetOptions() const { return config_; } ~LatticeIncrementalDecoderTpl(); - /// Decodes until there are no more frames left in the "decodable" object.. - /// note, this may block waiting for input if the "decodable" object blocks. - /// Returns true if any kind of traceback is available (not necessarily from a - /// final state). + /// An example of how to do decoding together with incremental + /// determinization. It decodes until there are no more frames left in the + /// "decodable" object. Note, this may block waiting for input + /// if the "decodable" object blocks. + /// In this example, config_.determinize_delay, config_.determinize_chunk_size + /// and config_.determinize_max_active are used to determine the time to + /// call GetLattice(). + /// Users may do it in their own ways by calling + /// AdvanceDecoding() and GetLattice(). So the logic for deciding + /// when we get the lattice would be driven by the user. + /// The function returns true if any kind + /// of traceback is available (not necessarily from a final state). bool Decode(DecodableInterface *decodable); /// says whether a final-state was active on the last frame. If it was not, the @@ -194,92 +194,78 @@ class LatticeIncrementalDecoderTpl { /// Returns true if result is nonempty (using the return status is deprecated, /// it will become void). If "use_final_probs" is true AND we reached the /// final-state of the graph then it will include those as final-probs, else - /// it will treat all final-probs as one. Note: this just calls GetRawLattice() - /// and figures out the shortest path. + /// it will treat all final-probs as one. bool GetBestPath(Lattice *ofst, bool use_final_probs = true); - // The following function is specifically designed for incremental - // determinization. The function obtains a CompactLattice for - // the part of this utterance up to the frame last_frame_of_chunk. - // If you call this multiple times - // (calling it on every frame would not make sense, - // but every, say, 10 to 40 frames might make sense) it will spread out the - // work of determinization over time, which might be useful for online - // applications. - // - // The procedure of incremental determinization is as follow: - // step 1: Get lattice chunk with initial and final states and arcs, called `raw - // lattice`. - // Here, we define a `final arc` as an arc to a final-state, and the source state - // of it as a `pre-final state` - // Similarly, we define a `initial arc` as an arc from a initial-state, and the - // destination state of it as a `post-initial state` - // The post-initial states are constructed corresponding to pre-final states - // in the determinized and appended lattice before this chunk - // The pre-final states are constructed correponding to tokens in the last frames - // of this chunk. - // Since the StateId can change during determinization, we need to give permanent - // unique labels (as olabel) to these - // raw-lattice states for latter appending. - // We give each token an olabel id, called `token_label`, and each determinized and - // appended state an olabel id, called `state_label` - // step 2: Determinize the chunk of above raw lattice using determinization - // algorithm the same as LatticeFasterDecoder. Benefit from above `state_label` and - // `token_label` in initial and final arcs, each pre-final state in the last chunk - // w.r.t the initial arc of this chunk can be treated uniquely and each token in - // the last frame of this chunk can also be treated uniquely. We call the - // determinized new - // chunk `compact lattice (clat)` - // step 3: Appending the new chunk `clat` to the determinized lattice - // before this chunk. First, for each StateId in clat except its - // initial state, allocate a new StateId in the appended - // compact lattice. Copy the arcs except whose incoming state is initial - // state. Secondly, for each initial arcs, change its source state to the state - // corresponding to its `state_label`, which is a determinized and appended state - // Finally, we make the previous final arcs point to a "dead state" - // step 4 (optional): We re-determinize the appended lattice if needed. - // - // In our implementation, step 1 is done in GetIncrementalRawLattice(), - // step 2-4 is taken care by the class - // LatticeIncrementalDeterminizer - // - // @param [in] use_final_probs If true *and* at least one final-state in HCLG - // was active on the final frame, include final-probs from - // HCLG - // in the lattice. Otherwise treat all final-costs of - // states active - // on the most recent frame as zero (i.e. Weight::One()). - // @param [in] redeterminize If true, re-determinize the CompactLattice - // after appending the most recently decoded chunk to it, - // to - // ensure that the output is fully deterministic. - // This does extra work, but not nearly as much as - // determinizing - // a RawLattice from scratch. - // @param [in] last_frame_of_chunk Pass the last frame of this chunk to - // the function. We make it not always equal to - // NumFramesDecoded() to have a delay on the - // deteriminization - // @param [out] olat The CompactLattice representing what has been decoded - // so far. - // If lat == NULL, the CompactLattice won't be outputed. - // @return ret This function will returns true if the chunk is processed - // successfully - bool GetLattice(bool use_final_probs, bool redeterminize, - int32 last_frame_of_chunk, CompactLattice *olat = NULL); + /** + The following function is specifically designed for incremental + determinization. The function obtains a CompactLattice for + the part of this utterance up to the frame last_frame_of_chunk. + If you call this multiple times + (calling it on every frame would not make sense, but every, say, + 10 to 40 frames might make sense) it will spread out the work of + determinization over time, which might be useful for online applications. + config_.determinize_delay, config_.determinize_chunk_size + and config_.determinize_max_active can be used to determine the time to + call this function. We show an example in Decode(). + + The procedure of incremental determinization is as follow: + step 1: Get lattice chunk with initial and final states and arcs, called `raw + lattice`. + Here, we define a `final arc` as an arc to a final-state, and the source state + of it as a `pre-final state` + Similarly, we define a `initial arc` as an arc from a initial-state, and the + destination state of it as a `post-initial state` + The post-initial states are constructed corresponding to pre-final states + in the determinized and appended lattice before this chunk + The pre-final states are constructed correponding to tokens in the last frames + of this chunk. + Since the StateId can change during determinization, we need to give permanent + unique labels (as olabel) to these + raw-lattice states for latter appending. + We give each token an olabel id, called `token_label`, and each determinized and + appended state an olabel id, called `state_label` + step 2: Determinize the chunk of above raw lattice using determinization + algorithm the same as LatticeFasterDecoder. Benefit from above `state_label` and + `token_label` in initial and final arcs, each pre-final state in the last chunk + w.r.t the initial arc of this chunk can be treated uniquely and each token in + the last frame of this chunk can also be treated uniquely. We call the + determinized new + chunk `compact lattice (clat)` + step 3: Appending the new chunk `clat` to the determinized lattice + before this chunk. First, for each StateId in clat except its + initial state, allocate a new StateId in the appended + compact lattice. Copy the arcs except whose incoming state is initial + state. Secondly, for each initial arcs, change its source state to the state + corresponding to its `state_label`, which is a determinized and appended state + Finally, we make the previous final arcs point to a "dead state" + step 4: We remove dead states in the very end. + + In our implementation, step 1 is done in GetIncrementalRawLattice(), + step 2-4 is taken care by the class + LatticeIncrementalDeterminizer + + @param [in] use_final_probs If true *and* at least one final-state in HCLG + was active on the final frame, include final-probs from + HCLG + in the lattice. Otherwise treat all final-costs of + states active + on the most recent frame as zero (i.e. Weight::One()). + @param [in] last_frame_of_chunk Pass the last frame of this chunk to + the function. We make it not always equal to + NumFramesDecoded() to have a delay on the + deteriminization + @param [out] olat The CompactLattice representing what has been decoded + so far. + If lat == NULL, the CompactLattice won't be outputed. + @return ret This function will returns true if the chunk is processed + successfully + */ + bool GetLattice(bool use_final_probs, int32 last_frame_of_chunk, + CompactLattice *olat = NULL); /// Specifically design when decoding_finalized_==true bool GetLattice(CompactLattice *olat); - /// This function is to keep forwards compatibility. - /// It outputs an FST corresponding to the raw, state-level - /// tracebacks. Returns true if result is nonempty. - /// If "use_final_probs" is true AND we reached the final-state - /// of the graph then it will include those as final-probs, else - /// it will treat all final-probs as one. - /// Notably, the raw lattice from this incremental determinization decoder - /// has already been partially determinized - bool GetRawLattice(Lattice *ofst, bool use_final_probs = true); - /// InitDecoding initializes the decoding, and should only be used if you /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to /// call this. You can also call InitDecoding if you have already decoded an @@ -294,9 +280,9 @@ class LatticeIncrementalDecoderTpl { /// This function may be optionally called after AdvanceDecoding(), when you /// do not plan to decode any further. It does an extra pruning step that - /// will help to prune the lattices output by GetLattice and (particularly) - /// GetRawLattice more accurately, particularly toward the end of the - /// utterance. It does this by using the final-probs in pruning (if any + /// will help to prune the lattices output by GetLattice more accurately, + /// particularly toward the end of the utterance. + /// It does this by using the final-probs in pruning (if any /// final-state survived); it also does a final pruning step that visits all /// states (the pruning that is done during decoding may fail to prune states /// that are within kPruningScale = 0.1 outside of the beam). If you call @@ -359,8 +345,7 @@ class LatticeIncrementalDecoderTpl { // If Token == StdToken, the 'backpointer' argument has no purpose (and will // hopefully be optimized out). inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, - BaseFloat tot_cost, Token *backpointer, - bool *changed); + BaseFloat tot_cost, Token *backpointer, bool *changed); // prunes outgoing links for all tokens in active_toks_[frame] // it's called by PruneActiveTokens @@ -589,7 +574,7 @@ class LatticeIncrementalDeterminizer { // Step 4 of incremental determinization, // which either re-determinize above lat_, or simply remove the dead // states of lat_ - bool Finalize(bool redeterminize); + bool Finalize(); std::vector &GetForwardCosts() { return forward_costs_; } private: From 35a7abc45d28a090c6067a8431269c912f758c6f Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Thu, 23 May 2019 21:23:47 -0400 Subject: [PATCH 28/60] We need to be careful about the case where the start state of the `appended lattice' becomes a redeterminized-state. Certain things break down then, we might get an empty lattice. This could happen if there is silence at the utterance start. It would likely be fixable by inserting an epsilon transition from a new start state to the old start state, doing the normal procedure, and then removing that epsilon (the same way we would likely do when applying the `max-redeterminize-frames` option). --- src/decoder/lattice-incremental-decoder.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 7c860412a95..0cebb264383 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1309,7 +1309,11 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { auto final_weight = lat_.Final(final_arc.nextstate); KALDI_ASSERT(final_weight != CompactLatticeWeight::Zero()); auto num_frames = Times(final_arc.weight, final_weight).String().size(); - if (num_frames <= config_.redeterminize_max_frames) + // If the state is too far from the end of the current appended lattice, + // we leave the non-final arcs unchanged and only redeterminize the final + // arcs by the following procedure. + // We also do above things once we prepare to redeterminize the start state. + if (num_frames <= config_.redeterminize_max_frames && prefinal_state != 0) processed_prefinal_states_[prefinal_state] = prefinal_state; else { KALDI_VLOG(7) << "Impose a limit of " << config_.redeterminize_max_frames From 39d4181ad248bf466bb87fb7696fa8d7f1b5514b Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Thu, 30 May 2019 08:52:46 +0800 Subject: [PATCH 29/60] code refine --- src/bin/latgen-incremental-mapped.cc | 74 ++++++++++++----------- src/decoder/decoder-wrappers.cc | 8 +-- src/decoder/lattice-incremental-decoder.h | 4 +- 3 files changed, 45 insertions(+), 41 deletions(-) diff --git a/src/bin/latgen-incremental-mapped.cc b/src/bin/latgen-incremental-mapped.cc index a5b7619c0ab..6753cf49077 100644 --- a/src/bin/latgen-incremental-mapped.cc +++ b/src/bin/latgen-incremental-mapped.cc @@ -17,7 +17,6 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. - #include "base/kaldi-common.h" #include "util/common-utils.h" #include "tree/context-dep.h" @@ -27,7 +26,6 @@ #include "decoder/decodable-matrix.h" #include "base/timer.h" - int main(int argc, char *argv[]) { try { using namespace kaldi; @@ -39,7 +37,8 @@ int main(int argc, char *argv[]) { const char *usage = "Generate lattices, reading log-likelihoods as matrices\n" " (model is needed only for the integer mappings in its transition-model)\n" - "Usage: latgen-incremental-mapped [options] trans-model-in (fst-in|fsts-rspecifier) loglikes-rspecifier" + "Usage: latgen-incremental-mapped [options] trans-model-in " + "(fst-in|fsts-rspecifier) loglikes-rspecifier" " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n"; ParseOptions po(usage); Timer timer; @@ -49,10 +48,13 @@ int main(int argc, char *argv[]) { std::string word_syms_filename; config.Register(&po); - po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods"); + po.Register("acoustic-scale", &acoustic_scale, + "Scaling factor for acoustic likelihoods"); - po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]"); - po.Register("allow-partial", &allow_partial, "If true, produce output even if end state was not reached."); + po.Register("word-symbol-table", &word_syms_filename, + "Symbol table for words [for debug output]"); + po.Register("allow-partial", &allow_partial, + "If true, produce output even if end state was not reached."); po.Read(argc, argv); @@ -61,12 +63,10 @@ int main(int argc, char *argv[]) { exit(1); } - std::string model_in_filename = po.GetArg(1), - fst_in_str = po.GetArg(2), - feature_rspecifier = po.GetArg(3), - lattice_wspecifier = po.GetArg(4), - words_wspecifier = po.GetOptArg(5), - alignment_wspecifier = po.GetOptArg(6); + std::string model_in_filename = po.GetArg(1), fst_in_str = po.GetArg(2), + feature_rspecifier = po.GetArg(3), lattice_wspecifier = po.GetArg(4), + words_wspecifier = po.GetOptArg(5), + alignment_wspecifier = po.GetOptArg(6); TransitionModel trans_model; ReadKaldiObject(model_in_filename, &trans_model); @@ -74,10 +74,10 @@ int main(int argc, char *argv[]) { bool determinize = true; CompactLatticeWriter compact_lattice_writer; LatticeWriter lattice_writer; - if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier) - : lattice_writer.Open(lattice_wspecifier))) + if (!(determinize ? compact_lattice_writer.Open(lattice_wspecifier) + : lattice_writer.Open(lattice_wspecifier))) KALDI_ERR << "Could not open table for writing lattices: " - << lattice_wspecifier; + << lattice_wspecifier; Int32VectorWriter words_writer(words_wspecifier); @@ -86,8 +86,7 @@ int main(int argc, char *argv[]) { fst::SymbolTable *word_syms = NULL; if (word_syms_filename != "") if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) - KALDI_ERR << "Could not read symbol table from file " - << word_syms_filename; + KALDI_ERR << "Could not read symbol table from file " << word_syms_filename; double tot_like = 0.0; kaldi::int64 frame_count = 0; @@ -104,7 +103,7 @@ int main(int argc, char *argv[]) { for (; !loglike_reader.Done(); loglike_reader.Next()) { std::string utt = loglike_reader.Key(); - Matrix loglikes (loglike_reader.Value()); + Matrix loglikes(loglike_reader.Value()); loglike_reader.FreeCurrent(); if (loglikes.NumRows() == 0) { KALDI_WARN << "Zero-length utterance: " << utt; @@ -112,22 +111,24 @@ int main(int argc, char *argv[]) { continue; } - DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale); + DecodableMatrixScaledMapped decodable(trans_model, loglikes, + acoustic_scale); double like; if (DecodeUtteranceLatticeIncremental( - decoder, decodable, trans_model, word_syms, utt, - acoustic_scale, determinize, allow_partial, &alignment_writer, - &words_writer, &compact_lattice_writer, &lattice_writer, - &like)) { + decoder, decodable, trans_model, word_syms, utt, acoustic_scale, + determinize, allow_partial, &alignment_writer, &words_writer, + &compact_lattice_writer, &lattice_writer, &like)) { tot_like += like; frame_count += loglikes.NumRows(); num_success++; - } else num_fail++; + } else { + num_fail++; + } } } delete decode_fst; // delete this only after decoder goes out of scope. - } else { // We have different FSTs for different utterances. + } else { // We have different FSTs for different utterances. SequentialTableReader fst_reader(fst_in_str); RandomAccessBaseFloatMatrixReader loglike_reader(feature_rspecifier); for (; !fst_reader.Done(); fst_reader.Next()) { @@ -154,23 +155,26 @@ int main(int argc, char *argv[]) { tot_like += like; frame_count += loglikes.NumRows(); num_success++; - } else num_fail++; + } else { + num_fail++; + } } } double elapsed = timer.Elapsed(); - KALDI_LOG << "Time taken "<< elapsed + KALDI_LOG << "Time taken " << elapsed << "s: real-time factor assuming 100 frames/sec is " - << (elapsed*100.0/frame_count); - KALDI_LOG << "Done " << num_success << " utterances, failed for " - << num_fail; - KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like/frame_count) << " over " - << frame_count<<" frames."; + << (elapsed * 100.0 / frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " << num_fail; + KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like / frame_count) + << " over " << frame_count << " frames."; delete word_syms; - if (num_success != 0) return 0; - else return 1; - } catch(const std::exception &e) { + if (num_success != 0) + return 0; + else + return 1; + } catch (const std::exception &e) { std::cerr << e.what(); return -1; } diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index a912af41077..af476f2322f 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -68,7 +68,7 @@ void DecodeUtteranceLatticeFasterClass::operator () () { success_ = true; using fst::VectorFst; if (!decoder_->Decode(decodable_)) { - KALDI_WARN << "Failed to decode file " << utt_; + KALDI_WARN << "Failed to decode utterance with id " << utt_; success_ = false; } if (!decoder_->ReachedFinal()) { @@ -212,7 +212,7 @@ bool DecodeUtteranceLatticeIncremental( double *like_ptr) { // puts utterance's like in like_ptr on success. using fst::VectorFst; if (!decoder.Decode(&decodable)) { - KALDI_WARN << "Failed to decode file " << utt; + KALDI_WARN << "Failed to decode utterance with id " << utt; return false; } if (!decoder.ReachedFinal()) { @@ -296,7 +296,7 @@ bool DecodeUtteranceLatticeFaster( using fst::VectorFst; if (!decoder.Decode(&decodable)) { - KALDI_WARN << "Failed to decode file " << utt; + KALDI_WARN << "Failed to decode utterance with id " << utt; return false; } if (!decoder.ReachedFinal()) { @@ -457,7 +457,7 @@ bool DecodeUtteranceLatticeSimple( using fst::VectorFst; if (!decoder.Decode(&decodable)) { - KALDI_WARN << "Failed to decode file " << utt; + KALDI_WARN << "Failed to decode utterance with id " << utt; return false; } if (!decoder.ReachedFinal()) { diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 3c0b3024c8d..1b79dbce08b 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -66,7 +66,7 @@ struct LatticeIncrementalDecoderConfig { beam_delta(0.5), hash_ratio(2.0), prune_scale(0.1), - max_word_id(1e7) {} + max_word_id(1e8) {} void Register(OptionsItf *opts) { det_opts.Register(opts); opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); @@ -137,7 +137,7 @@ class LatticeIncrementalDeterminizer; will normally be StdToken, but also may be BackpointerToken which is to support quick lookup of the current best path (see lattice-faster-online-decoder.h) - The FST you invoke this decoder with is expected to equal + The FST you invoke this decoder with is expected to be of type Fst::Fst, a.k.a. StdFst, or GrammarFst. If you invoke it with FST == StdFst and it notices that the actual FST type is fst::VectorFst or fst::ConstFst, the decoder object From 8e1648dddcbdb4349016e52350eaf71e486ae0f8 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Thu, 30 May 2019 14:17:23 +0800 Subject: [PATCH 30/60] Do the following modification. Results can be referred to sheet "ver 8" of https://docs.google.com/spreadsheets/d/1M_AlG1SYnDN663Fh8U0WMu7JrwKO9I_H5rCFXj6soI0/edit?usp=sharing Let a {\em final-arc} be an arc with an arc with a {\em state-label} on it. Suppose some state $s$ in the already-determinized part of the lattice has at least on {\em final-arc} leaving it, and it also has at least one successor-state $r$ whose time is more than {\em redeterminize-frames} earlier than the last frame in the already-determinized part of the lattice and which has no {\em final-arc} leaving it. Before processing the already-determinized part of the lattice to determine which states are {\em redeterminized states} as described in Section~\ref{sec:determinizing_second_half}*, we modify it as follows: for each such state $s$: we add a new state $t$, put an epsilon transition from $s$ to $t$, and move all the arcs which are {\em not} to successor-states such as state $r$ above, to leave from $t$ instead of $s$. This modification ensures that states such as $r$ do not become {\em redeterminized-states}, i.e. it limits the number of states we have to redeterminize. We will remove the newly-introduced epsilons after determinizing the latest piece of the lattice and combining it with the already-determinized portion. --- src/decoder/lattice-incremental-decoder.cc | 36 +++++++++++++++++----- src/decoder/lattice-incremental-decoder.h | 14 ++++++--- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 0cebb264383..14027de32bd 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1297,11 +1297,10 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { processed_prefinal_states_.clear(); // go over all prefinal state KALDI_ASSERT(final_arc_list_prev_.size()); + unordered_set prefinal_states; + for (auto &i : final_arc_list_prev_) { auto prefinal_state = i.first; - if (processed_prefinal_states_.find(prefinal_state) != - processed_prefinal_states_.end()) - continue; ArcIterator aiter(lat_, prefinal_state); KALDI_ASSERT(lat_.NumArcs(prefinal_state) > i.second); aiter.Seek(i.second); @@ -1327,12 +1326,35 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { std::vector arcs_remained; for (aiter.Reset(); !aiter.Done(); aiter.Next()) { auto arc = aiter.Value(); + bool remain_the_arc = true; // If we remain the arc, the state will not be + // re-determinized, vice versa. if (arc.olabel > config_.max_word_id) { // final arc KALDI_ASSERT(arc.olabel < state_last_initial_offset_); KALDI_ASSERT(lat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()); - lat_.AddArc(new_prefinal_state, arc); - } else + remain_the_arc = false; + } else { + int num_frames_exclude_arc = num_frames - arc.weight.String().size(); + // destination-state of the arc is further than redeterminize_max_frames + // from the most recent frame we are determinizing + if (num_frames_exclude_arc > config_.redeterminize_max_frames) + remain_the_arc = true; + else { + // destination-state of the arc is no further than + // redeterminize_max_frames from the most recent frame we are + // determinizing + auto r = final_arc_list_prev_.find(arc.nextstate); + // destination-state of the arc is not prefinal state + if (r == final_arc_list_prev_.end()) remain_the_arc = true; + // destination-state of the arc is prefinal state + else + remain_the_arc = false; + } + } + + if (remain_the_arc) arcs_remained.push_back(arc); + else + lat_.AddArc(new_prefinal_state, arc); } CompactLatticeArc arc_to_new(0, 0, CompactLatticeWeight::One(), new_prefinal_state); @@ -1475,7 +1497,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); // state_appended shouldn't be in invert_processed_prefinal_states // So we do not need to map it - final_arc_list_.push_back( + final_arc_list_.insert( pair(state_appended, aiter.Position())); } olat->AddArc(source_state, arc_appended); @@ -1517,7 +1539,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla olat->AddArc(source_state, arc_postinitial); if (arc_postinitial.olabel > config_.max_word_id) { KALDI_ASSERT(arc_postinitial.olabel < state_last_initial_offset_); - final_arc_list_.push_back(pair( + final_arc_list_.insert(pair( source_state, aiter_postinitial.Position() + arc_offset)); } } diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 1b79dbce08b..db9667b83de 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -599,7 +599,11 @@ class LatticeIncrementalDeterminizer { // appended lattice, we create an extra state for it; we add an epsilon arc // from that pre-final state to the extra state; we copy any final arcs from // the pre-final state to its extra state and we remove those final arcs from - // the original pre-final state. Now this extra state is the pre-final state to + // the original pre-final state. + // We also copy arcs meet the following requirements: i) destination-state of the + // arc is prefinal state. ii) destination-state of the arc is no further than than + // redeterminize_max_frames from the most recent frame we are determinizing. + // Now this extra state is the pre-final state to // redeterminize and the original pre-final state does not need to redeterminize // The epsilon would be removed later on in AppendLatticeChunks, while // splicing the compact lattices together @@ -611,9 +615,11 @@ class LatticeIncrementalDeterminizer { // Record whether we have finished determinized the whole utterance // (including re-determinize) bool determinization_finalized_; - // keep final_arc for appending later - std::vector> final_arc_list_; - std::vector> final_arc_list_prev_; + // A map from the prefinal state to its correponding first final arc (there could be + // multiple final arcs). We keep final arc information for GetRedeterminizedStates() + // later. It can also be used to identify whether a state is a prefinal state. + unordered_map final_arc_list_; + unordered_map final_arc_list_prev_; // alpha of each state in lat_ std::vector forward_costs_; // we allocate a unique id for each source-state of the last arc of a series of From 9a0873e317d13fce3c87db47f2abd79f1b4682f6 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Tue, 11 Jun 2019 05:04:33 -0400 Subject: [PATCH 31/60] make terms consistent with the paper --- src/decoder/lattice-incremental-decoder.cc | 6 +++--- src/decoder/lattice-incremental-decoder.h | 15 ++++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 14027de32bd..87ab4093ac7 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -107,10 +107,10 @@ bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decoda // Moreover, the delay on GetLattice to do determinization // make it process more skinny lattices which reduces the computation overheads. int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; - // The minimum length of chunk is config_.determinize_chunk_size. - if (frame_det_most % config_.determinize_chunk_size == 0) { + // The minimum length of chunk is config_.determinize_period. + if (frame_det_most % config_.determinize_period == 0) { int32 frame_det_least = - last_get_lattice_frame_ + config_.determinize_chunk_size; + last_get_lattice_frame_ + config_.determinize_period; // To adaptively decide the length of chunk, we further compare the number of // tokens in each frame and a pre-defined threshold. // If the number of tokens in a certain frame is less than diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index db9667b83de..9ec1b08a496 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -39,7 +39,7 @@ struct LatticeIncrementalDecoderConfig { BaseFloat lattice_beam; int32 prune_interval; int32 determinize_delay; - int32 determinize_chunk_size; + int32 determinize_period; int32 determinize_max_active; int32 redeterminize_max_frames; BaseFloat beam_delta; // has nothing to do with beam_ratio @@ -60,7 +60,7 @@ struct LatticeIncrementalDecoderConfig { lattice_beam(10.0), prune_interval(25), determinize_delay(25), - determinize_chunk_size(20), + determinize_period(20), determinize_max_active(std::numeric_limits::max()), redeterminize_max_frames(std::numeric_limits::max()), beam_delta(0.5), @@ -85,7 +85,7 @@ struct LatticeIncrementalDecoderConfig { "lattices. A larger delay reduces the computational " "overheads of incremental deteriminization while increasing" "the length of the last chunk which may increase latencies."); - opts->Register("determinize-chunk-size", &determinize_chunk_size, + opts->Register("determinize-period", &determinize_period, "The size (in frames) of chunk to do incrementally " "determinization. If working with --determinize-max-active," "it will become a lower bound of the size of chunk."); @@ -117,7 +117,7 @@ struct LatticeIncrementalDecoderConfig { KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 && min_active <= max_active && prune_interval > 0 && determinize_delay >= 0 && determinize_max_active >= 0 && - determinize_chunk_size >= 0 && redeterminize_max_frames >= 0 && + determinize_period >= 0 && redeterminize_max_frames >= 0 && beam_delta > 0.0 && hash_ratio >= 1.0 && prune_scale > 0.0 && prune_scale < 1.0); } @@ -174,7 +174,7 @@ class LatticeIncrementalDecoderTpl { /// determinization. It decodes until there are no more frames left in the /// "decodable" object. Note, this may block waiting for input /// if the "decodable" object blocks. - /// In this example, config_.determinize_delay, config_.determinize_chunk_size + /// In this example, config_.determinize_delay, config_.determinize_period /// and config_.determinize_max_active are used to determine the time to /// call GetLattice(). /// Users may do it in their own ways by calling @@ -205,7 +205,7 @@ class LatticeIncrementalDecoderTpl { (calling it on every frame would not make sense, but every, say, 10 to 40 frames might make sense) it will spread out the work of determinization over time, which might be useful for online applications. - config_.determinize_delay, config_.determinize_chunk_size + config_.determinize_delay, config_.determinize_period and config_.determinize_max_active can be used to determine the time to call this function. We show an example in Decode(). @@ -224,7 +224,8 @@ class LatticeIncrementalDecoderTpl { unique labels (as olabel) to these raw-lattice states for latter appending. We give each token an olabel id, called `token_label`, and each determinized and - appended state an olabel id, called `state_label` + appended state an olabel id, called `state_label`. Notably, in our + paper, we call both of them ``state labels'' for simplicity. step 2: Determinize the chunk of above raw lattice using determinization algorithm the same as LatticeFasterDecoder. Benefit from above `state_label` and `token_label` in initial and final arcs, each pre-final state in the last chunk From 4448c1fc2c46358fc5f6d81ec39182e0b07f37b0 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Thu, 20 Jun 2019 16:40:36 +0800 Subject: [PATCH 32/60] code refine according to Hainan's comments --- src/decoder/lattice-incremental-decoder.cc | 26 +++++++++++----------- src/decoder/lattice-incremental-decoder.h | 6 ++--- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 87ab4093ac7..223771f9214 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -957,10 +957,11 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, ret = determinizer_.ProcessChunk(raw_fst, last_get_lattice_frame_, last_frame_of_chunk); last_get_lattice_frame_ = last_frame_of_chunk; - } else if (last_get_lattice_frame_ > last_frame_of_chunk) + } else if (last_get_lattice_frame_ > last_frame_of_chunk) { KALDI_WARN << "Call GetLattice up to frame: " << last_frame_of_chunk << " while the determinizer_ has already done up to frame: " << last_get_lattice_frame_; + } // step 4 if (decoding_finalized_) ret &= determinizer_.Finalize(); @@ -1002,10 +1003,10 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( ofst->DeleteStates(); unordered_multimap - token_label2last_state_map; // for GetInitialRawLattice + token_label2last_state; // for GetInitialRawLattice // initial arcs for the chunk if (create_initial_state) - determinizer_.GetInitialRawLattice(ofst, &token_label2last_state_map, + determinizer_.GetInitialRawLattice(ofst, &token_label2last_state, token_label_final_cost_); // num-frames plus one (since frames are one-based, and we have // an extra frame for the start-state). @@ -1041,16 +1042,15 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( auto r = token_label_map_.find(tok); KALDI_ASSERT(r != token_label_map_.end()); // it should exist int32 token_label = r->second; - auto range = token_label2last_state_map.equal_range(token_label); + auto range = token_label2last_state.equal_range(token_label); if (range.first == range.second) { KALDI_WARN << "The token in the first frame of this chunk does not " - "exist in the last frame of previous chunk. It should be seldom" - " happen and probably caused by over-pruning in determinization," + "exist in the last frame of previous chunk. It should seldom" + " happen and would be caused by over-pruning in determinization," "e.g. the lattice reaches --max-mem constrain."; continue; } - std::vector tmp_vec; for (auto it = range.first; it != range.second; ++it) { // the destination state of the last of the sequence of arcs w.r.t the token // label @@ -1192,7 +1192,7 @@ template void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( StateId start_state, StateId state, const unordered_map &token_label_final_cost, - unordered_multimap *token_label2last_state_map, + unordered_multimap *token_label2last_state, Lattice *olat) { using namespace fst; typedef LatticeArc Arc; @@ -1253,7 +1253,7 @@ void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( // and connected to the state corresponding to token w.r.t arc_olabel // Notably, we have multiple states for one token label after determinization, // hence we use multiset here - token_label2last_state_map->insert( + token_label2last_state->insert( std::pair(arc_olabel, laststate_copy)); arc_olabel = 0; // remove token label } else { @@ -1288,7 +1288,7 @@ void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( if (proc_nextstate) GetRawLatticeForRedeterminizedStates(start_state, arc.nextstate, token_label_final_cost, - token_label2last_state_map, olat); + token_label2last_state, olat); } } template @@ -1374,7 +1374,7 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { template void LatticeIncrementalDeterminizer::GetInitialRawLattice( Lattice *olat, - unordered_multimap *token_label2last_state_map, + unordered_multimap *token_label2last_state, const unordered_map &token_label_final_cost) { using namespace fst; typedef LatticeArc Arc; @@ -1385,7 +1385,7 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( GetRedeterminizedStates(); olat->DeleteStates(); - token_label2last_state_map->clear(); + token_label2last_state->clear(); auto start_state = olat->AddState(); olat->SetStart(start_state); @@ -1396,7 +1396,7 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( if (modified) GetRawLatticeForRedeterminizedStates(start_state, prefinal_state, token_label_final_cost, - token_label2last_state_map, olat); + token_label2last_state, olat); } } diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 9ec1b08a496..12c8c8933e2 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -545,7 +545,7 @@ class LatticeIncrementalDeterminizer { // redeterminized states (see the description in redeterminized_states_) in the // determinized and appended lattice before this chunk. // We give each determinized and appended state an olabel id, called `state_label` - // We maintain a map (`token_label2last_state_map`) from token label (obtained from + // We maintain a map (`token_label2last_state`) from token label (obtained from // final arcs) to the destination state of the last of the sequence of initial arcs // w.r.t the token label here // Notably, we have multiple states for one token label after determinization, @@ -554,7 +554,7 @@ class LatticeIncrementalDeterminizer { // DeterminizeLatticePhonePrunedWrapper void GetInitialRawLattice( Lattice *olat, - unordered_multimap *token_label2last_state_map, + unordered_multimap *token_label2last_state, const unordered_map &token_label_final_cost); // This function consumes raw_fst generated by step 1 of incremental // determinization with specific initial and final arcs. @@ -591,7 +591,7 @@ class LatticeIncrementalDeterminizer { void GetRawLatticeForRedeterminizedStates( StateId start_state, StateId state, const unordered_map &token_label_final_cost, - unordered_multimap *token_label2last_state_map, + unordered_multimap *token_label2last_state, Lattice *olat); // This function is to preprocess the appended compact lattice before // generating raw lattices for the next chunk. From 0a4c9bb7ecd44bb964a8fbf9421756c68cfb80ab Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Tue, 2 Jul 2019 16:15:15 +0800 Subject: [PATCH 33/60] add final-prune-after-determinize --- src/decoder/lattice-incremental-decoder.cc | 6 +++++- src/decoder/lattice-incremental-decoder.h | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 223771f9214..3cbb7ed0d22 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1589,7 +1589,11 @@ bool LatticeIncrementalDeterminizer::Finalize() { // The lattice determinization only needs to be finalized once if (determinization_finalized_) return true; // step 4: remove dead states - Connect(olat); // Remove unreachable states... there might be + if (config_.final_prune_after_determinize) + PruneLattice(config_.lattice_beam, olat); + else + Connect(olat); // Remove unreachable states... there might be + KALDI_VLOG(2) << "states of the lattice: " << olat->NumStates(); determinization_finalized_ = true; diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 12c8c8933e2..2ae58e697d3 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -42,6 +42,7 @@ struct LatticeIncrementalDecoderConfig { int32 determinize_period; int32 determinize_max_active; int32 redeterminize_max_frames; + bool final_prune_after_determinize; BaseFloat beam_delta; // has nothing to do with beam_ratio BaseFloat hash_ratio; BaseFloat prune_scale; // Note: we don't make this configurable on the command line, @@ -63,6 +64,7 @@ struct LatticeIncrementalDecoderConfig { determinize_period(20), determinize_max_active(std::numeric_limits::max()), redeterminize_max_frames(std::numeric_limits::max()), + final_prune_after_determinize(true), beam_delta(0.5), hash_ratio(2.0), prune_scale(0.1), @@ -98,13 +100,15 @@ struct LatticeIncrementalDecoderConfig { "determinized up to this frame. It can work with " "--determinize-delay to further reduce the computation " "introduced by incremental determinization. "); - opts->Register("redeterminize_max_frames", &redeterminize_max_frames, + opts->Register("redeterminize-max-frames", &redeterminize_max_frames, "To impose a limit on how far back in time we will " "redeterminize states. This is mainly intended to avoid " "pathological cases. Smaller value leads to less " "deterministic but less likely to blow up the processing" "time in bad cases. You could set it infinite to get a fully " "determinized lattice."); + opts->Register("final-prune-after-determinize", &final_prune_after_determinize, + "prune lattice after determinization "); opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " "parameter is obscure and relates to a speedup in the way the " From b4416f55efc58eaf4c7ab1d1fb6024cbbb4ea55b Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Wed, 3 Jul 2019 05:00:28 -0400 Subject: [PATCH 34/60] more comments --- src/decoder/lattice-incremental-decoder.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 3cbb7ed0d22..86ec6f7842a 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -116,7 +116,9 @@ bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decoda // If the number of tokens in a certain frame is less than // config_.determinize_max_active, the lattice can be determinized up to this // frame. And we try to determinize as most frames as possible so we check - // numbers from frame_det_most to frame_det_least + // numbers from frame_det_most to frame_det_least. + // In the end of the utterance, all of the remaining chunks will be + // processed during the FinalizeDecoding() function later. for (int32 f = frame_det_most; f >= frame_det_least; f--) { if (config_.determinize_max_active == std::numeric_limits::max() || GetNumToksForFrame(f) < config_.determinize_max_active) { From a624b3e2e5a3d72f6db1ed5cb438c118cbe3016d Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Tue, 30 Jul 2019 19:40:33 -0400 Subject: [PATCH 35/60] bug fix --- src/decoder/lattice-incremental-decoder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 86ec6f7842a..8e430d0bb3a 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1536,7 +1536,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla !aiter_postinitial.Done(); aiter_postinitial.Next()) { auto arc_postinitial(aiter_postinitial.Value()); arc_postinitial.weight = - Times(arc_postinitial.weight, arc_appended.weight); + Times(arc_appended.weight, arc_postinitial.weight); arc_postinitial.nextstate += state_offset; olat->AddArc(source_state, arc_postinitial); if (arc_postinitial.olabel > config_.max_word_id) { From 6438a3b1958ccae7761bd46d20a6c716eaa1f83d Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sat, 3 Aug 2019 15:56:20 -0400 Subject: [PATCH 36/60] refine --- src/bin/latgen-incremental-mapped.cc | 2 + src/decoder/lattice-incremental-decoder.cc | 63 ++++++++++++---------- src/decoder/lattice-incremental-decoder.h | 15 +++--- 3 files changed, 45 insertions(+), 35 deletions(-) diff --git a/src/bin/latgen-incremental-mapped.cc b/src/bin/latgen-incremental-mapped.cc index 6753cf49077..80c65bfb535 100644 --- a/src/bin/latgen-incremental-mapped.cc +++ b/src/bin/latgen-incremental-mapped.cc @@ -37,6 +37,8 @@ int main(int argc, char *argv[]) { const char *usage = "Generate lattices, reading log-likelihoods as matrices\n" " (model is needed only for the integer mappings in its transition-model)\n" + "The lattice determinization algorithm here can operate\n" + "incrementally.\n" "Usage: latgen-incremental-mapped [options] trans-model-in " "(fst-in|fsts-rspecifier) loglikes-rspecifier" " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n"; diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 8e430d0bb3a..5abbea21993 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -85,6 +85,36 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { ProcessNonemitting(config_.beam); } +template +void LatticeIncrementalDecoderTpl::DeterminizeLattice() { + // We always incrementally determinize the lattice after lattice pruning in + // PruneActiveTokens() since we need extra_cost as the weights + // of final arcs to denote the "future" information of final states (Tokens) + // Moreover, the delay on GetLattice to do determinization + // make it process more skinny lattices which reduces the computation overheads. + int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; + // The minimum length of chunk is config_.determinize_period. + if (frame_det_most % config_.determinize_period == 0) { + int32 frame_det_least = last_get_lattice_frame_ + config_.determinize_period; + // Incremental determinization: + // To adaptively decide the length of chunk, we further compare the number of + // tokens in each frame and a pre-defined threshold. + // If the number of tokens in a certain frame is less than + // config_.determinize_max_active, the lattice can be determinized up to this + // frame. And we try to determinize as most frames as possible so we check + // numbers from frame_det_most to frame_det_least + for (int32 f = frame_det_most; f >= frame_det_least; f--) { + if (config_.determinize_max_active == std::numeric_limits::max() || + GetNumToksForFrame(f) < config_.determinize_max_active) { + KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() + << " incremental determinization up to " << f; + GetLattice(false, f); + break; + } + } + } + return; +} // Returns true if any kind of traceback is available (not necessarily from // a final state). It should only very rarely return false; this indicates // an unusual search error. @@ -101,34 +131,8 @@ bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decoda PruneActiveTokens(config_.lattice_beam * config_.prune_scale); } - // We always incrementally determinize the lattice after lattice pruning in - // PruneActiveTokens() since we need extra_cost as the weights - // of final arcs to denote the "future" information of final states (Tokens) - // Moreover, the delay on GetLattice to do determinization - // make it process more skinny lattices which reduces the computation overheads. - int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; - // The minimum length of chunk is config_.determinize_period. - if (frame_det_most % config_.determinize_period == 0) { - int32 frame_det_least = - last_get_lattice_frame_ + config_.determinize_period; - // To adaptively decide the length of chunk, we further compare the number of - // tokens in each frame and a pre-defined threshold. - // If the number of tokens in a certain frame is less than - // config_.determinize_max_active, the lattice can be determinized up to this - // frame. And we try to determinize as most frames as possible so we check - // numbers from frame_det_most to frame_det_least. - // In the end of the utterance, all of the remaining chunks will be - // processed during the FinalizeDecoding() function later. - for (int32 f = frame_det_most; f >= frame_det_least; f--) { - if (config_.determinize_max_active == std::numeric_limits::max() || - GetNumToksForFrame(f) < config_.determinize_max_active) { - KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() - << " incremental determinization up to " << f; - GetLattice(false, f); - break; - } - } - } + DeterminizeLattice(); + BaseFloat cost_cutoff = ProcessEmitting(decodable); ProcessNonemitting(cost_cutoff); } @@ -562,6 +566,9 @@ void LatticeIncrementalDecoderTpl::AdvanceDecoding( if (NumFramesDecoded() % config_.prune_interval == 0) { PruneActiveTokens(config_.lattice_beam * config_.prune_scale); } + + DeterminizeLattice(); + BaseFloat cost_cutoff = ProcessEmitting(decodable); ProcessNonemitting(cost_cutoff); } diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 2ae58e697d3..9f930b6610d 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -85,8 +85,8 @@ struct LatticeIncrementalDecoderConfig { opts->Register("determinize-delay", &determinize_delay, "Delay (in frames) at which to incrementally determinize " "lattices. A larger delay reduces the computational " - "overheads of incremental deteriminization while increasing" - "the length of the last chunk which may increase latencies."); + "overhead of incremental deteriminization while increasing" + "the length of the last chunk which may increase latency."); opts->Register("determinize-period", &determinize_period, "The size (in frames) of chunk to do incrementally " "determinization. If working with --determinize-max-active," @@ -508,8 +508,10 @@ class LatticeIncrementalDecoderTpl { // Get the number of tokens in each frame // It is useful, e.g. in using config_.determinize_max_active int32 GetNumToksForFrame(int32 frame); + void DeterminizeLattice(); - // The incremental lattice determinizer to take care of step 2-4 + // The incremental lattice determinizer to take care of determinization + // and appending the lattice. LatticeIncrementalDeterminizer determinizer_; int32 last_get_lattice_frame_; // the last time we call GetLattice // a map from Token to its token_label @@ -526,9 +528,8 @@ class LatticeIncrementalDecoderTpl { typedef LatticeIncrementalDecoderTpl LatticeIncrementalDecoder; -// This class is designed for step 2-4 and part of step 1 of incremental -// determinization -// introduced before above GetLattice() +// This class is designed for part of generating raw lattices and determnization +// and appending the lattice. template class LatticeIncrementalDeterminizer { public: @@ -562,7 +563,7 @@ class LatticeIncrementalDeterminizer { const unordered_map &token_label_final_cost); // This function consumes raw_fst generated by step 1 of incremental // determinization with specific initial and final arcs. - // It does step 2-4 and outputs the resultant CompactLattice if + // It processes lattices and outputs the resultant CompactLattice if // needed. Otherwise, it keeps the resultant lattice in lat_ bool ProcessChunk(Lattice &raw_fst, int32 first_frame, int32 last_frame); From 15cdab744ff03dd70592e51de829851000857e0b Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Sun, 4 Aug 2019 23:08:43 -0400 Subject: [PATCH 37/60] add online decoder --- src/decoder/Makefile | 2 +- .../lattice-incremental-online-decoder.cc | 150 +++++++++ .../lattice-incremental-online-decoder.h | 132 ++++++++ src/online2/Makefile | 2 +- src/online2/online-endpoint.cc | 44 ++- src/online2/online-endpoint.h | 14 +- src/online2/online-ivector-feature.cc | 58 ++++ src/online2/online-ivector-feature.h | 3 + .../online-nnet3-incremental-decoding.cc | 93 ++++++ .../online-nnet3-incremental-decoding.h | 128 ++++++++ src/online2bin/Makefile | 2 +- .../online2-wav-nnet3-latgen-incremental.cc | 304 ++++++++++++++++++ 12 files changed, 923 insertions(+), 9 deletions(-) create mode 100644 src/decoder/lattice-incremental-online-decoder.cc create mode 100644 src/decoder/lattice-incremental-online-decoder.h create mode 100644 src/online2/online-nnet3-incremental-decoding.cc create mode 100644 src/online2/online-nnet3-incremental-decoding.h create mode 100644 src/online2bin/online2-wav-nnet3-latgen-incremental.cc diff --git a/src/decoder/Makefile b/src/decoder/Makefile index ebac90e65ac..61e9670adba 100644 --- a/src/decoder/Makefile +++ b/src/decoder/Makefile @@ -8,7 +8,7 @@ TESTFILES = OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \ lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \ decoder-wrappers.o grammar-fst.o decodable-matrix.o \ - lattice-incremental-decoder.o + lattice-incremental-decoder.o lattice-incremental-online-decoder.o LIBNAME = kaldi-decoder diff --git a/src/decoder/lattice-incremental-online-decoder.cc b/src/decoder/lattice-incremental-online-decoder.cc new file mode 100644 index 00000000000..f64b194eff2 --- /dev/null +++ b/src/decoder/lattice-incremental-online-decoder.cc @@ -0,0 +1,150 @@ +// decoder/lattice-incremental-online-decoder.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// see note at the top of lattice-faster-decoder.cc, about how to maintain this +// file in sync with lattice-faster-decoder.cc + +#include "decoder/lattice-incremental-decoder.h" +#include "decoder/lattice-incremental-online-decoder.h" +#include "lat/lattice-functions.h" +#include "base/timer.h" + +namespace kaldi { + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeIncrementalOnlineDecoderTpl::GetBestPath(Lattice *olat, + bool use_final_probs) const { + olat->DeleteStates(); + BaseFloat final_graph_cost; + BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost); + if (iter.Done()) + return false; // would have printed warning. + StateId state = olat->AddState(); + olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0)); + while (!iter.Done()) { + LatticeArc arc; + iter = TraceBackBestPath(iter, &arc); + arc.nextstate = state; + StateId new_state = olat->AddState(); + olat->AddArc(new_state, arc); + state = new_state; + } + olat->SetStart(state); + return true; +} + +template +typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator LatticeIncrementalOnlineDecoderTpl::BestPathEnd( + bool use_final_probs, + BaseFloat *final_cost_out) const { + if (this->decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "BestPathEnd() with use_final_probs == false"; + KALDI_ASSERT(this->NumFramesDecoded() > 0 && + "You cannot call BestPathEnd if no frames were decoded."); + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (this->decoding_finalized_ ? this->final_costs_ :final_costs_local); + if (!this->decoding_finalized_ && use_final_probs) + this->ComputeFinalCosts(&final_costs_local, NULL, NULL); + + // Singly linked list of tokens on last frame (access list through "next" + // pointer). + BaseFloat best_cost = std::numeric_limits::infinity(); + BaseFloat best_final_cost = 0; + Token *best_tok = NULL; + for (Token *tok = this->active_toks_.back().toks; + tok != NULL; tok = tok->next) { + BaseFloat cost = tok->tot_cost, final_cost = 0.0; + if (use_final_probs && !final_costs.empty()) { + // if we are instructed to use final-probs, and any final tokens were + // active on final frame, include the final-prob in the cost of the token. + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) { + final_cost = iter->second; + cost += final_cost; + } else { + cost = std::numeric_limits::infinity(); + } + } + if (cost < best_cost) { + best_cost = cost; + best_tok = tok; + best_final_cost = final_cost; + } + } + if (best_tok == NULL) { // this should not happen, and is likely a code error or + // caused by infinities in likelihoods, but I'm not making + // it a fatal error for now. + KALDI_WARN << "No final token found."; + } + if (final_cost_out) + *final_cost_out = best_final_cost; + return BestPathIterator(best_tok, this->NumFramesDecoded() - 1); +} + + +template +typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator LatticeIncrementalOnlineDecoderTpl::TraceBackBestPath( + BestPathIterator iter, LatticeArc *oarc) const { + KALDI_ASSERT(!iter.Done() && oarc != NULL); + Token *tok = static_cast(iter.tok); + int32 cur_t = iter.frame, ret_t = cur_t; + if (tok->backpointer != NULL) { + ForwardLinkT *link; + for (link = tok->backpointer->links; + link != NULL; link = link->next) { + if (link->next_tok == tok) { // this is the link to "tok" + oarc->ilabel = link->ilabel; + oarc->olabel = link->olabel; + BaseFloat graph_cost = link->graph_cost, + acoustic_cost = link->acoustic_cost; + if (link->ilabel != 0) { + KALDI_ASSERT(static_cast(cur_t) < this->cost_offsets_.size()); + acoustic_cost -= this->cost_offsets_[cur_t]; + ret_t--; + } + oarc->weight = LatticeWeight(graph_cost, acoustic_cost); + break; + } + } + if (link == NULL) { // Did not find correct link. + KALDI_ERR << "Error tracing best-path back (likely " + << "bug in token-pruning algorithm)"; + } + } else { + oarc->ilabel = 0; + oarc->olabel = 0; + oarc->weight = LatticeWeight::One(); // zero costs. + } + return BestPathIterator(tok->backpointer, ret_t); +} + +// Instantiate the template for the FST types that we'll need. +template class LatticeIncrementalOnlineDecoderTpl >; +template class LatticeIncrementalOnlineDecoderTpl >; +template class LatticeIncrementalOnlineDecoderTpl >; +template class LatticeIncrementalOnlineDecoderTpl; + + +} // end namespace kaldi. diff --git a/src/decoder/lattice-incremental-online-decoder.h b/src/decoder/lattice-incremental-online-decoder.h new file mode 100644 index 00000000000..8bd41c851ab --- /dev/null +++ b/src/decoder/lattice-incremental-online-decoder.h @@ -0,0 +1,132 @@ +// decoder/lattice-incremental-online-decoder.h + +// Copyright 2019 Zhehuai Chen +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// see note at the top of lattice-faster-decoder.h, about how to maintain this +// file in sync with lattice-faster-decoder.h + + +#ifndef KALDI_DECODER_LATTICE_INCREMENTAL_ONLINE_DECODER_H_ +#define KALDI_DECODER_LATTICE_INCREMENTAL_ONLINE_DECODER_H_ + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/lattice-incremental-decoder.h" + +namespace kaldi { + + + +/** LatticeIncrementalOnlineDecoderTpl is as LatticeIncrementalDecoderTpl but also + supports an efficient way to get the best path (see the function + BestPathEnd()), which is useful in endpointing and in situations where you + might want to frequently access the best path. + + This is only templated on the FST type, since the Token type is required to + be BackpointerToken. Actually it only makes sense to instantiate + LatticeIncrementalDecoderTpl with Token == BackpointerToken if you do so indirectly via + this child class. + */ +template +class LatticeIncrementalOnlineDecoderTpl: + public LatticeIncrementalDecoderTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using Token = decoder::BackpointerToken; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeIncrementalOnlineDecoderTpl(const FST &fst, + const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config): + LatticeIncrementalDecoderTpl(fst, trans_model, config) { } + + // This version of the initializer takes ownership of 'fst', and will delete + // it when this object is destroyed. + LatticeIncrementalOnlineDecoderTpl(const LatticeIncrementalDecoderConfig &config, + FST *fst, + const TransitionModel &trans_model): + LatticeIncrementalDecoderTpl(config, fst, trans_model) { } + + + struct BestPathIterator { + void *tok; + int32 frame; + // note, "frame" is the frame-index of the frame you'll get the + // transition-id for next time, if you call TraceBackBestPath on this + // iterator (assuming it's not an epsilon transition). Note that this + // is one less than you might reasonably expect, e.g. it's -1 for + // the nonemitting transitions before the first frame. + BestPathIterator(void *t, int32 f): tok(t), frame(f) { } + bool Done() { return tok == NULL; } + }; + + + /// Outputs an FST corresponding to the single best path through the lattice. + /// This is quite efficient because it doesn't get the entire raw lattice and find + /// the best path through it; instead, it uses the BestPathEnd and BestPathIterator + /// so it basically traces it back through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true) const; + + + + /// This function returns an iterator that can be used to trace back + /// the best path. If use_final_probs == true and at least one final state + /// survived till the end, it will use the final-probs in working out the best + /// final Token, and will output the final cost to *final_cost (if non-NULL), + /// else it will use only the forward likelihood, and will put zero in + /// *final_cost (if non-NULL). + /// Requires that NumFramesDecoded() > 0. + BestPathIterator BestPathEnd(bool use_final_probs, + BaseFloat *final_cost = NULL) const; + + + /// This function can be used in conjunction with BestPathEnd() to trace back + /// the best path one link at a time (e.g. this can be useful in endpoint + /// detection). By "link" we mean a link in the graph; not all links cross + /// frame boundaries, but each time you see a nonzero ilabel you can interpret + /// that as a frame. The return value is the updated iterator. It outputs + /// the ilabel and olabel, and the (graph and acoustic) weight to the "arc" pointer, + /// while leaving its "nextstate" variable unchanged. + BestPathIterator TraceBackBestPath( + BestPathIterator iter, LatticeArc *arc) const; + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalOnlineDecoderTpl); +}; + +typedef LatticeIncrementalOnlineDecoderTpl LatticeIncrementalOnlineDecoder; + + +} // end namespace kaldi. + +#endif diff --git a/src/online2/Makefile b/src/online2/Makefile index 242c7be6da6..bbc7ac07bb1 100644 --- a/src/online2/Makefile +++ b/src/online2/Makefile @@ -9,7 +9,7 @@ OBJFILES = online-gmm-decodable.o online-feature-pipeline.o online-ivector-featu online-nnet2-feature-pipeline.o online-gmm-decoding.o online-timing.o \ online-endpoint.o onlinebin-util.o online-speex-wrapper.o \ online-nnet2-decoding.o online-nnet2-decoding-threaded.o \ - online-nnet3-decoding.o + online-nnet3-decoding.o online-nnet3-incremental-decoding.o LIBNAME = kaldi-online2 diff --git a/src/online2/online-endpoint.cc b/src/online2/online-endpoint.cc index aa7752c4484..a3be0791f03 100644 --- a/src/online2/online-endpoint.cc +++ b/src/online2/online-endpoint.cc @@ -71,10 +71,10 @@ bool EndpointDetected(const OnlineEndpointConfig &config, return false; } -template +template int32 TrailingSilenceLength(const TransitionModel &tmodel, const std::string &silence_phones_str, - const LatticeFasterOnlineDecoderTpl &decoder) { + const DEC &decoder) { std::vector silence_phones; if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones)) KALDI_ERR << "Bad --silence-phones option in endpointing config: " @@ -87,7 +87,7 @@ int32 TrailingSilenceLength(const TransitionModel &tmodel, ConstIntegerSet silence_set(silence_phones); bool use_final_probs = false; - typename LatticeFasterOnlineDecoderTpl::BestPathIterator iter = + typename DEC::BestPathIterator iter = decoder.BestPathEnd(use_final_probs, NULL); int32 num_silence_frames = 0; while (!iter.Done()) { // we're going backwards in time from the most @@ -117,7 +117,7 @@ bool EndpointDetected( BaseFloat final_relative_cost = decoder.FinalRelativeCost(); int32 num_frames_decoded = decoder.NumFramesDecoded(), - trailing_silence_frames = TrailingSilenceLength(tmodel, + trailing_silence_frames = TrailingSilenceLength>(tmodel, config.silence_phones, decoder); @@ -125,6 +125,26 @@ bool EndpointDetected( frame_shift_in_seconds, final_relative_cost); } +template +bool EndpointDetected( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl &decoder) { + if (decoder.NumFramesDecoded() == 0) return false; + + BaseFloat final_relative_cost = decoder.FinalRelativeCost(); + + int32 num_frames_decoded = decoder.NumFramesDecoded(), + trailing_silence_frames = TrailingSilenceLength>(tmodel, + config.silence_phones, + decoder); + + return EndpointDetected(config, num_frames_decoded, trailing_silence_frames, + frame_shift_in_seconds, final_relative_cost); +} + + // Instantiate EndpointDetected for the types we need. // It will require TrailingSilenceLength so we don't have to instantiate that. @@ -143,5 +163,21 @@ bool EndpointDetected( BaseFloat frame_shift_in_seconds, const LatticeFasterOnlineDecoderTpl &decoder); +template +bool EndpointDetected >( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl > &decoder); + + +template +bool EndpointDetected( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl &decoder); + + } // namespace kaldi diff --git a/src/online2/online-endpoint.h b/src/online2/online-endpoint.h index aaf9232db13..3171f0c532c 100644 --- a/src/online2/online-endpoint.h +++ b/src/online2/online-endpoint.h @@ -35,6 +35,7 @@ #include "lat/kaldi-lattice.h" #include "hmm/transition-model.h" #include "decoder/lattice-faster-online-decoder.h" +#include "decoder/lattice-incremental-online-decoder.h" namespace kaldi { /// @addtogroup onlinedecoding OnlineDecoding @@ -187,10 +188,10 @@ bool EndpointDetected(const OnlineEndpointConfig &config, /// integer id's of phones that we consider silence. We use the the /// BestPathEnd() and TraceBackOneLink() functions of LatticeFasterOnlineDecoder /// to do this efficiently. -template +template int32 TrailingSilenceLength(const TransitionModel &tmodel, const std::string &silence_phones, - const LatticeFasterOnlineDecoderTpl &decoder); + const DEC &decoder); /// This is a higher-level convenience function that works out the @@ -202,6 +203,15 @@ bool EndpointDetected( BaseFloat frame_shift_in_seconds, const LatticeFasterOnlineDecoderTpl &decoder); +/// This is a higher-level convenience function that works out the +/// arguments to the EndpointDetected function above, from the decoder. +template +bool EndpointDetected( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl &decoder); + diff --git a/src/online2/online-ivector-feature.cc b/src/online2/online-ivector-feature.cc index 2042fbb8b80..13a41ae4f68 100644 --- a/src/online2/online-ivector-feature.cc +++ b/src/online2/online-ivector-feature.cc @@ -510,6 +510,57 @@ void OnlineSilenceWeighting::ComputeCurrentTraceback( } } +template +void OnlineSilenceWeighting::ComputeCurrentTraceback( + const LatticeIncrementalOnlineDecoderTpl &decoder) { + int32 num_frames_decoded = decoder.NumFramesDecoded(), + num_frames_prev = frame_info_.size(); + // note, num_frames_prev is not the number of frames previously decoded, + // it's the generally-larger number of frames that we were requested to + // provide weights for. + if (num_frames_prev < num_frames_decoded) + frame_info_.resize(num_frames_decoded); + if (num_frames_prev > num_frames_decoded && + frame_info_[num_frames_decoded].transition_id != -1) + KALDI_ERR << "Number of frames decoded decreased"; // Likely bug + + if (num_frames_decoded == 0) + return; + int32 frame = num_frames_decoded - 1; + bool use_final_probs = false; + typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator iter = + decoder.BestPathEnd(use_final_probs, NULL); + while (frame >= 0) { + LatticeArc arc; + arc.ilabel = 0; + while (arc.ilabel == 0) // the while loop skips over input-epsilons + iter = decoder.TraceBackBestPath(iter, &arc); + // note, the iter.frame values are slightly unintuitively defined, + // they are one less than you might expect. + KALDI_ASSERT(iter.frame == frame - 1); + + if (frame_info_[frame].token == iter.tok) { + // we know that the traceback from this point back will be identical, so + // no point tracing back further. Note: we are comparing memory addresses + // of tokens of the decoder; this guarantees it's the same exact token + // because tokens, once allocated on a frame, are only deleted, never + // reallocated for that frame. + break; + } + + if (num_frames_output_and_correct_ > frame) + num_frames_output_and_correct_ = frame; + + frame_info_[frame].token = iter.tok; + frame_info_[frame].transition_id = arc.ilabel; + frame--; + // leave frame_info_.current_weight at zero for now (as set in the + // constructor), reflecting that we haven't already output a weight for that + // frame. + } +} + + // Instantiate the template OnlineSilenceWeighting::ComputeCurrentTraceback(). template void OnlineSilenceWeighting::ComputeCurrentTraceback >( @@ -517,6 +568,13 @@ void OnlineSilenceWeighting::ComputeCurrentTraceback >( template void OnlineSilenceWeighting::ComputeCurrentTraceback( const LatticeFasterOnlineDecoderTpl &decoder); +template +void OnlineSilenceWeighting::ComputeCurrentTraceback >( + const LatticeIncrementalOnlineDecoderTpl > &decoder); +template +void OnlineSilenceWeighting::ComputeCurrentTraceback( + const LatticeIncrementalOnlineDecoderTpl &decoder); + int32 OnlineSilenceWeighting::GetBeginFrame() { int32 max_duration = config_.max_state_duration; diff --git a/src/online2/online-ivector-feature.h b/src/online2/online-ivector-feature.h index 25e078f1a98..5e674e2b7f1 100644 --- a/src/online2/online-ivector-feature.h +++ b/src/online2/online-ivector-feature.h @@ -33,6 +33,7 @@ #include "feat/online-feature.h" #include "ivector/ivector-extractor.h" #include "decoder/lattice-faster-online-decoder.h" +#include "decoder/lattice-incremental-online-decoder.h" namespace kaldi { /// @addtogroup onlinefeat OnlineFeatureExtraction @@ -471,6 +472,8 @@ class OnlineSilenceWeighting { // It will be instantiated for FST == fst::Fst and fst::GrammarFst. template void ComputeCurrentTraceback(const LatticeFasterOnlineDecoderTpl &decoder); + template + void ComputeCurrentTraceback(const LatticeIncrementalOnlineDecoderTpl &decoder); // Calling this function gets the changes in weight that require us to modify // the stats... the output format is (frame-index, delta-weight). The diff --git a/src/online2/online-nnet3-incremental-decoding.cc b/src/online2/online-nnet3-incremental-decoding.cc new file mode 100644 index 00000000000..cf6ddc6e6e2 --- /dev/null +++ b/src/online2/online-nnet3-incremental-decoding.cc @@ -0,0 +1,93 @@ +// online2/online-nnet3-decoding.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "online2/online-nnet3-incremental-decoding.h" +#include "lat/lattice-functions.h" +#include "lat/determinize-lattice-pruned.h" +#include "decoder/grammar-fst.h" + +namespace kaldi { + +template +SingleUtteranceNnet3IncrementalDecoderTpl::SingleUtteranceNnet3IncrementalDecoderTpl( + const LatticeIncrementalDecoderConfig &decoder_opts, + const TransitionModel &trans_model, + const nnet3::DecodableNnetSimpleLoopedInfo &info, + const FST &fst, + OnlineNnet2FeaturePipeline *features): + decoder_opts_(decoder_opts), + input_feature_frame_shift_in_seconds_(features->FrameShiftInSeconds()), + trans_model_(trans_model), + decodable_(trans_model_, info, + features->InputFeature(), features->IvectorFeature()), + decoder_(fst, trans_model, decoder_opts_) { + decoder_.InitDecoding(); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::InitDecoding(int32 frame_offset) { + decoder_.InitDecoding(); + decodable_.SetFrameOffset(frame_offset); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::AdvanceDecoding() { + decoder_.AdvanceDecoding(&decodable_); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::FinalizeDecoding() { + decoder_.FinalizeDecoding(); +} + +template +int32 SingleUtteranceNnet3IncrementalDecoderTpl::NumFramesDecoded() const { + return decoder_.NumFramesDecoded(); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::GetLattice(bool end_of_utterance, + CompactLattice *clat) { + if (NumFramesDecoded() == 0) + KALDI_ERR << "You cannot get a lattice if you decoded no frames."; + decoder_.GetLattice(end_of_utterance, decoder_.NumFramesDecoded(), clat); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::GetBestPath(bool end_of_utterance, + Lattice *best_path) const { + decoder_.GetBestPath(best_path, end_of_utterance); +} + +template +bool SingleUtteranceNnet3IncrementalDecoderTpl::EndpointDetected( + const OnlineEndpointConfig &config) { + BaseFloat output_frame_shift = + input_feature_frame_shift_in_seconds_ * + decodable_.FrameSubsamplingFactor(); + return kaldi::EndpointDetected(config, trans_model_, + output_frame_shift, decoder_); +} + + +// Instantiate the template for the types needed. +template class SingleUtteranceNnet3IncrementalDecoderTpl >; +template class SingleUtteranceNnet3IncrementalDecoderTpl; + +} // namespace kaldi diff --git a/src/online2/online-nnet3-incremental-decoding.h b/src/online2/online-nnet3-incremental-decoding.h new file mode 100644 index 00000000000..e880b4238e8 --- /dev/null +++ b/src/online2/online-nnet3-incremental-decoding.h @@ -0,0 +1,128 @@ +// online2/online-nnet3-decoding.h + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_ONLINE2_ONLINE_NNET3_INCREMENTAL_DECODING_H_ +#define KALDI_ONLINE2_ONLINE_NNET3_INCREMENTAL_DECODING_H_ + +#include +#include +#include + +#include "nnet3/decodable-online-looped.h" +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" +#include "base/kaldi-error.h" +#include "itf/online-feature-itf.h" +#include "online2/online-endpoint.h" +#include "online2/online-nnet2-feature-pipeline.h" +#include "decoder/lattice-incremental-online-decoder.h" +#include "hmm/transition-model.h" +#include "hmm/posterior.h" + +namespace kaldi { +/// @addtogroup onlinedecoding OnlineDecoding +/// @{ + + +/** + You will instantiate this class when you want to decode a single utterance + using the online-decoding setup for neural nets. The template will be + instantiated only for FST = fst::Fst and FST = fst::GrammarFst. +*/ + +template +class SingleUtteranceNnet3IncrementalDecoderTpl { + public: + + // Constructor. The pointer 'features' is not being given to this class to own + // and deallocate, it is owned externally. + SingleUtteranceNnet3IncrementalDecoderTpl(const LatticeIncrementalDecoderConfig &decoder_opts, + const TransitionModel &trans_model, + const nnet3::DecodableNnetSimpleLoopedInfo &info, + const FST &fst, + OnlineNnet2FeaturePipeline *features); + + /// Initializes the decoding and sets the frame offset of the underlying + /// decodable object. This method is called by the constructor. You can also + /// call this method when you want to reset the decoder state, but want to + /// keep using the same decodable object, e.g. in case of an endpoint. + void InitDecoding(int32 frame_offset = 0); + + /// Advances the decoding as far as we can. + void AdvanceDecoding(); + + /// Finalizes the decoding. Cleans up and prunes remaining tokens, so the + /// GetLattice() call will return faster. You must not call this before + /// calling (TerminateDecoding() or InputIsFinished()) and then Wait(). + void FinalizeDecoding(); + + int32 NumFramesDecoded() const; + + /// Gets the lattice. The output lattice has any acoustic scaling in it + /// (which will typically be desirable in an online-decoding context); if you + /// want an un-scaled lattice, scale it using ScaleLattice() with the inverse + /// of the acoustic weight. "end_of_utterance" will be true if you want the + /// final-probs to be included. + void GetLattice(bool end_of_utterance, + CompactLattice *clat); + + /// Outputs an FST corresponding to the single best path through the current + /// lattice. If "use_final_probs" is true AND we reached the final-state of + /// the graph then it will include those as final-probs, else it will treat + /// all final-probs as one. + void GetBestPath(bool end_of_utterance, + Lattice *best_path) const; + + + /// This function calls EndpointDetected from online-endpoint.h, + /// with the required arguments. + bool EndpointDetected(const OnlineEndpointConfig &config); + + const LatticeIncrementalOnlineDecoderTpl &Decoder() const { return decoder_; } + + ~SingleUtteranceNnet3IncrementalDecoderTpl() { } + private: + + const LatticeIncrementalDecoderConfig &decoder_opts_; + + // this is remembered from the constructor; it's ultimately + // derived from calling FrameShiftInSeconds() on the feature pipeline. + BaseFloat input_feature_frame_shift_in_seconds_; + + // we need to keep a reference to the transition model around only because + // it's needed by the endpointing code. + const TransitionModel &trans_model_; + + nnet3::DecodableAmNnetLoopedOnline decodable_; + + LatticeIncrementalOnlineDecoderTpl decoder_; + +}; + + +typedef SingleUtteranceNnet3IncrementalDecoderTpl > SingleUtteranceNnet3IncrementalDecoder; + +/// @} End of "addtogroup onlinedecoding" + +} // namespace kaldi + + + +#endif // KALDI_ONLINE2_ONLINE_NNET3_DECODING_H_ diff --git a/src/online2bin/Makefile b/src/online2bin/Makefile index 28c135eb950..2552e7148dc 100644 --- a/src/online2bin/Makefile +++ b/src/online2bin/Makefile @@ -12,7 +12,7 @@ BINFILES = online2-wav-gmm-latgen-faster apply-cmvn-online \ online2-wav-dump-features ivector-randomize \ online2-wav-nnet2-am-compute online2-wav-nnet2-latgen-threaded \ online2-wav-nnet3-latgen-faster online2-wav-nnet3-latgen-grammar \ - online2-tcp-nnet3-decode-faster + online2-tcp-nnet3-decode-faster online2-wav-nnet3-latgen-incremental OBJFILES = diff --git a/src/online2bin/online2-wav-nnet3-latgen-incremental.cc b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc new file mode 100644 index 00000000000..b48337af5fb --- /dev/null +++ b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc @@ -0,0 +1,304 @@ +// online2bin/online2-wav-nnet3-latgen-incremental.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "feat/wave-reader.h" +#include "online2/online-nnet3-incremental-decoding.h" +#include "online2/online-nnet2-feature-pipeline.h" +#include "online2/onlinebin-util.h" +#include "online2/online-timing.h" +#include "online2/online-endpoint.h" +#include "fstext/fstext-lib.h" +#include "lat/lattice-functions.h" +#include "util/kaldi-thread.h" +#include "nnet3/nnet-utils.h" + +namespace kaldi { + +void GetDiagnosticsAndPrintOutput(const std::string &utt, + const fst::SymbolTable *word_syms, + const CompactLattice &clat, + int64 *tot_num_frames, + double *tot_like) { + if (clat.NumStates() == 0) { + KALDI_WARN << "Empty lattice."; + return; + } + CompactLattice best_path_clat; + CompactLatticeShortestPath(clat, &best_path_clat); + + Lattice best_path_lat; + ConvertLattice(best_path_clat, &best_path_lat); + + double likelihood; + LatticeWeight weight; + int32 num_frames; + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight); + num_frames = alignment.size(); + likelihood = -(weight.Value1() + weight.Value2()); + *tot_num_frames += num_frames; + *tot_like += likelihood; + KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is " + << (likelihood / num_frames) << " over " << num_frames + << " frames."; + + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << std::endl; + } +} + +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace fst; + + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Reads in wav file(s) and simulates online decoding with neural nets\n" + "(nnet3 setup), with optional iVector-based speaker adaptation and\n" + "optional endpointing. Note: some configuration values and inputs are\n" + "set via config files whose filenames are passed as options\n" + "The lattice determinization algorithm here can operate\n" + "incrementally.\n" + "\n" + "Usage: online2-wav-nnet3-latgen-incremental [options] " + " \n" + "The spk2utt-rspecifier can just be if\n" + "you want to decode utterance by utterance.\n"; + + ParseOptions po(usage); + + std::string word_syms_rxfilename; + + // feature_opts includes configuration for the iVector adaptation, + // as well as the basic features. + OnlineNnet2FeaturePipelineConfig feature_opts; + nnet3::NnetSimpleLoopedComputationOptions decodable_opts; + LatticeIncrementalDecoderConfig decoder_opts; + OnlineEndpointConfig endpoint_opts; + + BaseFloat chunk_length_secs = 0.18; + bool do_endpointing = false; + bool online = true; + + po.Register("chunk-length", &chunk_length_secs, + "Length of chunk size in seconds, that we process. Set to <= 0 " + "to use all input in one chunk."); + po.Register("word-symbol-table", &word_syms_rxfilename, + "Symbol table for words [for debug output]"); + po.Register("do-endpointing", &do_endpointing, + "If true, apply endpoint detection"); + po.Register("online", &online, + "You can set this to false to disable online iVector estimation " + "and have all the data for each utterance used, even at " + "utterance start. This is useful where you just want the best " + "results and don't care about online operation. Setting this to " + "false has the same effect as setting " + "--use-most-recent-ivector=true and --greedy-ivector-extractor=true " + "in the file given to --ivector-extraction-config, and " + "--chunk-length=-1."); + po.Register("num-threads-startup", &g_num_threads, + "Number of threads used when initializing iVector extractor."); + + feature_opts.Register(&po); + decodable_opts.Register(&po); + decoder_opts.Register(&po); + endpoint_opts.Register(&po); + + + po.Read(argc, argv); + + if (po.NumArgs() != 5) { + po.PrintUsage(); + return 1; + } + + std::string nnet3_rxfilename = po.GetArg(1), + fst_rxfilename = po.GetArg(2), + spk2utt_rspecifier = po.GetArg(3), + wav_rspecifier = po.GetArg(4), + clat_wspecifier = po.GetArg(5); + + OnlineNnet2FeaturePipelineInfo feature_info(feature_opts); + + if (!online) { + feature_info.ivector_extractor_info.use_most_recent_ivector = true; + feature_info.ivector_extractor_info.greedy_ivector_extractor = true; + chunk_length_secs = -1.0; + } + + TransitionModel trans_model; + nnet3::AmNnetSimple am_nnet; + { + bool binary; + Input ki(nnet3_rxfilename, &binary); + trans_model.Read(ki.Stream(), binary); + am_nnet.Read(ki.Stream(), binary); + SetBatchnormTestMode(true, &(am_nnet.GetNnet())); + SetDropoutTestMode(true, &(am_nnet.GetNnet())); + nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(am_nnet.GetNnet())); + } + + // this object contains precomputed stuff that is used by all decodable + // objects. It takes a pointer to am_nnet because if it has iVectors it has + // to modify the nnet to accept iVectors at intervals. + nnet3::DecodableNnetSimpleLoopedInfo decodable_info(decodable_opts, + &am_nnet); + + + fst::Fst *decode_fst = ReadFstKaldiGeneric(fst_rxfilename); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_rxfilename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename))) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_rxfilename; + + int32 num_done = 0, num_err = 0; + double tot_like = 0.0; + int64 num_frames = 0; + + SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); + RandomAccessTableReader wav_reader(wav_rspecifier); + CompactLatticeWriter clat_writer(clat_wspecifier); + + OnlineTimingStats timing_stats; + + for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { + std::string spk = spk2utt_reader.Key(); + const std::vector &uttlist = spk2utt_reader.Value(); + OnlineIvectorExtractorAdaptationState adaptation_state( + feature_info.ivector_extractor_info); + for (size_t i = 0; i < uttlist.size(); i++) { + std::string utt = uttlist[i]; + if (!wav_reader.HasKey(utt)) { + KALDI_WARN << "Did not find audio for utterance " << utt; + num_err++; + continue; + } + const WaveData &wave_data = wav_reader.Value(utt); + // get the data for channel zero (if the signal is not mono, we only + // take the first channel). + SubVector data(wave_data.Data(), 0); + + OnlineNnet2FeaturePipeline feature_pipeline(feature_info); + feature_pipeline.SetAdaptationState(adaptation_state); + + OnlineSilenceWeighting silence_weighting( + trans_model, + feature_info.silence_weighting_config, + decodable_opts.frame_subsampling_factor); + + SingleUtteranceNnet3IncrementalDecoder decoder(decoder_opts, trans_model, + decodable_info, + *decode_fst, &feature_pipeline); + OnlineTimer decoding_timer(utt); + + BaseFloat samp_freq = wave_data.SampFreq(); + int32 chunk_length; + if (chunk_length_secs > 0) { + chunk_length = int32(samp_freq * chunk_length_secs); + if (chunk_length == 0) chunk_length = 1; + } else { + chunk_length = std::numeric_limits::max(); + } + + int32 samp_offset = 0; + std::vector > delta_weights; + + while (samp_offset < data.Dim()) { + int32 samp_remaining = data.Dim() - samp_offset; + int32 num_samp = chunk_length < samp_remaining ? chunk_length + : samp_remaining; + + SubVector wave_part(data, samp_offset, num_samp); + feature_pipeline.AcceptWaveform(samp_freq, wave_part); + + samp_offset += num_samp; + decoding_timer.WaitUntil(samp_offset / samp_freq); + if (samp_offset == data.Dim()) { + // no more input. flush out last frames + feature_pipeline.InputFinished(); + } + + if (silence_weighting.Active() && + feature_pipeline.IvectorFeature() != NULL) { + silence_weighting.ComputeCurrentTraceback(decoder.Decoder()); + silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(), + &delta_weights); + feature_pipeline.IvectorFeature()->UpdateFrameWeights(delta_weights); + } + + decoder.AdvanceDecoding(); + + if (do_endpointing && decoder.EndpointDetected(endpoint_opts)) { + break; + } + } + decoder.FinalizeDecoding(); + + CompactLattice clat; + bool end_of_utterance = true; + decoder.GetLattice(end_of_utterance, &clat); + + GetDiagnosticsAndPrintOutput(utt, word_syms, clat, + &num_frames, &tot_like); + + decoding_timer.OutputStats(&timing_stats); + + // In an application you might avoid updating the adaptation state if + // you felt the utterance had low confidence. See lat/confidence.h + feature_pipeline.GetAdaptationState(&adaptation_state); + + // we want to output the lattice with un-scaled acoustics. + BaseFloat inv_acoustic_scale = + 1.0 / decodable_opts.acoustic_scale; + ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat); + + clat_writer.Write(utt, clat); + KALDI_LOG << "Decoded utterance " << utt; + num_done++; + } + } + timing_stats.Print(online); + + KALDI_LOG << "Decoded " << num_done << " utterances, " + << num_err << " with errors."; + KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames) + << " per frame over " << num_frames << " frames."; + delete decode_fst; + delete word_syms; // will delete if non-NULL. + return (num_done != 0 ? 0 : 1); + } catch(const std::exception& e) { + std::cerr << e.what(); + return -1; + } +} // main() From 95663706ffc3879244d9bf54cf14a0f9edeb147c Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Fri, 6 Sep 2019 07:47:39 -0400 Subject: [PATCH 38/60] refine --- src/online2/online-nnet3-incremental-decoding.cc | 2 +- src/online2/online-nnet3-incremental-decoding.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/online2/online-nnet3-incremental-decoding.cc b/src/online2/online-nnet3-incremental-decoding.cc index cf6ddc6e6e2..540a3a4f850 100644 --- a/src/online2/online-nnet3-incremental-decoding.cc +++ b/src/online2/online-nnet3-incremental-decoding.cc @@ -1,4 +1,4 @@ -// online2/online-nnet3-decoding.cc +// online2/online-nnet3-incremental-decoding.cc // Copyright 2019 Zhehuai Chen diff --git a/src/online2/online-nnet3-incremental-decoding.h b/src/online2/online-nnet3-incremental-decoding.h index e880b4238e8..ddd9707bf54 100644 --- a/src/online2/online-nnet3-incremental-decoding.h +++ b/src/online2/online-nnet3-incremental-decoding.h @@ -1,4 +1,4 @@ -// online2/online-nnet3-decoding.h +// online2/online-nnet3-incremental-decoding.h // Copyright 2019 Zhehuai Chen From 10a597abeb35a026f0abd7c78f5b738e15049ab7 Mon Sep 17 00:00:00 2001 From: Zhehuai Chen Date: Tue, 1 Oct 2019 19:44:58 -0400 Subject: [PATCH 39/60] refine --- src/decoder/lattice-incremental-online-decoder.cc | 2 +- src/online2/online-ivector-feature.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/decoder/lattice-incremental-online-decoder.cc b/src/decoder/lattice-incremental-online-decoder.cc index f64b194eff2..85f902bde3d 100644 --- a/src/decoder/lattice-incremental-online-decoder.cc +++ b/src/decoder/lattice-incremental-online-decoder.cc @@ -98,7 +98,7 @@ typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator LatticeIncrem // it a fatal error for now. KALDI_WARN << "No final token found."; } - if (final_cost_out) + if (final_cost_out == NULL) *final_cost_out = best_final_cost; return BestPathIterator(best_tok, this->NumFramesDecoded() - 1); } diff --git a/src/online2/online-ivector-feature.cc b/src/online2/online-ivector-feature.cc index 13a41ae4f68..fb1b7d9225d 100644 --- a/src/online2/online-ivector-feature.cc +++ b/src/online2/online-ivector-feature.cc @@ -542,7 +542,7 @@ void OnlineSilenceWeighting::ComputeCurrentTraceback( if (frame_info_[frame].token == iter.tok) { // we know that the traceback from this point back will be identical, so // no point tracing back further. Note: we are comparing memory addresses - // of tokens of the decoder; this guarantees it's the same exact token + // of tokens of the decoder; this guarantees it's the same exact token, // because tokens, once allocated on a frame, are only deleted, never // reallocated for that frame. break; From b0799805a2074ecad79111ba99656272a525703f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 7 Nov 2019 12:30:14 -0800 Subject: [PATCH 40/60] Some initial work on rewriting incremental determinization --- src/decoder/decoder-wrappers.cc | 7 +- src/decoder/lattice-faster-decoder.h | 10 +- src/decoder/lattice-incremental-decoder.cc | 233 ++++---- src/decoder/lattice-incremental-decoder.h | 650 +++++++++++---------- src/lat/determinize-lattice-pruned.h | 10 +- 5 files changed, 463 insertions(+), 447 deletions(-) diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index e34c83f20d8..15465b88635 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -232,7 +232,8 @@ bool DecodeUtteranceLatticeIncremental( int32 num_frames; { // First do some stuff with word-level traceback... VectorFst decoded; - if (!decoder.GetBestPath(&decoded)) + decoder.GetBestPath(&decoded); + if (decoded.Start() == fst::kNoStateId) // Shouldn't really reach this point as already checked success. KALDI_ERR << "Failed to get traceback for utterance " << utt; @@ -258,9 +259,9 @@ bool DecodeUtteranceLatticeIncremental( likelihood = -(weight.Value1() + weight.Value2()); } - // Get lattice, and do determinization if requested. + // Get lattice CompactLattice clat; - decoder.GetLattice(&clat); + decoder.GetLattice(true, decoder.NumFramesDecoded(), &clat); if (clat.NumStates() == 0) KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; // We'll write the lattice without acoustic scaling. diff --git a/src/decoder/lattice-faster-decoder.h b/src/decoder/lattice-faster-decoder.h index e0cf7dea8d6..d6bac1bca5d 100644 --- a/src/decoder/lattice-faster-decoder.h +++ b/src/decoder/lattice-faster-decoder.h @@ -43,11 +43,13 @@ struct LatticeFasterDecoderConfig { int32 prune_interval; bool determinize_lattice; // not inspected by this class... used in // command-line program. - BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat beam_delta; BaseFloat hash_ratio; - BaseFloat prune_scale; // Note: we don't make this configurable on the command line, - // it's not a very important parameter. It affects the - // algorithm that prunes the tokens as we go. + // Note: we don't make prune_scale configurable on the command line, it's not + // a very important parameter. It affects the algorithm that prunes the + // tokens as we go. + BaseFloat prune_scale; + // Most of the options inside det_opts are not actually queried by the // LatticeFasterDecoder class itself, but by the code that calls it, for // example in the function DecodeUtteranceLatticeFaster. diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 5abbea21993..8c2404d9a76 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -30,8 +30,8 @@ LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( const LatticeIncrementalDecoderConfig &config) : fst_(&fst), delete_fst_(false), - config_(config), num_toks_(0), + config_(config), determinizer_(config, trans_model) { config.Check(); toks_.SetSize(1000); // just so on the first frame we do something reasonable. @@ -43,8 +43,8 @@ LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( const TransitionModel &trans_model) : fst_(fst), delete_fst_(true), - config_(config), num_toks_(0), + config_(config), determinizer_(config, trans_model) { config.Check(); toks_.SetSize(1000); // just so on the first frame we do something reasonable. @@ -76,10 +76,10 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { num_toks_++; last_get_lattice_frame_ = 0; - token_label_map_.clear(); - token_label_map_.reserve(std::min((int32)1e5, config_.max_active)); + token2label_map_.clear(); + token2label_map_.reserve(std::min((int32)1e5, config_.max_active)); token_label_available_idx_ = config_.max_word_id + 1; - token_label_final_cost_.clear(); + token_label2final_cost_.clear(); determinizer_.Init(); ProcessNonemitting(config_.beam); @@ -149,13 +149,12 @@ bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decoda // Outputs an FST corresponding to the single best path through the lattice. template -bool LatticeIncrementalDecoderTpl::GetBestPath(Lattice *olat, +void LatticeIncrementalDecoderTpl::GetBestPath(Lattice *olat, bool use_final_probs) { CompactLattice lat, slat; GetLattice(use_final_probs, NumFramesDecoded(), &lat); ShortestPath(lat, &slat); ConvertLattice(slat, olat); - return (olat->NumStates() != 0); } template @@ -937,34 +936,27 @@ void LatticeIncrementalDecoderTpl::TopSortTokens( } template -bool LatticeIncrementalDecoderTpl::GetLattice(CompactLattice *olat) { - return GetLattice(true, NumFramesDecoded(), olat); -} - -template -bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, +void LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, int32 last_frame_of_chunk, CompactLattice *olat) { + olat->DeleteStates(); /* Clear the FST */ + KALDI_ASSERT(olat->Start() == fst::kNoStateId); // TODO: remove using namespace fst; - bool not_first_chunk = last_get_lattice_frame_ != 0; - bool ret = true; + bool first_chunk = last_get_lattice_frame_ == 0; - // last_get_lattice_frame_ is used to record the first frame of the chunk - // last time we obtain from calling this function. If it reaches - // last_frame_of_chunk - // we cannot generate any more chunk + KALDI_ASSERT(last_get_lattice_frame_ <= last_frame_of_chunk); if (last_get_lattice_frame_ < last_frame_of_chunk) { Lattice raw_fst; // step 1: Get lattice chunk with initial and final states // In this function, we do not create the initial state in // the first chunk, and we do not create the final state in the last chunk if (!GetIncrementalRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, - last_frame_of_chunk, not_first_chunk, + last_frame_of_chunk, !first_chunk, !decoding_finalized_)) KALDI_ERR << "Unexpected problem when getting lattice"; // step 2-3 - ret = determinizer_.ProcessChunk(raw_fst, last_get_lattice_frame_, - last_frame_of_chunk); + determinizer_.AcceptRawLatticeChunk(last_get_lattice_frame_, + last_frame_of_chunk, &raw_fst); last_get_lattice_frame_ = last_frame_of_chunk; } else if (last_get_lattice_frame_ > last_frame_of_chunk) { KALDI_WARN << "Call GetLattice up to frame: " << last_frame_of_chunk @@ -972,22 +964,11 @@ bool LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, << last_get_lattice_frame_; } - // step 4 - if (decoding_finalized_) ret &= determinizer_.Finalize(); - if (olat) { - *olat = determinizer_.GetDeterminizedLattice(); - ret &= (olat->NumStates() > 0); - } - if (!ret) { - KALDI_WARN << "Last chunk processing failed." - << " We will retry from frame 0."; - // Reset determinizer_ and re-determinize from - // frame 0 to last_frame_of_chunk - last_get_lattice_frame_ = 0; - determinizer_.Init(); - } + if (decoding_finalized_) + determinizer_.Finalize(); - return ret; + if (olat) + *olat = determinizer_.GetDeterminizedLattice(); } template @@ -1011,12 +992,12 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( ComputeFinalCosts(&final_costs_local, NULL, NULL); ofst->DeleteStates(); - unordered_multimap - token_label2last_state; // for GetInitialRawLattice + unordered_map + token_label2state; // for InitializeRawLatticeChunk // initial arcs for the chunk if (create_initial_state) - determinizer_.GetInitialRawLattice(ofst, &token_label2last_state, - token_label_final_cost_); + determinizer_.InitializeRawLatticeChunk(ofst, token_label2final_cost_, + &token_label2state); // num-frames plus one (since frames are one-based, and we have // an extra frame for the start-state). KALDI_ASSERT(frame_end > 0); @@ -1047,11 +1028,11 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( if (create_initial_state) { for (Token *tok = active_toks_[frame_begin].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; - // token_label_map_ is construct during create_final_state - auto r = token_label_map_.find(tok); - KALDI_ASSERT(r != token_label_map_.end()); // it should exist + // token2label_map_ is construct during create_final_state + auto r = token2label_map_.find(tok); + KALDI_ASSERT(r != token2label_map_.end()); // it should exist int32 token_label = r->second; - auto range = token_label2last_state.equal_range(token_label); + auto range = token_label2state.equal_range(token_label); if (range.first == range.second) { KALDI_WARN << "The token in the first frame of this chunk does not " @@ -1063,7 +1044,7 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( for (auto it = range.first; it != range.second; ++it) { // the destination state of the last of the sequence of arcs w.r.t the token // label - // here created by GetInitialRawLattice + // here created by InitializeRawLatticeChunk auto state_last_initial = it->second; // connect it to the state correponding to the token w.r.t the token label // here @@ -1120,14 +1101,14 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( StateId end_state = ofst->AddState(); // final-state for the chunk ofst->SetFinal(end_state, Weight::One()); - token_label_map_.clear(); - token_label_map_.reserve(std::min((int32)1e5, config_.max_active)); + token2label_map_.clear(); + token2label_map_.reserve(std::min((int32)1e5, config_.max_active)); for (Token *tok = active_toks_[frame_end].toks; tok != NULL; tok = tok->next) { StateId cur_state = tok_map[tok]; // We assign an unique state label for each of the token in the last frame // of this chunk int32 id = token_label_available_idx_++; - token_label_map_[tok] = id; + token2label_map_[tok] = id; // The final weight has been worked out in the previous for loop and // store in the states // Here, we create a specific final state, and move the final costs to @@ -1141,12 +1122,13 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( BaseFloat cost_offset = tok->extra_cost - tok->tot_cost; // We record these cost_offset, and after we appending two chunks // we will cancel them out - token_label_final_cost_[id] = cost_offset; + token_label2final_cost_[id] = cost_offset; Arc arc(0, id, Times(final_weight, Weight(0, cost_offset)), end_state); ofst->AddArc(cur_state, arc); ofst->SetFinal(cur_state, Weight::Zero()); } } + // TODO: clean up maps used internally. TopSortLatticeIfNeeded(ofst); return (ofst->NumStates() > 0); } @@ -1167,20 +1149,20 @@ template void LatticeIncrementalDeterminizer::Init() { final_arc_list_.clear(); final_arc_list_prev_.clear(); - lat_.DeleteStates(); + clat_.DeleteStates(); determinization_finalized_ = false; forward_costs_.clear(); state_last_initial_offset_ = 2 * config_.max_word_id; - redeterminized_states_.clear(); + redeterminized_state_map_.clear(); processed_prefinal_states_.clear(); } template -bool LatticeIncrementalDeterminizer::AddRedeterminizedState( +bool LatticeIncrementalDeterminizer::FindOrAddRedeterminizedState( Lattice::StateId nextstate, Lattice *olat, Lattice::StateId *nextstate_copy) { using namespace fst; bool modified = false; StateId nextstate_insert = kNoStateId; - auto r = redeterminized_states_.insert({nextstate, nextstate_insert}); + auto r = redeterminized_state_map_.insert({nextstate, nextstate_insert}); if (r.second) { // didn't exist, successfully insert here // create a new state w.r.t state nextstate_insert = olat->AddState(); @@ -1198,10 +1180,10 @@ bool LatticeIncrementalDeterminizer::AddRedeterminizedState( } template -void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( - StateId start_state, StateId state, - const unordered_map &token_label_final_cost, - unordered_multimap *token_label2last_state, +void LatticeIncrementalDeterminizer::ProcessRedeterminizedState( + Lattice::StateId state, + const unordered_map &token_label2final_cost, + unordered_map *token_label2state, Lattice *olat) { using namespace fst; typedef LatticeArc Arc; @@ -1209,11 +1191,11 @@ void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( typedef Arc::Weight Weight; typedef Arc::Label Label; - auto r = redeterminized_states_.find(state); - KALDI_ASSERT(r != redeterminized_states_.end()); + auto r = redeterminized_state_map_.find(state); + KALDI_ASSERT(r != redeterminized_state_map_.end()); auto state_copy = r->second; KALDI_ASSERT(state_copy != kNoStateId); - ArcIterator aiter(lat_, state); + ArcIterator aiter(clat_, state); // use state_label in initial arcs int state_label = state + state_last_initial_offset_; @@ -1222,6 +1204,7 @@ void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( KALDI_ASSERT(state < forward_costs_.size()); auto alpha_cost = forward_costs_[state]; Arc arc_initial(0, state_label, LatticeWeight(0, alpha_cost), state_copy); + Lattice::StateId start_state = olat->Start(); if (alpha_cost != std::numeric_limits::infinity()) olat->AddArc(start_state, arc_initial); @@ -1234,8 +1217,8 @@ void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( KALDI_ASSERT(arc.olabel == arc.ilabel); auto arc_olabel = arc.olabel; - // the destination of the arc is the final state - if (lat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { + // the destination of the arc is a final -> a "splice state". + if (clat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { KALDI_ASSERT(arc_olabel > config_.max_word_id && arc_olabel < state_last_initial_offset_); // token label // create a initial arc @@ -1245,14 +1228,14 @@ void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( CompactLatticeWeight weight_offset; // To cancel out the weight on the final arcs, which is (extra cost - forward // cost). - // see token_label_final_cost for more details - const auto r = token_label_final_cost.find(arc_olabel); - KALDI_ASSERT(r != token_label_final_cost.end()); + // see token_label2final_cost for more details + const auto r = token_label2final_cost.find(arc_olabel); + KALDI_ASSERT(r != token_label2final_cost.end()); auto cost_offset = r->second; weight_offset.SetWeight(LatticeWeight(0, -cost_offset)); // The arc weight is a combination of original arc weight, above cost_offset // and the weights on the final state - arc_weight = Times(Times(arc_weight, lat_.Final(arc.nextstate)), weight_offset); + arc_weight = Times(Times(arc_weight, clat_.Final(arc.nextstate)), weight_offset); // We create a respective destination state for each final arc // later we will connect it to the state correponding to the token w.r.t @@ -1262,7 +1245,7 @@ void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( // and connected to the state corresponding to token w.r.t arc_olabel // Notably, we have multiple states for one token label after determinization, // hence we use multiset here - token_label2last_state->insert( + token_label2state->insert( std::pair(arc_olabel, laststate_copy)); arc_olabel = 0; // remove token label } else { @@ -1271,7 +1254,7 @@ void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( KALDI_ASSERT(arc_olabel); // get the nextstate_copy w.r.t arc.nextstate StateId nextstate_copy = kNoStateId; - proc_nextstate = AddRedeterminizedState(arc.nextstate, olat, &nextstate_copy); + proc_nextstate = FindOrAddRedeterminizedState(arc.nextstate, olat, &nextstate_copy); KALDI_ASSERT(nextstate_copy != kNoStateId); laststate_copy = nextstate_copy; } @@ -1294,10 +1277,12 @@ void LatticeIncrementalDeterminizer::GetRawLatticeForRedeterminizedStates( olat->AddArc(prev_state, arc_last); // not final state && previously didn't process this state + + // TODO: verify that the following call is not necessary. if (proc_nextstate) - GetRawLatticeForRedeterminizedStates(start_state, arc.nextstate, - token_label_final_cost, - token_label2last_state, olat); + ProcessRedeterminizedState(arc.nextstate, + token_label2final_cost, + token_label2state, olat); } } template @@ -1310,11 +1295,11 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { for (auto &i : final_arc_list_prev_) { auto prefinal_state = i.first; - ArcIterator aiter(lat_, prefinal_state); - KALDI_ASSERT(lat_.NumArcs(prefinal_state) > i.second); + ArcIterator aiter(clat_, prefinal_state); + KALDI_ASSERT(clat_.NumArcs(prefinal_state) > i.second); aiter.Seek(i.second); auto final_arc = aiter.Value(); - auto final_weight = lat_.Final(final_arc.nextstate); + auto final_weight = clat_.Final(final_arc.nextstate); KALDI_ASSERT(final_weight != CompactLatticeWeight::Zero()); auto num_frames = Times(final_arc.weight, final_weight).String().size(); // If the state is too far from the end of the current appended lattice, @@ -1328,18 +1313,18 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { << " on how far back in time we will redeterminize states. " << num_frames << " frames in this arc. "; - auto new_prefinal_state = lat_.AddState(); + auto new_prefinal_state = clat_.AddState(); forward_costs_.resize(new_prefinal_state + 1); forward_costs_[new_prefinal_state] = forward_costs_[prefinal_state]; - std::vector arcs_remained; + std::vector arcs_remaining; for (aiter.Reset(); !aiter.Done(); aiter.Next()) { auto arc = aiter.Value(); bool remain_the_arc = true; // If we remain the arc, the state will not be // re-determinized, vice versa. if (arc.olabel > config_.max_word_id) { // final arc KALDI_ASSERT(arc.olabel < state_last_initial_offset_); - KALDI_ASSERT(lat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()); + KALDI_ASSERT(clat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()); remain_the_arc = false; } else { int num_frames_exclude_arc = num_frames - arc.weight.String().size(); @@ -1361,30 +1346,31 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { } if (remain_the_arc) - arcs_remained.push_back(arc); + arcs_remaining.push_back(arc); else - lat_.AddArc(new_prefinal_state, arc); + clat_.AddArc(new_prefinal_state, arc); } CompactLatticeArc arc_to_new(0, 0, CompactLatticeWeight::One(), new_prefinal_state); - arcs_remained.push_back(arc_to_new); + arcs_remaining.push_back(arc_to_new); - lat_.DeleteArcs(prefinal_state); - for (auto &i : arcs_remained) lat_.AddArc(prefinal_state, i); + clat_.DeleteArcs(prefinal_state); + for (auto &i : arcs_remaining) + clat_.AddArc(prefinal_state, i); processed_prefinal_states_[prefinal_state] = new_prefinal_state; } } KALDI_VLOG(8) << "states of the lattice after GetRedeterminizedStates: " - << lat_.NumStates(); + << clat_.NumStates(); } // This function is specifically designed to obtain the initial arcs for a chunk // We have multiple states for one token label after determinization template -void LatticeIncrementalDeterminizer::GetInitialRawLattice( +void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( Lattice *olat, - unordered_multimap *token_label2last_state, - const unordered_map &token_label_final_cost) { + const unordered_map &token_label2final_cost, + unordered_map *token_label2state) { using namespace fst; typedef LatticeArc Arc; typedef Arc::StateId StateId; @@ -1394,27 +1380,26 @@ void LatticeIncrementalDeterminizer::GetInitialRawLattice( GetRedeterminizedStates(); olat->DeleteStates(); - token_label2last_state->clear(); + token_label2state->clear(); auto start_state = olat->AddState(); olat->SetStart(start_state); // go over all prefinal states after preprocessing for (auto &i : processed_prefinal_states_) { auto prefinal_state = i.second; - bool modified = AddRedeterminizedState(prefinal_state, olat); + bool modified = FindOrAddRedeterminizedState(prefinal_state, olat); if (modified) - GetRawLatticeForRedeterminizedStates(start_state, prefinal_state, - token_label_final_cost, - token_label2last_state, olat); + ProcessRedeterminizedState(prefinal_state, + token_label2final_cost, + token_label2state, olat); } } template -bool LatticeIncrementalDeterminizer::ProcessChunk(Lattice &raw_fst, - int32 first_frame, - int32 last_frame) { - bool not_first_chunk = first_frame != 0; - bool ret = true; +bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk(int32 first_frame, + int32 last_frame, + Lattice *raw_fst) { + bool first_chunk = first_frame == 0; // step 2: Determinize the chunk CompactLattice clat; // We do determinization with beam pruning here @@ -1427,38 +1412,36 @@ bool LatticeIncrementalDeterminizer::ProcessChunk(Lattice &raw_fst, // LatticeFasterDecoder, we need to use a slightly larger beam here // than the lattice_beam used PruneActiveTokens. Hence the beam we use is // (0.1 + config_.lattice_beam) - ret &= DeterminizeLatticePhonePrunedWrapper( - trans_model_, &raw_fst, (config_.lattice_beam + 0.1), &clat, config_.det_opts); + bool determinized_till_beam = DeterminizeLatticePhonePrunedWrapper( + trans_model_, raw_fst, (config_.lattice_beam + 0.1), &clat, config_.det_opts); // step 3: Appending the new chunk in clat to the old one in lat_ - ret &= AppendLatticeChunks(clat, not_first_chunk); + // later we need to calculate forward_costs_ for clat + + TopSortCompactLatticeIfNeeded(&clat); + AppendLatticeChunks(clat, first_chunk); - ret &= (lat_.NumStates() > 0); KALDI_VLOG(2) << "Frame: ( " << first_frame << " , " << last_frame << " )" << " states of the chunk: " << clat.NumStates() - << " states of the lattice: " << lat_.NumStates(); - return ret; + << " states of the lattice: " << clat_.NumStates(); + return determinized_till_beam; } template -bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice clat, - bool not_first_chunk) { +void LatticeIncrementalDeterminizer::AppendLatticeChunks( + const CompactLattice &clat, bool first_chunk) { using namespace fst; - CompactLattice *olat = &lat_; - - // later we need to calculate forward_costs_ for clat - TopSortCompactLatticeIfNeeded(&clat); - + CompactLattice *olat = &clat_; // step 3.1: Appending new chunk to the old one int32 state_offset = olat->NumStates(); - if (not_first_chunk) { + if (!first_chunk) { state_offset--; // since we do not append initial state in the first chunk - // remove arcs from redeterminized_states_ - for (auto i : redeterminized_states_) { + // remove arcs from redeterminized_state_map_ + for (auto i : redeterminized_state_map_) { olat->DeleteArcs(i.first); olat->SetFinal(i.first, CompactLatticeWeight::Zero()); } - redeterminized_states_.clear(); + redeterminized_state_map_.clear(); } else { forward_costs_.push_back(0); // for the first state } @@ -1475,7 +1458,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla auto s = siter.Value(); StateId state_appended = kNoStateId; // We do not copy initial state, which exists except the first chunk - if (!not_first_chunk || s != 0) { + if (first_chunk || s != 0) { state_appended = s + state_offset; auto r = olat->AddState(); KALDI_ASSERT(state_appended == r); @@ -1494,7 +1477,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla // process it here // In the last chunk, there could be a initial arc ending in final state, and // we process it in "process initial arcs" in the following - bool is_initial_state = (not_first_chunk && s == 0); + bool is_initial_state = (!first_chunk && s == 0); if (!is_initial_state) { KALDI_ASSERT(state_appended != kNoStateId); KALDI_ASSERT(arc.olabel < state_last_initial_offset_); @@ -1566,7 +1549,7 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla KALDI_ASSERT(olat->NumStates() == clat.NumStates() + state_offset); KALDI_VLOG(8) << "states of the lattice: " << olat->NumStates(); - if (!not_first_chunk) { + if (first_chunk) { olat->SetStart(0); // Initialize the first chunk for olat } else { // The extra prefinal states generated by @@ -1587,26 +1570,22 @@ bool LatticeIncrementalDeterminizer::AppendLatticeChunks(CompactLattice cla final_arc_list_.swap(final_arc_list_prev_); final_arc_list_.clear(); - - return true; } template -bool LatticeIncrementalDeterminizer::Finalize() { +void LatticeIncrementalDeterminizer::Finalize() { using namespace fst; - auto *olat = &lat_; // The lattice determinization only needs to be finalized once - if (determinization_finalized_) return true; + if (determinization_finalized_) + return; // step 4: remove dead states if (config_.final_prune_after_determinize) - PruneLattice(config_.lattice_beam, olat); + PruneLattice(config_.lattice_beam, &clat_); else - Connect(olat); // Remove unreachable states... there might be + Connect(&clat_); // Remove unreachable states... there might be - KALDI_VLOG(2) << "states of the lattice: " << olat->NumStates(); + KALDI_VLOG(2) << "states of the lattice: " << clat_.NumStates(); determinization_finalized_ = true; - - return (olat->NumStates() > 0); } // Instantiate the template for the combination of token types and FST types diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 9f930b6610d..a81bcff0984 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -1,6 +1,6 @@ // decoder/lattice-incremental-decoder.h -// Copyright 2019 Zhehuai Chen +// Copyright 2019 Zhehuai Chen, Hainan Xu, Daniel Povey // See ../../COPYING for clarification regarding multiple authors // @@ -31,18 +31,62 @@ #include "lattice-faster-decoder.h" namespace kaldi { +/** + The normal decoder, lattice-faster-decoder.h, sometimes has an issue when + doing real-time applications with long utterances, that each time you get the + lattice the lattice determinization can take a considerable amount of time; + this introduces latency. This version of the decoder spreads the work of + lattice determinization out throughout the decoding process. + + NOTE: + + Please see https://www.danielpovey.com/files/ *TBD* .pdf for a technical + explanation of what is going on here. + + GLOSSARY OF TERMS: + chunk: We do the determinization on chunks of frames; these + may coincide with the chunks on which the user calls + AdvanceDecoding(). The basic idea is to extract chunks + of the raw lattice and determinize them individually, but + it gets much more complicated than that. The chunks + should normally be at least as long as a word (let's say, + at least 20 frames), or the overhead of this algorithm + might become excessive and affect RTF. + + raw lattice chunk: A chunk of raw (i.e. undeterminized) lattice + that we will determinize. In the paper this corresponds + to the FST B that is described in Section 5.2. + + token_label, state_label: In the paper these are both + referred to as `state labels` (these are special, large integer + id's that refer to states in the undeterminized lattice + and in the the determinized lattice); + but we use two separate terms here, for more clarity, + when referring to the undeterminized vs. determinized lattice. + + token_label conceptually refers to states in the + raw lattice, but we don't materialize the entire + raw lattice as a physical FST and and these tokens + are actually tokens (template type Token) held by + the decoder + + state_label when used in this code refers specifically + to labels that identify states in the determinized + lattice (i.e. state indexes in lat_). + + redeterminized-non-splice-state, aka redetnss: + A redeterminized state which is not also a splice state; + refer to the paper for explanation. + */ struct LatticeIncrementalDecoderConfig { + // All the configuration values until det_opts are the same as in + // LatticeFasterDecoder. For clarity we repeat them rather than inheriting. BaseFloat beam; int32 max_active; int32 min_active; BaseFloat lattice_beam; int32 prune_interval; - int32 determinize_delay; - int32 determinize_period; - int32 determinize_max_active; - int32 redeterminize_max_frames; - bool final_prune_after_determinize; BaseFloat beam_delta; // has nothing to do with beam_ratio BaseFloat hash_ratio; BaseFloat prune_scale; // Note: we don't make this configurable on the command line, @@ -51,23 +95,33 @@ struct LatticeIncrementalDecoderConfig { // Most of the options inside det_opts are not actually queried by the // LatticeIncrementalDecoder class itself, but by the code that calls it, for // example in the function DecodeUtteranceLatticeIncremental. - int32 max_word_id; // for GetLattice fst::DeterminizeLatticePhonePrunedOptions det_opts; + // The configuration values from this point on are specific to the + // incremental determinization. + // TODO: explain the following. + int32 determinize_delay; + int32 determinize_period; + int32 determinize_max_active; + int32 redeterminize_max_frames; + bool final_prune_after_determinize; + int32 max_word_id; // for GetLattice + + LatticeIncrementalDecoderConfig() : beam(16.0), max_active(std::numeric_limits::max()), min_active(200), lattice_beam(10.0), prune_interval(25), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1), determinize_delay(25), determinize_period(20), determinize_max_active(std::numeric_limits::max()), redeterminize_max_frames(std::numeric_limits::max()), final_prune_after_determinize(true), - beam_delta(0.5), - hash_ratio(2.0), - prune_scale(0.1), max_word_id(1e8) {} void Register(OptionsItf *opts) { det_opts.Register(opts); @@ -82,6 +136,7 @@ struct LatticeIncrementalDecoderConfig { opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " "which to prune tokens"); + // TODO: check the following. opts->Register("determinize-delay", &determinize_delay, "Delay (in frames) at which to incrementally determinize " "lattices. A larger delay reduces the computational " @@ -130,12 +185,14 @@ struct LatticeIncrementalDecoderConfig { template class LatticeIncrementalDeterminizer; -/* This is an extention to the "normal" lattice-generating decoder. +/** This is an extention to the "normal" lattice-generating decoder. See \ref lattices_generation \ref decoders_faster and \ref decoders_simple for more information. The main difference is the incremental determinization which will be - discussed in the function GetLattice(). + discussed in the function GetLattice(). This means that the work of determinizatin + isn't done all at once at the end of the file, but incrementally while decoding. + See the comment at the top of this file for more explanation. The decoder is templated on the FST type and the token type. The token type will normally be StdToken, but also may be BackpointerToken which is to support @@ -174,153 +231,145 @@ class LatticeIncrementalDecoderTpl { ~LatticeIncrementalDecoderTpl(); - /// An example of how to do decoding together with incremental - /// determinization. It decodes until there are no more frames left in the - /// "decodable" object. Note, this may block waiting for input - /// if the "decodable" object blocks. - /// In this example, config_.determinize_delay, config_.determinize_period - /// and config_.determinize_max_active are used to determine the time to - /// call GetLattice(). - /// Users may do it in their own ways by calling - /// AdvanceDecoding() and GetLattice(). So the logic for deciding - /// when we get the lattice would be driven by the user. - /// The function returns true if any kind - /// of traceback is available (not necessarily from a final state). + /** + CAUTION: this function is provided only for testing and instructional + purposes. In a scenario where you have the entire file and just want + to decode it, there is no point using this decoder. + + An example of how to do decoding together with incremental + determinization. It decodes until there are no more frames left in the + "decodable" object. + + In this example, config_.determinize_delay, config_.determinize_period + and config_.determinize_max_active are used to determine the time to + call GetLattice(). + + Users will probably want to use appropriate combinations of + AdvanceDecoding() and GetLattice() to build their application; this just + gives you some idea how. + + The function returns true if any kind of traceback is available (not + necessarily from a final state). + */ bool Decode(DecodableInterface *decodable); - /// says whether a final-state was active on the last frame. If it was not, the - /// lattice (or traceback) will end with states that are not final-states. + /// says whether a final-state was active on the last frame. If it was not, + /// the lattice (or traceback) will end with states that are not final-states. bool ReachedFinal() const { return FinalRelativeCost() != std::numeric_limits::infinity(); } - /// Outputs an FST corresponding to the single best path through the lattice. - /// Returns true if result is nonempty (using the return status is deprecated, - /// it will become void). If "use_final_probs" is true AND we reached the - /// final-state of the graph then it will include those as final-probs, else - /// it will treat all final-probs as one. - bool GetBestPath(Lattice *ofst, bool use_final_probs = true); + /** + Outputs an FST corresponding to the single best path through the lattice. + If "use_final_probs" is true AND we reached the + final-state of the graph then it will include those as final-probs, else + it will treat all final-probs as one. + + Note: this gets the traceback from the compact lattice, which will not + include the most recently decoded frames if determinize_delay > 0 and + FinalizeDecoding() has not been called. If you'll be wanting to call + GetBestPath() a lot and need it to be up to date, you may prefer to + use LatticeIncrementalOnlineDecoder. + */ + void GetBestPath(Lattice *ofst, bool use_final_probs = true); /** - The following function is specifically designed for incremental - determinization. The function obtains a CompactLattice for - the part of this utterance up to the frame last_frame_of_chunk. - If you call this multiple times - (calling it on every frame would not make sense, but every, say, - 10 to 40 frames might make sense) it will spread out the work of - determinization over time, which might be useful for online applications. - config_.determinize_delay, config_.determinize_period - and config_.determinize_max_active can be used to determine the time to - call this function. We show an example in Decode(). - - The procedure of incremental determinization is as follow: - step 1: Get lattice chunk with initial and final states and arcs, called `raw - lattice`. - Here, we define a `final arc` as an arc to a final-state, and the source state - of it as a `pre-final state` - Similarly, we define a `initial arc` as an arc from a initial-state, and the - destination state of it as a `post-initial state` - The post-initial states are constructed corresponding to pre-final states - in the determinized and appended lattice before this chunk - The pre-final states are constructed correponding to tokens in the last frames - of this chunk. - Since the StateId can change during determinization, we need to give permanent - unique labels (as olabel) to these - raw-lattice states for latter appending. - We give each token an olabel id, called `token_label`, and each determinized and - appended state an olabel id, called `state_label`. Notably, in our - paper, we call both of them ``state labels'' for simplicity. - step 2: Determinize the chunk of above raw lattice using determinization - algorithm the same as LatticeFasterDecoder. Benefit from above `state_label` and - `token_label` in initial and final arcs, each pre-final state in the last chunk - w.r.t the initial arc of this chunk can be treated uniquely and each token in - the last frame of this chunk can also be treated uniquely. We call the - determinized new - chunk `compact lattice (clat)` - step 3: Appending the new chunk `clat` to the determinized lattice - before this chunk. First, for each StateId in clat except its - initial state, allocate a new StateId in the appended - compact lattice. Copy the arcs except whose incoming state is initial - state. Secondly, for each initial arcs, change its source state to the state - corresponding to its `state_label`, which is a determinized and appended state - Finally, we make the previous final arcs point to a "dead state" - step 4: We remove dead states in the very end. - - In our implementation, step 1 is done in GetIncrementalRawLattice(), - step 2-4 is taken care by the class - LatticeIncrementalDeterminizer - - @param [in] use_final_probs If true *and* at least one final-state in HCLG - was active on the final frame, include final-probs from - HCLG - in the lattice. Otherwise treat all final-costs of - states active - on the most recent frame as zero (i.e. Weight::One()). - @param [in] last_frame_of_chunk Pass the last frame of this chunk to - the function. We make it not always equal to - NumFramesDecoded() to have a delay on the - deteriminization - @param [out] olat The CompactLattice representing what has been decoded - so far. - If lat == NULL, the CompactLattice won't be outputed. - @return ret This function will returns true if the chunk is processed - successfully + This GetLattice() function is the main way you will interact with the + incremental determinization that this class provides. Note that the + interface is slightly different from that of other decoders. For example, + if olat is NULL it will do the work of incremental determinization without + actually giving you the lattice (which can save it some time). + + Note: calling it on every frame doesn't make sense as it would + still have to do a fair amount of work; calling it every, say, + 10 to 40 frames would make sense though. + + @param [in] use_final_probs If true *and* at least one final-state in HCLG + was active on the most recently decoded frame, include the + final-probs from the decoding FST (HCLG) in the lattice. + Otherwise treat all final-costs of states active on the + most recent frame as zero (i.e. use Weight::One()). You + can tell whether a final-prob was active on the most + recent frame by calling ReachedFinal(). + Setting use_final_probs will not affect the lattices + output by subsequent calls to this function. (TODO: + verify this). + + @param [in] num_frames_to_include The number of frames that you want + to be included in the lattice. Must be >0 and + <= NumFramesDecoded(). If you are calling this + just to keep the incremental lattice determinization up to + date and don't really need the lattice (olat == NULL), you + will probably want to give it some delay (at least 5 or 10 + frames); search for determinize-delay in the paper + and for determinize_delay in the configuration class and the + code. You may not call this with a num_frames_to_include + that is smaller than the largest value previously + provided. + + @param [out] olat The CompactLattice representing what has been decoded + up until `num_frames_to_include` (e.g., LatticeStateTimes() + on this lattice would return `num_frames_to_include`). + If NULL, the lattice won't be output, and this will save + the work of copying it, but the incremental determinization + will still be done. */ - bool GetLattice(bool use_final_probs, int32 last_frame_of_chunk, + void GetLattice(bool use_final_probs, int32 num_frames_to_include, CompactLattice *olat = NULL); - /// Specifically design when decoding_finalized_==true - bool GetLattice(CompactLattice *olat); - /// InitDecoding initializes the decoding, and should only be used if you - /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to - /// call this. You can also call InitDecoding if you have already decoded an - /// utterance and want to start with a new utterance. + + /** + InitDecoding initializes the decoding, and should only be used if you + intend to call AdvanceDecoding(). If you call Decode(), you don't need to + call this. You can also call InitDecoding if you have already decoded an + utterance and want to start with a new utterance. + */ void InitDecoding(); - /// This will decode until there are no more frames ready in the decodable - /// object. You can keep calling it each time more frames become available. - /// If max_num_frames is specified, it specifies the maximum number of frames - /// the function will decode before returning. + /** + This will decode until there are no more frames ready in the decodable + object. You can keep calling it each time more frames become available + (this is the normal pattern in a real-time/online decoding scenario). + If max_num_frames is specified, it specifies the maximum number of frames + the function will decode before returning. + */ void AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames = -1); - /// This function may be optionally called after AdvanceDecoding(), when you - /// do not plan to decode any further. It does an extra pruning step that - /// will help to prune the lattices output by GetLattice more accurately, - /// particularly toward the end of the utterance. - /// It does this by using the final-probs in pruning (if any - /// final-state survived); it also does a final pruning step that visits all - /// states (the pruning that is done during decoding may fail to prune states - /// that are within kPruningScale = 0.1 outside of the beam). If you call - /// this, you cannot call AdvanceDecoding again (it will fail), and you - /// cannot call GetLattice() and related functions with use_final_probs = - /// false. - /// Used to be called PruneActiveTokensFinal(). + + /** + This function may be optionally called after AdvanceDecoding(), when you + do not plan to decode any further. It does an extra pruning step that + will help to prune the lattices output by GetLattice more accurately, + particularly toward the end of the utterance. + It does this by using the final-probs in pruning (if any + final-state survived); it also does a final pruning step that visits all + states (the pruning that is done during decoding may fail to prune states + that are within kPruningScale = 0.1 outside of the beam). If you call + this, you cannot call AdvanceDecoding again (it will fail), and you + cannot call GetLattice() and related functions with use_final_probs = + false. + */ void FinalizeDecoding(); - /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives - /// more information. It returns the difference between the best (final-cost - /// plus cost) of any token on the final frame, and the best cost of any token - /// on the final frame. If it is infinity it means no final-states were - /// present on the final frame. It will usually be nonnegative. If it not - /// too positive (e.g. < 5 is my first guess, but this is not tested) you can - /// take it as a good indication that we reached the final-state with - /// reasonable likelihood. + /** FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + more information. It returns the difference between the best (final-cost + plus cost) of any token on the final frame, and the best cost of any token + on the final frame. If it is infinity it means no final-states were + present on the final frame. It will usually be nonnegative. If it not + too positive (e.g. < 5 is my first guess, but this is not tested) you can + take it as a good indication that we reached the final-state with + reasonable likelihood. */ BaseFloat FinalRelativeCost() const; - // Returns the number of frames decoded so far. The value returned changes - // whenever we call ProcessEmitting(). + /** Returns the number of frames decoded so far. */ inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } protected: - // we make things protected instead of private, as future code in - // LatticeIncrementalOnlineDecoderTpl, which inherits from this, also will - // use the internals. + /* Some protected things are needed in LatticeIncrementalOnlineDecoderTpl. */ - // Deletes the elements of the singly linked list tok->links. + /** NOTE: for parts the internal implementation that are shared with LatticeFasterDecoer, + we have removed the comments.*/ inline static void DeleteForwardLinks(Token *tok); - - // head of per-frame list of Tokens (list is in topological order), - // and something saying whether we ever pruned it using PruneForwardLinks. struct TokenList { Token *toks; bool must_prune_forward_links; @@ -328,139 +377,56 @@ class LatticeIncrementalDecoderTpl { TokenList() : toks(NULL), must_prune_forward_links(true), must_prune_tokens(true) {} }; - using Elem = typename HashList::Elem; - // Equivalent to: - // struct Elem { - // StateId key; - // Token *val; - // Elem *tail; - // }; - void PossiblyResizeHash(size_t num_toks); - - // FindOrAddToken either locates a token in hash of toks_, or if necessary - // inserts a new, empty token (i.e. with no forward links) for the current - // frame. [note: it's inserted if necessary into hash toks_ and also into the - // singly linked list of tokens active on this frame (whose head is at - // active_toks_[frame]). The frame_plus_one argument is the acoustic frame - // index plus one, which is used to index into the active_toks_ array. - // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the - // token was newly created or the cost changed. - // If Token == StdToken, the 'backpointer' argument has no purpose (and will - // hopefully be optimized out). inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer, bool *changed); - - // prunes outgoing links for all tokens in active_toks_[frame] - // it's called by PruneActiveTokens - // all links, that have link_extra_cost > lattice_beam are pruned - // delta is the amount by which the extra_costs must change - // before we set *extra_costs_changed = true. - // If delta is larger, we'll tend to go back less far - // toward the beginning of the file. - // extra_costs_changed is set to true if extra_cost was changed for any token - // links_pruned is set to true if any link in any token was pruned void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned, BaseFloat delta); - - // This function computes the final-costs for tokens active on the final - // frame. It outputs to final-costs, if non-NULL, a map from the Token* - // pointer to the final-prob of the corresponding state, for all Tokens - // that correspond to states that have final-probs. This map will be - // empty if there were no final-probs. It outputs to - // final_relative_cost, if non-NULL, the difference between the best - // forward-cost including the final-prob cost, and the best forward-cost - // without including the final-prob cost (this will usually be positive), or - // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which - // outputs this quanitity]. It outputs to final_best_cost, if - // non-NULL, the lowest for any token t active on the final frame, of - // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in - // the graph of the state corresponding to token t, or the best of - // forward-cost[t] if there were no final-probs active on the final frame. - // You cannot call this after FinalizeDecoding() has been called; in that - // case you should get the answer from class-member variables. void ComputeFinalCosts(unordered_map *final_costs, BaseFloat *final_relative_cost, BaseFloat *final_best_cost) const; - - // PruneForwardLinksFinal is a version of PruneForwardLinks that we call - // on the final frame. If there are final tokens active, it uses - // the final-probs for pruning, otherwise it treats all tokens as final. void PruneForwardLinksFinal(); - - // Prune away any tokens on this frame that have no forward links. - // [we don't do this in PruneForwardLinks because it would give us - // a problem with dangling pointers]. - // It's called by PruneActiveTokens if any forward links have been pruned void PruneTokensForFrame(int32 frame_plus_one); - - // Go backwards through still-alive tokens, pruning them if the - // forward+backward cost is more than lat_beam away from the best path. It's - // possible to prove that this is "correct" in the sense that we won't lose - // anything outside of lat_beam, regardless of what happens in the future. - // delta controls when it considers a cost to have changed enough to continue - // going backward and propagating the change. larger delta -> will recurse - // less far. void PruneActiveTokens(BaseFloat delta); - - /// Gets the weight cutoff. Also counts the active tokens. BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem); - - /// Processes emitting arcs for one frame. Propagates from prev_toks_ to - /// cur_toks_. Returns the cost cutoff for subsequent ProcessNonemitting() to - /// use. BaseFloat ProcessEmitting(DecodableInterface *decodable); - - /// Processes nonemitting (epsilon) arcs for one frame. Called after - /// ProcessEmitting() on each frame. The cost cutoff is computed by the - /// preceding ProcessEmitting(). void ProcessNonemitting(BaseFloat cost_cutoff); - // HashList defined in ../util/hash-list.h. It actually allows us to maintain - // more than one list (e.g. for current and previous frames), but only one of - // them at a time can be indexed by StateId. It is indexed by frame-index - // plus one, where the frame-index is zero-based, as used in decodable object. - // That is, the emitting probs of frame t are accounted for in tokens at - // toks_[t+1]. The zeroth frame is for nonemitting transition at the start of - // the graph. HashList toks_; - - std::vector active_toks_; // Lists of tokens, indexed by - // frame (members of TokenList are toks, must_prune_forward_links, - // must_prune_tokens). + std::vector active_toks_; // indexed by frame. std::vector queue_; // temp variable used in ProcessNonemitting, std::vector tmp_array_; // used in GetCutoff. - - // fst_ is a pointer to the FST we are decoding from. const FST *fst_; - // delete_fst_ is true if the pointer fst_ needs to be deleted when this - // object is destroyed. bool delete_fst_; - - std::vector cost_offsets_; // This contains, for each - // frame, an offset that was added to the acoustic log-likelihoods on that - // frame in order to keep everything in a nice dynamic range i.e. close to - // zero, to reduce roundoff errors. - LatticeIncrementalDecoderConfig config_; - int32 num_toks_; // current total #toks allocated... + std::vector cost_offsets_; + int32 num_toks_; bool warned_; - - /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note, - /// calling this is optional]. If true, it's forbidden to decode more. Also, - /// if this is set, then the output of ComputeFinalCosts() is in the next - /// three variables. The reason we need to do this is that after - /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some - /// of the tokens on the last frame are freed, so we free the list from toks_ - /// to avoid having dangling pointers hanging around. bool decoding_finalized_; - /// For the meaning of the next 3 variables, see the comment for - /// decoding_finalized_ above., and ComputeFinalCosts(). unordered_map final_costs_; BaseFloat final_relative_cost_; BaseFloat final_best_cost_; + /*** Variables below this point relate to the incremental + determinization. ***/ + LatticeIncrementalDecoderConfig config_; + /** Much of the the incremental determinization algorithm is encapsulated in + the determinize_ object. */ + LatticeIncrementalDeterminizer determinizer_; + /** last_get_lattice_frame_ is the highest `num_frames_to_include_` argument + for any prior call to GetLattice(). */ + int32 last_get_lattice_frame_; + // a map from Token to its token_label + unordered_map token2label_map_; + // we allocate a unique id for each Token + int32 token_label_available_idx_; + // We keep cost_offset for each token_label (Token) in final arcs. We need them to + // guide determinization + // We cancel them after determinization + unordered_map token_label2final_cost_; + + // There are various cleanup tasks... the the toks_ structure contains // singly linked lists of Token pointers, where Elem is the list type. // It also indexes them in a hash, indexed by state (this hash is only @@ -505,31 +471,29 @@ class LatticeIncrementalDecoderTpl { bool GetIncrementalRawLattice(Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, bool create_initial_state, bool create_final_state); - // Get the number of tokens in each frame - // It is useful, e.g. in using config_.determinize_max_active + // Returns the number of active tokens on frame `frame`. int32 GetNumToksForFrame(int32 frame); + + // DeterminizeLattice() is just a wrapper for GetLattice() that uses the various + // heuristics specified in the config class to decide when, and with what arguments, + // to call GetLattice() in order to make sure that the incremental determinization + // is kept up to date. It is mainly of use for documentation (it is called inside + // Decode() which is not recommended for users to call in most scenarios). + // We may at some point decide to make this public. void DeterminizeLattice(); - // The incremental lattice determinizer to take care of determinization - // and appending the lattice. - LatticeIncrementalDeterminizer determinizer_; - int32 last_get_lattice_frame_; // the last time we call GetLattice - // a map from Token to its token_label - unordered_map token_label_map_; - // we allocate a unique id for each Token - int32 token_label_available_idx_; - // We keep cost_offset for each token_label (Token) in final arcs. We need them to - // guide determinization - // We cancel them after determinization - unordered_map token_label_final_cost_; KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); }; typedef LatticeIncrementalDecoderTpl LatticeIncrementalDecoder; -// This class is designed for part of generating raw lattices and determnization -// and appending the lattice. +/** + This class is used inside LatticeIncrementalDecoderTpl; it handles + some of the details of incremental determinization. + https://www.danielpovey.com/files/ *TBD*.pdf for the paper. + +*/ template class LatticeIncrementalDeterminizer { public: @@ -540,33 +504,84 @@ class LatticeIncrementalDeterminizer { LatticeIncrementalDeterminizer(const LatticeIncrementalDecoderConfig &config, const TransitionModel &trans_model); - // Reset the lattice determinization data for an utterance + + // Resets the lattice determinization data for new utterance void Init(); - // Output the resultant determinized lattice in the form of CompactLattice - const CompactLattice &GetDeterminizedLattice() const { return lat_; } - - // Part of step 1 of incremental determinization, - // where the post-initial states are constructed corresponding to - // redeterminized states (see the description in redeterminized_states_) in the - // determinized and appended lattice before this chunk. - // We give each determinized and appended state an olabel id, called `state_label` - // We maintain a map (`token_label2last_state`) from token label (obtained from - // final arcs) to the destination state of the last of the sequence of initial arcs - // w.r.t the token label here - // Notably, we have multiple states for one token label after determinization, - // hence we use multiset here - // We need `token_label_final_cost` to cancel out the cost offset used in guiding - // DeterminizeLatticePhonePrunedWrapper - void GetInitialRawLattice( + + // Returns the current determinized lattice. + const CompactLattice &GetDeterminizedLattice() const { return clat_; } + + + /** + Starts the process of creating a raw lattice chunk. (Search the glossary + for "raw lattice chunk"). This just sets up the initial states and + redeterminized-states in the chunk. Relates to sec. 5.2 in the paper, + specifically the initial-state i and the redeterminized-states. + + After calling this, the caller would add the remaining arcs and states + to `olat` and then call AcceptChunk() with the result. + + @param [out] olat The lattice to be (partially) created + @param [in] token_label2final_cost For each token-label, + contains a cost that we will need to negate and then + introduce into newly created arcs in 'olat' that correspond + to arcs with that token-label on in the previous + determinized chunk. (They won't actually have the token + label as we remove them at this point). This relates + to an issue not discussed in the paper, which is + that to get pruned determinization to work right we + have to introduce special final-probs when determinizing + the previous chunk (think of the previous chunk as + the FST A in the paper). This map allows us to cancel + out those final-probs. + @param [out] token_label2state For each token-label (say, t) + that appears in lat_ (i.e. the result of determinizing previous + chunks), this identifies the state in `olat` that we allocate + for that token-label. This is a so-called `splice-state`; they + will never have arcs leaving them within `lat_`. When the + calling code processes the arcs in the raw lattice, it will add + arcs leaving these splice states. + See the last bullet point before Sec. 5.3 in the paper. + */ + void InitializeRawLatticeChunk( Lattice *olat, - unordered_multimap *token_label2last_state, - const unordered_map &token_label_final_cost); - // This function consumes raw_fst generated by step 1 of incremental - // determinization with specific initial and final arcs. - // It processes lattices and outputs the resultant CompactLattice if - // needed. Otherwise, it keeps the resultant lattice in lat_ - bool ProcessChunk(Lattice &raw_fst, int32 first_frame, int32 last_frame); + const unordered_map &token_label2final_cost, + unordered_map *token_label2state); + /** + This function accepts the raw FST (state-level lattice) corresponding + to a single chunk of the lattice, determinizes it and appends it to + this->clat_. + + @param [in] first_frame The start frame-index, which equals the + total number of frames in all chunks previous to this one. + Only needed to ask "is this the first chunk", plus + debug info. + @param [in] last_frame The end frame-index, which equals the + total number of frames in all previous chunks plus + this one. Only needed for debug. + @param [in] raw_fst (Consumed destructively). The input + raw (state-level) lattice. Would correspond to the + FST A in the paper if first_frame == 0, and B + otherwise. + + @return returns false if determinization finished earlier than the beam, + true otherwise. + */ + bool AcceptRawLatticeChunk(int32 first_frame, int32 last_frame, Lattice *raw_fst); + + + /** + Finalize incremental decoding by pruning the lattice (if + config_.final_prune_after_determinize), otherwise just removing unreachable + states. + */ + void Finalize(); + + + private: + + /** // Step 3 of incremental determinization, // which is to append the new chunk in clat to the old one in lat_ // If not_first_chunk == false, we do not need to append and just copy @@ -575,35 +590,48 @@ class LatticeIncrementalDeterminizer { // the last chunk to states of the first frame of this chunk. // These post-initial and pre-final states are corresponding to the same Token, // guaranteed by unique state labels. - bool AppendLatticeChunks(CompactLattice clat, bool not_first_chunk); + NOTE clat must be top sorted. + */ + void AppendLatticeChunks(const CompactLattice &clat, bool first_chunk); - // Step 4 of incremental determinization, - // which either re-determinize above lat_, or simply remove the dead - // states of lat_ - bool Finalize(); - std::vector &GetForwardCosts() { return forward_costs_; } - private: - // This function either locates a redeterminized state w.r.t nextstate previously - // added, or if necessary inserts a new one. - // The new one is inserted in olat and kept by the map (redeterminized_states_) - // which is from the state in the appended compact lattice to the state_copy in the - // raw lattice. The function returns whether a new one is inserted - // The StateId of the redeterminized state will be outputed by nextstate_copy - bool AddRedeterminizedState(Lattice::StateId nextstate, Lattice *olat, - Lattice::StateId *nextstate_copy = NULL); - // Sub function of GetInitialRawLattice(). Refer to description there - void GetRawLatticeForRedeterminizedStates( - StateId start_state, StateId state, - const unordered_map &token_label_final_cost, - unordered_multimap *token_label2last_state, - Lattice *olat); + /** + In the paper, recall from Sec. 5.2 that some states in det(A) (specifically: + redeterminized state) are also included in B. In the paper we just assumed + that the same state-ids were used, but in practice they are different numbers; + in redeterminized_state_map_ we store the mapping from state-id in det(A)==clat_ to + the state-id in B==raw_lat_chunk. The map is re-initialized each time we + process a new chunk. This function maps from the the state-id in clat_ + to the state_id in `raw_lat_chunk`, adding to the map and creating a new state + in `raw_lat_chunk` if it was not already present. + + @param [in] redet_state State-id of a redeterminized-state in clat_ + @param [in] raw_lat_chunk The raw lattice that we are creating; + this function may add a new state to it. + @param [out] state_id If non-NULL, the state-id in `raw_lat_chunk` + will be output to here + @return Returns true if a new state was created and added to the map + */ + bool FindOrAddRedeterminizedState( + CompactLattice::StateId redet_state, + Lattice *raw_lat_chunk, + Lattice::StateId *state_id = NULL); + + /** + TODO + */ + void ProcessRedeterminizedState( + Lattice::StateId state, + const unordered_map &token_label2final_cost, + unordered_map *token_label2state, + Lattice *raw_lat_chunk); + // This function is to preprocess the appended compact lattice before // generating raw lattices for the next chunk. - // After identifying pre-final states, for any such state that is separated by + // After identifying redeterminized states, for any such state that is separated by // more than config_.redeterminize_max_frames from the end of the current // appended lattice, we create an extra state for it; we add an epsilon arc - // from that pre-final state to the extra state; we copy any final arcs from + // from that redeterminized state to the extra state; we copy any final arcs from // the pre-final state to its extra state and we remove those final arcs from // the original pre-final state. // We also copy arcs meet the following requirements: i) destination-state of the @@ -621,16 +649,19 @@ class LatticeIncrementalDeterminizer { // Record whether we have finished determinized the whole utterance // (including re-determinize) bool determinization_finalized_; + /** // A map from the prefinal state to its correponding first final arc (there could be // multiple final arcs). We keep final arc information for GetRedeterminizedStates() // later. It can also be used to identify whether a state is a prefinal state. + */ unordered_map final_arc_list_; unordered_map final_arc_list_prev_; // alpha of each state in lat_ std::vector forward_costs_; // we allocate a unique id for each source-state of the last arc of a series of - // initial arcs in GetInitialRawLattice + // initial arcs in InitializeRawLatticeChunk int32 state_last_initial_offset_; + // We define a state in the appended lattice as a 'redeterminized-state' (meaning: // one that will be redeterminized), if it is: a pre-final state, or there // exists an arc from a redeterminized state to this state. We keep reapplying @@ -638,16 +669,19 @@ class LatticeIncrementalDeterminizer { // is not included. These redeterminized states will be stored in this map // which is a map from the state in the appended compact lattice to the // state_copy in the newly-created raw lattice. - unordered_map redeterminized_states_; + unordered_map redeterminized_state_map_; + + /** // It is a map used in GetRedeterminizedStates (see the description there) // A map from the original pre-final state to the pre-final states (i.e. the // original pre-final state or an extra state generated by // GetRedeterminizedStates) used for generating raw lattices of the next chunk. + */ unordered_map processed_prefinal_states_; - // The compact lattice we obtain. It should be reseted before processing a - // new utterance - CompactLattice lat_; + // The compact lattice we obtain. It should be cleared before processing a new + // utterance + CompactLattice clat_; KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDeterminizer); }; diff --git a/src/lat/determinize-lattice-pruned.h b/src/lat/determinize-lattice-pruned.h index 8e1858aa2b1..35323709815 100644 --- a/src/lat/determinize-lattice-pruned.h +++ b/src/lat/determinize-lattice-pruned.h @@ -105,8 +105,8 @@ namespace fst { representation" and hence the "minimal representation" will be the same. We can use this to reduce compute. Note that if two initial representations are different, this does not preclude the other representations from being the same. - -*/ + +*/ struct DeterminizeLatticePrunedOptions { @@ -190,7 +190,7 @@ template bool DeterminizeLatticePruned( const ExpandedFst > &ifst, double prune, - MutableFst > *ofst, + MutableFst > *ofst, DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions()); @@ -199,7 +199,7 @@ bool DeterminizeLatticePruned( (i.e. the sequences of output symbols are represented directly as strings The input FST must be topologically sorted in order for the algorithm to work. For efficiency it is recommended to sort the ilabel for the input FST as well. - Returns true on success, and false if it had to terminate the determinization + Returns true on normal success, and false if it had to terminate the determinization earlier than specified by the "prune" beam-- that is, if it terminated because of the max_mem, max_loop or max_arcs constraints in the options. CAUTION: if Lattice is the input, you need to Invert() before calling this, @@ -261,7 +261,7 @@ bool DeterminizeLatticePhonePruned( = DeterminizeLatticePhonePrunedOptions()); /** "Destructive" version of DeterminizeLatticePhonePruned() where the input - lattice might be changed. + lattice might be changed. */ template bool DeterminizeLatticePhonePruned( From a3fb8ce01dbc6a49c38d25ec39dee8d2622207f4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 7 Nov 2019 14:39:57 -0800 Subject: [PATCH 41/60] Further cleanup --- src/decoder/lattice-incremental-decoder.cc | 10 ++++------ src/decoder/lattice-incremental-decoder.h | 12 ++++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 8c2404d9a76..63e4c93bef7 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1148,7 +1148,6 @@ LatticeIncrementalDeterminizer::LatticeIncrementalDeterminizer( template void LatticeIncrementalDeterminizer::Init() { final_arc_list_.clear(); - final_arc_list_prev_.clear(); clat_.DeleteStates(); determinization_finalized_ = false; forward_costs_.clear(); @@ -1290,10 +1289,10 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { using namespace fst; processed_prefinal_states_.clear(); // go over all prefinal state - KALDI_ASSERT(final_arc_list_prev_.size()); + KALDI_ASSERT(final_arc_list_.size()); unordered_set prefinal_states; - for (auto &i : final_arc_list_prev_) { + for (auto &i : final_arc_list_) { auto prefinal_state = i.first; ArcIterator aiter(clat_, prefinal_state); KALDI_ASSERT(clat_.NumArcs(prefinal_state) > i.second); @@ -1336,9 +1335,9 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { // destination-state of the arc is no further than // redeterminize_max_frames from the most recent frame we are // determinizing - auto r = final_arc_list_prev_.find(arc.nextstate); + auto r = final_arc_list_.find(arc.nextstate); // destination-state of the arc is not prefinal state - if (r == final_arc_list_prev_.end()) remain_the_arc = true; + if (r == final_arc_list_.end()) remain_the_arc = true; // destination-state of the arc is prefinal state else remain_the_arc = false; @@ -1568,7 +1567,6 @@ void LatticeIncrementalDeterminizer::AppendLatticeChunks( } } - final_arc_list_.swap(final_arc_list_prev_); final_arc_list_.clear(); } diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index a81bcff0984..a139635cfef 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -74,9 +74,10 @@ namespace kaldi { to labels that identify states in the determinized lattice (i.e. state indexes in lat_). - redeterminized-non-splice-state, aka redetnss: + redeterminized-non-splice-state, aka ns_redet: A redeterminized state which is not also a splice state; - refer to the paper for explanation. + refer to the paper for explanation. In the already-determinized + part this means a redeterminized state which is not final. */ struct LatticeIncrementalDecoderConfig { @@ -649,13 +650,12 @@ class LatticeIncrementalDeterminizer { // Record whether we have finished determinized the whole utterance // (including re-determinize) bool determinization_finalized_; + /** - // A map from the prefinal state to its correponding first final arc (there could be - // multiple final arcs). We keep final arc information for GetRedeterminizedStates() - // later. It can also be used to identify whether a state is a prefinal state. + final_arc_list_ is a map from each non-final redeterminized-state + in clat_ to the arc-index of its first arc to a final state. */ unordered_map final_arc_list_; - unordered_map final_arc_list_prev_; // alpha of each state in lat_ std::vector forward_costs_; // we allocate a unique id for each source-state of the last arc of a series of From 629c449e4a0c43d63b01f6446c22df2629e1cb56 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 8 Nov 2019 18:56:39 -0800 Subject: [PATCH 42/60] Some intermediate work on incremental-decoder rewrite --- src/decoder/lattice-faster-decoder.h | 13 +- src/decoder/lattice-incremental-decoder.cc | 481 +++++++++++++++++++-- src/decoder/lattice-incremental-decoder.h | 397 ++++++++--------- 3 files changed, 615 insertions(+), 276 deletions(-) diff --git a/src/decoder/lattice-faster-decoder.h b/src/decoder/lattice-faster-decoder.h index d6bac1bca5d..57cbe5fe178 100644 --- a/src/decoder/lattice-faster-decoder.h +++ b/src/decoder/lattice-faster-decoder.h @@ -318,15 +318,10 @@ class LatticeFasterDecoderTpl { /// This function may be optionally called after AdvanceDecoding(), when you /// do not plan to decode any further. It does an extra pruning step that /// will help to prune the lattices output by GetLattice and (particularly) - /// GetRawLattice more accurately, particularly toward the end of the - /// utterance. It does this by using the final-probs in pruning (if any - /// final-state survived); it also does a final pruning step that visits all - /// states (the pruning that is done during decoding may fail to prune states - /// that are within kPruningScale = 0.1 outside of the beam). If you call - /// this, you cannot call AdvanceDecoding again (it will fail), and you - /// cannot call GetLattice() and related functions with use_final_probs = - /// false. - /// Used to be called PruneActiveTokensFinal(). + /// GetRawLattice more completely, particularly toward the end of the + /// utterance. If you call this, you cannot call AdvanceDecoding again (it + /// will fail), and you cannot call GetLattice() and related functions with + /// use_final_probs = false. Used to be called PruneActiveTokensFinal(). void FinalizeDecoding(); /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 63e4c93bef7..9949f094ad7 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -75,7 +75,7 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { toks_.Insert(start_state, start_tok); num_toks_++; - last_get_lattice_frame_ = 0; + num_frames_in_lattice_ = 0; token2label_map_.clear(); token2label_map_.reserve(std::min((int32)1e5, config_.max_active)); token_label_available_idx_ = config_.max_word_id + 1; @@ -95,7 +95,7 @@ void LatticeIncrementalDecoderTpl::DeterminizeLattice() { int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; // The minimum length of chunk is config_.determinize_period. if (frame_det_most % config_.determinize_period == 0) { - int32 frame_det_least = last_get_lattice_frame_ + config_.determinize_period; + int32 frame_det_least = num_frames_in_lattice_ + config_.determinize_period; // Incremental determinization: // To adaptively decide the length of chunk, we further compare the number of // tokens in each frame and a pre-defined threshold. @@ -942,26 +942,26 @@ void LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, olat->DeleteStates(); /* Clear the FST */ KALDI_ASSERT(olat->Start() == fst::kNoStateId); // TODO: remove using namespace fst; - bool first_chunk = last_get_lattice_frame_ == 0; + bool first_chunk = num_frames_in_lattice_ == 0; - KALDI_ASSERT(last_get_lattice_frame_ <= last_frame_of_chunk); - if (last_get_lattice_frame_ < last_frame_of_chunk) { + KALDI_ASSERT(num_frames_in_lattice_ <= last_frame_of_chunk); + if (num_frames_in_lattice_ < last_frame_of_chunk) { Lattice raw_fst; // step 1: Get lattice chunk with initial and final states // In this function, we do not create the initial state in // the first chunk, and we do not create the final state in the last chunk - if (!GetIncrementalRawLattice(&raw_fst, use_final_probs, last_get_lattice_frame_, + if (!GetIncrementalRawLattice(&raw_fst, use_final_probs, num_frames_in_lattice_, last_frame_of_chunk, !first_chunk, !decoding_finalized_)) KALDI_ERR << "Unexpected problem when getting lattice"; // step 2-3 - determinizer_.AcceptRawLatticeChunk(last_get_lattice_frame_, + determinizer_.AcceptRawLatticeChunk(num_frames_in_lattice_, last_frame_of_chunk, &raw_fst); - last_get_lattice_frame_ = last_frame_of_chunk; - } else if (last_get_lattice_frame_ > last_frame_of_chunk) { + num_frames_in_lattice_ = last_frame_of_chunk; + } else if (num_frames_in_lattice_ > last_frame_of_chunk) { KALDI_WARN << "Call GetLattice up to frame: " << last_frame_of_chunk << " while the determinizer_ has already done up to frame: " - << last_get_lattice_frame_; + << num_frames_in_lattice_; } if (decoding_finalized_) @@ -1140,27 +1140,25 @@ int32 LatticeIncrementalDecoderTpl::GetNumToksForFrame(int32 frame) return r; } -template -LatticeIncrementalDeterminizer::LatticeIncrementalDeterminizer( +LatticeIncrementalDeterminizer::LatticeIncrementalDeterminizer( const LatticeIncrementalDecoderConfig &config, const TransitionModel &trans_model) : config_(config), trans_model_(trans_model) {} -template -void LatticeIncrementalDeterminizer::Init() { +void LatticeIncrementalDeterminizer::Init() { final_arc_list_.clear(); clat_.DeleteStates(); determinization_finalized_ = false; forward_costs_.clear(); - state_last_initial_offset_ = 2 * config_.max_word_id; + state_label_offset_ = 2 * config_.max_word_id; redeterminized_state_map_.clear(); processed_prefinal_states_.clear(); } -template -bool LatticeIncrementalDeterminizer::FindOrAddRedeterminizedState( + +bool LatticeIncrementalDeterminizer::FindOrAddRedeterminizedState( Lattice::StateId nextstate, Lattice *olat, Lattice::StateId *nextstate_copy) { using namespace fst; bool modified = false; - StateId nextstate_insert = kNoStateId; + LatticeArc::StateId nextstate_insert = kNoStateId; auto r = redeterminized_state_map_.insert({nextstate, nextstate_insert}); if (r.second) { // didn't exist, successfully insert here // create a new state w.r.t state @@ -1178,8 +1176,7 @@ bool LatticeIncrementalDeterminizer::FindOrAddRedeterminizedState( return modified; } -template -void LatticeIncrementalDeterminizer::ProcessRedeterminizedState( +void LatticeIncrementalDeterminizer::ProcessRedeterminizedState( Lattice::StateId state, const unordered_map &token_label2final_cost, unordered_map *token_label2state, @@ -1197,7 +1194,7 @@ void LatticeIncrementalDeterminizer::ProcessRedeterminizedState( ArcIterator aiter(clat_, state); // use state_label in initial arcs - int state_label = state + state_last_initial_offset_; + int state_label = state + state_label_offset_; // Moreover, we need to use the forward coast (alpha) of this determinized and // appended state to guide the determinization later KALDI_ASSERT(state < forward_costs_.size()); @@ -1219,7 +1216,7 @@ void LatticeIncrementalDeterminizer::ProcessRedeterminizedState( // the destination of the arc is a final -> a "splice state". if (clat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { KALDI_ASSERT(arc_olabel > config_.max_word_id && - arc_olabel < state_last_initial_offset_); // token label + arc_olabel < state_label_offset_); // token label // create a initial arc // Get arc weight here @@ -1284,8 +1281,7 @@ void LatticeIncrementalDeterminizer::ProcessRedeterminizedState( token_label2state, olat); } } -template -void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { +void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { using namespace fst; processed_prefinal_states_.clear(); // go over all prefinal state @@ -1322,7 +1318,7 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { bool remain_the_arc = true; // If we remain the arc, the state will not be // re-determinized, vice versa. if (arc.olabel > config_.max_word_id) { // final arc - KALDI_ASSERT(arc.olabel < state_last_initial_offset_); + KALDI_ASSERT(arc.olabel < state_label_offset_); KALDI_ASSERT(clat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()); remain_the_arc = false; } else { @@ -1365,8 +1361,7 @@ void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { // This function is specifically designed to obtain the initial arcs for a chunk // We have multiple states for one token label after determinization -template -void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( +void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( Lattice *olat, const unordered_map &token_label2final_cost, unordered_map *token_label2state) { @@ -1394,11 +1389,417 @@ void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( } } -template -bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk(int32 first_frame, - int32 last_frame, - Lattice *raw_fst) { - bool first_chunk = first_frame == 0; + + +/* This utility function adds an arc to a Lattice, but where the source is a + CompactLatticeArc. If the CompactLatticeArc has a string with length greater + than 1, this will require adding extra states to `lat`. + */ +static void AddCompactLatticeArcToLattice( + const CompactLatticeArc &clat_arc, + LatticeArc::StateId src_state, + Lattice *lat) { + const std::vector &string = clat_arc.weight.String(); + size_t N = string.size(); + if (N == 0) { + LatticeArc arc; + arc.ilabel = 0; + arc.olabel = clat_arc.label; + arc.nextstate = clat_arc.nextstate; + arc.weight = clat_arc.weight.Weight(); + lat->AddArc(src_state, arc); + } else { + LatticeArc::StateId cur_state = arc_state; + for (size_t i = 0; i < N; i++) { + LatticeArc arc; + arc.ilabel = string[i]; + arc.olabel = (i == 0 ? clat_arc.ilabel : 0); + arc.nextstate = (i + 1 == N ? clat_arc.nextstate : lat->AddState()); + arc.weight = (i == 0 ? clat_arc.weight.Weight() : 0); + lat->AddArc(cur_state, arc); + cur_state = arc.nextstate; + } + } +} + + +/** + Reweights a compact lattice chunk in a way that makes the combination with + the current compact lattice easier. Also removes some temporary + forward-probs that we previously added. +*/ +void LatticeIncrementalDeterminizer2::ReweightChunk( + CompactLattice *chunk_clat) { + using StateId = CompactLatticeArc::StateId; + using Label = CompactLatticeArc::Label; + StateId start = chunk_clat->Start(); + + std::vector potentials(chunk_clat->NumStates(), + CompactLatticeWeight::One()); + + for (fst::MutableArcIterator aiter(chunk_clat, start_state); + !aiter.Done(); aiter.Next()) { + CompactLatticeArc arc = aiter.Value(); + Label label = arc.ilabel; // ilabel == olabel. + StateId clat_state = label - kStateLabelOffset; + KALDI_ASSERT(clat_state >= 0 && clat_state < clat_num_states); + // `extra_weight` serves to cancel out the weight + // `forward_costs_[clat_state]` that we introduced in + // InitializeRawLatticeChunk(); the purpose of that was to + // make the pruned determinization work right, but they are + // no longer needed. + LatticeWeight extra_weight(-forward_costs_[clat_state], 0.0); + arc.weight.SetWeight( + CompactLatticeWeight::Times(arc.weight.Weight(), + extra_weight)); + aiter.SetValue(arc); + potentials[arc.nextstate] = arc.weight; + } + // TODO: consider doing the following manually for this special case, + // since most states are not reweighted. + fst::Reweight(potentials, fst::ReweightToFinal, chunk_clat); + + // Below is just a check that weights on arcs leaving initial state + // are all One(). + // TODO: remove the following. + for (fst::ArcIterator aiter(*chunk_clat, start_state); + !aiter.Done(); aiter.Next()) { + KALDI_ASSERT(fst::ApproxEqual(aiter.Value().weight, + CompactLatticeWeight::One())); + } + Label label = arc.ilabel; // ilabel == olabel. + StateId clat_state = label - kStateLabelOffset; + KALDI_ASSERT(clat_state >= 0 && clat_state < clat_num_states); + +} + + +/** + Identifies states in `chunk_clat` that have arcs entering them with a + `token-label` on them (see glossary in header for definition). + It produces a map from such states in chunk_clat, to the `token-label` + on arcs entering them. (It is not possible that the same state would + have multiple arcs entering it with different token-labels, or + some arcs entering with one token-label and some another, or be + both initial and have such arcs; this is true due to how we construct + the raw lattice.) + */ +void LatticeIncrementalDeterminizer2::IdentifyTokenFinalStates( + const CompactLattice &chunk_clat, + std::unordered_map *token_map) { + token_map->clear(); + using StateId = CompactLatticeArc::StateId; + using Label = CompactLatticeArc::Label; + + StateId num_states = chunk_clat.NumStates(); + for (StateId state = 0; state < num_states; state++) { + for (fst::ArcIterator aiter(chunk_clat, start_state); + !aiter.Done(); aiter.Next()) { + CompactLatticeArc &arc = aiter.Value(); + if (arc.olabel >= kTokenLabelOffset && arc.olabel < kMaxTokenLabel) { + StateId nextstate = arc.nextstate; + auto r = token_map->insert({nextstate, arc.olabel}); + // Check consistency of labels on incoming arcs + KALDI_ASSERT(r->second.second == arc.olabel); + } + } + } +} + + + +void LatticeIncrementalDeterminizer2::InitializeRawLatticeChunk( + Lattice *olat, + unordered_map *token_label2state) { + using namespace fst; + + + olat->DeleteStates(); + LatticeArc::State start_state = olat->AddState(); + token_label2state->clear(); + + // redet_state_map maps from state-ids in clat_ to state-ids in olat_. + unordered_map redet_state_map; + + for (CompactLatticeArc::StateId redet_state: non_final_redet_states_) + redet_state_map[redet_state] = olat->AddState(); + + // First, process any arcs leaving the non-final redeterminized states that + // are not to final-states. (What we mean by "not to final states" is, not to + // stats that are final in the `canonical appended lattice`.. they may + // actually be physically final in clat_, because we make clat_ what we want + // to return to the user. + for (CompactLatticeArc::StateId redet_state: non_final_redet_states_) { + LatticeArc::StateId lat_state = redet_state_map[redet_state]; + + for (ArcIterator aiter(clat_, redet_state); + !aiter.Done(); aiter.Next()) { + const CompactLatticeArc &arc = aiter.Value(); + CompactLatticeArc::StateId nextstate = arc.nextstate; + auto iter = redet_state_map.find(nextstate); + KALDI_(iter != redet_state_map.end()); + CompactLatticeArc clat_arc(arc); + clat_arc.nextstate = iter->second; + AddCompactLatticeArcToLattice(clat_arc, lat_state, olat); + } + } + + for (const CompactLatticeArc &arc: final_arcs_) { + // We abuse the `nextstate` field to store the source state. + CompactLatticeArc::StateId src_state = arc.nextstate; + Label token_label = arc.ilabel; // will be == arc.olabel. + KALDI_ASSERT(token_label >= kTokenLabelOffset && + token_label < kMaxTokenLabel); + CompactLatticeArc + + auto r = token_label2state->insert({token_labelstate_label, + olat->NumStates()}); + if (r.second) { // was inserted + StateId new_state = olat->AddState(); + KALDI_ASSERT(r.first->second == new_state); + } + LatticeArc::StateId next_lat_state = r.second; + auto iter = redet_state_map.find(src_state); + KALDI_ASSERT(iter != redet_state_map.end()); + LatticeArc::StateId src_lat_state = iter->second; + CompactLatticeArc new_arc; + new_arc.nextstate = next_lat_state; + new_arc.ilabel = new_arc.olabel = token_label; + new_arc.weight = arc.weight; + AddCompactLatticeArcToLattice(new_arc, src_lat_state, olat); + } + + + // Now deal with the initial-probs. Arcs from initial-states to + // redeterminized-states in the raw lattice have an olabel that identifies the + // id of that redeterminized-state in clat_, and a cost that is derived from + // its entry in forward_costs_. These forward-probs are used to get the + // pruned lattice determinization to behave correctly, and will be canceled + // out later on. + // + // In the paper this is the second-from-last bullet in Sec. 5.2. NOTE: in the + // paper we state that we only include such arcs for "each redeterminized + // state that is either initial in det(A) or that has an arc entering it from + // a state that is not a redeterminized state." In fact, we include these + // arcs for all redeterminized states. I realized that it won't make a + // difference to the outcome, and it's easier to do it this way. + for (auto iter: non_final_redet_states_) { + CompactLatticeArc::StateId state_id = iter->first; + BaseFloat forward_cost = forward_costs_[state_id]; + LatticeArc arc; + arc.ilabel = 0; + // The olabel (which appears where the word-id would) is what + // we call a 'state-label'. It identifies a state in clat_. + arc.olabel = state_id + kStateLabelOffset; + // It doesn't matter what field we put forward_cost in (or whether we + // divide it among them both; the effect on pruning is the same, and + // we will cancel it out later anyway. + arc.weight = LatticeWeight(forward_cost, 0); + auto iter = redet_state_map.find(state_id); + KALDI_ASSERT(iter != redet_state_map.end()); + arc.nextstate = iter->second; + olat->AddArc(start_state, arc); + } +} + + +bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( + Lattice *raw_fst, + std::unordered_map *new_final_costs) { + using Label = CompactLatticeArc::Label; + using StateId = CompactLatticeArc::StateId; + + bool first_chunk = (first_frame == 0); + + + // final_costs is a map from a `token-label` (see glossary) to the + // associated final-prob in a final-state of `raw_fst`, that is associated with + // that Token. These are Tokens that were active at the end of + // the chunk. The final-probs may arise from beta (backward) costs, + // introduced for pruning purposes, and/or from final-probs in HCLG. + // Those costs will not be included in anything we store in this class; + // we will use `old_final_costs` later to cancel them out. + std::unordered_map old_final_costs; + if (!is_last_chunk) { + StateId raw_fst_num_states = raw_fst->NumStates(); + for (LatticeArc::StateId s = 0; s < raw_fst_num_states; s++) { + for (ArcIterator aiter(*raw_fst, s); !aiter.Done(); + aiter.Next()) { + const LatticeArc &value = aiter.Value(); + if (value.olabel >= (Label)kTokenLabelOffset && + value.olabel < (Label)kMaxTokenLabel) { + LatticeWeight final_weight = raw_fst->Final(value.nextstate); + if (final_weight == LatticeState::Zero() || + final_weight.Value2() != 0) { + KALDI_ERR << "Label " << value.olabel + << " looks like a token-label but its next-state " + "has unexpected final-weight " << final_weight.Value1() << ',' + << final_weight.Value2(); + } + auto r = final_costs.insert({value.olabel, final_weight.Value1()}); + if (!r->second && r->first.second != final_weight.Value1()) { + // For any given token-label, all arcs in raw_fst with that + // olabel should go to the same state, so this should be + // impossible. + KALDI_ERR << "Unexpected mismatch in final-costs for tokens, " + << r->first.second << " vs " << final_weight.Value1(); + } + } + } + } + } + + + CompactLattice chunk_clat; + bool determinized_till_beam = DeterminizeLatticePhonePrunedWrapper( + trans_model_, raw_fst, (config_.lattice_beam + 0.1), &chunk_clat, + config_.det_opts); + + TopSortCompactLatticeIfNeeded(&chunk_clat); + + StateId num_chunk_states = chunk_clat.NumStates(); + if (num_chunk_states == 0) { + // This will be an error but user-level calling code can detect it from the + // lattice being empty. + chunk_clat_.DeleteStates(); + return; + } + + ReweightChunk(&chunk_clat); + + StateId start_state = chunk_clat.Start(); // would be 0. + KALDI_ASSERT(start_state == 0); + + // Process arcs leaving the start state. All arcs leaving the start state will + // have `state labels` on them (identifying redeterminized-states in clat_), + // and will transition to a state in `chunk_clat` that we can identify with + // that redeterminized- state. + + // state_map maps from (non-initial state s in chunk_clat) to: + // if s is not final, then a state in clat_, + // if s is final, then a state-label allocated by AllocateNewStateLabel(); + // this will become a .nextstate in final_arcs_). + std::unordered_map state_map; + + StateId clat_num_states = clat_.NumStates(); + + // Process arcs leaving the start state of chunk_clat. These will + // have state-labels on them. The weights will all be One(); + // this is ensured in ReweightChunk(). + for (fst::ArcIterator aiter(chunk_clat, start_state); + !aiter.Done(); aiter.Next()) { + const CompactLatticeArc &arc = aiter.Value(); + Label label = arc.ilabel; // ilabel == olabel. + StateId clat_state = label - kStateLabelOffset; + KALDI_ASSERT(clat_state >= 0 && clat_state < clat_num_states); + StateId chunk_state = arc.nextstate; + + CompactLatticeWeight weight(arc.weight); + + bool inserted = state_map.insert({chunk_state, clat_state}); + // Should not have been in the map before. + KALDI_ASSERT(inserted); + } + + + // Remove any existing arcs in clat_ that leave redeterminized-states, + // and make those states non-final. + for (auto iter: non_final_redet_states_) { + StateId clat_state = *iter; + clat_.DeleteArcs(clat_state); + clat.SetFinal(clat_state, CompactLatticeWeight::Zero()); + } + + // The final-arc info is no longer relevant, we'll recreate it below. + final_arcs_.clear(); + + + // assume start-state == 0; we asserted it above. Allocate state-ids for all + // remaining states in chunk_clat (Except final-states, if this is not the + // last chunk). + for (StateId state = 1; state < num_chunk_states; state++) { + if (is_last_chunk || chunk_clat.Final(state) == CompactLatticeWeight::Zero()) { + // Allocate an actual state. + StateId new_clat_state = clat_.NumStates(); + if (state_map.insert({state, new_clat_state}).second) { + // If it was inserted then we need to actually allocate that state + StateId s = clat_.NewState(); + KALDI_ASSERT(s == new_clat_state); + } // else do nothing; it would have been a redeterminized-state and no + // allocation is needed since they already exist in clat_. and + // in state_map. + } + } + + // Now transfer arcs from chunk_clat to clat_. + for (StateId chunk_state = 1; chunk_state < num_chunk_states; chunk_state++) { + bool is_final = chunk_clat.Final(chunk_state) != CompactLattice::Zero(); + if (is_last_chunk || !is_final) { + auto iter = state_map.find(chunk_state); + KALDI_ASSERT(iter != state_map.end()); + StateId clat_state = iter->second; + if (is_last_chunk && is_final) + clat_.SetFinal(clat_state, chunk_clat.Final(chunk_state)); + for (ArcIterator aiter(chunk_clat, chunk_state); + !aiter.Done(); aiter.Next()) { + CompactLatticeArc arc(aiter.Value()); + + auto next_iter = state_map.find(arc.nextstate); + if (next_iter != state_map.end()) { + arc.nextstate = next_iter->second; + clat_->AddArc(clat_state, arc); + } else { + KALDI_ASSERT(chunk_clat.Final(arc.nextstate) != CompactLatticeWeight::Zero() && + + arc.olabel >= (Label)kTokenLabelOffset && + arc.olabel < (Label)kMaxTokenLabel); + // Below we'll correct arc.weight for the final-cost. + arc.weight = fst::Times(arc.weight, chunk_clat.Final(arc.nextstate)); + // We just use the .nextstate field to encode the source state. + arc.nextstate = clat_state; + + // Note: the only reason we introduce these final-probs to clat_ + // is so that the user can obtain the compact lattice at an intermediate + // stage of the calculation. + if (keep_final_probs) + clat_->SetFinal(fst::Sum(lat_->Final(), + arc.weight)); + + // Cancel out `final_cost` (which will really be some kind of + // `backward`/beta cost from the raw lattice, introduced to guide + // pruned determinization) from arc.weight. + auto final_cost_iter = final_costs.find(arc.olabel); + KALDI_ASSERT(final_cost_iter != final_costs.end()); + BaseFloat final_cost = final_cost_iter; + arc.weight.SetWeight(Times(arc.weight.Weight(), + LatticeWeight(-final_cost, 0))); + + if (!keep_final_probs) // Set the final-prob of the state after + // sutracting the backward cost. + clat_->SetFinal(fst::Sum(lat_->Final(), + arc.weight)); + final_arcs_.push_back(arc); + } + } + } + } + return determinized_till_beam; +} + +/* + TODO: move outside. + KALDI_VLOG(2) << "Frame: ( " << first_frame << " , " << last_frame << " )" + << " states of the chunk: " << clat.NumStates() + << " states of the lattice: " << clat_.NumStates(); +*/ + + + +bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( + int32 first_frame, int32 last_frame, + Lattice *raw_fst) { + + bool first_chunk = (first_frame == 0); // step 2: Determinize the chunk CompactLattice clat; // We do determinization with beam pruning here @@ -1426,8 +1827,7 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk(int32 first_fram return determinized_till_beam; } -template -void LatticeIncrementalDeterminizer::AppendLatticeChunks( +void LatticeIncrementalDeterminizer::AppendLatticeChunks( const CompactLattice &clat, bool first_chunk) { using namespace fst; CompactLattice *olat = &clat_; @@ -1479,12 +1879,12 @@ void LatticeIncrementalDeterminizer::AppendLatticeChunks( bool is_initial_state = (!first_chunk && s == 0); if (!is_initial_state) { KALDI_ASSERT(state_appended != kNoStateId); - KALDI_ASSERT(arc.olabel < state_last_initial_offset_); + KALDI_ASSERT(arc.olabel < state_label_offset_); source_state = state_appended; // process final arcs if (arc.olabel > config_.max_word_id) { // record final_arc in this chunk for the step 3.2 in the next call - KALDI_ASSERT(arc.olabel < state_last_initial_offset_); + KALDI_ASSERT(arc.olabel < state_label_offset_); KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); // state_appended shouldn't be in invert_processed_prefinal_states // So we do not need to map it @@ -1498,8 +1898,8 @@ void LatticeIncrementalDeterminizer::AppendLatticeChunks( // state_label auto state_label = arc.olabel; KALDI_ASSERT(state_label > config_.max_word_id); - KALDI_ASSERT(state_label >= state_last_initial_offset_); - source_state = state_label - state_last_initial_offset_; + KALDI_ASSERT(state_label >= state_label_offset_); + source_state = state_label - state_label_offset_; arc_appended.olabel = 0; arc_appended.ilabel = 0; CompactLatticeWeight weight_offset; @@ -1529,7 +1929,7 @@ void LatticeIncrementalDeterminizer::AppendLatticeChunks( arc_postinitial.nextstate += state_offset; olat->AddArc(source_state, arc_postinitial); if (arc_postinitial.olabel > config_.max_word_id) { - KALDI_ASSERT(arc_postinitial.olabel < state_last_initial_offset_); + KALDI_ASSERT(arc_postinitial.olabel < state_label_offset_); final_arc_list_.insert(pair( source_state, aiter_postinitial.Position() + arc_offset)); } @@ -1570,8 +1970,7 @@ void LatticeIncrementalDeterminizer::AppendLatticeChunks( final_arc_list_.clear(); } -template -void LatticeIncrementalDeterminizer::Finalize() { +void LatticeIncrementalDeterminizer::Finalize() { using namespace fst; // The lattice determinization only needs to be finalized once if (determinization_finalized_) diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index a139635cfef..395afc24bec 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -79,6 +79,17 @@ namespace kaldi { refer to the paper for explanation. In the already-determinized part this means a redeterminized state which is not final. + canonical appended lattice: This is the appended compact lattice + that we conceptually have (i.e. what we described in the paper). + The difference from the "actual appended lattice" is that the + actual appended lattice has all its final-arcs replaced with + final-probs (we keep the real final-arcs "on the side" in a + separate data structure). + + final-arc: An arc in the canonical appended CompactLattice which + goes to a final-state. These arcs will have `state-labels` as + their labels. + */ struct LatticeIncrementalDecoderConfig { // All the configuration values until det_opts are the same as in @@ -183,8 +194,133 @@ struct LatticeIncrementalDecoderConfig { } }; -template -class LatticeIncrementalDeterminizer; + + +/** + This class is used inside LatticeIncrementalDecoderTpl; it handles + some of the details of incremental determinization. + https://www.danielpovey.com/files/ *TBD*.pdf for the paper. + +*/ +class LatticeIncrementalDeterminizer2 { + public: + using Label = typename LatticeArc::Label; /* Actualy the same labels appear + in both lattice and compact + lattice, so we don't use the + specific type all the time but + just say 'Label' */ + + LatticeIncrementalDeterminizer2(const LatticeIncrementalDecoderConfig &config, + const TransitionModel &trans_model); + + // Resets the lattice determinization data for new utterance + void Init(); + + // Returns the current determinized lattice. + const CompactLattice &GetDeterminizedLattice() const { return clat_; } + + /** + Starts the process of creating a raw lattice chunk. (Search the glossary + for "raw lattice chunk"). This just sets up the initial states and + redeterminized-states in the chunk. Relates to sec. 5.2 in the paper, + specifically the initial-state i and the redeterminized-states. + + After calling this, the caller would add the remaining arcs and states + to `olat` and then call AcceptChunk() with the result. + + @param [out] olat The lattice to be (partially) created + + @param [out] token_label2state This function outputs to here + a map from `token-label` to the state we created for + it in *olat. See glossary for `token-label`. + The keys actually correspond to the .nextstate fields + in the arcs in final_arcs_; values are states in `olat`. + See the last bullet point before Sec. 5.3 in the paper. + */ + void InitializeRawLatticeChunk( + Lattice *olat, + unordered_map *token_label2state); + + /** + This function accepts the raw FST (state-level lattice) corresponding to a + single chunk of the lattice, determinizes it and appends it to this->clat_. + Unless this was the + + Note: final-probs in `raw_fst` are treated specially: they are used to + guide the pruned determinization, but when you call GetLattice() it will be + -- except for pruning effects-- as if all nonzero final-probs in `raw_fst` + were: One() if final_costs == NULL; else the value present in `final_costs`. + + @param [in] raw_fst (Consumed destructively). The input + raw (state-level) lattice. Would correspond to the + FST A in the paper if first_frame == 0, and B + otherwise. + @param [in] final_costs Final-costs that the user wants to + be included in clat_. These replace the values present + in the Final() probs in raw_fst whenever there was + a nonzero final-prob in raw_fst. (States in raw_fst + that had a final-prob will still be non-final). + + @return returns false if determinization finished earlier than the beam, + true otherwise. + */ + bool AcceptRawLatticeChunk(Lattice *raw_fst, + const std::unordered_map *final_costs = NULL); + + + const CompactLattice &GetLattice() { return clat_; } + + private: + + // kTokenLabelOffset is where we start allocating labels corresponding to Tokens + // (these correspond with raw lattice states); + // kStateLabelOffset is what we add to state-ids in clat_ to produce labels + // to identify them in the raw lattice chunk + enum { kStateLabelOffset = (int)1e8, kTokenLabelOffset = (int)2e8, kMaxTokenLabel = (int)3e8 }; + + Label next_state_label_; + + // clat_ is the appended lattice (containing all chunks processed so + // far), except its `final-arcs` (i.e. arcs which in the canonical + // lattice would go to final-states) are not present (they are stored + // separately in final_arcs_) and states which in the canonical lattice + // should have final-arcs leaving them will instead have a final-prob. + CompactLattice clat_; + + // The elements of this set are the redeterminized-states which are not final in + // the canonical appended lattice. This means the set of .first elements in + // final_arcs, plus whatever states in clat_ are reachable from such states. + // (The final redeterminized states/splice-states are never actually + // materialized.) + std::unordered_set non_final_redet_states_; + + + // final_arcs_ contains arcs which would appear in the canonical appended + // lattice but for implementation reasons are not physically present in clat_. + // These are arcs to final states in the canonical appended lattice. The + // .first elements are the source states in clat_ (these will all be elements + // of non_final_redet_states_); the .nextstate elements of the arcs does not + // contain a physical state, but contain state-labels allocated by + // AllocateNewStateLabel(). + std::vector final_arcs_; + + + // final_weights_ contain the final-probs of states that are final in the + // canonical compact lattice. Physically it maps from the state-labels which + // are allocated by AllocateNewStateLabel() and are stored in the .nextstate + // in final_arcs_, to the weight that would be on that final-state in the + // canonical compact lattice. + std::unordered_map final_weights_; + + // forward_costs_, indexed by the state-id in clat_, stores the alpha + // (forward) costs, i.e. the minimum cost from the start state to each state + // in clat_. This is relevant for pruned determinization. The BaseFloat can + // be thought of as the sum of a Value1() + Value2() in a LatticeWeight. + std::vector forward_costs_; + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDeterminizer2); +}; + /** This is an extention to the "normal" lattice-generating decoder. See \ref lattices_generation \ref decoders_faster and \ref decoders_simple @@ -285,38 +421,41 @@ class LatticeIncrementalDecoderTpl { still have to do a fair amount of work; calling it every, say, 10 to 40 frames would make sense though. - @param [in] use_final_probs If true *and* at least one final-state in HCLG - was active on the most recently decoded frame, include the - final-probs from the decoding FST (HCLG) in the lattice. - Otherwise treat all final-costs of states active on the - most recent frame as zero (i.e. use Weight::One()). You - can tell whether a final-prob was active on the most - recent frame by calling ReachedFinal(). - Setting use_final_probs will not affect the lattices - output by subsequent calls to this function. (TODO: - verify this). + @param [in] use_final_probs True if you want the final-probs + of HCLG to be included in the output lattice. + (However, if no state was final on frame + `num_frames_to_include` they won't be included regardless + of use_final_probs; if this equals NumFramesDecoded() you + can test this with ReachedFinal()). Caution: + it is an error to call this function with + the same num_frames_to_include and different values + of `use_final_probs`. (This is not a fundamental + limitation but just the way we coded it.) @param [in] num_frames_to_include The number of frames that you want to be included in the lattice. Must be >0 and - <= NumFramesDecoded(). If you are calling this - just to keep the incremental lattice determinization up to - date and don't really need the lattice (olat == NULL), you - will probably want to give it some delay (at least 5 or 10 - frames); search for determinize-delay in the paper - and for determinize_delay in the configuration class and the - code. You may not call this with a num_frames_to_include - that is smaller than the largest value previously - provided. - - @param [out] olat The CompactLattice representing what has been decoded + <= NumFramesDecoded(). If you are calling this just to + keep the incremental lattice determinization up to date and + don't really need the lattice now or don't need it to be up + to date, you will probably want to make + num_frames_to_include at least 5 or 10 frames less than + NumFramessDecoded(); search for determinize-delay in the + paper and for determinize_delay in the configuration class + and the code. You may not call this with a + num_frames_to_include that is smaller than the largest + value previously provided. Calling it with an + only-slightly-larger version than the last time (e.g. just + a few frames larger) is probably not a good use of + computational resources. + + @return clat The CompactLattice representing what has been decoded up until `num_frames_to_include` (e.g., LatticeStateTimes() on this lattice would return `num_frames_to_include`). - If NULL, the lattice won't be output, and this will save - the work of copying it, but the incremental determinization - will still be done. + */ - void GetLattice(bool use_final_probs, int32 num_frames_to_include, - CompactLattice *olat = NULL); + const CompactLattice &GetLattice(bool use_final_probs, + int32 num_frames_to_include); + /** @@ -414,10 +553,11 @@ class LatticeIncrementalDecoderTpl { LatticeIncrementalDecoderConfig config_; /** Much of the the incremental determinization algorithm is encapsulated in the determinize_ object. */ - LatticeIncrementalDeterminizer determinizer_; - /** last_get_lattice_frame_ is the highest `num_frames_to_include_` argument + LatticeIncrementalDeterminizer2 determinizer_; + + /** num_frames_in_lattice_ is the highest `num_frames_to_include_` argument for any prior call to GetLattice(). */ - int32 last_get_lattice_frame_; + int32 num_frames_in_lattice_; // a map from Token to its token_label unordered_map token2label_map_; // we allocate a unique id for each Token @@ -489,201 +629,6 @@ class LatticeIncrementalDecoderTpl { typedef LatticeIncrementalDecoderTpl LatticeIncrementalDecoder; -/** - This class is used inside LatticeIncrementalDecoderTpl; it handles - some of the details of incremental determinization. - https://www.danielpovey.com/files/ *TBD*.pdf for the paper. - -*/ -template -class LatticeIncrementalDeterminizer { - public: - using Arc = typename FST::Arc; - using Label = typename Arc::Label; - using StateId = typename Arc::StateId; - using Weight = typename Arc::Weight; - - LatticeIncrementalDeterminizer(const LatticeIncrementalDecoderConfig &config, - const TransitionModel &trans_model); - - // Resets the lattice determinization data for new utterance - void Init(); - - // Returns the current determinized lattice. - const CompactLattice &GetDeterminizedLattice() const { return clat_; } - - - /** - Starts the process of creating a raw lattice chunk. (Search the glossary - for "raw lattice chunk"). This just sets up the initial states and - redeterminized-states in the chunk. Relates to sec. 5.2 in the paper, - specifically the initial-state i and the redeterminized-states. - - After calling this, the caller would add the remaining arcs and states - to `olat` and then call AcceptChunk() with the result. - - @param [out] olat The lattice to be (partially) created - @param [in] token_label2final_cost For each token-label, - contains a cost that we will need to negate and then - introduce into newly created arcs in 'olat' that correspond - to arcs with that token-label on in the previous - determinized chunk. (They won't actually have the token - label as we remove them at this point). This relates - to an issue not discussed in the paper, which is - that to get pruned determinization to work right we - have to introduce special final-probs when determinizing - the previous chunk (think of the previous chunk as - the FST A in the paper). This map allows us to cancel - out those final-probs. - @param [out] token_label2state For each token-label (say, t) - that appears in lat_ (i.e. the result of determinizing previous - chunks), this identifies the state in `olat` that we allocate - for that token-label. This is a so-called `splice-state`; they - will never have arcs leaving them within `lat_`. When the - calling code processes the arcs in the raw lattice, it will add - arcs leaving these splice states. - See the last bullet point before Sec. 5.3 in the paper. - */ - void InitializeRawLatticeChunk( - Lattice *olat, - const unordered_map &token_label2final_cost, - unordered_map *token_label2state); - - /** - This function accepts the raw FST (state-level lattice) corresponding - to a single chunk of the lattice, determinizes it and appends it to - this->clat_. - - @param [in] first_frame The start frame-index, which equals the - total number of frames in all chunks previous to this one. - Only needed to ask "is this the first chunk", plus - debug info. - @param [in] last_frame The end frame-index, which equals the - total number of frames in all previous chunks plus - this one. Only needed for debug. - @param [in] raw_fst (Consumed destructively). The input - raw (state-level) lattice. Would correspond to the - FST A in the paper if first_frame == 0, and B - otherwise. - - @return returns false if determinization finished earlier than the beam, - true otherwise. - */ - bool AcceptRawLatticeChunk(int32 first_frame, int32 last_frame, Lattice *raw_fst); - - - /** - Finalize incremental decoding by pruning the lattice (if - config_.final_prune_after_determinize), otherwise just removing unreachable - states. - */ - void Finalize(); - - - private: - - /** - // Step 3 of incremental determinization, - // which is to append the new chunk in clat to the old one in lat_ - // If not_first_chunk == false, we do not need to append and just copy - // clat into olat - // Otherwise, we need to connect states of the last frame of - // the last chunk to states of the first frame of this chunk. - // These post-initial and pre-final states are corresponding to the same Token, - // guaranteed by unique state labels. - NOTE clat must be top sorted. - */ - void AppendLatticeChunks(const CompactLattice &clat, bool first_chunk); - - - /** - In the paper, recall from Sec. 5.2 that some states in det(A) (specifically: - redeterminized state) are also included in B. In the paper we just assumed - that the same state-ids were used, but in practice they are different numbers; - in redeterminized_state_map_ we store the mapping from state-id in det(A)==clat_ to - the state-id in B==raw_lat_chunk. The map is re-initialized each time we - process a new chunk. This function maps from the the state-id in clat_ - to the state_id in `raw_lat_chunk`, adding to the map and creating a new state - in `raw_lat_chunk` if it was not already present. - - @param [in] redet_state State-id of a redeterminized-state in clat_ - @param [in] raw_lat_chunk The raw lattice that we are creating; - this function may add a new state to it. - @param [out] state_id If non-NULL, the state-id in `raw_lat_chunk` - will be output to here - @return Returns true if a new state was created and added to the map - */ - bool FindOrAddRedeterminizedState( - CompactLattice::StateId redet_state, - Lattice *raw_lat_chunk, - Lattice::StateId *state_id = NULL); - - /** - TODO - */ - void ProcessRedeterminizedState( - Lattice::StateId state, - const unordered_map &token_label2final_cost, - unordered_map *token_label2state, - Lattice *raw_lat_chunk); - - // This function is to preprocess the appended compact lattice before - // generating raw lattices for the next chunk. - // After identifying redeterminized states, for any such state that is separated by - // more than config_.redeterminize_max_frames from the end of the current - // appended lattice, we create an extra state for it; we add an epsilon arc - // from that redeterminized state to the extra state; we copy any final arcs from - // the pre-final state to its extra state and we remove those final arcs from - // the original pre-final state. - // We also copy arcs meet the following requirements: i) destination-state of the - // arc is prefinal state. ii) destination-state of the arc is no further than than - // redeterminize_max_frames from the most recent frame we are determinizing. - // Now this extra state is the pre-final state to - // redeterminize and the original pre-final state does not need to redeterminize - // The epsilon would be removed later on in AppendLatticeChunks, while - // splicing the compact lattices together - void GetRedeterminizedStates(); - - const LatticeIncrementalDecoderConfig config_; - const TransitionModel &trans_model_; // keep it for determinization - - // Record whether we have finished determinized the whole utterance - // (including re-determinize) - bool determinization_finalized_; - - /** - final_arc_list_ is a map from each non-final redeterminized-state - in clat_ to the arc-index of its first arc to a final state. - */ - unordered_map final_arc_list_; - // alpha of each state in lat_ - std::vector forward_costs_; - // we allocate a unique id for each source-state of the last arc of a series of - // initial arcs in InitializeRawLatticeChunk - int32 state_last_initial_offset_; - - // We define a state in the appended lattice as a 'redeterminized-state' (meaning: - // one that will be redeterminized), if it is: a pre-final state, or there - // exists an arc from a redeterminized state to this state. We keep reapplying - // this rule until there are no more redeterminized states. The final state - // is not included. These redeterminized states will be stored in this map - // which is a map from the state in the appended compact lattice to the - // state_copy in the newly-created raw lattice. - unordered_map redeterminized_state_map_; - - /** - // It is a map used in GetRedeterminizedStates (see the description there) - // A map from the original pre-final state to the pre-final states (i.e. the - // original pre-final state or an extra state generated by - // GetRedeterminizedStates) used for generating raw lattices of the next chunk. - */ - unordered_map processed_prefinal_states_; - - // The compact lattice we obtain. It should be cleared before processing a new - // utterance - CompactLattice clat_; - KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDeterminizer); -}; } // end namespace kaldi. From 5f27eb9fc5b81c7949ea7da463076571d8cabcfd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 9 Nov 2019 13:47:19 -0800 Subject: [PATCH 43/60] Storing some intermediate work --- src/decoder/lattice-incremental-decoder.cc | 924 ++++++++------------- src/decoder/lattice-incremental-decoder.h | 158 ++-- 2 files changed, 410 insertions(+), 672 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 9949f094ad7..3e16c19b5b2 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1,6 +1,6 @@ // decoder/lattice-incremental-decoder.cc -// Copyright 2019 Zhehuai Chen +// Copyright 2019 Zhehuai Chen, Daniel Povey // See ../../COPYING for clarification regarding multiple authors // @@ -75,13 +75,10 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { toks_.Insert(start_state, start_tok); num_toks_++; - num_frames_in_lattice_ = 0; - token2label_map_.clear(); - token2label_map_.reserve(std::min((int32)1e5, config_.max_active)); - token_label_available_idx_ = config_.max_word_id + 1; - token_label2final_cost_.clear(); determinizer_.Init(); - + num_frames_in_lattice_ = 0; + tokentlabel_map_.clear(); + next_token_label_ = LatticeIncrementalDeterminizer2::kTokenLabelOffset; ProcessNonemitting(config_.beam); } @@ -411,13 +408,11 @@ void LatticeIncrementalDecoderTpl::PruneForwardLinksFinal() { template BaseFloat LatticeIncrementalDecoderTpl::FinalRelativeCost() const { - if (!decoding_finalized_) { + if (NumFramesDecoded() != final_cost_frame_) { BaseFloat relative_cost; ComputeFinalCosts(NULL, &relative_cost, NULL); return relative_cost; } else { - // we're not allowed to call that function if FinalizeDecoding() has - // been called; return a cached value. return final_relative_cost_; } } @@ -936,41 +931,160 @@ void LatticeIncrementalDecoderTpl::TopSortTokens( } template -void LatticeIncrementalDecoderTpl::GetLattice(bool use_final_probs, - int32 last_frame_of_chunk, - CompactLattice *olat) { - olat->DeleteStates(); /* Clear the FST */ - KALDI_ASSERT(olat->Start() == fst::kNoStateId); // TODO: remove - using namespace fst; - bool first_chunk = num_frames_in_lattice_ == 0; - - KALDI_ASSERT(num_frames_in_lattice_ <= last_frame_of_chunk); - if (num_frames_in_lattice_ < last_frame_of_chunk) { - Lattice raw_fst; - // step 1: Get lattice chunk with initial and final states - // In this function, we do not create the initial state in - // the first chunk, and we do not create the final state in the last chunk - if (!GetIncrementalRawLattice(&raw_fst, use_final_probs, num_frames_in_lattice_, - last_frame_of_chunk, !first_chunk, - !decoding_finalized_)) - KALDI_ERR << "Unexpected problem when getting lattice"; - // step 2-3 - determinizer_.AcceptRawLatticeChunk(num_frames_in_lattice_, - last_frame_of_chunk, &raw_fst); - num_frames_in_lattice_ = last_frame_of_chunk; - } else if (num_frames_in_lattice_ > last_frame_of_chunk) { - KALDI_WARN << "Call GetLattice up to frame: " << last_frame_of_chunk - << " while the determinizer_ has already done up to frame: " - << num_frames_in_lattice_; +const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( + int32 num_frames_to_include, + bool use_final_probs, bool finalize) { + + if (num_frames_to_include == num_frames_in_lattice_) { + // We've already obtained the lattice up to here. + KALDI_ASSERT(finalize == decoding_finalized_); + return determinizer_.GetLattice(); + } + + if (num_frames_to_include < num_frames_in_lattice_ || + num_frames_to_include > NumFramesDecoded()) { + KALDI_ERR << "GetLattice() called with num-frames-to-include = " + << num_frames_to_include << " but already determinized " + << num_frames_in_lattice_ << " frames and " + << NumFramesDecoded() << " frames decoded so far."; + } + KALDI_ASSERT(!decoding_finalized_); + if (finalize) + FinalizeDecoding(); // does pruning of the raw lattice. + + if (num_frames_to_include < NumFramesDecoded() && + (use_final_probs || finalize)) { + KALDI_ERR << "You cannot set use_final_probs or finalize if not requesting " + "all the frames decoded so far."; + } + + Lattice chunk_lat; + + unordered_map token_label2state; + if (num_frames_in_lattice_ != 0) { + determinizer_.InitializeRawLatticeChunk(&chunk_lat, + &token_label2state); } - if (decoding_finalized_) - determinizer_.Finalize(); + // tok_map will map from Token* to state-id in chunk_lat. + // The cur and prev versions alternate on different frames. + unordered_map &tok2state_map(temp_token_map_); + tok2state_map.clear(); + + unordered_map &next_token2label_map(token2label_map_temp_); + next_token2label_map_.clear(); + + + { // Deal with the last frame. We allocate token labels, and set tokens as + // final, but don't add any transitions. This may leave some states + // disconnected (e.g. due to chains of nonemitting arcs), but it's OK; we'll + // fix it when we generate the next chunk of lattice. + int32 frame = num_frames_to_include; + // Allocate state-ids for all tokens on this frame. + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + StateId state = chunk_lat.AddState(); + tok2state_map[tok] = state; + next_token2label_map[tok] = AllocateNewTokenLabel(); + // First imagine extra_cost == 0, which it would be if + // num_frames_to_include == NumFramesDecoded(). We use a final-prob + // (i.e. beta) that is the negative of the token's total cost. This + // ensures that all tokens on the final frame are the 'best token' / have + // the same best-path cost. This is done for pruning purposes, so we + // never prune anything out that's active on the last frame. For earlier + // frames than the final one, the extra cost is equal to the beta (==cost + // to the end), assuming we had set the betas on last frame to the + // negatives of the alphas. + chunk_lat.SetFinal(state, + LatticeWeight(-(tok->tot_cost + tok->extra_cost), 0.0)); + } + } - if (olat) - *olat = determinizer_.GetDeterminizedLattice(); + // Go in reverse order over the remaining frames so we can create arcs as we + // go, and their destination-states will already be in the map. + for (int32 frame = num_frames_to_include - 1; + frame >= num_frames_in_lattice_; frame--) { + BaseFloat cost_offset = cost_offsets_[f]; + + // For the first frame of the chunk, we need to make sure the states are + // the ones created by InitializeRawLatticeChunk() (where not pruned away). + if (frame == num_frames_in_lattice_ && num_frames_in_lattice_ != 0) { + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + auto iter = token2label_map_.find(tok); + KALDI_ASSERT(iter != token2label_map_.end()); + Label token_label = iter->second; + auto iter2 = token_label2state.find(token_label); + if (iter2 != token_label2state.end()) { + StateId state = iter2->second; + tok2state_map[tok] = state; + } else { + // Some states may have been pruned out, but we should still allocate + // them. They might have been part of chains of nonemitting arcs + // where the state became disconnected because the last chunk didn't + // include arcs starting at this frame. + StateId state = chunk_lat.AddState(); + tok2state_map[tok] = state; + } + } + } else { + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + StateId state = chunk_lat.AddState(); + tok2state_map[tok] = state; + } + } + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + auto iter = tok2state_map.find(tok); + KALDI_ASSERT(iter != tok2state_map.end()); + StateId cur_state = iter->second; + for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { + auto next_iter = tok2state_map.find(l->next_tok); + KALDI_ASSERT(next_iter != tok2state_map.end()); + StateId next_state = next_iter->second; + BaseFloat this_offset = (l->ilabel != 0 ? cost_offset : 0); + LatticeArc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), + next_state); + chunk_lat.AddArc(state, arc); + } + } + } + if (num_frames_in_lattice_ == 0) { + std::vector tok_list; + TopSortTokens(active_toks_[0], &tok_list); + Tok *start_token = tok_list[0]; + auto iter = tok2state_map.find(start_token); + KALDI_ASSERT(iter != tok2state_map.end()); + StateId start_state = iter->second; + chunk_lat.SetStart(start_state); + } + token2label_map_.swap(next_token2label_map); + + std::unordered_map final_costs; + if (use_final_probs) { + final_cost_frame_ = num_frames_to_include; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, + &final_best_cost_); + for (auto iter = final_costs_.begin(); iter != final_costs_.end(); + ++iter) { + Token *tok = iter->first; + BaseFloat final_cost = iter->second; + auto iter2 = tok2state_map.find(tok); + KALDI_ASSERT(iter2 != tok2state_map.end()); + StateId lat_state = iter2->second; + bool inserted = final_costs.insert({lat_state, final_cost}).second; + KALDI_ASSERT(inserted); + } + } + + + bool finished_before_beam = + determinizer_.AcceptRawLatticeChunk(chunk_lat, + (use_final_probs ? &final_costs : NULL)); + + num_frames_in_lattice_ = num_frames_to_include; + return determinizer_.GetLattice(); } + template bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, @@ -985,12 +1099,12 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( << "GetIncrementalRawLattice() with use_final_probs == false"; unordered_map final_costs_local; - const unordered_map &final_costs = (decoding_finalized_ ? final_costs_ : final_costs_local); if (!decoding_finalized_ && use_final_probs) ComputeFinalCosts(&final_costs_local, NULL, NULL); + ofst->DeleteStates(); unordered_map token_label2state; // for InitializeRawLatticeChunk @@ -1001,8 +1115,9 @@ bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( // num-frames plus one (since frames are one-based, and we have // an extra frame for the start-state). KALDI_ASSERT(frame_end > 0); - const int32 bucket_count = num_toks_ / 2 + 3; - unordered_map tok_map(bucket_count); + unordered_map &tok_map(temp_token_map_); + tok_map.clear(); + // First create all states. std::vector token_list; for (int32 f = frame_begin; f <= frame_end; f++) { @@ -1140,254 +1255,6 @@ int32 LatticeIncrementalDecoderTpl::GetNumToksForFrame(int32 frame) return r; } -LatticeIncrementalDeterminizer::LatticeIncrementalDeterminizer( - const LatticeIncrementalDecoderConfig &config, const TransitionModel &trans_model) - : config_(config), trans_model_(trans_model) {} - -void LatticeIncrementalDeterminizer::Init() { - final_arc_list_.clear(); - clat_.DeleteStates(); - determinization_finalized_ = false; - forward_costs_.clear(); - state_label_offset_ = 2 * config_.max_word_id; - redeterminized_state_map_.clear(); - processed_prefinal_states_.clear(); -} - -bool LatticeIncrementalDeterminizer::FindOrAddRedeterminizedState( - Lattice::StateId nextstate, Lattice *olat, Lattice::StateId *nextstate_copy) { - using namespace fst; - bool modified = false; - LatticeArc::StateId nextstate_insert = kNoStateId; - auto r = redeterminized_state_map_.insert({nextstate, nextstate_insert}); - if (r.second) { // didn't exist, successfully insert here - // create a new state w.r.t state - nextstate_insert = olat->AddState(); - // map from arc.nextstate to nextstate_insert - r.first->second = nextstate_insert; - modified = true; - } else { // else already exist - // get nextstate_insert - nextstate_insert = r.first->second; - KALDI_ASSERT(nextstate_insert != kNoStateId); - modified = false; - } - if (nextstate_copy) *nextstate_copy = nextstate_insert; - return modified; -} - -void LatticeIncrementalDeterminizer::ProcessRedeterminizedState( - Lattice::StateId state, - const unordered_map &token_label2final_cost, - unordered_map *token_label2state, - Lattice *olat) { - using namespace fst; - typedef LatticeArc Arc; - typedef Arc::StateId StateId; - typedef Arc::Weight Weight; - typedef Arc::Label Label; - - auto r = redeterminized_state_map_.find(state); - KALDI_ASSERT(r != redeterminized_state_map_.end()); - auto state_copy = r->second; - KALDI_ASSERT(state_copy != kNoStateId); - ArcIterator aiter(clat_, state); - - // use state_label in initial arcs - int state_label = state + state_label_offset_; - // Moreover, we need to use the forward coast (alpha) of this determinized and - // appended state to guide the determinization later - KALDI_ASSERT(state < forward_costs_.size()); - auto alpha_cost = forward_costs_[state]; - Arc arc_initial(0, state_label, LatticeWeight(0, alpha_cost), state_copy); - Lattice::StateId start_state = olat->Start(); - if (alpha_cost != std::numeric_limits::infinity()) - olat->AddArc(start_state, arc_initial); - - for (; !aiter.Done(); aiter.Next()) { - const auto &arc = aiter.Value(); - auto laststate_copy = kNoStateId; - bool proc_nextstate = false; - auto arc_weight = arc.weight; - - KALDI_ASSERT(arc.olabel == arc.ilabel); - auto arc_olabel = arc.olabel; - - // the destination of the arc is a final -> a "splice state". - if (clat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { - KALDI_ASSERT(arc_olabel > config_.max_word_id && - arc_olabel < state_label_offset_); // token label - // create a initial arc - - // Get arc weight here - // We will include it in arc_last in the following - CompactLatticeWeight weight_offset; - // To cancel out the weight on the final arcs, which is (extra cost - forward - // cost). - // see token_label2final_cost for more details - const auto r = token_label2final_cost.find(arc_olabel); - KALDI_ASSERT(r != token_label2final_cost.end()); - auto cost_offset = r->second; - weight_offset.SetWeight(LatticeWeight(0, -cost_offset)); - // The arc weight is a combination of original arc weight, above cost_offset - // and the weights on the final state - arc_weight = Times(Times(arc_weight, clat_.Final(arc.nextstate)), weight_offset); - - // We create a respective destination state for each final arc - // later we will connect it to the state correponding to the token w.r.t - // arc_olabel - laststate_copy = olat->AddState(); - // the destination state of the last of the sequence of arcs will be recorded - // and connected to the state corresponding to token w.r.t arc_olabel - // Notably, we have multiple states for one token label after determinization, - // hence we use multiset here - token_label2state->insert( - std::pair(arc_olabel, laststate_copy)); - arc_olabel = 0; // remove token label - } else { - // the arc connects to a non-final state (redeterminized state) - KALDI_ASSERT(arc_olabel < config_.max_word_id); // no token label - KALDI_ASSERT(arc_olabel); - // get the nextstate_copy w.r.t arc.nextstate - StateId nextstate_copy = kNoStateId; - proc_nextstate = FindOrAddRedeterminizedState(arc.nextstate, olat, &nextstate_copy); - KALDI_ASSERT(nextstate_copy != kNoStateId); - laststate_copy = nextstate_copy; - } - auto &state_seqs = arc_weight.String(); - // create new arcs w.r.t arc - // the following is for a normal arc - // We generate a linear sequence of arcs sufficient to contain all the - // transition-ids on the string - auto prev_state = state_copy; // from state_copy - for (auto &j : state_seqs) { - auto cur_state = olat->AddState(); - Arc arc(j, 0, LatticeWeight::One(), cur_state); - olat->AddArc(prev_state, arc); - prev_state = cur_state; - } - - // connect previous sequence of arcs to the laststate_copy - // the weight on the previous arc is stored in the arc to laststate_copy here - Arc arc_last(0, arc_olabel, arc_weight.Weight(), laststate_copy); - olat->AddArc(prev_state, arc_last); - - // not final state && previously didn't process this state - - // TODO: verify that the following call is not necessary. - if (proc_nextstate) - ProcessRedeterminizedState(arc.nextstate, - token_label2final_cost, - token_label2state, olat); - } -} -void LatticeIncrementalDeterminizer::GetRedeterminizedStates() { - using namespace fst; - processed_prefinal_states_.clear(); - // go over all prefinal state - KALDI_ASSERT(final_arc_list_.size()); - unordered_set prefinal_states; - - for (auto &i : final_arc_list_) { - auto prefinal_state = i.first; - ArcIterator aiter(clat_, prefinal_state); - KALDI_ASSERT(clat_.NumArcs(prefinal_state) > i.second); - aiter.Seek(i.second); - auto final_arc = aiter.Value(); - auto final_weight = clat_.Final(final_arc.nextstate); - KALDI_ASSERT(final_weight != CompactLatticeWeight::Zero()); - auto num_frames = Times(final_arc.weight, final_weight).String().size(); - // If the state is too far from the end of the current appended lattice, - // we leave the non-final arcs unchanged and only redeterminize the final - // arcs by the following procedure. - // We also do above things once we prepare to redeterminize the start state. - if (num_frames <= config_.redeterminize_max_frames && prefinal_state != 0) - processed_prefinal_states_[prefinal_state] = prefinal_state; - else { - KALDI_VLOG(7) << "Impose a limit of " << config_.redeterminize_max_frames - << " on how far back in time we will redeterminize states. " - << num_frames << " frames in this arc. "; - - auto new_prefinal_state = clat_.AddState(); - forward_costs_.resize(new_prefinal_state + 1); - forward_costs_[new_prefinal_state] = forward_costs_[prefinal_state]; - - std::vector arcs_remaining; - for (aiter.Reset(); !aiter.Done(); aiter.Next()) { - auto arc = aiter.Value(); - bool remain_the_arc = true; // If we remain the arc, the state will not be - // re-determinized, vice versa. - if (arc.olabel > config_.max_word_id) { // final arc - KALDI_ASSERT(arc.olabel < state_label_offset_); - KALDI_ASSERT(clat_.Final(arc.nextstate) != CompactLatticeWeight::Zero()); - remain_the_arc = false; - } else { - int num_frames_exclude_arc = num_frames - arc.weight.String().size(); - // destination-state of the arc is further than redeterminize_max_frames - // from the most recent frame we are determinizing - if (num_frames_exclude_arc > config_.redeterminize_max_frames) - remain_the_arc = true; - else { - // destination-state of the arc is no further than - // redeterminize_max_frames from the most recent frame we are - // determinizing - auto r = final_arc_list_.find(arc.nextstate); - // destination-state of the arc is not prefinal state - if (r == final_arc_list_.end()) remain_the_arc = true; - // destination-state of the arc is prefinal state - else - remain_the_arc = false; - } - } - - if (remain_the_arc) - arcs_remaining.push_back(arc); - else - clat_.AddArc(new_prefinal_state, arc); - } - CompactLatticeArc arc_to_new(0, 0, CompactLatticeWeight::One(), - new_prefinal_state); - arcs_remaining.push_back(arc_to_new); - - clat_.DeleteArcs(prefinal_state); - for (auto &i : arcs_remaining) - clat_.AddArc(prefinal_state, i); - processed_prefinal_states_[prefinal_state] = new_prefinal_state; - } - } - KALDI_VLOG(8) << "states of the lattice after GetRedeterminizedStates: " - << clat_.NumStates(); -} - -// This function is specifically designed to obtain the initial arcs for a chunk -// We have multiple states for one token label after determinization -void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( - Lattice *olat, - const unordered_map &token_label2final_cost, - unordered_map *token_label2state) { - using namespace fst; - typedef LatticeArc Arc; - typedef Arc::StateId StateId; - typedef Arc::Weight Weight; - typedef Arc::Label Label; - - GetRedeterminizedStates(); - - olat->DeleteStates(); - token_label2state->clear(); - - auto start_state = olat->AddState(); - olat->SetStart(start_state); - // go over all prefinal states after preprocessing - for (auto &i : processed_prefinal_states_) { - auto prefinal_state = i.second; - bool modified = FindOrAddRedeterminizedState(prefinal_state, olat); - if (modified) - ProcessRedeterminizedState(prefinal_state, - token_label2final_cost, - token_label2state, olat); - } -} @@ -1423,6 +1290,13 @@ static void AddCompactLatticeArcToLattice( } +void LatticeIncrementalDeterminizer2::Init() { + non_final_redet_states_.clear(); + clat_.DeleteStates(); + final_arcs_.clear(); + forward_costs_.clear(); +} + /** Reweights a compact lattice chunk in a way that makes the combination with the current compact lattice easier. Also removes some temporary @@ -1476,13 +1350,13 @@ void LatticeIncrementalDeterminizer2::ReweightChunk( /** Identifies states in `chunk_clat` that have arcs entering them with a - `token-label` on them (see glossary in header for definition). - It produces a map from such states in chunk_clat, to the `token-label` - on arcs entering them. (It is not possible that the same state would - have multiple arcs entering it with different token-labels, or - some arcs entering with one token-label and some another, or be - both initial and have such arcs; this is true due to how we construct - the raw lattice.) + `token-label` on them (see glossary in header for definition). We're calling + these `token-final` states. This function outputs a map from such states in + chunk_clat, to the `token-label` on arcs entering them. (It is not possible + that the same state would have multiple arcs entering it with different + token-labels, or some arcs entering with one token-label and some another, or + be both initial and have such arcs; this is true due to how we construct the + raw lattice.) */ void LatticeIncrementalDeterminizer2::IdentifyTokenFinalStates( const CompactLattice &chunk_clat, @@ -1508,6 +1382,37 @@ void LatticeIncrementalDeterminizer2::IdentifyTokenFinalStates( + +void LatticeIncrementalDeterminizer2::GetNonFinalRedetStates() { + using StateId = CompactLatticeArc::StateId; + non_final_redet_states_.clear(); + non_final_redet_states_.reserve(final_arcs_.size()); + + std::vector state_queue; + for (const CompactLatticeArc &arc: final_arcs_) { + // Note: we abuse the .nextstate field to store the state which is really + // the source of that arc. + StateId redet_state = arc.nextstate; + if (non_final_redet_states_.insert(redet_state).second) { + // it was not already there + state_queue.push_back(state); + } + } + // Add any states that are reachable from the states above. + while (!state_queue.empty()) { + StateId s = state_queue.back(); + state_queue.pop_back(); + for (fst::ArcIterator aiter(clat_, s); !aiter.Done(); + aiter.Next()) { + const CompactLatticeArc &arc = aiter.Value(); + StateId nextstate = arc.nextstate; + if (non_final_redet_states_.insert(nextstate).second) + state_queue.push_back(nextstate); // it was not already there + } + } +} + + void LatticeIncrementalDeterminizer2::InitializeRawLatticeChunk( Lattice *olat, unordered_map *token_label2state) { @@ -1518,7 +1423,7 @@ void LatticeIncrementalDeterminizer2::InitializeRawLatticeChunk( LatticeArc::State start_state = olat->AddState(); token_label2state->clear(); - // redet_state_map maps from state-ids in clat_ to state-ids in olat_. + // redet_state_map maps from state-ids in clat_ to state-ids in olat. unordered_map redet_state_map; for (CompactLatticeArc::StateId redet_state: non_final_redet_states_) @@ -1605,12 +1510,10 @@ void LatticeIncrementalDeterminizer2::InitializeRawLatticeChunk( bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( Lattice *raw_fst, - std::unordered_map *new_final_costs) { + const std::unordered_map *new_final_costs) { using Label = CompactLatticeArc::Label; using StateId = CompactLatticeArc::StateId; - bool first_chunk = (first_frame == 0); - // final_costs is a map from a `token-label` (see glossary) to the // associated final-prob in a final-state of `raw_fst`, that is associated with @@ -1620,66 +1523,68 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( // Those costs will not be included in anything we store in this class; // we will use `old_final_costs` later to cancel them out. std::unordered_map old_final_costs; - if (!is_last_chunk) { - StateId raw_fst_num_states = raw_fst->NumStates(); - for (LatticeArc::StateId s = 0; s < raw_fst_num_states; s++) { - for (ArcIterator aiter(*raw_fst, s); !aiter.Done(); - aiter.Next()) { - const LatticeArc &value = aiter.Value(); - if (value.olabel >= (Label)kTokenLabelOffset && - value.olabel < (Label)kMaxTokenLabel) { - LatticeWeight final_weight = raw_fst->Final(value.nextstate); - if (final_weight == LatticeState::Zero() || - final_weight.Value2() != 0) { - KALDI_ERR << "Label " << value.olabel - << " looks like a token-label but its next-state " - "has unexpected final-weight " << final_weight.Value1() << ',' - << final_weight.Value2(); - } - auto r = final_costs.insert({value.olabel, final_weight.Value1()}); - if (!r->second && r->first.second != final_weight.Value1()) { - // For any given token-label, all arcs in raw_fst with that - // olabel should go to the same state, so this should be - // impossible. - KALDI_ERR << "Unexpected mismatch in final-costs for tokens, " - << r->first.second << " vs " << final_weight.Value1(); - } + StateId raw_fst_num_states = raw_fst->NumStates(); + for (LatticeArc::StateId s = 0; s < raw_fst_num_states; s++) { + for (ArcIterator aiter(*raw_fst, s); !aiter.Done(); + aiter.Next()) { + const LatticeArc &value = aiter.Value(); + if (value.olabel >= (Label)kTokenLabelOffset && + value.olabel < (Label)kMaxTokenLabel) { + LatticeWeight final_weight = raw_fst->Final(value.nextstate); + if (final_weight == LatticeState::Zero() || + final_weight.Value2() != 0) { + KALDI_ERR << "Label " << value.olabel + << " looks like a token-label but its next-state " + "has unexpected final-weight " << final_weight.Value1() << ',' + << final_weight.Value2(); + } + auto r = final_costs.insert({value.olabel, final_weight.Value1()}); + if (!r->second && r->first.second != final_weight.Value1()) { + // For any given token-label, all arcs in raw_fst with that + // olabel should go to the same state, so this should be + // impossible. + KALDI_ERR << "Unexpected mismatch in final-costs for tokens, " + << r->first.second << " vs " << final_weight.Value1(); } } } } - CompactLattice chunk_clat; bool determinized_till_beam = DeterminizeLatticePhonePrunedWrapper( - trans_model_, raw_fst, (config_.lattice_beam + 0.1), &chunk_clat, + trans_model_, raw_fst, config_.lattice_beam, &chunk_clat, config_.det_opts); TopSortCompactLatticeIfNeeded(&chunk_clat); - StateId num_chunk_states = chunk_clat.NumStates(); - if (num_chunk_states == 0) { + std::unordered_map chunk_state_to_token; + IdentifyTokenFinalStates(chunk_clat, + &chunk_state_to_token); + + StateId chunk_num_states = chunk_clat.NumStates(); + if (chunk_num_states == 0) { // This will be an error but user-level calling code can detect it from the // lattice being empty. + KALDI_WARN << "Empty lattice, something went wrong."; chunk_clat_.DeleteStates(); return; } - ReweightChunk(&chunk_clat); - StateId start_state = chunk_clat.Start(); // would be 0. KALDI_ASSERT(start_state == 0); - // Process arcs leaving the start state. All arcs leaving the start state will - // have `state labels` on them (identifying redeterminized-states in clat_), - // and will transition to a state in `chunk_clat` that we can identify with - // that redeterminized- state. + // Process arcs leaving the start state of chunk_clat. Unless this is the + // first chunk in the lattice, all arcs leaving the start state of chunk_clat + // will have `state labels` on them (identifying redeterminized-states in + // clat_), and will transition to a state in `chunk_clat` that we can identify + // with that redeterminized-state. // state_map maps from (non-initial state s in chunk_clat) to: // if s is not final, then a state in clat_, // if s is final, then a state-label allocated by AllocateNewStateLabel(); // this will become a .nextstate in final_arcs_). std::unordered_map state_map; + bool is_first_chunk = false; StateId clat_num_states = clat_.NumStates(); @@ -1689,18 +1594,24 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( for (fst::ArcIterator aiter(chunk_clat, start_state); !aiter.Done(); aiter.Next()) { const CompactLatticeArc &arc = aiter.Value(); - Label label = arc.ilabel; // ilabel == olabel. + Label label = arc.ilabel; // ilabel == olabel; would be the olabel (word + // label) in a Lattice. + if (!(label >= kStateLabelOffset && label < clat_num_states)) { + // The label was not a state-label. This should only be possible on the + // first chunk. + KALDI_ASSERT(state_map.empty()); + is_first_chunk = true; + break; + } StateId clat_state = label - kStateLabelOffset; - KALDI_ASSERT(clat_state >= 0 && clat_state < clat_num_states); StateId chunk_state = arc.nextstate; - - CompactLatticeWeight weight(arc.weight); - bool inserted = state_map.insert({chunk_state, clat_state}); // Should not have been in the map before. KALDI_ASSERT(inserted); } + if (!is_first_chunk) + ReweightChunk(&chunk_clat); // Note: we haven't inspected any weights yet. // Remove any existing arcs in clat_ that leave redeterminized-states, // and make those states non-final. @@ -1710,266 +1621,113 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( clat.SetFinal(clat_state, CompactLatticeWeight::Zero()); } - // The final-arc info is no longer relevant, we'll recreate it below. + // The final-arc info is no longer relevant; we'll recreate it below. final_arcs_.clear(); - - // assume start-state == 0; we asserted it above. Allocate state-ids for all - // remaining states in chunk_clat (Except final-states, if this is not the - // last chunk). - for (StateId state = 1; state < num_chunk_states; state++) { - if (is_last_chunk || chunk_clat.Final(state) == CompactLatticeWeight::Zero()) { - // Allocate an actual state. - StateId new_clat_state = clat_.NumStates(); - if (state_map.insert({state, new_clat_state}).second) { - // If it was inserted then we need to actually allocate that state - StateId s = clat_.NewState(); - KALDI_ASSERT(s == new_clat_state); - } // else do nothing; it would have been a redeterminized-state and no - // allocation is needed since they already exist in clat_. and + // assume chunk_lat.Start() == 0; we asserted it above. Allocate state-ids + // for all remaining states in chunk_clat, except for token-final states. + for (StateId state = (is_first_chunk ? 0 : 1); + state < chunk_num_states; state++) { + if (chunk_state_to_token.count(state) != 0) + continue; // these `token-final` states don't get a state allocated. + + StateId new_clat_state = clat_.NumStates(); + if (state_map.insert({state, new_clat_state}).second) { + // If it was inserted then we need to actually allocate that state + StateId s = clat_.NewState(); + KALDI_ASSERT(s == new_clat_state); + } // else do nothing; it would have been a redeterminized-state and no + } // allocation is needed since they already exist in clat_. and // in state_map. - } - } + + if (is_first_chunk) + clat_.SetStart(state_map[start_state]); // Now transfer arcs from chunk_clat to clat_. - for (StateId chunk_state = 1; chunk_state < num_chunk_states; chunk_state++) { - bool is_final = chunk_clat.Final(chunk_state) != CompactLattice::Zero(); - if (is_last_chunk || !is_final) { - auto iter = state_map.find(chunk_state); - KALDI_ASSERT(iter != state_map.end()); - StateId clat_state = iter->second; - if (is_last_chunk && is_final) - clat_.SetFinal(clat_state, chunk_clat.Final(chunk_state)); - for (ArcIterator aiter(chunk_clat, chunk_state); - !aiter.Done(); aiter.Next()) { - CompactLatticeArc arc(aiter.Value()); - - auto next_iter = state_map.find(arc.nextstate); - if (next_iter != state_map.end()) { - arc.nextstate = next_iter->second; - clat_->AddArc(clat_state, arc); - } else { - KALDI_ASSERT(chunk_clat.Final(arc.nextstate) != CompactLatticeWeight::Zero() && - - arc.olabel >= (Label)kTokenLabelOffset && - arc.olabel < (Label)kMaxTokenLabel); - // Below we'll correct arc.weight for the final-cost. - arc.weight = fst::Times(arc.weight, chunk_clat.Final(arc.nextstate)); - // We just use the .nextstate field to encode the source state. - arc.nextstate = clat_state; - - // Note: the only reason we introduce these final-probs to clat_ - // is so that the user can obtain the compact lattice at an intermediate - // stage of the calculation. - if (keep_final_probs) - clat_->SetFinal(fst::Sum(lat_->Final(), - arc.weight)); - - // Cancel out `final_cost` (which will really be some kind of - // `backward`/beta cost from the raw lattice, introduced to guide - // pruned determinization) from arc.weight. - auto final_cost_iter = final_costs.find(arc.olabel); - KALDI_ASSERT(final_cost_iter != final_costs.end()); - BaseFloat final_cost = final_cost_iter; - arc.weight.SetWeight(Times(arc.weight.Weight(), - LatticeWeight(-final_cost, 0))); - - if (!keep_final_probs) // Set the final-prob of the state after - // sutracting the backward cost. - clat_->SetFinal(fst::Sum(lat_->Final(), - arc.weight)); - final_arcs_.push_back(arc); - } - } + for (StateId chunk_state = (is_first_chunk ? 0 : 1); + chunk_state < chunk_num_states; chunk_state++) { + auto iter = state_map.find(chunk_state); + if (iter == state_map.end()) { + KALDI_ASSERT(chunk_state_to_token.count(chunk_state) != 0); + // Don't process token-final states. Anyway they have no arcs leaving + // them. + continue; } - } - return determinized_till_beam; -} - -/* - TODO: move outside. - KALDI_VLOG(2) << "Frame: ( " << first_frame << " , " << last_frame << " )" - << " states of the chunk: " << clat.NumStates() - << " states of the lattice: " << clat_.NumStates(); -*/ + StateId clat_state = iter->second; + // Only in the last chunk of the lattice would be there be a final-prob on + // states that are not `token-final states`; these final-probs would + // normally all be zero at this point. + // So in almost all cases the following call will do nothing. + clat_->SetFinal(clat_state, chunk_clat.Final(chunk_state)); + for (ArcIterator aiter(chunk_clat, chunk_state); + !aiter.Done(); aiter.Next()) { + CompactLatticeArc arc(aiter.Value()); -bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( - int32 first_frame, int32 last_frame, - Lattice *raw_fst) { - - bool first_chunk = (first_frame == 0); - // step 2: Determinize the chunk - CompactLattice clat; - // We do determinization with beam pruning here - // Only if we use a beam larger than (config_.beam+config_.lattice_beam) here, we - // can guarantee no final or initial arcs in clat are pruned by this function. - // These pruned final arcs can hurt oracle WER performance in the final lattice - // (also result in less lattice density) but they seldom hurt 1-best WER. - // Since pruning behaviors in DeterminizeLatticePhonePrunedWrapper and - // PruneActiveTokens are not the same, to get similar lattice density as - // LatticeFasterDecoder, we need to use a slightly larger beam here - // than the lattice_beam used PruneActiveTokens. Hence the beam we use is - // (0.1 + config_.lattice_beam) - bool determinized_till_beam = DeterminizeLatticePhonePrunedWrapper( - trans_model_, raw_fst, (config_.lattice_beam + 0.1), &clat, config_.det_opts); - - // step 3: Appending the new chunk in clat to the old one in lat_ - // later we need to calculate forward_costs_ for clat - - TopSortCompactLatticeIfNeeded(&clat); - AppendLatticeChunks(clat, first_chunk); - - KALDI_VLOG(2) << "Frame: ( " << first_frame << " , " << last_frame << " )" - << " states of the chunk: " << clat.NumStates() - << " states of the lattice: " << clat_.NumStates(); - return determinized_till_beam; -} - -void LatticeIncrementalDeterminizer::AppendLatticeChunks( - const CompactLattice &clat, bool first_chunk) { - using namespace fst; - CompactLattice *olat = &clat_; - // step 3.1: Appending new chunk to the old one - int32 state_offset = olat->NumStates(); - if (!first_chunk) { - state_offset--; // since we do not append initial state in the first chunk - // remove arcs from redeterminized_state_map_ - for (auto i : redeterminized_state_map_) { - olat->DeleteArcs(i.first); - olat->SetFinal(i.first, CompactLatticeWeight::Zero()); - } - redeterminized_state_map_.clear(); - } else { - forward_costs_.push_back(0); // for the first state - } - forward_costs_.resize(state_offset + clat.NumStates(), - std::numeric_limits::infinity()); - - // Here we construct a map from the original prefinal state to the prefinal states - // for later use - unordered_map invert_processed_prefinal_states; - invert_processed_prefinal_states.reserve(processed_prefinal_states_.size()); - for (auto i : processed_prefinal_states_) - invert_processed_prefinal_states[i.second] = i.first; - for (StateIterator siter(clat); !siter.Done(); siter.Next()) { - auto s = siter.Value(); - StateId state_appended = kNoStateId; - // We do not copy initial state, which exists except the first chunk - if (first_chunk || s != 0) { - state_appended = s + state_offset; - auto r = olat->AddState(); - KALDI_ASSERT(state_appended == r); - olat->SetFinal(state_appended, clat.Final(s)); - } - - for (ArcIterator aiter(clat, s); !aiter.Done(); aiter.Next()) { - const auto &arc = aiter.Value(); - - StateId source_state = kNoStateId; - // We do not copy initial arcs, which exists except the first chunk. - // These arcs will be taken care later in step 3.2 - CompactLatticeArc arc_appended(arc); - arc_appended.nextstate += state_offset; - // In the first chunk, there could be a final arc starting from state 0, and we - // process it here - // In the last chunk, there could be a initial arc ending in final state, and - // we process it in "process initial arcs" in the following - bool is_initial_state = (!first_chunk && s == 0); - if (!is_initial_state) { - KALDI_ASSERT(state_appended != kNoStateId); - KALDI_ASSERT(arc.olabel < state_label_offset_); - source_state = state_appended; - // process final arcs - if (arc.olabel > config_.max_word_id) { - // record final_arc in this chunk for the step 3.2 in the next call - KALDI_ASSERT(arc.olabel < state_label_offset_); - KALDI_ASSERT(clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()); - // state_appended shouldn't be in invert_processed_prefinal_states - // So we do not need to map it - final_arc_list_.insert( - pair(state_appended, aiter.Position())); - } - olat->AddArc(source_state, arc_appended); - } else { // process initial arcs - // a special olabel in the arc that corresponds to the identity of the - // source-state of the last arc, we use its StateId and a offset here, called - // state_label - auto state_label = arc.olabel; - KALDI_ASSERT(state_label > config_.max_word_id); - KALDI_ASSERT(state_label >= state_label_offset_); - source_state = state_label - state_label_offset_; - arc_appended.olabel = 0; - arc_appended.ilabel = 0; - CompactLatticeWeight weight_offset; - // remove alpha in weight - weight_offset.SetWeight(LatticeWeight(0, -forward_costs_[source_state])); - arc_appended.weight = Times(arc_appended.weight, weight_offset); - - // if it is an extra prefinal state, we should use its original prefinal - // state - int arc_offset = 0; - auto r = invert_processed_prefinal_states.find(source_state); - if (r != invert_processed_prefinal_states.end() && r->second != r->first) { - source_state = r->second; - arc_offset = olat->NumArcs(source_state); + auto next_iter = state_map.find(arc.nextstate); + if (next_iter != state_map.end()) { + arc.nextstate = next_iter->second; + clat_->AddArc(clat_state, arc); + } else { + // TODO: remove the following slightly excessive assertion. + KALDI_ASSERT(chunk_clat.Final(arc.nextstate) != CompactLatticeWeight::Zero() && + arc.olabel >= (Label)kTokenLabelOffset && + arc.olabel < (Label)kMaxTokenLabel && + chunk_state_to_token.count(arc.nextstate) != 0 && + old_final_costs.count(arc.olabel) != 0); + + // Include the final-cost of the next state (which should be final) + // in arc.weight. + arc.weight = fst::Times(arc.weight, + chunk_clat.Final(arc.nextstate)); + + BaseFloat old_final_cost = old_final_costs[arc.olabel]; + auto iter = new_final_costs->find(arc.olabel); + + BaseFloat new_cost; + if (new_final_costs == NULL) { + new_cost = 0.0; // treat all new final-costs as One() + } else if (iter != new_final_costs->end()) { + new_cost = iter->second; + } else { + new_cost == std::numeric_limits::infinity; } - if (clat.Final(arc.nextstate) != CompactLatticeWeight::Zero()) { - // it should be the last chunk - olat->AddArc(source_state, arc_appended); - } else { - // append lattice chunk and remove Epsilon together - for (ArcIterator aiter_postinitial(clat, arc.nextstate); - !aiter_postinitial.Done(); aiter_postinitial.Next()) { - auto arc_postinitial(aiter_postinitial.Value()); - arc_postinitial.weight = - Times(arc_appended.weight, arc_postinitial.weight); - arc_postinitial.nextstate += state_offset; - olat->AddArc(source_state, arc_postinitial); - if (arc_postinitial.olabel > config_.max_word_id) { - KALDI_ASSERT(arc_postinitial.olabel < state_label_offset_); - final_arc_list_.insert(pair( - source_state, aiter_postinitial.Position() + arc_offset)); - } - } + if (new_cost != std::numeric_limits::infinity) { + // Add a final-prob in clat_. + // These final-probs will be consumed by the user if they get the + // lattices as we incrementally determinize, but they will not affect + // what happens after we process the next chunk. These final-probs + // would not exist in the `canonical compact lattice` (see glossary). + LatticeWeight cost_correction(new_cost - old_final_cost, 0.0); + CompactLatticeWeight final_prob(arc.weight); + final_prob.SetWeight(fst::Times(cost_correction, final_prob.Weight())); + clat_->SetFinal(clat_state, fst::Sum(clat_->Final(clat_state), + final_prob)); } - } - // update forward_costs_ (alpha) - KALDI_ASSERT(arc_appended.nextstate < forward_costs_.size()); - auto &alpha_nextstate = forward_costs_[arc_appended.nextstate]; - auto &weight = arc_appended.weight.Weight(); - alpha_nextstate = - std::min(alpha_nextstate, - forward_costs_[source_state] + weight.Value1() + weight.Value2()); - } - } - KALDI_ASSERT(olat->NumStates() == clat.NumStates() + state_offset); - KALDI_VLOG(8) << "states of the lattice: " << olat->NumStates(); - if (first_chunk) { - olat->SetStart(0); // Initialize the first chunk for olat - } else { - // The extra prefinal states generated by - // GetRedeterminizedStates are removed here, while splicing - // the compact lattices together - for (auto &i : processed_prefinal_states_) { - auto prefinal_state = i.first; - auto new_prefinal_state = i.second; - // It is without an extra prefinal state, hence do not need to process - if (prefinal_state == new_prefinal_state) continue; - for (ArcIterator aiter(*olat, new_prefinal_state); - !aiter.Done(); aiter.Next()) - olat->AddArc(prefinal_state, aiter.Value()); - olat->DeleteArcs(new_prefinal_state); - olat->SetFinal(new_prefinal_state, CompactLatticeWeight::Zero()); + // OK, `arc` is going to become an element of final_arcs_. These + // contain information about transitions from states in clat_ to + // `token-final` states (i.e. states that have a token-label on the arc + // to them and that are final in the canonical compact lattice). + arc.weight.SetWeight(fst::Times(arc.weight, + LatticeWeight{-old_final_cost, 0.0))); + // In a slight abuse of the Arc data structure, the nextstate is set to + // the source state. The label (ilabel == olabel) indicates the + // token it is associated with. + arc.nextstate = clat_state; + final_arcs_.push_back(arc); + } } } + GetNonFinalRedetStates(); - final_arc_list_.clear(); + return determinized_till_beam; } + + void LatticeIncrementalDeterminizer::Finalize() { using namespace fst; // The lattice determinization only needs to be finalized once diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 395afc24bec..09366360cdb 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -1,6 +1,6 @@ // decoder/lattice-incremental-decoder.h -// Copyright 2019 Zhehuai Chen, Hainan Xu, Daniel Povey +// Copyright 2019 Zhehuai Chen, Daniel Povey // See ../../COPYING for clarification regarding multiple authors // @@ -116,8 +116,6 @@ struct LatticeIncrementalDecoderConfig { int32 determinize_period; int32 determinize_max_active; int32 redeterminize_max_frames; - bool final_prune_after_determinize; - int32 max_word_id; // for GetLattice LatticeIncrementalDecoderConfig() @@ -132,9 +130,7 @@ struct LatticeIncrementalDecoderConfig { determinize_delay(25), determinize_period(20), determinize_max_active(std::numeric_limits::max()), - redeterminize_max_frames(std::numeric_limits::max()), - final_prune_after_determinize(true), - max_word_id(1e8) {} + redeterminize_max_frames(std::numeric_limits::max()) { } void Register(OptionsItf *opts) { det_opts.Register(opts); opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); @@ -210,8 +206,7 @@ class LatticeIncrementalDeterminizer2 { specific type all the time but just say 'Label' */ - LatticeIncrementalDeterminizer2(const LatticeIncrementalDecoderConfig &config, - const TransitionModel &trans_model); + LatticeIncrementalDeterminizer2() { } // Resets the lattice determinization data for new utterance void Init(); @@ -226,7 +221,7 @@ class LatticeIncrementalDeterminizer2 { specifically the initial-state i and the redeterminized-states. After calling this, the caller would add the remaining arcs and states - to `olat` and then call AcceptChunk() with the result. + to `olat` and then call AcceptRawLatticeChunk() with the result. @param [out] olat The lattice to be (partially) created @@ -270,15 +265,24 @@ class LatticeIncrementalDeterminizer2 { const CompactLattice &GetLattice() { return clat_; } - private: - // kTokenLabelOffset is where we start allocating labels corresponding to Tokens - // (these correspond with raw lattice states); // kStateLabelOffset is what we add to state-ids in clat_ to produce labels // to identify them in the raw lattice chunk + // kTokenLabelOffset is where we start allocating labels corresponding to Tokens + // (these correspond with raw lattice states); enum { kStateLabelOffset = (int)1e8, kTokenLabelOffset = (int)2e8, kMaxTokenLabel = (int)3e8 }; - Label next_state_label_; + private: + // Sets up non_final_redet_states_. See documentation for that variable. + void GetNonFinalRedetStates(); + + // Contains the set of redeterminized-states which are not final in the + // canonical appended lattice. Since the final ones don't physically appear + // in clat_, this means the set of redeterminized-states which are physically + // in clat_. In code terms, this means set of .first elements in final_arcs, + // plus whatever other states in clat_ are reachable from such states. + std::unordered_set non_final_redet_states_; + // clat_ is the appended lattice (containing all chunks processed so // far), except its `final-arcs` (i.e. arcs which in the canonical @@ -287,14 +291,6 @@ class LatticeIncrementalDeterminizer2 { // should have final-arcs leaving them will instead have a final-prob. CompactLattice clat_; - // The elements of this set are the redeterminized-states which are not final in - // the canonical appended lattice. This means the set of .first elements in - // final_arcs, plus whatever states in clat_ are reachable from such states. - // (The final redeterminized states/splice-states are never actually - // materialized.) - std::unordered_set non_final_redet_states_; - - // final_arcs_ contains arcs which would appear in the canonical appended // lattice but for implementation reasons are not physically present in clat_. // These are arcs to final states in the canonical appended lattice. The @@ -304,14 +300,6 @@ class LatticeIncrementalDeterminizer2 { // AllocateNewStateLabel(). std::vector final_arcs_; - - // final_weights_ contain the final-probs of states that are final in the - // canonical compact lattice. Physically it maps from the state-labels which - // are allocated by AllocateNewStateLabel() and are stored in the .nextstate - // in final_arcs_, to the weight that would be on that final-state in the - // canonical compact lattice. - std::unordered_map final_weights_; - // forward_costs_, indexed by the state-id in clat_, stores the alpha // (forward) costs, i.e. the minimum cost from the start state to each state // in clat_. This is relevant for pruned determinization. The BaseFloat can @@ -421,18 +409,7 @@ class LatticeIncrementalDecoderTpl { still have to do a fair amount of work; calling it every, say, 10 to 40 frames would make sense though. - @param [in] use_final_probs True if you want the final-probs - of HCLG to be included in the output lattice. - (However, if no state was final on frame - `num_frames_to_include` they won't be included regardless - of use_final_probs; if this equals NumFramesDecoded() you - can test this with ReachedFinal()). Caution: - it is an error to call this function with - the same num_frames_to_include and different values - of `use_final_probs`. (This is not a fundamental - limitation but just the way we coded it.) - - @param [in] num_frames_to_include The number of frames that you want + @param [in] num_frames_to_include The number of frames that you want to be included in the lattice. Must be >0 and <= NumFramesDecoded(). If you are calling this just to keep the incremental lattice determinization up to date and @@ -448,13 +425,29 @@ class LatticeIncrementalDecoderTpl { a few frames larger) is probably not a good use of computational resources. + @param [in] use_final_probs True if you want the final-probs + of HCLG to be included in the output lattice. Must not be + set if num_frames_to_include < NumFramesDecoded(). If no + state was final on frame `num_frames_to_include` they won't + be included regardless of use_final_probs; you can test this + with ReachedFinal(). Caution: it is an error to call this + function in succession with the same num_frames_to_include + and different values of `use_final_probs`. (This is not a + fundamental limitation but just the way we coded it.) + + @param [in] finalize If true, finalize the lattice (does an extra + pruning step on the raw lattice). After this call, no + further calls to GetLattice() will be allowed. + @return clat The CompactLattice representing what has been decoded up until `num_frames_to_include` (e.g., LatticeStateTimes() on this lattice would return `num_frames_to_include`). */ - const CompactLattice &GetLattice(bool use_final_probs, - int32 num_frames_to_include); + const CompactLattice &GetLattice(int32 num_frames_to_include, + bool use_final_probs = false, + bool finalize = false); + @@ -476,21 +469,6 @@ class LatticeIncrementalDecoderTpl { void AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames = -1); - /** - This function may be optionally called after AdvanceDecoding(), when you - do not plan to decode any further. It does an extra pruning step that - will help to prune the lattices output by GetLattice more accurately, - particularly toward the end of the utterance. - It does this by using the final-probs in pruning (if any - final-state survived); it also does a final pruning step that visits all - states (the pruning that is done during decoding may fail to prune states - that are within kPruningScale = 0.1 outside of the beam). If you call - this, you cannot call AdvanceDecoding again (it will fail), and you - cannot call GetLattice() and related functions with use_final_probs = - false. - */ - void FinalizeDecoding(); - /** FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives more information. It returns the difference between the best (final-cost plus cost) of any token on the final frame, and the best cost of any token @@ -544,28 +522,38 @@ class LatticeIncrementalDecoderTpl { int32 num_toks_; bool warned_; bool decoding_finalized_; + + int32 final_cost_frame_; // TODO: initialize. unordered_map final_costs_; BaseFloat final_relative_cost_; BaseFloat final_best_cost_; - /*** Variables below this point relate to the incremental - determinization. ***/ + /*********************** + Variables below this point relate to the incremental + determinization. + *********************/ LatticeIncrementalDecoderConfig config_; /** Much of the the incremental determinization algorithm is encapsulated in the determinize_ object. */ LatticeIncrementalDeterminizer2 determinizer_; + + /* Just a temporary used in a function; stored here to avoid reallocation. */ + unordered_map temp_token_map_; + /** num_frames_in_lattice_ is the highest `num_frames_to_include_` argument for any prior call to GetLattice(). */ int32 num_frames_in_lattice_; - // a map from Token to its token_label - unordered_map token2label_map_; + // A map from Token to its token_label. Will contain an entry for + // each Token in active_toks_[num_frames_in_lattice_]. + unordered_map token2label_map_; + + // A temporary used in a function, kept here to avoid reallocation. + unordered_map token2label_map_temp_; + // we allocate a unique id for each Token - int32 token_label_available_idx_; - // We keep cost_offset for each token_label (Token) in final arcs. We need them to - // guide determinization - // We cancel them after determinization - unordered_map token_label2final_cost_; + Label next_token_label_; + inline Label AllocateNewTokenLabel() { return next_token_label_++; } // There are various cleanup tasks... the the toks_ structure contains @@ -591,28 +579,10 @@ class LatticeIncrementalDecoderTpl { void ClearActiveTokens(); - // The following part is specifically designed for incremental determinization - // This function is modified from LatticeFasterDecoderTpl::GetRawLattice() - // and specific design for step 1 of incremental determinization - // introduced before above GetLattice() - // It does the same thing as GetRawLattice in lattice-faster-decoder.cc except: - // - // i) it creates a initial state, and connect - // each token in the first frame of this chunk to the initial state - // by one or more arcs with a state_label correponding to the pre-final state w.r.t - // this token(the pre-final state is appended in the last chunk) as its olabel - // ii) it creates a final state, and connect - // all the tokens in the last frame of this chunk to the final state - // by an arc with a per-token token_label as its olabel - // `frame_begin` and `frame_end` are the first and last frame of this chunk - // if `create_initial_state` == false, we will not create initial state and - // the corresponding initial arcs. Similar for `create_final_state` - // In incremental GetLattice, we do not create the initial state in - // the first chunk, and we do not create the final state in the last chunk - bool GetIncrementalRawLattice(Lattice *ofst, bool use_final_probs, - int32 frame_begin, int32 frame_end, - bool create_initial_state, bool create_final_state); - // Returns the number of active tokens on frame `frame`. + + // Returns the number of active tokens on frame `frame`. Can be used as part + // of a heuristic to decide which frame to determinize until, if you are not + // at the end of an utterance. int32 GetNumToksForFrame(int32 frame); // DeterminizeLattice() is just a wrapper for GetLattice() that uses the various @@ -623,6 +593,16 @@ class LatticeIncrementalDecoderTpl { // We may at some point decide to make this public. void DeterminizeLattice(); + /** + This function used to be public in LatticeFasterDecoder but is now accessed + only by including the 'finalize' argument to GetLattice(). It may be + called only once per utterance, at the end. (GetLattice() will ensure this + anyway. + It prunes the raw lattice. + */ + void FinalizeDecoding(); + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); }; From 86a6bc1ae60049f79aef6b1e1f4f8583398ced48 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 10 Nov 2019 18:28:54 -0800 Subject: [PATCH 44/60] Some more cleanup, working on making it compile --- src/decoder/lattice-incremental-decoder.cc | 285 +++++---------------- src/decoder/lattice-incremental-decoder.h | 28 +- 2 files changed, 96 insertions(+), 217 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 3e16c19b5b2..ce02d00ab29 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -31,8 +31,7 @@ LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( : fst_(&fst), delete_fst_(false), num_toks_(0), - config_(config), - determinizer_(config, trans_model) { + config_(config) { config.Check(); toks_.SetSize(1000); // just so on the first frame we do something reasonable. } @@ -44,8 +43,7 @@ LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( : fst_(fst), delete_fst_(true), num_toks_(0), - config_(config), - determinizer_(config, trans_model) { + config_(config){ config.Check(); toks_.SetSize(1000); // just so on the first frame we do something reasonable. } @@ -77,7 +75,7 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { determinizer_.Init(); num_frames_in_lattice_ = 0; - tokentlabel_map_.clear(); + token2label_map_.clear(); next_token_label_ = LatticeIncrementalDeterminizer2::kTokenLabelOffset; ProcessNonemitting(config_.beam); } @@ -972,7 +970,7 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( tok2state_map.clear(); unordered_map &next_token2label_map(token2label_map_temp_); - next_token2label_map_.clear(); + next_token2label_map.clear(); { // Deal with the last frame. We allocate token labels, and set tokens as @@ -1003,7 +1001,7 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( // go, and their destination-states will already be in the map. for (int32 frame = num_frames_to_include - 1; frame >= num_frames_in_lattice_; frame--) { - BaseFloat cost_offset = cost_offsets_[f]; + BaseFloat cost_offset = cost_offsets_[frame]; // For the first frame of the chunk, we need to make sure the states are // the ones created by InitializeRawLatticeChunk() (where not pruned away). @@ -1041,16 +1039,16 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( StateId next_state = next_iter->second; BaseFloat this_offset = (l->ilabel != 0 ? cost_offset : 0); LatticeArc arc(l->ilabel, l->olabel, - Weight(l->graph_cost, l->acoustic_cost - cost_offset), + Weight(l->graph_cost, l->acoustic_cost - this_offset), next_state); - chunk_lat.AddArc(state, arc); + chunk_lat.AddArc(cur_state, arc); } } } if (num_frames_in_lattice_ == 0) { std::vector tok_list; TopSortTokens(active_toks_[0], &tok_list); - Tok *start_token = tok_list[0]; + Token *start_token = tok_list[0]; auto iter = tok2state_map.find(start_token); KALDI_ASSERT(iter != tok2state_map.end()); StateId start_state = iter->second; @@ -1075,179 +1073,16 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( } } - - bool finished_before_beam = - determinizer_.AcceptRawLatticeChunk(chunk_lat, - (use_final_probs ? &final_costs : NULL)); + // bool finished_before_beam = + determinizer_.AcceptRawLatticeChunk(chunk_lat, + (use_final_probs ? &final_costs : NULL)); + // We are ignoring the return status, which say whether it finished before the beam. num_frames_in_lattice_ = num_frames_to_include; return determinizer_.GetLattice(); } -template -bool LatticeIncrementalDecoderTpl::GetIncrementalRawLattice( - Lattice *ofst, bool use_final_probs, int32 frame_begin, int32 frame_end, - bool create_initial_state, bool create_final_state) { - typedef LatticeArc Arc; - typedef Arc::StateId StateId; - typedef Arc::Weight Weight; - typedef Arc::Label Label; - - if (decoding_finalized_ && !use_final_probs) - KALDI_ERR << "You cannot call FinalizeDecoding() and then call " - << "GetIncrementalRawLattice() with use_final_probs == false"; - - unordered_map final_costs_local; - const unordered_map &final_costs = - (decoding_finalized_ ? final_costs_ : final_costs_local); - if (!decoding_finalized_ && use_final_probs) - ComputeFinalCosts(&final_costs_local, NULL, NULL); - - - ofst->DeleteStates(); - unordered_map - token_label2state; // for InitializeRawLatticeChunk - // initial arcs for the chunk - if (create_initial_state) - determinizer_.InitializeRawLatticeChunk(ofst, token_label2final_cost_, - &token_label2state); - // num-frames plus one (since frames are one-based, and we have - // an extra frame for the start-state). - KALDI_ASSERT(frame_end > 0); - unordered_map &tok_map(temp_token_map_); - tok_map.clear(); - - // First create all states. - std::vector token_list; - for (int32 f = frame_begin; f <= frame_end; f++) { - if (active_toks_[f].toks == NULL) { - KALDI_WARN << "GetIncrementalRawLattice: no tokens active on frame " << f - << ": not producing lattice.\n"; - return false; - } - TopSortTokens(active_toks_[f].toks, &token_list); - for (size_t i = 0; i < token_list.size(); i++) - if (token_list[i] != NULL) tok_map[token_list[i]] = ofst->AddState(); - } - // The next statement sets the start state of the output FST. - // No matter create_initial_state or not , state zero must be the start-state. - StateId start_state = 0; - ofst->SetStart(start_state); - - KALDI_VLOG(4) << "init:" << num_toks_ / 2 + 3 - << " buckets:" << tok_map.bucket_count() - << " load:" << tok_map.load_factor() - << " max:" << tok_map.max_load_factor(); - // step 1.1: create initial_arc for later appending with the previous chunk - if (create_initial_state) { - for (Token *tok = active_toks_[frame_begin].toks; tok != NULL; tok = tok->next) { - StateId cur_state = tok_map[tok]; - // token2label_map_ is construct during create_final_state - auto r = token2label_map_.find(tok); - KALDI_ASSERT(r != token2label_map_.end()); // it should exist - int32 token_label = r->second; - auto range = token_label2state.equal_range(token_label); - if (range.first == range.second) { - KALDI_WARN - << "The token in the first frame of this chunk does not " - "exist in the last frame of previous chunk. It should seldom" - " happen and would be caused by over-pruning in determinization," - "e.g. the lattice reaches --max-mem constrain."; - continue; - } - for (auto it = range.first; it != range.second; ++it) { - // the destination state of the last of the sequence of arcs w.r.t the token - // label - // here created by InitializeRawLatticeChunk - auto state_last_initial = it->second; - // connect it to the state correponding to the token w.r.t the token label - // here - Arc arc(0, 0, Weight::One(), cur_state); - ofst->AddArc(state_last_initial, arc); - } - } - } - // step 1.2: create all arcs as GetRawLattice() of LatticeFasterDecoder - for (int32 f = frame_begin; f <= frame_end; f++) { - for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { - StateId cur_state = tok_map[tok]; - for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { - // for the arcs outgoing from the last frame Token in this chunk, we will - // create these arcs in the next chunk - if (f == frame_end && l->ilabel > 0) continue; - typename unordered_map::const_iterator iter = - tok_map.find(l->next_tok); - KALDI_ASSERT(iter != tok_map.end()); - StateId nextstate = iter->second; - BaseFloat cost_offset = 0.0; - if (l->ilabel != 0) { // emitting.. - KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); - cost_offset = cost_offsets_[f]; - } - Arc arc(l->ilabel, l->olabel, - Weight(l->graph_cost, l->acoustic_cost - cost_offset), nextstate); - ofst->AddArc(cur_state, arc); - } - // For the last frame in this chunk, we need to work out a - // proper final weight for the corresponding state. - // If use_final_probs == true, we will try to use the final cost we just - // calculated - // Otherwise, we use LatticeWeight::One(). We record these cost in the state - // Later in the code, if create_final_state == true, we will create - // a specific final state, and move the final costs to the cost of an arc - // connecting to the final state - if (f == frame_end) { - LatticeWeight weight = LatticeWeight::One(); - if (use_final_probs && !final_costs.empty()) { - typename unordered_map::const_iterator iter = - final_costs.find(tok); - if (iter != final_costs.end()) - weight = LatticeWeight(iter->second, 0); - else - weight = LatticeWeight::Zero(); - } - ofst->SetFinal(cur_state, weight); - } - } - } - // step 1.3 create final_arc for later appending with the next chunk - if (create_final_state) { - StateId end_state = ofst->AddState(); // final-state for the chunk - ofst->SetFinal(end_state, Weight::One()); - - token2label_map_.clear(); - token2label_map_.reserve(std::min((int32)1e5, config_.max_active)); - for (Token *tok = active_toks_[frame_end].toks; tok != NULL; tok = tok->next) { - StateId cur_state = tok_map[tok]; - // We assign an unique state label for each of the token in the last frame - // of this chunk - int32 id = token_label_available_idx_++; - token2label_map_[tok] = id; - // The final weight has been worked out in the previous for loop and - // store in the states - // Here, we create a specific final state, and move the final costs to - // the cost of an arc connecting to the final state - KALDI_ASSERT(ofst->Final(cur_state) != Weight::Zero()); - Weight final_weight = ofst->Final(cur_state); - // Use cost_offsets to guide DeterminizeLatticePruned() - // For now, we use extra_cost from the decoding stage , which has some - // "future information", as - // the final weights of this chunk - BaseFloat cost_offset = tok->extra_cost - tok->tot_cost; - // We record these cost_offset, and after we appending two chunks - // we will cancel them out - token_label2final_cost_[id] = cost_offset; - Arc arc(0, id, Times(final_weight, Weight(0, cost_offset)), end_state); - ofst->AddArc(cur_state, arc); - ofst->SetFinal(cur_state, Weight::Zero()); - } - } - // TODO: clean up maps used internally. - TopSortLatticeIfNeeded(ofst); - return (ofst->NumStates() > 0); -} - template int32 LatticeIncrementalDecoderTpl::GetNumToksForFrame(int32 frame) { int32 r = 0; @@ -1257,7 +1092,6 @@ int32 LatticeIncrementalDecoderTpl::GetNumToksForFrame(int32 frame) - /* This utility function adds an arc to a Lattice, but where the source is a CompactLatticeArc. If the CompactLatticeArc has a string with length greater than 1, this will require adding extra states to `lat`. @@ -1271,18 +1105,18 @@ static void AddCompactLatticeArcToLattice( if (N == 0) { LatticeArc arc; arc.ilabel = 0; - arc.olabel = clat_arc.label; + arc.olabel = clat_arc.ilabel; arc.nextstate = clat_arc.nextstate; arc.weight = clat_arc.weight.Weight(); lat->AddArc(src_state, arc); } else { - LatticeArc::StateId cur_state = arc_state; + LatticeArc::StateId cur_state = src_state; for (size_t i = 0; i < N; i++) { LatticeArc arc; arc.ilabel = string[i]; arc.olabel = (i == 0 ? clat_arc.ilabel : 0); arc.nextstate = (i + 1 == N ? clat_arc.nextstate : lat->AddState()); - arc.weight = (i == 0 ? clat_arc.weight.Weight() : 0); + arc.weight = (i == 0 ? clat_arc.weight.Weight() : LatticeWeight::One()); lat->AddArc(cur_state, arc); cur_state = arc.nextstate; } @@ -1297,16 +1131,12 @@ void LatticeIncrementalDeterminizer2::Init() { forward_costs_.clear(); } -/** - Reweights a compact lattice chunk in a way that makes the combination with - the current compact lattice easier. Also removes some temporary - forward-probs that we previously added. -*/ +// See documentation in header void LatticeIncrementalDeterminizer2::ReweightChunk( - CompactLattice *chunk_clat) { + CompactLattice *chunk_clat) const { using StateId = CompactLatticeArc::StateId; using Label = CompactLatticeArc::Label; - StateId start = chunk_clat->Start(); + StateId start_state = chunk_clat->Start(), num_states = chunk_clat->NumStates(); std::vector potentials(chunk_clat->NumStates(), CompactLatticeWeight::One()); @@ -1316,22 +1146,20 @@ void LatticeIncrementalDeterminizer2::ReweightChunk( CompactLatticeArc arc = aiter.Value(); Label label = arc.ilabel; // ilabel == olabel. StateId clat_state = label - kStateLabelOffset; - KALDI_ASSERT(clat_state >= 0 && clat_state < clat_num_states); + KALDI_ASSERT(clat_state >= 0 && clat_state < num_states); // `extra_weight` serves to cancel out the weight // `forward_costs_[clat_state]` that we introduced in // InitializeRawLatticeChunk(); the purpose of that was to // make the pruned determinization work right, but they are // no longer needed. LatticeWeight extra_weight(-forward_costs_[clat_state], 0.0); - arc.weight.SetWeight( - CompactLatticeWeight::Times(arc.weight.Weight(), - extra_weight)); + arc.weight.SetWeight(fst::Times(arc.weight.Weight(), extra_weight)); aiter.SetValue(arc); potentials[arc.nextstate] = arc.weight; } // TODO: consider doing the following manually for this special case, // since most states are not reweighted. - fst::Reweight(potentials, fst::ReweightToFinal, chunk_clat); + fst::Reweight(chunk_clat, potentials, fst::REWEIGHT_TO_FINAL); // Below is just a check that weights on arcs leaving initial state // are all One(). @@ -1341,10 +1169,6 @@ void LatticeIncrementalDeterminizer2::ReweightChunk( KALDI_ASSERT(fst::ApproxEqual(aiter.Value().weight, CompactLatticeWeight::One())); } - Label label = arc.ilabel; // ilabel == olabel. - StateId clat_state = label - kStateLabelOffset; - KALDI_ASSERT(clat_state >= 0 && clat_state < clat_num_states); - } @@ -1508,6 +1332,50 @@ void LatticeIncrementalDeterminizer2::InitializeRawLatticeChunk( } +static bool incr_det_warned = false; +void LatticeIncrementalDeterminizer2::UpdateForwardCosts( + const std::unordered_map &state_map) { + using StateId = CompactLattice::StateId; + BaseFloat infinity = std::numeric_limits::infinity; + StateId cur_size = forward_costs_.size(); + for (auto &p: state_map) { + StateId state = p.second; // the state-id in clat_ + // The reason we can make the following assertion is that the states should + // be in topological order and each state should be reachable from an + // earlier state (and we should have processed that earlier state by now). + KALDI_ASSERT(state < cur_size); + BaseFloat cur_cost = forward_costs_[state]; + if (cur_cost == infinity) { + // I don't think I can exclude that there might be unreachable + // states + if (!incr_det_warned) { + KALDI_WARN << "Found unreachable state in compact lattice while determinizing"; + incr_det_warned = true; + } + continue; + } + KALDI_ASSERT(cur_cost < infinity); + for (fst::ArcIterator aiter(clat_, state); + !aiter.Done(); aiter.Next()) { + const CompactLatticeArc &arc = aiter.Value(); + BaseFloat arc_cost = arc.weight.GetWeight().Value1() + + arc.weight.GetWeight().Value2(), + next_cost = cur_cost + arc_cost; + if (arc.nextstate >= cur_size) { + forward_costs_.resize(arc.nextstate + 1, infinity); + cur_size = arc.nextstate + 1; + forward_costs_[arc.nextstate] = next_cost; + } else if (forward_costs_[arc.nextstate] > next_cost) { + forward_costs_[arc.nextstate] = next_cost; + } + } + + + } + +} + + bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( Lattice *raw_fst, const std::unordered_map *new_final_costs) { @@ -1640,8 +1508,11 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( } // allocation is needed since they already exist in clat_. and // in state_map. - if (is_first_chunk) + if (is_first_chunk) { + KALDI_ASSERT(forward_probs_.empty() && start_state == 0); + forward_probs_.push_back(0.0); // forward-cost of start state is 0. clat_.SetStart(state_map[start_state]); + } // Now transfer arcs from chunk_clat to clat_. for (StateId chunk_state = (is_first_chunk ? 0 : 1); @@ -1723,26 +1594,12 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( } GetNonFinalRedetStates(); + UpdateForwardProbs(state_map); + return determinized_till_beam; } - -void LatticeIncrementalDeterminizer::Finalize() { - using namespace fst; - // The lattice determinization only needs to be finalized once - if (determinization_finalized_) - return; - // step 4: remove dead states - if (config_.final_prune_after_determinize) - PruneLattice(config_.lattice_beam, &clat_); - else - Connect(&clat_); // Remove unreachable states... there might be - - KALDI_VLOG(2) << "states of the lattice: " << clat_.NumStates(); - determinization_finalized_ = true; -} - // Instantiate the template for the combination of token types and FST types // that we'll need. template class LatticeIncrementalDecoderTpl, decoder::StdToken>; diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 09366360cdb..dce964bccb3 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -170,8 +170,6 @@ struct LatticeIncrementalDecoderConfig { "deterministic but less likely to blow up the processing" "time in bad cases. You could set it infinite to get a fully " "determinized lattice."); - opts->Register("final-prune-after-determinize", &final_prune_after_determinize, - "prune lattice after determinization "); opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " "parameter is obscure and relates to a speedup in the way the " @@ -205,7 +203,6 @@ class LatticeIncrementalDeterminizer2 { lattice, so we don't use the specific type all the time but just say 'Label' */ - LatticeIncrementalDeterminizer2() { } // Resets the lattice determinization data for new utterance @@ -276,6 +273,31 @@ class LatticeIncrementalDeterminizer2 { // Sets up non_final_redet_states_. See documentation for that variable. void GetNonFinalRedetStates(); + // Updates forward_costs_ for all the states which are successors of states + // appearing as values in `state_map`. (By "a is a successor of b" I mean + // there is an arc from a to b.) + void UpdateForwardCosts( + const std::unordered_map &state_map); + + + // Reweights `chunk_clat`. Must not be called if this is the first chunk. + // This does: + // (1) For arcs leaving chunk_clat->Start(), identify the redeterminized-state + // clat_state in clat_ that its .nextstate corresponds to, and multiply the weight + // by LatticeWeight(-forward_costs_[clat_state], 0). This is the opposite + // of a cost that we introduced when constructing the raw lattice chunk, + // in order to make sure that determinized pruning works right. We need to + // cancel it out because it's not really part of this chunk. + // (2) After doing (1), modifies chunk_clat so that the weights on arcs + // leaving its start state are all CompactLatticeWeight::One()... + // does this while maintaining equivalence, using OpenFst's + // Reweight() function. This is done for convenience, because + // the start state doesn't correspond to any state in clat_, + // and if there were weights on arcs leaving it we'd need to take + // them into account somehow. + void ReweightChunk(CompactLattice *chunk_clat) const; + + // Contains the set of redeterminized-states which are not final in the // canonical appended lattice. Since the final ones don't physically appear // in clat_, this means the set of redeterminized-states which are physically From a182fe3abb4eba5975b9b4e58d9a5f1a5fe8defb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 11 Nov 2019 14:13:09 -0800 Subject: [PATCH 45/60] Got decoder directory to compile --- src/decoder/lattice-faster-decoder.cc | 17 +- src/decoder/lattice-incremental-decoder.cc | 261 ++++++++------------- src/decoder/lattice-incremental-decoder.h | 77 +++--- 3 files changed, 150 insertions(+), 205 deletions(-) diff --git a/src/decoder/lattice-faster-decoder.cc b/src/decoder/lattice-faster-decoder.cc index 9106309eb84..83c582d3b5e 100644 --- a/src/decoder/lattice-faster-decoder.cc +++ b/src/decoder/lattice-faster-decoder.cc @@ -229,24 +229,17 @@ void LatticeFasterDecoderTpl::PossiblyResizeHash(size_t num_toks) { extra_cost is used in pruning tokens, to save memory. - Define the 'forward cost' of a token as zero for any token on the frame - we're currently decoding; and for other frames, as the shortest-path cost - between that token and a token on the frame we're currently decoding. - (by "currently decoding" I mean the most recently processed frame). - - Then define the extra_cost of a token (always >= 0) as the forward-cost of - the token minus the smallest forward-cost of any token on the same frame. + extra_cost can be thought of as a beta (backward) cost assuming + we had set the betas on currently-active tokens to all be the negative + of the alphas for those tokens. (So all currently active tokens would + be on (tied) best paths). We can use the extra_cost to accurately prune away tokens that we know will never appear in the lattice. If the extra_cost is greater than the desired lattice beam, the token would provably never appear in the lattice, so we can prune away the token. - The advantage of storing the extra_cost rather than the forward-cost, is that - it is less costly to keep the extra_cost up-to-date when we process new frames. - When we process a new frame, *all* the previous frames' forward-costs would change; - but in general the extra_cost will change only for a finite number of frames. - (Actually we don't update all the extra_costs every time we update a frame; we + (Note: we don't update all the extra_costs every time we update a frame; we only do it every 'config_.prune_interval' frames). */ diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index ce02d00ab29..e5f7590b497 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -31,7 +31,8 @@ LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( : fst_(&fst), delete_fst_(false), num_toks_(0), - config_(config) { + config_(config), + determinizer_(trans_model, config) { config.Check(); toks_.SetSize(1000); // just so on the first frame we do something reasonable. } @@ -43,7 +44,8 @@ LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( : fst_(fst), delete_fst_(true), num_toks_(0), - config_(config){ + config_(config), + determinizer_(trans_model, config) { config.Check(); toks_.SetSize(1000); // just so on the first frame we do something reasonable. } @@ -142,15 +144,6 @@ bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decoda return !active_toks_.empty() && active_toks_.back().toks != NULL; } -// Outputs an FST corresponding to the single best path through the lattice. -template -void LatticeIncrementalDecoderTpl::GetBestPath(Lattice *olat, - bool use_final_probs) { - CompactLattice lat, slat; - GetLattice(use_final_probs, NumFramesDecoded(), &lat); - ShortestPath(lat, &slat); - ConvertLattice(slat, olat); -} template void LatticeIncrementalDecoderTpl::PossiblyResizeHash(size_t num_toks) { @@ -166,6 +159,12 @@ void LatticeIncrementalDecoderTpl::PossiblyResizeHash(size_t num_tok extra_cost is used in pruning tokens, to save memory. + extra_cost can be thought of as a beta (backward) cost assuming + we had set the betas on currently-active tokens to all be the negative + of the alphas for those tokens. (So all currently active tokens would + be on (tied) best paths). + + Define the 'forward cost' of a token as zero for any token on the frame we're currently decoding; and for other frames, as the shortest-path cost between that token and a token on the frame we're currently decoding. @@ -850,83 +849,6 @@ void LatticeIncrementalDecoderTpl< KALDI_ASSERT(num_toks_ == 0); } -// static -template -void LatticeIncrementalDecoderTpl::TopSortTokens( - Token *tok_list, std::vector *topsorted_list) { - unordered_map token2pos; - typedef typename unordered_map::iterator IterType; - int32 num_toks = 0; - for (Token *tok = tok_list; tok != NULL; tok = tok->next) num_toks++; - int32 cur_pos = 0; - // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. - // This is likely to be in closer to topological order than - // if we had given them ascending order, because of the way - // new tokens are put at the front of the list. - for (Token *tok = tok_list; tok != NULL; tok = tok->next) - token2pos[tok] = num_toks - ++cur_pos; - - unordered_set reprocess; - - for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { - Token *tok = iter->first; - int32 pos = iter->second; - for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { - if (link->ilabel == 0) { - // We only need to consider epsilon links, since non-epsilon links - // transition between frames and this function only needs to sort a list - // of tokens from a single frame. - IterType following_iter = token2pos.find(link->next_tok); - if (following_iter != token2pos.end()) { // another token on this frame, - // so must consider it. - int32 next_pos = following_iter->second; - if (next_pos < pos) { // reassign the position of the next Token. - following_iter->second = cur_pos++; - reprocess.insert(link->next_tok); - } - } - } - } - // In case we had previously assigned this token to be reprocessed, we can - // erase it from that set because it's "happy now" (we just processed it). - reprocess.erase(tok); - } - - size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles. - for (loop_count = 0; !reprocess.empty() && loop_count < max_loop; ++loop_count) { - std::vector reprocess_vec; - for (typename unordered_set::iterator iter = reprocess.begin(); - iter != reprocess.end(); ++iter) - reprocess_vec.push_back(*iter); - reprocess.clear(); - for (typename std::vector::iterator iter = reprocess_vec.begin(); - iter != reprocess_vec.end(); ++iter) { - Token *tok = *iter; - int32 pos = token2pos[tok]; - // Repeat the processing we did above (for comments, see above). - for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { - if (link->ilabel == 0) { - IterType following_iter = token2pos.find(link->next_tok); - if (following_iter != token2pos.end()) { - int32 next_pos = following_iter->second; - if (next_pos < pos) { - following_iter->second = cur_pos++; - reprocess.insert(link->next_tok); - } - } - } - } - } - } - KALDI_ASSERT(loop_count < max_loop && - "Epsilon loops exist in your decoding " - "graph (this is not allowed!)"); - - topsorted_list->clear(); - topsorted_list->resize(cur_pos, NULL); // create a list with NULLs in between. - for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) - (*topsorted_list)[iter->second] = iter->first; -} template const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( @@ -1039,16 +961,25 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( StateId next_state = next_iter->second; BaseFloat this_offset = (l->ilabel != 0 ? cost_offset : 0); LatticeArc arc(l->ilabel, l->olabel, - Weight(l->graph_cost, l->acoustic_cost - this_offset), + LatticeWeight(l->graph_cost, l->acoustic_cost - this_offset), next_state); chunk_lat.AddArc(cur_state, arc); } } } if (num_frames_in_lattice_ == 0) { - std::vector tok_list; - TopSortTokens(active_toks_[0], &tok_list); - Token *start_token = tok_list[0]; + // This block locates the start token. NOTE: we use the fact that in the + // linked list of tokens, things are added at the head, so the start state + // must be at the tail. If this data structure is changed in future, we + // might need to explicitly store the start token as a class member. + Token *tok = active_toks_[0].toks; + if (tok == NULL) { + KALDI_WARN << "No tokens exist on start frame"; + return determinizer_.GetLattice(); // will be empty. + } + while (tok->next != NULL) + tok = tok->next; + Token *start_token = tok; auto iter = tok2state_map.find(start_token); KALDI_ASSERT(iter != tok2state_map.end()); StateId start_state = iter->second; @@ -1056,7 +987,7 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( } token2label_map_.swap(next_token2label_map); - std::unordered_map final_costs; + std::unordered_map final_costs; if (use_final_probs) { final_cost_frame_ = num_frames_to_include; ComputeFinalCosts(&final_costs_, &final_relative_cost_, @@ -1074,7 +1005,7 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( } // bool finished_before_beam = - determinizer_.AcceptRawLatticeChunk(chunk_lat, + determinizer_.AcceptRawLatticeChunk(&chunk_lat, (use_final_probs ? &final_costs : NULL)); // We are ignoring the return status, which say whether it finished before the beam. @@ -1172,33 +1103,24 @@ void LatticeIncrementalDeterminizer2::ReweightChunk( } -/** - Identifies states in `chunk_clat` that have arcs entering them with a - `token-label` on them (see glossary in header for definition). We're calling - these `token-final` states. This function outputs a map from such states in - chunk_clat, to the `token-label` on arcs entering them. (It is not possible - that the same state would have multiple arcs entering it with different - token-labels, or some arcs entering with one token-label and some another, or - be both initial and have such arcs; this is true due to how we construct the - raw lattice.) - */ +// See documentation in header void LatticeIncrementalDeterminizer2::IdentifyTokenFinalStates( const CompactLattice &chunk_clat, - std::unordered_map *token_map) { + std::unordered_map *token_map) const { token_map->clear(); using StateId = CompactLatticeArc::StateId; using Label = CompactLatticeArc::Label; StateId num_states = chunk_clat.NumStates(); for (StateId state = 0; state < num_states; state++) { - for (fst::ArcIterator aiter(chunk_clat, start_state); + for (fst::ArcIterator aiter(chunk_clat, state); !aiter.Done(); aiter.Next()) { - CompactLatticeArc &arc = aiter.Value(); + const CompactLatticeArc &arc = aiter.Value(); if (arc.olabel >= kTokenLabelOffset && arc.olabel < kMaxTokenLabel) { StateId nextstate = arc.nextstate; auto r = token_map->insert({nextstate, arc.olabel}); // Check consistency of labels on incoming arcs - KALDI_ASSERT(r->second.second == arc.olabel); + KALDI_ASSERT(r.first->second == arc.olabel); } } } @@ -1219,7 +1141,7 @@ void LatticeIncrementalDeterminizer2::GetNonFinalRedetStates() { StateId redet_state = arc.nextstate; if (non_final_redet_states_.insert(redet_state).second) { // it was not already there - state_queue.push_back(state); + state_queue.push_back(redet_state); } } // Add any states that are reachable from the states above. @@ -1242,13 +1164,16 @@ void LatticeIncrementalDeterminizer2::InitializeRawLatticeChunk( unordered_map *token_label2state) { using namespace fst; - olat->DeleteStates(); - LatticeArc::State start_state = olat->AddState(); + LatticeArc::StateId start_state = olat->AddState(); token_label2state->clear(); - // redet_state_map maps from state-ids in clat_ to state-ids in olat. - unordered_map redet_state_map; + // redet_state_map maps from state-ids in clat_ to state-ids in olat. This + // will be the set of states from which the arcs to final-states in the + // canonical appended lattice leave (physically, these are in the .nextstate + // elements of arcs_, since we use that field for the source state), plus any + // states reachable from those states. + unordered_map redet_state_map; for (CompactLatticeArc::StateId redet_state: non_final_redet_states_) redet_state_map[redet_state] = olat->AddState(); @@ -1265,10 +1190,17 @@ void LatticeIncrementalDeterminizer2::InitializeRawLatticeChunk( !aiter.Done(); aiter.Next()) { const CompactLatticeArc &arc = aiter.Value(); CompactLatticeArc::StateId nextstate = arc.nextstate; - auto iter = redet_state_map.find(nextstate); - KALDI_(iter != redet_state_map.end()); + LatticeArc::StateId lat_nextstate = olat->NumStates(); + auto r = redet_state_map.insert({nextstate, lat_nextstate}); + if (r.second) { // Was inserted. + LatticeArc::StateId s = olat->AddState(); + KALDI_ASSERT(s == lat_nextstate); + } else { + // was not inserted -> was already there. + lat_nextstate = r.first->second; + } CompactLatticeArc clat_arc(arc); - clat_arc.nextstate = iter->second; + clat_arc.nextstate = lat_nextstate; AddCompactLatticeArcToLattice(clat_arc, lat_state, olat); } } @@ -1276,29 +1208,26 @@ void LatticeIncrementalDeterminizer2::InitializeRawLatticeChunk( for (const CompactLatticeArc &arc: final_arcs_) { // We abuse the `nextstate` field to store the source state. CompactLatticeArc::StateId src_state = arc.nextstate; + auto iter = redet_state_map.find(src_state); + KALDI_ASSERT(iter != redet_state_map.end()); + LatticeArc::StateId src_lat_state = iter->second; Label token_label = arc.ilabel; // will be == arc.olabel. KALDI_ASSERT(token_label >= kTokenLabelOffset && token_label < kMaxTokenLabel); - CompactLatticeArc - - auto r = token_label2state->insert({token_labelstate_label, - olat->NumStates()}); + auto r = token_label2state->insert({token_label, + olat->NumStates()}); + LatticeArc::StateId dest_lat_state = r.first->second; if (r.second) { // was inserted - StateId new_state = olat->AddState(); - KALDI_ASSERT(r.first->second == new_state); + LatticeArc::StateId new_state = olat->AddState(); + KALDI_ASSERT(new_state == dest_lat_state); } - LatticeArc::StateId next_lat_state = r.second; - auto iter = redet_state_map.find(src_state); - KALDI_ASSERT(iter != redet_state_map.end()); - LatticeArc::StateId src_lat_state = iter->second; CompactLatticeArc new_arc; - new_arc.nextstate = next_lat_state; + new_arc.nextstate = dest_lat_state; new_arc.ilabel = new_arc.olabel = token_label; new_arc.weight = arc.weight; AddCompactLatticeArcToLattice(new_arc, src_lat_state, olat); } - // Now deal with the initial-probs. Arcs from initial-states to // redeterminized-states in the raw lattice have an olabel that identifies the // id of that redeterminized-state in clat_, and a cost that is derived from @@ -1312,8 +1241,7 @@ void LatticeIncrementalDeterminizer2::InitializeRawLatticeChunk( // a state that is not a redeterminized state." In fact, we include these // arcs for all redeterminized states. I realized that it won't make a // difference to the outcome, and it's easier to do it this way. - for (auto iter: non_final_redet_states_) { - CompactLatticeArc::StateId state_id = iter->first; + for (CompactLatticeArc::StateId state_id: non_final_redet_states_) { BaseFloat forward_cost = forward_costs_[state_id]; LatticeArc arc; arc.ilabel = 0; @@ -1336,7 +1264,7 @@ static bool incr_det_warned = false; void LatticeIncrementalDeterminizer2::UpdateForwardCosts( const std::unordered_map &state_map) { using StateId = CompactLattice::StateId; - BaseFloat infinity = std::numeric_limits::infinity; + BaseFloat infinity = std::numeric_limits::infinity(); StateId cur_size = forward_costs_.size(); for (auto &p: state_map) { StateId state = p.second; // the state-id in clat_ @@ -1358,8 +1286,8 @@ void LatticeIncrementalDeterminizer2::UpdateForwardCosts( for (fst::ArcIterator aiter(clat_, state); !aiter.Done(); aiter.Next()) { const CompactLatticeArc &arc = aiter.Value(); - BaseFloat arc_cost = arc.weight.GetWeight().Value1() + - arc.weight.GetWeight().Value2(), + BaseFloat arc_cost = arc.weight.Weight().Value1() + + arc.weight.Weight().Value2(), next_cost = cur_cost + arc_cost; if (arc.nextstate >= cur_size) { forward_costs_.resize(arc.nextstate + 1, infinity); @@ -1382,37 +1310,38 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( using Label = CompactLatticeArc::Label; using StateId = CompactLatticeArc::StateId; - - // final_costs is a map from a `token-label` (see glossary) to the - // associated final-prob in a final-state of `raw_fst`, that is associated with - // that Token. These are Tokens that were active at the end of - // the chunk. The final-probs may arise from beta (backward) costs, - // introduced for pruning purposes, and/or from final-probs in HCLG. - // Those costs will not be included in anything we store in this class; - // we will use `old_final_costs` later to cancel them out. + // old_final_costs is a map from a `token-label` (see glossary) to the + // associated final-prob in a final-state of `raw_fst`, that is associated + // with that Token. These are Tokens that were active at the end of the + // chunk. The final-probs may arise from beta (backward) costs, introduced + // for pruning purposes, and/or from final-probs in HCLG. Those costs will + // not be included in anything we store permamently in this class; they used + // only to guide pruned determinization, and we will use `old_final_costs` + // later to cancel them out. std::unordered_map old_final_costs; StateId raw_fst_num_states = raw_fst->NumStates(); for (LatticeArc::StateId s = 0; s < raw_fst_num_states; s++) { - for (ArcIterator aiter(*raw_fst, s); !aiter.Done(); + for (fst::ArcIterator aiter(*raw_fst, s); !aiter.Done(); aiter.Next()) { const LatticeArc &value = aiter.Value(); if (value.olabel >= (Label)kTokenLabelOffset && value.olabel < (Label)kMaxTokenLabel) { LatticeWeight final_weight = raw_fst->Final(value.nextstate); - if (final_weight == LatticeState::Zero() || + if (final_weight == LatticeWeight::Zero() || final_weight.Value2() != 0) { KALDI_ERR << "Label " << value.olabel << " looks like a token-label but its next-state " "has unexpected final-weight " << final_weight.Value1() << ',' << final_weight.Value2(); } - auto r = final_costs.insert({value.olabel, final_weight.Value1()}); - if (!r->second && r->first.second != final_weight.Value1()) { + auto r = old_final_costs.insert({value.olabel, + final_weight.Value1()}); + if (!r.second && r.first->second != final_weight.Value1()) { // For any given token-label, all arcs in raw_fst with that // olabel should go to the same state, so this should be // impossible. KALDI_ERR << "Unexpected mismatch in final-costs for tokens, " - << r->first.second << " vs " << final_weight.Value1(); + << r.first->second << " vs " << final_weight.Value1(); } } } @@ -1434,8 +1363,8 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( // This will be an error but user-level calling code can detect it from the // lattice being empty. KALDI_WARN << "Empty lattice, something went wrong."; - chunk_clat_.DeleteStates(); - return; + clat_.DeleteStates(); + return false; } StateId start_state = chunk_clat.Start(); // would be 0. @@ -1473,7 +1402,7 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( } StateId clat_state = label - kStateLabelOffset; StateId chunk_state = arc.nextstate; - bool inserted = state_map.insert({chunk_state, clat_state}); + bool inserted = state_map.insert({chunk_state, clat_state}).second; // Should not have been in the map before. KALDI_ASSERT(inserted); } @@ -1481,12 +1410,12 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( if (!is_first_chunk) ReweightChunk(&chunk_clat); // Note: we haven't inspected any weights yet. - // Remove any existing arcs in clat_ that leave redeterminized-states, - // and make those states non-final. - for (auto iter: non_final_redet_states_) { - StateId clat_state = *iter; + // Remove any existing arcs in clat_ that leave redeterminized-states, and + // make those states non-final. Below, we'll add arcs leaving those states + // (and possibly new final-probs.) + for (StateId clat_state: non_final_redet_states_) { clat_.DeleteArcs(clat_state); - clat.SetFinal(clat_state, CompactLatticeWeight::Zero()); + clat_.SetFinal(clat_state, CompactLatticeWeight::Zero()); } // The final-arc info is no longer relevant; we'll recreate it below. @@ -1502,15 +1431,15 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( StateId new_clat_state = clat_.NumStates(); if (state_map.insert({state, new_clat_state}).second) { // If it was inserted then we need to actually allocate that state - StateId s = clat_.NewState(); + StateId s = clat_.AddState(); KALDI_ASSERT(s == new_clat_state); } // else do nothing; it would have been a redeterminized-state and no } // allocation is needed since they already exist in clat_. and // in state_map. if (is_first_chunk) { - KALDI_ASSERT(forward_probs_.empty() && start_state == 0); - forward_probs_.push_back(0.0); // forward-cost of start state is 0. + KALDI_ASSERT(forward_costs_.empty() && start_state == 0); + forward_costs_.push_back(0.0); // forward-cost of start state is 0. clat_.SetStart(state_map[start_state]); } @@ -1530,16 +1459,16 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( // states that are not `token-final states`; these final-probs would // normally all be zero at this point. // So in almost all cases the following call will do nothing. - clat_->SetFinal(clat_state, chunk_clat.Final(chunk_state)); + clat_.SetFinal(clat_state, chunk_clat.Final(chunk_state)); - for (ArcIterator aiter(chunk_clat, chunk_state); + for (fst::ArcIterator aiter(chunk_clat, chunk_state); !aiter.Done(); aiter.Next()) { CompactLatticeArc arc(aiter.Value()); auto next_iter = state_map.find(arc.nextstate); if (next_iter != state_map.end()) { arc.nextstate = next_iter->second; - clat_->AddArc(clat_state, arc); + clat_.AddArc(clat_state, arc); } else { // TODO: remove the following slightly excessive assertion. KALDI_ASSERT(chunk_clat.Final(arc.nextstate) != CompactLatticeWeight::Zero() && @@ -1562,10 +1491,10 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( } else if (iter != new_final_costs->end()) { new_cost = iter->second; } else { - new_cost == std::numeric_limits::infinity; + new_cost = std::numeric_limits::infinity(); } - if (new_cost != std::numeric_limits::infinity) { + if (new_cost != std::numeric_limits::infinity()) { // Add a final-prob in clat_. // These final-probs will be consumed by the user if they get the // lattices as we incrementally determinize, but they will not affect @@ -1574,7 +1503,7 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( LatticeWeight cost_correction(new_cost - old_final_cost, 0.0); CompactLatticeWeight final_prob(arc.weight); final_prob.SetWeight(fst::Times(cost_correction, final_prob.Weight())); - clat_->SetFinal(clat_state, fst::Sum(clat_->Final(clat_state), + clat_.SetFinal(clat_state, fst::Plus(clat_.Final(clat_state), final_prob)); } @@ -1582,8 +1511,8 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( // contain information about transitions from states in clat_ to // `token-final` states (i.e. states that have a token-label on the arc // to them and that are final in the canonical compact lattice). - arc.weight.SetWeight(fst::Times(arc.weight, - LatticeWeight{-old_final_cost, 0.0))); + arc.weight.SetWeight(fst::Times(arc.weight.Weight(), + LatticeWeight{-old_final_cost, 0.0})); // In a slight abuse of the Arc data structure, the nextstate is set to // the source state. The label (ilabel == olabel) indicates the // token it is associated with. @@ -1594,7 +1523,7 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( } GetNonFinalRedetStates(); - UpdateForwardProbs(state_map); + UpdateForwardCosts(state_map); return determinized_till_beam; } diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index dce964bccb3..3a1e5f6099e 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -81,15 +81,20 @@ namespace kaldi { canonical appended lattice: This is the appended compact lattice that we conceptually have (i.e. what we described in the paper). - The difference from the "actual appended lattice" is that the + The difference from the "actual appended lattice" stored + in LatticeIncrementalDeterminizer::clat_ is that the actual appended lattice has all its final-arcs replaced with - final-probs (we keep the real final-arcs "on the side" in a - separate data structure). + final-probs, and we keep the real final-arcs "on the side" in a + separate data structure. The final-probs in clat_ aren't + necessarily related to the costs on the final-arcs; instead + they can have arbitrary values passed in by the user (e.g. + if we want to include final-probs). This means that the + clat_ can be returned without modification to the user who wants + a partially determinized result. final-arc: An arc in the canonical appended CompactLattice which goes to a final-state. These arcs will have `state-labels` as their labels. - */ struct LatticeIncrementalDecoderConfig { // All the configuration values until det_opts are the same as in @@ -203,7 +208,10 @@ class LatticeIncrementalDeterminizer2 { lattice, so we don't use the specific type all the time but just say 'Label' */ - LatticeIncrementalDeterminizer2() { } + LatticeIncrementalDeterminizer2( + const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config): + trans_model_(trans_model), config_(config) { } // Resets the lattice determinization data for new utterance void Init(); @@ -253,8 +261,8 @@ class LatticeIncrementalDeterminizer2 { a nonzero final-prob in raw_fst. (States in raw_fst that had a final-prob will still be non-final). - @return returns false if determinization finished earlier than the beam, - true otherwise. + @return returns false if determinization finished earlier than the beam + or the determinized lattice was empty; true otherwise. */ bool AcceptRawLatticeChunk(Lattice *raw_fst, const std::unordered_map *final_costs = NULL); @@ -276,6 +284,18 @@ class LatticeIncrementalDeterminizer2 { // Updates forward_costs_ for all the states which are successors of states // appearing as values in `state_map`. (By "a is a successor of b" I mean // there is an arc from a to b.) + // For states that already had entries in the forward_costs_ array, this + // will never decrease their forward costs. This may in theory make + // the forward-costs inaccurate (too large) in cases where arcs + // between redeterminized-states were removed by pruned determinization. + // But the forward_costs_ are anyway only used for the pruned determinization, + // and this would never cause things to be pruned away + // and such paths can never become the best-path (this is true because of + // how we set the betas/final-probs/extra-costs on the tokens + + // But this is OK because + // adding a piece of lattice should never worsen the cost of existing + // states void UpdateForwardCosts( const std::unordered_map &state_map); @@ -298,6 +318,26 @@ class LatticeIncrementalDeterminizer2 { void ReweightChunk(CompactLattice *chunk_clat) const; + // Identifies states in `chunk_clat` that have arcs entering them with a + // `token-label` on them (see glossary in header for definition). We're + // calling these `token-final` states. This function outputs a map from such + // states in chunk_clat, to the `token-label` on arcs entering them. (It is + // not possible that the same state would have multiple arcs entering it with + // different token-labels, or some arcs entering with one token-label and some + // another, or be both initial and have such arcs; this is true due to how we + // construct the raw lattice.) + void IdentifyTokenFinalStates( + const CompactLattice &chunk_clat, + std::unordered_map *token_map) const; + + // trans_model_ is needed by DeterminizeLatticePhonePrunedWrapper() which this + // class calls. + const TransitionModel &trans_model_; + // config_ is needed by DeterminizeLatticePhonePrunedWrapper() which this + // class calls. + const LatticeIncrementalDecoderConfig &config_; + + // Contains the set of redeterminized-states which are not final in the // canonical appended lattice. Since the final ones don't physically appear // in clat_, this means the set of redeterminized-states which are physically @@ -407,18 +447,9 @@ class LatticeIncrementalDecoderTpl { } /** - Outputs an FST corresponding to the single best path through the lattice. - If "use_final_probs" is true AND we reached the - final-state of the graph then it will include those as final-probs, else - it will treat all final-probs as one. - - Note: this gets the traceback from the compact lattice, which will not - include the most recently decoded frames if determinize_delay > 0 and - FinalizeDecoding() has not been called. If you'll be wanting to call - GetBestPath() a lot and need it to be up to date, you may prefer to - use LatticeIncrementalOnlineDecoder. - */ - void GetBestPath(Lattice *ofst, bool use_final_probs = true); + This decoder has no GetBestPath() function. + If you need that functionality you should probably use lattice-incremental-online-decoder.h, + which makes it very efficient to obtain the best path. */ /** This GetLattice() function is the main way you will interact with the @@ -591,14 +622,6 @@ class LatticeIncrementalDecoderTpl { // using the "next" pointer. We delete them manually. void DeleteElems(Elem *list); - // This function takes a singly linked list of tokens for a single frame, and - // outputs a list of them in topological order (it will crash if no such order - // can be found, which will typically be due to decoding graphs with epsilon - // cycles, which are not allowed). Note: the output list may contain NULLs, - // which the caller should pass over; it just happens to be more efficient for - // the algorithm to output a list that contains NULLs. - static void TopSortTokens(Token *tok_list, std::vector *topsorted_list); - void ClearActiveTokens(); From 780343c564d07995ecff8225b0d2477c31369b57 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 14 Nov 2019 10:31:37 +0800 Subject: [PATCH 46/60] Fix some compilation issues --- src/decoder/decoder-wrappers.cc | 20 +++--- src/decoder/lattice-incremental-decoder.cc | 79 ++++++++++++++++------ 2 files changed, 70 insertions(+), 29 deletions(-) diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index 15465b88635..68a1431f470 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -227,12 +227,21 @@ bool DecodeUtteranceLatticeIncremental( } } + // Get lattice + CompactLattice clat; + decoder.GetLattice(decoder.NumFramesDecoded(), true, true); + if (clat.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; + double likelihood; LatticeWeight weight; int32 num_frames; { // First do some stuff with word-level traceback... - VectorFst decoded; - decoder.GetBestPath(&decoded); + CompactLattice decoded_clat; + CompactLatticeShortestPath(clat, &decoded_clat); + Lattice decoded; + fst::ConvertLattice(decoded_clat, &decoded); + if (decoded.Start() == fst::kNoStateId) // Shouldn't really reach this point as already checked success. KALDI_ERR << "Failed to get traceback for utterance " << utt; @@ -259,12 +268,7 @@ bool DecodeUtteranceLatticeIncremental( likelihood = -(weight.Value1() + weight.Value2()); } - // Get lattice - CompactLattice clat; - decoder.GetLattice(true, decoder.NumFramesDecoded(), &clat); - if (clat.NumStates() == 0) - KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; - // We'll write the lattice without acoustic scaling. + // We'll write the lattice without acoustic scaling. if (acoustic_scale != 0.0) fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); compact_lattice_writer->Write(utt, clat); diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index e5f7590b497..a32ed936a7f 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -861,6 +861,11 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( return determinizer_.GetLattice(); } + if (finalize) { + /* Will prune away tokens on the last frame.. */ + FinalizeDecoding(); + } + if (num_frames_to_include < num_frames_in_lattice_ || num_frames_to_include > NumFramesDecoded()) { KALDI_ERR << "GetLattice() called with num-frames-to-include = " @@ -895,16 +900,57 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( next_token2label_map.clear(); - { // Deal with the last frame. We allocate token labels, and set tokens as + std::unordered_map state2final_cost; + + + if (use_final_probs) { + final_cost_frame_ = num_frames_to_include; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, + &final_best_cost_); + } + if (final_costs_.empty()) { + /* If no states were final on the last frame, we don't use the final-probs. + The user can detect this by calling ReachedFinal(). */ + use_final_probs = false; + } + /* + for (auto iter = final_costs_.begin(); iter != final_costs_.end(); + ++iter) { + Token *tok = iter->first; + BaseFloat final_cost = iter->second; + auto iter2 = tok2state_map.find(tok); + KALDI_ASSERT(iter2 != tok2state_map.end()); + StateId lat_state = iter2->second; + bool inserted = final_costs.insert({lat_state, final_cost}).second; + KALDI_ASSERT(inserted); + } + } */ + + + { // Deal with the last frame in the chunk, the one numbered `num_frames_to_include`. + // (Yes, this is backwards). We allocate token labels, and set tokens as // final, but don't add any transitions. This may leave some states // disconnected (e.g. due to chains of nonemitting arcs), but it's OK; we'll // fix it when we generate the next chunk of lattice. int32 frame = num_frames_to_include; // Allocate state-ids for all tokens on this frame. + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + BaseFloat final_cost = 0.0; + if (use_final_probs) { + auto iter = final_costs_.find(tok); + if (iter == final_costs_.end()) + continue; + final_cost = iter->second; + } + StateId state = chunk_lat.AddState(); tok2state_map[tok] = state; next_token2label_map[tok] = AllocateNewTokenLabel(); + if (use_final_probs) + state2final_cost[state] = final_cost; + + // (A) the case where this is not the last chunk: // First imagine extra_cost == 0, which it would be if // num_frames_to_include == NumFramesDecoded(). We use a final-prob // (i.e. beta) that is the negative of the token's total cost. This @@ -914,8 +960,16 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( // frames than the final one, the extra cost is equal to the beta (==cost // to the end), assuming we had set the betas on last frame to the // negatives of the alphas. - chunk_lat.SetFinal(state, - LatticeWeight(-(tok->tot_cost + tok->extra_cost), 0.0)); + // + // (B) the case where this is the last chunk: + // use `final_cost`, which reflect the final-cost in HCLG if + // requested by the user and if at least one final state was present + // on the last frame; and otherwise zero. + + BaseFloat chunk_final_cost = (finalize ? final_cost : + -(tok->tot_cost + tok->extra_cost)); + + chunk_lat.SetFinal(state, LatticeWeight(chunk_final_cost, 0.0)); } } @@ -987,26 +1041,9 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( } token2label_map_.swap(next_token2label_map); - std::unordered_map final_costs; - if (use_final_probs) { - final_cost_frame_ = num_frames_to_include; - ComputeFinalCosts(&final_costs_, &final_relative_cost_, - &final_best_cost_); - for (auto iter = final_costs_.begin(); iter != final_costs_.end(); - ++iter) { - Token *tok = iter->first; - BaseFloat final_cost = iter->second; - auto iter2 = tok2state_map.find(tok); - KALDI_ASSERT(iter2 != tok2state_map.end()); - StateId lat_state = iter2->second; - bool inserted = final_costs.insert({lat_state, final_cost}).second; - KALDI_ASSERT(inserted); - } - } - // bool finished_before_beam = determinizer_.AcceptRawLatticeChunk(&chunk_lat, - (use_final_probs ? &final_costs : NULL)); + (use_final_probs ? &state2final_cost : NULL)); // We are ignoring the return status, which say whether it finished before the beam. num_frames_in_lattice_ = num_frames_to_include; From 6836282f5285cddbc6a8c806910fb93826c2857f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 16 Nov 2019 08:13:13 +0800 Subject: [PATCH 47/60] Some code simplification in incremental determinization --- src/decoder/lattice-incremental-decoder.cc | 551 +++++++++++---------- src/decoder/lattice-incremental-decoder.h | 267 +++++----- 2 files changed, 441 insertions(+), 377 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index a32ed936a7f..0aa36a14211 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -78,38 +78,39 @@ void LatticeIncrementalDecoderTpl::InitDecoding() { determinizer_.Init(); num_frames_in_lattice_ = 0; token2label_map_.clear(); - next_token_label_ = LatticeIncrementalDeterminizer2::kTokenLabelOffset; + next_token_label_ = LatticeIncrementalDeterminizer::kTokenLabelOffset; ProcessNonemitting(config_.beam); } template -void LatticeIncrementalDecoderTpl::DeterminizeLattice() { - // We always incrementally determinize the lattice after lattice pruning in - // PruneActiveTokens() since we need extra_cost as the weights - // of final arcs to denote the "future" information of final states (Tokens) - // Moreover, the delay on GetLattice to do determinization - // make it process more skinny lattices which reduces the computation overheads. - int32 frame_det_most = NumFramesDecoded() - config_.determinize_delay; - // The minimum length of chunk is config_.determinize_period. - if (frame_det_most % config_.determinize_period == 0) { - int32 frame_det_least = num_frames_in_lattice_ + config_.determinize_period; - // Incremental determinization: - // To adaptively decide the length of chunk, we further compare the number of - // tokens in each frame and a pre-defined threshold. - // If the number of tokens in a certain frame is less than - // config_.determinize_max_active, the lattice can be determinized up to this - // frame. And we try to determinize as most frames as possible so we check - // numbers from frame_det_most to frame_det_least - for (int32 f = frame_det_most; f >= frame_det_least; f--) { - if (config_.determinize_max_active == std::numeric_limits::max() || - GetNumToksForFrame(f) < config_.determinize_max_active) { - KALDI_VLOG(2) << "Frame: " << NumFramesDecoded() - << " incremental determinization up to " << f; - GetLattice(false, f); - break; - } +void LatticeIncrementalDecoderTpl::UpdateLatticeDeterminization() { + if (NumFramesDecoded() - num_frames_in_lattice_ < + config_.determinize_max_delay) + return; + + + /* Make sure the token-pruning is active. Note: PruneActiveTokens() has + internal logic that prevents it from doing unnecessary work if you + call it and then immediately call it again. */ + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + + int32 first = num_frames_in_lattice_ + config_.determinize_min_chunk_size, + last = NumFramesDecoded(), + fewest_tokens = std::numeric_limits::max(), + best_frame = -1; + for (int32 t = first; t <= last; t++) { + /* Make sure PruneActiveTokens() has computed num_toks for all these + frames... */ + KALDI_ASSERT(!active_toks_[t].num_toks != -1); + if (active_toks_[t].num_toks < fewest_tokens) { + fewest_tokens = active_toks_[t].num_toks; + best_frame = t; } } + /* OK, determinize the chunk that spans from num_frames_in_lattice_ to + best_frame. */ + bool use_final_probs = false; + GetLattice(best_frame, use_final_probs); return; } // Returns true if any kind of traceback is available (not necessarily from @@ -128,14 +129,15 @@ bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decoda PruneActiveTokens(config_.lattice_beam * config_.prune_scale); } - DeterminizeLattice(); + UpdateLatticeDeterminization(); BaseFloat cost_cutoff = ProcessEmitting(decodable); ProcessNonemitting(cost_cutoff); } Timer timer; FinalizeDecoding(); - GetLattice(true, NumFramesDecoded()); + bool use_final_probs = true; + GetLattice(NumFramesDecoded(), use_final_probs); KALDI_VLOG(2) << "Delay time during and after FinalizeDecoding()" << "(secs): " << timer.Elapsed(); @@ -405,13 +407,9 @@ void LatticeIncrementalDecoderTpl::PruneForwardLinksFinal() { template BaseFloat LatticeIncrementalDecoderTpl::FinalRelativeCost() const { - if (NumFramesDecoded() != final_cost_frame_) { - BaseFloat relative_cost; - ComputeFinalCosts(NULL, &relative_cost, NULL); - return relative_cost; - } else { - return final_relative_cost_; - } + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; } // Prune away any tokens on this frame that have no forward links. @@ -425,7 +423,8 @@ void LatticeIncrementalDecoderTpl::PruneTokensForFrame( Token *&toks = active_toks_[frame_plus_one].toks; if (toks == NULL) KALDI_WARN << "No tokens alive [doing pruning]"; Token *tok, *next_tok, *prev_tok = NULL; - for (tok = toks; tok != NULL; tok = next_tok) { + int32 num_toks = 0; + for (tok = toks; tok != NULL; tok = next_tok, num_toks++) { next_tok = tok->next; if (tok->extra_cost == std::numeric_limits::infinity()) { // token is unreachable from end of graph; (no forward links survived) @@ -440,6 +439,7 @@ void LatticeIncrementalDecoderTpl::PruneTokensForFrame( prev_tok = tok; } } + active_toks_[frame_plus_one].num_toks = num_toks; } // Go backwards through still-alive tokens, pruning them, starting not from @@ -557,9 +557,6 @@ void LatticeIncrementalDecoderTpl::AdvanceDecoding( if (NumFramesDecoded() % config_.prune_interval == 0) { PruneActiveTokens(config_.lattice_beam * config_.prune_scale); } - - DeterminizeLattice(); - BaseFloat cost_cutoff = ProcessEmitting(decodable); ProcessNonemitting(cost_cutoff); } @@ -853,200 +850,178 @@ void LatticeIncrementalDecoderTpl< template const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( int32 num_frames_to_include, - bool use_final_probs, bool finalize) { - - if (num_frames_to_include == num_frames_in_lattice_) { - // We've already obtained the lattice up to here. - KALDI_ASSERT(finalize == decoding_finalized_); - return determinizer_.GetLattice(); + bool use_final_probs) { + KALDI_ASSERT(num_frames_to_include >= num_frames_in_lattice_ && + num_frames_to_include <= NumFramesDecoded()); + + if (decoding_finalized_ && !use_final_probs) { + // This is not supported + KALDI_ERR << "You cannot get the lattice without final-probs after " + "calling FinalizeDecoding()."; } - - if (finalize) { - /* Will prune away tokens on the last frame.. */ - FinalizeDecoding(); + if (use_final_probs && num_frames_to_include != NumFramesDecoded()) { + /* This is because we only remember the relation between HCLG states and + Tokens for the current frame; the Token does not have a `state` field. */ + KALDI_ERR << "use-final-probs may no be true if you are not " + "getting a lattice for all frames decoded so far."; } - if (num_frames_to_include < num_frames_in_lattice_ || - num_frames_to_include > NumFramesDecoded()) { - KALDI_ERR << "GetLattice() called with num-frames-to-include = " - << num_frames_to_include << " but already determinized " - << num_frames_in_lattice_ << " frames and " - << NumFramesDecoded() << " frames decoded so far."; - } - KALDI_ASSERT(!decoding_finalized_); - if (finalize) - FinalizeDecoding(); // does pruning of the raw lattice. - - if (num_frames_to_include < NumFramesDecoded() && - (use_final_probs || finalize)) { - KALDI_ERR << "You cannot set use_final_probs or finalize if not requesting " - "all the frames decoded so far."; - } - Lattice chunk_lat; + if (num_frames_to_include > num_frames_in_lattice_) { + /* Make sure the token-pruning is up to date. If we just pruned the tokens, + this will do very little work. */ + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); - unordered_map token_label2state; - if (num_frames_in_lattice_ != 0) { - determinizer_.InitializeRawLatticeChunk(&chunk_lat, - &token_label2state); - } + Lattice chunk_lat; - // tok_map will map from Token* to state-id in chunk_lat. - // The cur and prev versions alternate on different frames. - unordered_map &tok2state_map(temp_token_map_); - tok2state_map.clear(); + unordered_map token_label2state; + if (num_frames_in_lattice_ != 0) { + determinizer_.InitializeRawLatticeChunk(&chunk_lat, + &token_label2state); + } - unordered_map &next_token2label_map(token2label_map_temp_); - next_token2label_map.clear(); + // tok_map will map from Token* to state-id in chunk_lat. + // The cur and prev versions alternate on different frames. + unordered_map &tok2state_map(temp_token_map_); + tok2state_map.clear(); + unordered_map &next_token2label_map(token2label_map_temp_); + next_token2label_map.clear(); - std::unordered_map state2final_cost; + { // Deal with the last frame in the chunk, the one numbered `num_frames_to_include`. + // (Yes, this is backwards). We allocate token labels, and set tokens as + // final, but don't add any transitions. This may leave some states + // disconnected (e.g. due to chains of nonemitting arcs), but it's OK; we'll + // fix it when we generate the next chunk of lattice. + int32 frame = num_frames_to_include; + // Allocate state-ids for all tokens on this frame. + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + /* If we included the final-costs at this stage, they will cause + non-final states to be pruned out from the end of the lattice. */ + BaseFloat final_cost; + if (decoding_finalized_) { + if (final_costs_.empty()) { + final_cost = 0.0; /* No final-state survived, so treat all as final + * with probability One(). */ + } else { + auto iter = final_costs_.find(tok); + if (iter == final_costs_.end()) + continue; + final_cost = iter->second; + } + } else { + /* this is a `fake` final-cost used to guide pruning. This equals + the alpha+beta of the state, if we were to set the betas on + the final frame to the negatives of the alphas (this is a trick + to make all such tokens on the best path, to avoid pruning out + anything that might be within `lattice-beam` of the eventual + best path). + */ + final_cost = -(tok->tot_cost + tok->extra_cost); + } - if (use_final_probs) { - final_cost_frame_ = num_frames_to_include; - ComputeFinalCosts(&final_costs_, &final_relative_cost_, - &final_best_cost_); - } - if (final_costs_.empty()) { - /* If no states were final on the last frame, we don't use the final-probs. - The user can detect this by calling ReachedFinal(). */ - use_final_probs = false; - } - /* - for (auto iter = final_costs_.begin(); iter != final_costs_.end(); - ++iter) { - Token *tok = iter->first; - BaseFloat final_cost = iter->second; - auto iter2 = tok2state_map.find(tok); - KALDI_ASSERT(iter2 != tok2state_map.end()); - StateId lat_state = iter2->second; - bool inserted = final_costs.insert({lat_state, final_cost}).second; - KALDI_ASSERT(inserted); - } - } */ - - - { // Deal with the last frame in the chunk, the one numbered `num_frames_to_include`. - // (Yes, this is backwards). We allocate token labels, and set tokens as - // final, but don't add any transitions. This may leave some states - // disconnected (e.g. due to chains of nonemitting arcs), but it's OK; we'll - // fix it when we generate the next chunk of lattice. - int32 frame = num_frames_to_include; - // Allocate state-ids for all tokens on this frame. - - for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { - BaseFloat final_cost = 0.0; - if (use_final_probs) { - auto iter = final_costs_.find(tok); - if (iter == final_costs_.end()) - continue; - final_cost = iter->second; + StateId state = chunk_lat.AddState(); + tok2state_map[tok] = state; + next_token2label_map[tok] = AllocateNewTokenLabel(); + chunk_lat.SetFinal(state, LatticeWeight(final_cost, 0.0)); } - - StateId state = chunk_lat.AddState(); - tok2state_map[tok] = state; - next_token2label_map[tok] = AllocateNewTokenLabel(); - if (use_final_probs) - state2final_cost[state] = final_cost; - - // (A) the case where this is not the last chunk: - // First imagine extra_cost == 0, which it would be if - // num_frames_to_include == NumFramesDecoded(). We use a final-prob - // (i.e. beta) that is the negative of the token's total cost. This - // ensures that all tokens on the final frame are the 'best token' / have - // the same best-path cost. This is done for pruning purposes, so we - // never prune anything out that's active on the last frame. For earlier - // frames than the final one, the extra cost is equal to the beta (==cost - // to the end), assuming we had set the betas on last frame to the - // negatives of the alphas. - // - // (B) the case where this is the last chunk: - // use `final_cost`, which reflect the final-cost in HCLG if - // requested by the user and if at least one final state was present - // on the last frame; and otherwise zero. - - BaseFloat chunk_final_cost = (finalize ? final_cost : - -(tok->tot_cost + tok->extra_cost)); - - chunk_lat.SetFinal(state, LatticeWeight(chunk_final_cost, 0.0)); } - } - // Go in reverse order over the remaining frames so we can create arcs as we - // go, and their destination-states will already be in the map. - for (int32 frame = num_frames_to_include - 1; - frame >= num_frames_in_lattice_; frame--) { - BaseFloat cost_offset = cost_offsets_[frame]; - - // For the first frame of the chunk, we need to make sure the states are - // the ones created by InitializeRawLatticeChunk() (where not pruned away). - if (frame == num_frames_in_lattice_ && num_frames_in_lattice_ != 0) { - for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { - auto iter = token2label_map_.find(tok); - KALDI_ASSERT(iter != token2label_map_.end()); - Label token_label = iter->second; - auto iter2 = token_label2state.find(token_label); - if (iter2 != token_label2state.end()) { - StateId state = iter2->second; - tok2state_map[tok] = state; - } else { - // Some states may have been pruned out, but we should still allocate - // them. They might have been part of chains of nonemitting arcs - // where the state became disconnected because the last chunk didn't - // include arcs starting at this frame. + // Go in reverse order over the remaining frames so we can create arcs as we + // go, and their destination-states will already be in the map. + for (int32 frame = num_frames_to_include - 1; + frame >= num_frames_in_lattice_; frame--) { + BaseFloat cost_offset = cost_offsets_[frame]; + + // For the first frame of the chunk, we need to make sure the states are + // the ones created by InitializeRawLatticeChunk() (where not pruned away). + if (frame == num_frames_in_lattice_ && num_frames_in_lattice_ != 0) { + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + auto iter = token2label_map_.find(tok); + KALDI_ASSERT(iter != token2label_map_.end()); + Label token_label = iter->second; + auto iter2 = token_label2state.find(token_label); + if (iter2 != token_label2state.end()) { + StateId state = iter2->second; + tok2state_map[tok] = state; + } else { + // Some states may have been pruned out, but we should still allocate + // them. They might have been part of chains of nonemitting arcs + // where the state became disconnected because the last chunk didn't + // include arcs starting at this frame. + StateId state = chunk_lat.AddState(); + tok2state_map[tok] = state; + } + } + } else { + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { StateId state = chunk_lat.AddState(); tok2state_map[tok] = state; } } - } else { for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { - StateId state = chunk_lat.AddState(); - tok2state_map[tok] = state; + auto iter = tok2state_map.find(tok); + KALDI_ASSERT(iter != tok2state_map.end()); + StateId cur_state = iter->second; + for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { + auto next_iter = tok2state_map.find(l->next_tok); + KALDI_ASSERT(next_iter != tok2state_map.end()); + StateId next_state = next_iter->second; + BaseFloat this_offset = (l->ilabel != 0 ? cost_offset : 0); + LatticeArc arc(l->ilabel, l->olabel, + LatticeWeight(l->graph_cost, l->acoustic_cost - this_offset), + next_state); + chunk_lat.AddArc(cur_state, arc); + } } } - for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { - auto iter = tok2state_map.find(tok); - KALDI_ASSERT(iter != tok2state_map.end()); - StateId cur_state = iter->second; - for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { - auto next_iter = tok2state_map.find(l->next_tok); - KALDI_ASSERT(next_iter != tok2state_map.end()); - StateId next_state = next_iter->second; - BaseFloat this_offset = (l->ilabel != 0 ? cost_offset : 0); - LatticeArc arc(l->ilabel, l->olabel, - LatticeWeight(l->graph_cost, l->acoustic_cost - this_offset), - next_state); - chunk_lat.AddArc(cur_state, arc); + if (num_frames_in_lattice_ == 0) { + // This block locates the start token. NOTE: we use the fact that in the + // linked list of tokens, things are added at the head, so the start state + // must be at the tail. If this data structure is changed in future, we + // might need to explicitly store the start token as a class member. + Token *tok = active_toks_[0].toks; + if (tok == NULL) { + KALDI_WARN << "No tokens exist on start frame"; + return determinizer_.GetLattice(); // will be empty. } + while (tok->next != NULL) + tok = tok->next; + Token *start_token = tok; + auto iter = tok2state_map.find(start_token); + KALDI_ASSERT(iter != tok2state_map.end()); + StateId start_state = iter->second; + chunk_lat.SetStart(start_state); } + token2label_map_.swap(next_token2label_map); + + // bool finished_before_beam = + determinizer_.AcceptRawLatticeChunk(&chunk_lat); + // We are ignoring the return status, which say whether it finished before the beam. + + num_frames_in_lattice_ = num_frames_to_include; } - if (num_frames_in_lattice_ == 0) { - // This block locates the start token. NOTE: we use the fact that in the - // linked list of tokens, things are added at the head, so the start state - // must be at the tail. If this data structure is changed in future, we - // might need to explicitly store the start token as a class member. - Token *tok = active_toks_[0].toks; - if (tok == NULL) { - KALDI_WARN << "No tokens exist on start frame"; - return determinizer_.GetLattice(); // will be empty. + + unordered_map token2final_cost; + unordered_map token_label2final_cost; + if (use_final_probs) { + ComputeFinalCosts(&token2final_cost, NULL, NULL); + for (const auto &p: token2final_cost) { + Token *tok = p.first; + BaseFloat cost = p.second; + auto iter = token2label_map_.find(tok); + KALDI_ASSERT(iter != token2label_map_.end()); + Label token_label = iter->second; + bool ret = token_label2final_cost.insert({token_label, cost}).second; + KALDI_ASSERT(ret); /* Make sure it was inserted. */ } - while (tok->next != NULL) - tok = tok->next; - Token *start_token = tok; - auto iter = tok2state_map.find(start_token); - KALDI_ASSERT(iter != tok2state_map.end()); - StateId start_state = iter->second; - chunk_lat.SetStart(start_state); } - token2label_map_.swap(next_token2label_map); - - // bool finished_before_beam = - determinizer_.AcceptRawLatticeChunk(&chunk_lat, - (use_final_probs ? &state2final_cost : NULL)); - // We are ignoring the return status, which say whether it finished before the beam. + /* Note: these final-probs won't affect the next chunk, only the lattice + returned from GetLattice(). They are kind of temporaries. */ + determinizer_.SetFinalCosts(token_label2final_cost.empty() ? NULL : + &token_label2final_cost); - num_frames_in_lattice_ = num_frames_to_include; return determinizer_.GetLattice(); } @@ -1092,7 +1067,7 @@ static void AddCompactLatticeArcToLattice( } -void LatticeIncrementalDeterminizer2::Init() { +void LatticeIncrementalDeterminizer::Init() { non_final_redet_states_.clear(); clat_.DeleteStates(); final_arcs_.clear(); @@ -1100,7 +1075,7 @@ void LatticeIncrementalDeterminizer2::Init() { } // See documentation in header -void LatticeIncrementalDeterminizer2::ReweightChunk( +void LatticeIncrementalDeterminizer::ReweightChunk( CompactLattice *chunk_clat) const { using StateId = CompactLatticeArc::StateId; using Label = CompactLatticeArc::Label; @@ -1141,7 +1116,7 @@ void LatticeIncrementalDeterminizer2::ReweightChunk( // See documentation in header -void LatticeIncrementalDeterminizer2::IdentifyTokenFinalStates( +void LatticeIncrementalDeterminizer::IdentifyTokenFinalStates( const CompactLattice &chunk_clat, std::unordered_map *token_map) const { token_map->clear(); @@ -1166,7 +1141,7 @@ void LatticeIncrementalDeterminizer2::IdentifyTokenFinalStates( -void LatticeIncrementalDeterminizer2::GetNonFinalRedetStates() { +void LatticeIncrementalDeterminizer::GetNonFinalRedetStates() { using StateId = CompactLatticeArc::StateId; non_final_redet_states_.clear(); non_final_redet_states_.reserve(final_arcs_.size()); @@ -1196,7 +1171,7 @@ void LatticeIncrementalDeterminizer2::GetNonFinalRedetStates() { } -void LatticeIncrementalDeterminizer2::InitializeRawLatticeChunk( +void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( Lattice *olat, unordered_map *token_label2state) { using namespace fst; @@ -1298,7 +1273,7 @@ void LatticeIncrementalDeterminizer2::InitializeRawLatticeChunk( static bool incr_det_warned = false; -void LatticeIncrementalDeterminizer2::UpdateForwardCosts( +void LatticeIncrementalDeterminizer::UpdateForwardCosts( const std::unordered_map &state_map) { using StateId = CompactLattice::StateId; BaseFloat infinity = std::numeric_limits::infinity(); @@ -1334,36 +1309,21 @@ void LatticeIncrementalDeterminizer2::UpdateForwardCosts( forward_costs_[arc.nextstate] = next_cost; } } - - } - } -bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( - Lattice *raw_fst, - const std::unordered_map *new_final_costs) { - using Label = CompactLatticeArc::Label; - using StateId = CompactLatticeArc::StateId; - - // old_final_costs is a map from a `token-label` (see glossary) to the - // associated final-prob in a final-state of `raw_fst`, that is associated - // with that Token. These are Tokens that were active at the end of the - // chunk. The final-probs may arise from beta (backward) costs, introduced - // for pruning purposes, and/or from final-probs in HCLG. Those costs will - // not be included in anything we store permamently in this class; they used - // only to guide pruned determinization, and we will use `old_final_costs` - // later to cancel them out. - std::unordered_map old_final_costs; - StateId raw_fst_num_states = raw_fst->NumStates(); +void LatticeIncrementalDeterminizer::GetRawLatticeFinalCosts( + const Lattice &raw_fst, + std::unordered_map *old_final_costs) { + LatticeArc::StateId raw_fst_num_states = raw_fst.NumStates(); for (LatticeArc::StateId s = 0; s < raw_fst_num_states; s++) { - for (fst::ArcIterator aiter(*raw_fst, s); !aiter.Done(); + for (fst::ArcIterator aiter(raw_fst, s); !aiter.Done(); aiter.Next()) { const LatticeArc &value = aiter.Value(); if (value.olabel >= (Label)kTokenLabelOffset && value.olabel < (Label)kMaxTokenLabel) { - LatticeWeight final_weight = raw_fst->Final(value.nextstate); + LatticeWeight final_weight = raw_fst.Final(value.nextstate); if (final_weight == LatticeWeight::Zero() || final_weight.Value2() != 0) { KALDI_ERR << "Label " << value.olabel @@ -1371,7 +1331,7 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( "has unexpected final-weight " << final_weight.Value1() << ',' << final_weight.Value2(); } - auto r = old_final_costs.insert({value.olabel, + auto r = old_final_costs->insert({value.olabel, final_weight.Value1()}); if (!r.second && r.first->second != final_weight.Value1()) { // For any given token-label, all arcs in raw_fst with that @@ -1383,6 +1343,24 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( } } } +} + + +bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( + Lattice *raw_fst) { + using Label = CompactLatticeArc::Label; + using StateId = CompactLatticeArc::StateId; + + // old_final_costs is a map from a `token-label` (see glossary) to the + // associated final-prob in a final-state of `raw_fst`, that is associated + // with that Token. These are Tokens that were active at the end of the + // chunk. The final-probs may arise from beta (backward) costs, introduced + // for pruning purposes, and/or from final-probs in HCLG. Those costs will + // not be included in anything we store permamently in this class; they used + // only to guide pruned determinization, and we will use `old_final_costs` + // later to cancel them out. + std::unordered_map old_final_costs; + GetRawLatticeFinalCosts(*raw_fst, &old_final_costs); CompactLattice chunk_clat; bool determinized_till_beam = DeterminizeLatticePhonePrunedWrapper( @@ -1418,8 +1396,8 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( // if s is final, then a state-label allocated by AllocateNewStateLabel(); // this will become a .nextstate in final_arcs_). std::unordered_map state_map; - bool is_first_chunk = false; + bool is_first_chunk = false; StateId clat_num_states = clat_.NumStates(); // Process arcs leaving the start state of chunk_clat. These will @@ -1492,24 +1470,34 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( } StateId clat_state = iter->second; + // We know that this point that `clat_state` is not a token-final state + // (see glossary for definition) as if it were we would have done + // `continue` above. + // // Only in the last chunk of the lattice would be there be a final-prob on // states that are not `token-final states`; these final-probs would - // normally all be zero at this point. - // So in almost all cases the following call will do nothing. + // normally all be Zero() at this point. So in almost all cases the following + // call will do nothing. clat_.SetFinal(clat_state, chunk_clat.Final(chunk_state)); + // Process arcs leaving this state. for (fst::ArcIterator aiter(chunk_clat, chunk_state); !aiter.Done(); aiter.Next()) { CompactLatticeArc arc(aiter.Value()); auto next_iter = state_map.find(arc.nextstate); if (next_iter != state_map.end()) { + // The normal case (when the .nextstate has a corresponding + // state in clat_) is very simple. Just copy the arc over. arc.nextstate = next_iter->second; clat_.AddArc(clat_state, arc); } else { - // TODO: remove the following slightly excessive assertion. + // This is the case when the arc is to a `token-final` state (see + // glossary.) + + // TODO: remove the following slightly excessive assertion? KALDI_ASSERT(chunk_clat.Final(arc.nextstate) != CompactLatticeWeight::Zero() && - arc.olabel >= (Label)kTokenLabelOffset && + arc.olabel >= (Label)kTokenLabelOffset && arc.olabel < (Label)kMaxTokenLabel && chunk_state_to_token.count(arc.nextstate) != 0 && old_final_costs.count(arc.olabel) != 0); @@ -1520,34 +1508,13 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( chunk_clat.Final(arc.nextstate)); BaseFloat old_final_cost = old_final_costs[arc.olabel]; - auto iter = new_final_costs->find(arc.olabel); - - BaseFloat new_cost; - if (new_final_costs == NULL) { - new_cost = 0.0; // treat all new final-costs as One() - } else if (iter != new_final_costs->end()) { - new_cost = iter->second; - } else { - new_cost = std::numeric_limits::infinity(); - } - if (new_cost != std::numeric_limits::infinity()) { - // Add a final-prob in clat_. - // These final-probs will be consumed by the user if they get the - // lattices as we incrementally determinize, but they will not affect - // what happens after we process the next chunk. These final-probs - // would not exist in the `canonical compact lattice` (see glossary). - LatticeWeight cost_correction(new_cost - old_final_cost, 0.0); - CompactLatticeWeight final_prob(arc.weight); - final_prob.SetWeight(fst::Times(cost_correction, final_prob.Weight())); - clat_.SetFinal(clat_state, fst::Plus(clat_.Final(clat_state), - final_prob)); - } - - // OK, `arc` is going to become an element of final_arcs_. These + // `arc` is going to become an element of final_arcs_. These // contain information about transitions from states in clat_ to // `token-final` states (i.e. states that have a token-label on the arc // to them and that are final in the canonical compact lattice). + // We subtract the old_final_cost as it was just a temporary cost + // introduced for pruning purposes. arc.weight.SetWeight(fst::Times(arc.weight.Weight(), LatticeWeight{-old_final_cost, 0.0})); // In a slight abuse of the Arc data structure, the nextstate is set to @@ -1566,6 +1533,64 @@ bool LatticeIncrementalDeterminizer2::AcceptRawLatticeChunk( } + +void LatticeIncrementalDeterminizer::SetFinalCosts( + const unordered_map *token_label2final_cost) { + if (final_arcs_.empty()) { + KALDI_WARN << "SetFinalCosts() called when final_arcs_.empty()... possibly " + "means you are calling this after Finalize()? Not allowed: could " + "indicate a code error. Or possibly decoding failed somehow."; + } + + /* + prefinal states a terminology that does not appear in the paper. What it + means is: the set of states that have an arc with a Token-label as the label + leaving them in the canonical appended lattice. + */ + std::unordered_set &prefinal_states(temp_); + prefinal_states.clear(); + for (const auto &arc: final_arcs_) { + /* Caution: `state` is actually the state the arc would + leave from in the canonical appended lattice; we just store + that in the .nextstate field. */ + CompactLatticeArc::StateId state = arc.nextstate; + prefinal_states.insert(state); + } + + for (int32 state: prefinal_states) + clat_.SetFinal(state, CompactLatticeWeight::Zero()); + + + for (const CompactLatticeArc &arc: final_arcs_) { + Label token_label = arc.ilabel; + /* Note: we store the source state in the .nextstate field. */ + CompactLatticeArc::StateId src_state = arc.nextstate; + BaseFloat graph_final_cost; + if (token_label2final_cost == NULL) { + graph_final_cost = 0.0; + } else { + auto iter = token_label2final_cost->find(token_label); + if (iter == token_label2final_cost->end()) + continue; + else + graph_final_cost = iter->second; + } + /* It might seem odd to set a final-prob on the src-state of the arc.. + the point is that the symbol on the arc is a token-label, which should not + appear in the lattice the user sees, so after that token-label is removed + the arc would just become a final-prob. + */ + clat_.SetFinal(src_state, + fst::Plus(clat_.Final(src_state), + fst::Times(arc.weight, + CompactLatticeWeight( + LatticeWeight(graph_final_cost, 0), {})))); + } +} + + + + // Instantiate the template for the combination of token types and FST types // that we'll need. template class LatticeIncrementalDecoderTpl, decoder::StdToken>; diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 3a1e5f6099e..e91e9a3805b 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -57,12 +57,13 @@ namespace kaldi { that we will determinize. In the paper this corresponds to the FST B that is described in Section 5.2. - token_label, state_label: In the paper these are both - referred to as `state labels` (these are special, large integer - id's that refer to states in the undeterminized lattice - and in the the determinized lattice); - but we use two separate terms here, for more clarity, - when referring to the undeterminized vs. determinized lattice. + token_label, state_label / token-label, state-label: + + In the paper these are both referred to as `state labels` (these are + special, large integer id's that refer to states in the undeterminized + lattice and in the the determinized lattice); but we use two separate + terms here, for more clarity, when referring to the undeterminized + vs. determinized lattice. token_label conceptually refers to states in the raw lattice, but we don't materialize the entire @@ -74,6 +75,11 @@ namespace kaldi { to labels that identify states in the determinized lattice (i.e. state indexes in lat_). + token-final state + A state in a raw lattice or in a determinized chunk that has an arc + entering it that has a `token-label` on it (as defined above). + These states will have nonzero final-probs. + redeterminized-non-splice-state, aka ns_redet: A redeterminized state which is not also a splice state; refer to the paper for explanation. In the already-determinized @@ -95,6 +101,7 @@ namespace kaldi { final-arc: An arc in the canonical appended CompactLattice which goes to a final-state. These arcs will have `state-labels` as their labels. + */ struct LatticeIncrementalDecoderConfig { // All the configuration values until det_opts are the same as in @@ -115,12 +122,12 @@ struct LatticeIncrementalDecoderConfig { fst::DeterminizeLatticePhonePrunedOptions det_opts; // The configuration values from this point on are specific to the - // incremental determinization. - // TODO: explain the following. - int32 determinize_delay; - int32 determinize_period; - int32 determinize_max_active; - int32 redeterminize_max_frames; + // incremental determinization. See where they are registered for + // explanation. + // Caution: these are only inspected in UpdateLatticeDeterminization(). + // If you call + int32 determinize_max_delay; + int32 determinize_min_chunk_size; LatticeIncrementalDecoderConfig() @@ -132,10 +139,8 @@ struct LatticeIncrementalDecoderConfig { beam_delta(0.5), hash_ratio(2.0), prune_scale(0.1), - determinize_delay(25), - determinize_period(20), - determinize_max_active(std::numeric_limits::max()), - redeterminize_max_frames(std::numeric_limits::max()) { } + determinize_max_delay(60), + determinize_min_chunk_size(20) { } void Register(OptionsItf *opts) { det_opts.Register(opts); opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); @@ -149,32 +154,6 @@ struct LatticeIncrementalDecoderConfig { opts->Register("prune-interval", &prune_interval, "Interval (in frames) at " "which to prune tokens"); - // TODO: check the following. - opts->Register("determinize-delay", &determinize_delay, - "Delay (in frames) at which to incrementally determinize " - "lattices. A larger delay reduces the computational " - "overhead of incremental deteriminization while increasing" - "the length of the last chunk which may increase latency."); - opts->Register("determinize-period", &determinize_period, - "The size (in frames) of chunk to do incrementally " - "determinization. If working with --determinize-max-active," - "it will become a lower bound of the size of chunk."); - opts->Register("determinize-max-active", &determinize_max_active, - "This option is to adaptively decide the size of the chunk " - "to be determinized. " - "If the number of active tokens(in a certain frame) is less " - "than this number (typically 50), we will start to " - "incrementally determinize lattices from the last frame we " - "determinized up to this frame. It can work with " - "--determinize-delay to further reduce the computation " - "introduced by incremental determinization. "); - opts->Register("redeterminize-max-frames", &redeterminize_max_frames, - "To impose a limit on how far back in time we will " - "redeterminize states. This is mainly intended to avoid " - "pathological cases. Smaller value leads to less " - "deterministic but less likely to blow up the processing" - "time in bad cases. You could set it infinite to get a fully " - "determinized lattice."); opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this " "parameter is obscure and relates to a speedup in the way the " @@ -182,14 +161,20 @@ struct LatticeIncrementalDecoderConfig { opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to " "control hash behavior"); + opts->Register("determinize-max-delay", &determinize_max_delay, + "Maximum frames of delay between decoding a frame and " + "determinizing it"); + opts->Register("determinize-min-chunk-size", &determinize_min_chunk_size, + "Minimum chunk size used in determinization"); + } void Check() const { KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 && min_active <= max_active && prune_interval > 0 && - determinize_delay >= 0 && determinize_max_active >= 0 && - determinize_period >= 0 && redeterminize_max_frames >= 0 && - beam_delta > 0.0 && hash_ratio >= 1.0 && prune_scale > 0.0 && - prune_scale < 1.0); + beam_delta > 0.0 && hash_ratio >= 1.0 && + prune_scale > 0.0 && prune_scale < 1.0 && + determinize_max_delay > determinize_min_chunk_size && + determinize_min_chunk_size > 0); } }; @@ -201,14 +186,14 @@ struct LatticeIncrementalDecoderConfig { https://www.danielpovey.com/files/ *TBD*.pdf for the paper. */ -class LatticeIncrementalDeterminizer2 { +class LatticeIncrementalDeterminizer { public: using Label = typename LatticeArc::Label; /* Actualy the same labels appear in both lattice and compact lattice, so we don't use the specific type all the time but just say 'Label' */ - LatticeIncrementalDeterminizer2( + LatticeIncrementalDeterminizer( const TransitionModel &trans_model, const LatticeIncrementalDecoderConfig &config): trans_model_(trans_model), config_(config) { } @@ -255,22 +240,42 @@ class LatticeIncrementalDeterminizer2 { raw (state-level) lattice. Would correspond to the FST A in the paper if first_frame == 0, and B otherwise. - @param [in] final_costs Final-costs that the user wants to - be included in clat_. These replace the values present - in the Final() probs in raw_fst whenever there was - a nonzero final-prob in raw_fst. (States in raw_fst - that had a final-prob will still be non-final). @return returns false if determinization finished earlier than the beam or the determinized lattice was empty; true otherwise. - */ - bool AcceptRawLatticeChunk(Lattice *raw_fst, - const std::unordered_map *final_costs = NULL); + NOTE: if this is not the final chunk, you will probably want to call + SetFinalCosts() directly after calling this. + */ + bool AcceptRawLatticeChunk(Lattice *raw_fst); + + /* + Sets final-probs in `clat_`. Must only be called if the final chunk + has not been processed. (The final chunk is whenever GetLattice() is + called with finalize == true). + + The reason this is a separate function from AcceptRawLatticeChunk() is that + there may be situations where a user wants to get the latice with + final-probs in it, after previously getting it without final-probs; or + vice versa. By final-probs, we mean the Final() probabilities in the + HCLG (decoding graph; this->fst_). + + @param [in] token_label2final_cost A map from the token-label + corresponding to Tokens active on the final frame of the + lattice in the object, to the final-cost we want to use for + those tokens. If NULL, it means all Tokens should be treated + as final with probability One(). If non-NULL, and a particular + token-label is not a key of this map, it means that Token + corresponded to a state that was not final in HCLG; and + such tokens will be treated as non-final. However, + if this would result in no states in the lattice being final, + we will treat all Tokens as final with probability One(), + a warning will be printed (this should not happen.) + */ + void SetFinalCosts(const unordered_map *token_label2final_cost = NULL); const CompactLattice &GetLattice() { return clat_; } - // kStateLabelOffset is what we add to state-ids in clat_ to produce labels // to identify them in the raw lattice chunk // kTokenLabelOffset is where we start allocating labels corresponding to Tokens @@ -278,6 +283,18 @@ class LatticeIncrementalDeterminizer2 { enum { kStateLabelOffset = (int)1e8, kTokenLabelOffset = (int)2e8, kMaxTokenLabel = (int)3e8 }; private: + + // Gets the final costs from token-final states in the raw lattice (see + // glossary for definition). These final costs will be subtracted after + // determinization; in the normal case they are `temporaries` used to guide + // pruning. NOTE: the index of the array is not the FST state that is final, + // but the label on arcs entering it (these will be `token-labels`). Each + // token-final state will have the same label on all arcs entering it. + // + // `old_final_costs` is assumed to be empty at entry. + void GetRawLatticeFinalCosts(const Lattice &raw_fst, + std::unordered_map *old_final_costs); + // Sets up non_final_redet_states_. See documentation for that variable. void GetNonFinalRedetStates(); @@ -289,13 +306,12 @@ class LatticeIncrementalDeterminizer2 { // the forward-costs inaccurate (too large) in cases where arcs // between redeterminized-states were removed by pruned determinization. // But the forward_costs_ are anyway only used for the pruned determinization, - // and this would never cause things to be pruned away - // and such paths can never become the best-path (this is true because of - // how we set the betas/final-probs/extra-costs on the tokens - - // But this is OK because - // adding a piece of lattice should never worsen the cost of existing - // states + // and this type of error would never cause things to be pruned away that + // should not have been pruned away. (Bear in mind that + // such paths can never become the best-path; this is true because of + // how we set the betas/final-probs/extra-costs on the tokens, which + // makes the distance of a state or arc from the best path a lower bound on what that + // distance will eventually become after we finish decoding.) void UpdateForwardCosts( const std::unordered_map &state_map); @@ -318,11 +334,10 @@ class LatticeIncrementalDeterminizer2 { void ReweightChunk(CompactLattice *chunk_clat) const; - // Identifies states in `chunk_clat` that have arcs entering them with a - // `token-label` on them (see glossary in header for definition). We're - // calling these `token-final` states. This function outputs a map from such - // states in chunk_clat, to the `token-label` on arcs entering them. (It is - // not possible that the same state would have multiple arcs entering it with + // Identifies token-final states in `chunk_clat`; see glossary above for + // definition of `token-final`. This function outputs a map from such states + // in chunk_clat, to the `token-label` on arcs entering them. (It is not + // possible that the same state would have multiple arcs entering it with // different token-labels, or some arcs entering with one token-label and some // another, or be both initial and have such arcs; this is true due to how we // construct the raw lattice.) @@ -368,7 +383,10 @@ class LatticeIncrementalDeterminizer2 { // be thought of as the sum of a Value1() + Value2() in a LatticeWeight. std::vector forward_costs_; - KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDeterminizer2); + // temporary used in a function, kept here to avoid excessive reallocation. + std::unordered_set temp_; + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDeterminizer); }; @@ -419,9 +437,9 @@ class LatticeIncrementalDecoderTpl { ~LatticeIncrementalDecoderTpl(); /** - CAUTION: this function is provided only for testing and instructional - purposes. In a scenario where you have the entire file and just want - to decode it, there is no point using this decoder. + CAUTION: it's unlikely that you will ever want to call this function. In a + scenario where you have the entire file and just want to decode it, there + is no point using this decoder. An example of how to do decoding together with incremental determinization. It decodes until there are no more frames left in the @@ -466,43 +484,69 @@ class LatticeIncrementalDecoderTpl { to be included in the lattice. Must be >0 and <= NumFramesDecoded(). If you are calling this just to keep the incremental lattice determinization up to date and - don't really need the lattice now or don't need it to be up + don't really need the lattice now or don't need it to be fully up to date, you will probably want to make num_frames_to_include at least 5 or 10 frames less than - NumFramessDecoded(); search for determinize-delay in the - paper and for determinize_delay in the configuration class - and the code. You may not call this with a - num_frames_to_include that is smaller than the largest - value previously provided. Calling it with an + NumFramessDecoded(), to avoid the lattice having too many + arcs. + + CAUTION: You may not call this with a + num_frames_to_include that is smaller than NumFramesInLattice() + value previously provided to Get. Calling it with an only-slightly-larger version than the last time (e.g. just a few frames larger) is probably not a good use of computational resources. @param [in] use_final_probs True if you want the final-probs - of HCLG to be included in the output lattice. Must not be - set if num_frames_to_include < NumFramesDecoded(). If no - state was final on frame `num_frames_to_include` they won't - be included regardless of use_final_probs; you can test this - with ReachedFinal(). Caution: it is an error to call this - function in succession with the same num_frames_to_include - and different values of `use_final_probs`. (This is not a - fundamental limitation but just the way we coded it.) - - @param [in] finalize If true, finalize the lattice (does an extra - pruning step on the raw lattice). After this call, no - further calls to GetLattice() will be allowed. + of HCLG to be included in the output lattice. Must not + be set to true if num_frames_to_include != + NumFramesDecoded(). Must be set to true if you have + previously called FinalizeDecoding(). + + (If no state was final on frame `num_frames_to_include` the + final-probs won't be included regardless of use_final_probs; + you can test this with ReachedFinal(). @return clat The CompactLattice representing what has been decoded up until `num_frames_to_include` (e.g., LatticeStateTimes() on this lattice would return `num_frames_to_include`). + See also UpdateLatticeDeterminizaton(). */ const CompactLattice &GetLattice(int32 num_frames_to_include, - bool use_final_probs = false, - bool finalize = false); - - + bool use_final_probs = false); + /** + UpdateLatticeDeterminization() is what you call when you don't yet need + the lattice, but want to make sure the work of determinization is kept + up to date so that when you do need the lattice you can get it fast. + It uses the configuration values `determinize_delay`, + `determinize_period`, `determinize_period` and `determinize_max_active` + to decide whether and when to call GetLattice(). You can safely + call this as often as you want, it won't do subtantially more work if it + is called frequently. + */ + void UpdateLatticeDeterminization(); + + /* + Returns the number of frames in the currently-determinized part of the + lattice which will be a number in [0, NumFramesDecoded()]. It will + be the largest number that GetLattice() was called with, but note + that GetLattice() may be called from UpdateLatticeDeterminization(). + + Made + available in case the user wants to give that same number to + GetLattice(). + + CAUTION: if the caller wants to call GetLattice() with use_final_probs == + true, or use_final_probs == true and finalize == true, and you previously + updated the lattice with UpdateLatticeDeterminization() up to a certain + frame, it's necessary to call GetLattice() with a *higher* frame than the + previously used one, because the interface doesn't support adding/removing + the final-probs unless you are processing a new chunk. This is + not a fundamental limitation, it's just the way we have coded it. + */ + int NumFramesInLattice() const { return num_frames_in_lattice_; } /** InitDecoding initializes the decoding, and should only be used if you @@ -535,6 +579,12 @@ class LatticeIncrementalDecoderTpl { /** Returns the number of frames decoded so far. */ inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + /** + Finalizes the decoding, doing an extra pruning step on the last frame + that uses the final-probs. May be called only once. + */ + void FinalizeDecoding(); + protected: /* Some protected things are needed in LatticeIncrementalOnlineDecoderTpl. */ @@ -545,8 +595,12 @@ class LatticeIncrementalDecoderTpl { Token *toks; bool must_prune_forward_links; bool must_prune_tokens; + int32 num_toks; /* Note: you can only trust `num_toks` if must_prune_tokens + * == false, because it is only set in + * PruneTokensForFrame(). */ TokenList() - : toks(NULL), must_prune_forward_links(true), must_prune_tokens(true) {} + : toks(NULL), must_prune_forward_links(true), must_prune_tokens(true), + num_toks(-1) {} }; using Elem = typename HashList::Elem; void PossiblyResizeHash(size_t num_toks); @@ -576,7 +630,6 @@ class LatticeIncrementalDecoderTpl { bool warned_; bool decoding_finalized_; - int32 final_cost_frame_; // TODO: initialize. unordered_map final_costs_; BaseFloat final_relative_cost_; BaseFloat final_best_cost_; @@ -588,7 +641,7 @@ class LatticeIncrementalDecoderTpl { LatticeIncrementalDecoderConfig config_; /** Much of the the incremental determinization algorithm is encapsulated in the determinize_ object. */ - LatticeIncrementalDeterminizer2 determinizer_; + LatticeIncrementalDeterminizer determinizer_; /* Just a temporary used in a function; stored here to avoid reallocation. */ @@ -597,6 +650,7 @@ class LatticeIncrementalDecoderTpl { /** num_frames_in_lattice_ is the highest `num_frames_to_include_` argument for any prior call to GetLattice(). */ int32 num_frames_in_lattice_; + // A map from Token to its token_label. Will contain an entry for // each Token in active_toks_[num_frames_in_lattice_]. unordered_map token2label_map_; @@ -606,6 +660,7 @@ class LatticeIncrementalDecoderTpl { // we allocate a unique id for each Token Label next_token_label_; + inline Label AllocateNewTokenLabel() { return next_token_label_++; } @@ -630,22 +685,6 @@ class LatticeIncrementalDecoderTpl { // at the end of an utterance. int32 GetNumToksForFrame(int32 frame); - // DeterminizeLattice() is just a wrapper for GetLattice() that uses the various - // heuristics specified in the config class to decide when, and with what arguments, - // to call GetLattice() in order to make sure that the incremental determinization - // is kept up to date. It is mainly of use for documentation (it is called inside - // Decode() which is not recommended for users to call in most scenarios). - // We may at some point decide to make this public. - void DeterminizeLattice(); - - /** - This function used to be public in LatticeFasterDecoder but is now accessed - only by including the 'finalize' argument to GetLattice(). It may be - called only once per utterance, at the end. (GetLattice() will ensure this - anyway. - It prunes the raw lattice. - */ - void FinalizeDecoding(); KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); From 4faf9bc7b61c7c1e7f24d477b6eed1700f013b2d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 16 Nov 2019 08:54:37 +0800 Subject: [PATCH 48/60] Simplify interface in lattice determinization --- src/decoder/lattice-incremental-decoder.cc | 4 +- src/decoder/lattice-incremental-decoder.h | 61 ++++++++-------------- 2 files changed, 24 insertions(+), 41 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 0aa36a14211..87e097681ea 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -101,7 +101,7 @@ void LatticeIncrementalDecoderTpl::UpdateLatticeDeterminization() { for (int32 t = first; t <= last; t++) { /* Make sure PruneActiveTokens() has computed num_toks for all these frames... */ - KALDI_ASSERT(!active_toks_[t].num_toks != -1); + KALDI_ASSERT(active_toks_[t].num_toks != -1); if (active_toks_[t].num_toks < fewest_tokens) { fewest_tokens = active_toks_[t].num_toks; best_frame = t; @@ -128,7 +128,6 @@ bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decoda if (NumFramesDecoded() % config_.prune_interval == 0) { PruneActiveTokens(config_.lattice_beam * config_.prune_scale); } - UpdateLatticeDeterminization(); BaseFloat cost_cutoff = ProcessEmitting(decodable); @@ -560,6 +559,7 @@ void LatticeIncrementalDecoderTpl::AdvanceDecoding( BaseFloat cost_cutoff = ProcessEmitting(decodable); ProcessNonemitting(cost_cutoff); } + UpdateLatticeDeterminization(); } // FinalizeDecoding() is a version of PruneActiveTokens that we call diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index e91e9a3805b..2e4c6c62291 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -470,32 +470,16 @@ class LatticeIncrementalDecoderTpl { which makes it very efficient to obtain the best path. */ /** - This GetLattice() function is the main way you will interact with the - incremental determinization that this class provides. Note that the - interface is slightly different from that of other decoders. For example, - if olat is NULL it will do the work of incremental determinization without - actually giving you the lattice (which can save it some time). - - Note: calling it on every frame doesn't make sense as it would - still have to do a fair amount of work; calling it every, say, - 10 to 40 frames would make sense though. + This GetLattice() function returns the lattice containing + `num_frames_to_decode` frames; this will be all frames decoded so + far, if you let num_frames_to_decode == NumFramesDecoded(), + but it will generally be better to make it a few frames less than + that to avoid the lattice having too many active states at + the end. @param [in] num_frames_to_include The number of frames that you want - to be included in the lattice. Must be >0 and - <= NumFramesDecoded(). If you are calling this just to - keep the incremental lattice determinization up to date and - don't really need the lattice now or don't need it to be fully up - to date, you will probably want to make - num_frames_to_include at least 5 or 10 frames less than - NumFramessDecoded(), to avoid the lattice having too many - arcs. - - CAUTION: You may not call this with a - num_frames_to_include that is smaller than NumFramesInLattice() - value previously provided to Get. Calling it with an - only-slightly-larger version than the last time (e.g. just - a few frames larger) is probably not a good use of - computational resources. + to be included in the lattice. Must be >= + NumFramesInLattice() and <= NumFramesDecoded(). @param [in] use_final_probs True if you want the final-probs of HCLG to be included in the output lattice. Must not @@ -503,9 +487,10 @@ class LatticeIncrementalDecoderTpl { NumFramesDecoded(). Must be set to true if you have previously called FinalizeDecoding(). - (If no state was final on frame `num_frames_to_include` the - final-probs won't be included regardless of use_final_probs; - you can test this with ReachedFinal(). + (If no state was final on frame `num_frames_to_include`, the + final-probs won't be included regardless of + `use_final_probs`; you can test whether this + was the case by calling ReachedFinal(). @return clat The CompactLattice representing what has been decoded up until `num_frames_to_include` (e.g., LatticeStateTimes() @@ -516,18 +501,6 @@ class LatticeIncrementalDecoderTpl { const CompactLattice &GetLattice(int32 num_frames_to_include, bool use_final_probs = false); - /** - UpdateLatticeDeterminization() is what you call when you don't yet need - the lattice, but want to make sure the work of determinization is kept - up to date so that when you do need the lattice you can get it fast. - It uses the configuration values `determinize_delay`, - `determinize_period`, `determinize_period` and `determinize_max_active` - to decide whether and when to call GetLattice(). You can safely - call this as often as you want, it won't do subtantially more work if it - is called frequently. - */ - void UpdateLatticeDeterminization(); - /* Returns the number of frames in the currently-determinized part of the lattice which will be a number in [0, NumFramesDecoded()]. It will @@ -685,6 +658,16 @@ class LatticeIncrementalDecoderTpl { // at the end of an utterance. int32 GetNumToksForFrame(int32 frame); + /** + UpdateLatticeDeterminization() ensures the work of determinization is kept + up to date so that when you do need the lattice you can get it fast. It + uses the configuration values `determinize_delay`, `determinize_max_delay` + and `determinize_min_chunk_size` to decide whether and when to call + GetLattice(). You can safely call this as often as you want (e.g. after + each time you call AdvanceDecoding(); it won't do subtantially more work if + it is called frequently. + */ + void UpdateLatticeDeterminization(); KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); From c8ef5fffd3a048d6d0a5d7b72a883b76eae71e56 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 15 Nov 2019 23:50:15 -0500 Subject: [PATCH 49/60] Fix compilation errors in incremental decoding --- src/decoder/decoder-wrappers.cc | 2 +- src/decoder/lattice-incremental-decoder.h | 4 +- .../online-nnet3-incremental-decoding.cc | 18 -------- .../online-nnet3-incremental-decoding.h | 46 +++++++++++++------ 4 files changed, 37 insertions(+), 33 deletions(-) diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index 68a1431f470..6eff770f8ea 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -229,7 +229,7 @@ bool DecodeUtteranceLatticeIncremental( // Get lattice CompactLattice clat; - decoder.GetLattice(decoder.NumFramesDecoded(), true, true); + decoder.GetLattice(decoder.NumFramesDecoded(), true); if (clat.NumStates() == 0) KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 2e4c6c62291..138e8241573 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -496,7 +496,9 @@ class LatticeIncrementalDecoderTpl { up until `num_frames_to_include` (e.g., LatticeStateTimes() on this lattice would return `num_frames_to_include`). - See also UpdateLatticeDeterminizaton(). + See also UpdateLatticeDeterminizaton(). Caution: this const ref + is only valid until the next time you call AdvanceDecoding() or + GetLattice(). */ const CompactLattice &GetLattice(int32 num_frames_to_include, bool use_final_probs = false); diff --git a/src/online2/online-nnet3-incremental-decoding.cc b/src/online2/online-nnet3-incremental-decoding.cc index 540a3a4f850..5e7acf147ee 100644 --- a/src/online2/online-nnet3-incremental-decoding.cc +++ b/src/online2/online-nnet3-incremental-decoding.cc @@ -51,24 +51,6 @@ void SingleUtteranceNnet3IncrementalDecoderTpl::AdvanceDecoding() { decoder_.AdvanceDecoding(&decodable_); } -template -void SingleUtteranceNnet3IncrementalDecoderTpl::FinalizeDecoding() { - decoder_.FinalizeDecoding(); -} - -template -int32 SingleUtteranceNnet3IncrementalDecoderTpl::NumFramesDecoded() const { - return decoder_.NumFramesDecoded(); -} - -template -void SingleUtteranceNnet3IncrementalDecoderTpl::GetLattice(bool end_of_utterance, - CompactLattice *clat) { - if (NumFramesDecoded() == 0) - KALDI_ERR << "You cannot get a lattice if you decoded no frames."; - decoder_.GetLattice(end_of_utterance, decoder_.NumFramesDecoded(), clat); -} - template void SingleUtteranceNnet3IncrementalDecoderTpl::GetBestPath(bool end_of_utterance, Lattice *best_path) const { diff --git a/src/online2/online-nnet3-incremental-decoding.h b/src/online2/online-nnet3-incremental-decoding.h index ddd9707bf54..e407cc2be2b 100644 --- a/src/online2/online-nnet3-incremental-decoding.h +++ b/src/online2/online-nnet3-incremental-decoding.h @@ -54,10 +54,10 @@ class SingleUtteranceNnet3IncrementalDecoderTpl { // Constructor. The pointer 'features' is not being given to this class to own // and deallocate, it is owned externally. SingleUtteranceNnet3IncrementalDecoderTpl(const LatticeIncrementalDecoderConfig &decoder_opts, - const TransitionModel &trans_model, - const nnet3::DecodableNnetSimpleLoopedInfo &info, - const FST &fst, - OnlineNnet2FeaturePipeline *features); + const TransitionModel &trans_model, + const nnet3::DecodableNnetSimpleLoopedInfo &info, + const FST &fst, + OnlineNnet2FeaturePipeline *features); /// Initializes the decoding and sets the frame offset of the underlying /// decodable object. This method is called by the constructor. You can also @@ -71,17 +71,37 @@ class SingleUtteranceNnet3IncrementalDecoderTpl { /// Finalizes the decoding. Cleans up and prunes remaining tokens, so the /// GetLattice() call will return faster. You must not call this before /// calling (TerminateDecoding() or InputIsFinished()) and then Wait(). - void FinalizeDecoding(); + void FinalizeDecoding() { decoder_.FinalizeDecoding(); } + + int32 NumFramesDecoded() const { return decoder_.NumFramesDecoded(); } + + int32 NumFramesInLattice() const { return decoder_.NumFramesInLattice(); } + + /* Gets the lattice. The output lattice has any acoustic scaling in it + (which will typically be desirable in an online-decoding context); if you + want an un-scaled lattice, scale it using ScaleLattice() with the inverse + of the acoustic weight. + + @param [in] num_frames_to_include The number of frames you want + to be included in the lattice. Must be in the range + [NumFramesInLattice().. NumFramesDecoded()]. If you + make it a few frames less than NumFramesDecoded(), it + will save significant computation. + @param [in] use_final_probs True if you want the lattice to + contain final-probs (if at least one state was final + on the most recently decoded frame). Must be false + if num_frames_to_include < NumFramesDecoded(). + Must be true if you have previously called + FinalizeDecoding(). + */ + const CompactLattice &GetLattice(int32 num_frames_to_include, + bool use_final_probs = false) { + return decoder_.GetLattice(num_frames_to_include, use_final_probs); + } + + - int32 NumFramesDecoded() const; - /// Gets the lattice. The output lattice has any acoustic scaling in it - /// (which will typically be desirable in an online-decoding context); if you - /// want an un-scaled lattice, scale it using ScaleLattice() with the inverse - /// of the acoustic weight. "end_of_utterance" will be true if you want the - /// final-probs to be included. - void GetLattice(bool end_of_utterance, - CompactLattice *clat); /// Outputs an FST corresponding to the single best path through the current /// lattice. If "use_final_probs" is true AND we reached the final-state of From 3590656530fe08f6eed6de10c6238804c7ccd45e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 16 Nov 2019 08:34:09 -0500 Subject: [PATCH 50/60] Some progress on fixing runtime errors --- src/decoder/lattice-incremental-decoder.cc | 112 ++++++++++++++---- src/decoder/lattice-incremental-decoder.h | 30 +++-- .../online2-wav-nnet3-latgen-incremental.cc | 6 +- 3 files changed, 113 insertions(+), 35 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 87e097681ea..20bf3b815c9 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -98,11 +98,12 @@ void LatticeIncrementalDecoderTpl::UpdateLatticeDeterminization() { last = NumFramesDecoded(), fewest_tokens = std::numeric_limits::max(), best_frame = -1; - for (int32 t = first; t <= last; t++) { + for (int32 t = last; t >= first; t--) { /* Make sure PruneActiveTokens() has computed num_toks for all these frames... */ KALDI_ASSERT(active_toks_[t].num_toks != -1); if (active_toks_[t].num_toks < fewest_tokens) { + // <= because we want the latest one in case of ties. fewest_tokens = active_toks_[t].num_toks; best_frame = t; } @@ -450,6 +451,17 @@ template void LatticeIncrementalDecoderTpl::PruneActiveTokens(BaseFloat delta) { int32 cur_frame_plus_one = NumFramesDecoded(); int32 num_toks_begin = num_toks_; + + if (active_toks_[cur_frame_plus_one].num_toks == -1){ + // The current frame's tokens don't get pruned so they don't get counted + // (the count is needed by the incremental determinization code). + // Fix this. + int this_frame_num_toks = 0; + for (Token *t = active_toks_[cur_frame_plus_one].toks; t != NULL; t = t->next) + this_frame_num_toks++; + active_toks_[cur_frame_plus_one].num_toks = this_frame_num_toks; + } + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract // one to get the corresponding index for the decodable object. for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { @@ -472,7 +484,8 @@ void LatticeIncrementalDecoderTpl::PruneActiveTokens(BaseFloat delta active_toks_[f + 1].must_prune_tokens = false; } } - KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin + PruneTokensForFrame(0); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin << " to " << num_toks_; } @@ -924,7 +937,16 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( StateId state = chunk_lat.AddState(); tok2state_map[tok] = state; next_token2label_map[tok] = AllocateNewTokenLabel(); - chunk_lat.SetFinal(state, LatticeWeight(final_cost, 0.0)); + StateId token_final_state = chunk_lat.AddState(); + LatticeArc::Label ilabel = 0, + olabel = (next_token2label_map[tok] = AllocateNewTokenLabel()); + chunk_lat.AddArc(state, + LatticeArc(ilabel, olabel, + LatticeWeight::One(), + token_final_state)); + KALDI_ASSERT(final_cost - final_cost == 0.0); // no inf. + KALDI_LOG << "Setting state " << token_final_state << " final w.p. " << final_cost; + chunk_lat.SetFinal(token_final_state, LatticeWeight(final_cost, 0.0)); } } @@ -1079,17 +1101,18 @@ void LatticeIncrementalDeterminizer::ReweightChunk( CompactLattice *chunk_clat) const { using StateId = CompactLatticeArc::StateId; using Label = CompactLatticeArc::Label; - StateId start_state = chunk_clat->Start(), num_states = chunk_clat->NumStates(); - + StateId chunk_start_state = chunk_clat->Start(), + clat_num_states = clat_.NumStates(); std::vector potentials(chunk_clat->NumStates(), CompactLatticeWeight::One()); - for (fst::MutableArcIterator aiter(chunk_clat, start_state); + for (fst::MutableArcIterator aiter(chunk_clat, + chunk_start_state); !aiter.Done(); aiter.Next()) { CompactLatticeArc arc = aiter.Value(); Label label = arc.ilabel; // ilabel == olabel. StateId clat_state = label - kStateLabelOffset; - KALDI_ASSERT(clat_state >= 0 && clat_state < num_states); + KALDI_ASSERT(clat_state >= 0 && clat_state < clat_num_states); // `extra_weight` serves to cancel out the weight // `forward_costs_[clat_state]` that we introduced in // InitializeRawLatticeChunk(); the purpose of that was to @@ -1107,7 +1130,8 @@ void LatticeIncrementalDeterminizer::ReweightChunk( // Below is just a check that weights on arcs leaving initial state // are all One(). // TODO: remove the following. - for (fst::ArcIterator aiter(*chunk_clat, start_state); + for (fst::ArcIterator aiter(*chunk_clat, + chunk_start_state); !aiter.Done(); aiter.Next()) { KALDI_ASSERT(fst::ApproxEqual(aiter.Value().weight, CompactLatticeWeight::One())); @@ -1178,6 +1202,7 @@ void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( olat->DeleteStates(); LatticeArc::StateId start_state = olat->AddState(); + olat->SetStart(start_state); token_label2state->clear(); // redet_state_map maps from state-ids in clat_ to state-ids in olat. This @@ -1235,7 +1260,8 @@ void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( } CompactLatticeArc new_arc; new_arc.nextstate = dest_lat_state; - new_arc.ilabel = new_arc.olabel = token_label; + /* We convert the token-label to epsilon; it's not needed anymore. */ + new_arc.ilabel = new_arc.olabel = 0; new_arc.weight = arc.weight; AddCompactLatticeArcToLattice(new_arc, src_lat_state, olat); } @@ -1274,12 +1300,11 @@ void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( static bool incr_det_warned = false; void LatticeIncrementalDeterminizer::UpdateForwardCosts( - const std::unordered_map &state_map) { + const std::vector &new_states) { using StateId = CompactLattice::StateId; BaseFloat infinity = std::numeric_limits::infinity(); StateId cur_size = forward_costs_.size(); - for (auto &p: state_map) { - StateId state = p.second; // the state-id in clat_ + for (StateId state: new_states) { // The reason we can make the following assertion is that the states should // be in topological order and each state should be reachable from an // earlier state (and we should have processed that earlier state by now). @@ -1326,9 +1351,10 @@ void LatticeIncrementalDeterminizer::GetRawLatticeFinalCosts( LatticeWeight final_weight = raw_fst.Final(value.nextstate); if (final_weight == LatticeWeight::Zero() || final_weight.Value2() != 0) { - KALDI_ERR << "Label " << value.olabel + KALDI_ERR << "Label " << value.olabel << " from state " << s << " looks like a token-label but its next-state " - "has unexpected final-weight " << final_weight.Value1() << ',' + << value.nextstate << + " has unexpected final-weight " << final_weight.Value1() << ',' << final_weight.Value2(); } auto r = old_final_costs->insert({value.olabel, @@ -1351,6 +1377,16 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( using Label = CompactLatticeArc::Label; using StateId = CompactLatticeArc::StateId; + + { + std::ostringstream os; + bool acceptor = false, write_one = false; + fst::FstPrinter printer(*raw_fst, NULL, NULL, + NULL, acceptor, write_one, "\t"); + printer.Print(&os, ""); + KALDI_LOG << "Raw FST is " << os.str(); + } + // old_final_costs is a map from a `token-label` (see glossary) to the // associated final-prob in a final-state of `raw_fst`, that is associated // with that Token. These are Tokens that were active at the end of the @@ -1367,12 +1403,23 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( trans_model_, raw_fst, config_.lattice_beam, &chunk_clat, config_.det_opts); + { + std::ostringstream os; + bool acceptor = false, write_one = false; + fst::FstPrinter printer(chunk_clat, NULL, NULL, + NULL, acceptor, write_one, "\t"); + printer.Print(&os, ""); + KALDI_LOG << "Determinized chunk FST is " << os.str(); + } + + + TopSortCompactLatticeIfNeeded(&chunk_clat); std::unordered_map chunk_state_to_token; IdentifyTokenFinalStates(chunk_clat, &chunk_state_to_token); - + KALDI_LOG << "num token-final states is " << chunk_state_to_token.size(); StateId chunk_num_states = chunk_clat.NumStates(); if (chunk_num_states == 0) { // This will be an error but user-level calling code can detect it from the @@ -1392,11 +1439,15 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( // with that redeterminized-state. // state_map maps from (non-initial state s in chunk_clat) to: - // if s is not final, then a state in clat_, - // if s is final, then a state-label allocated by AllocateNewStateLabel(); - // this will become a .nextstate in final_arcs_). + // if s is not token-final, then a state in clat_, + + // if s is token-final, then a state-label allocated by + // AllocateNewStateLabel(); this will become a .nextstate in final_arcs_). std::unordered_map state_map; + std::unordered_map rev_state_map; + + bool is_first_chunk = false; StateId clat_num_states = clat_.NumStates(); @@ -1406,9 +1457,10 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( for (fst::ArcIterator aiter(chunk_clat, start_state); !aiter.Done(); aiter.Next()) { const CompactLatticeArc &arc = aiter.Value(); - Label label = arc.ilabel; // ilabel == olabel; would be the olabel (word - // label) in a Lattice. - if (!(label >= kStateLabelOffset && label < clat_num_states)) { + Label label = arc.ilabel; // ilabel == olabel; would be the olabel + // in a Lattice. + if (!(label >= kStateLabelOffset && + label - kStateLabelOffset < clat_num_states)) { // The label was not a state-label. This should only be possible on the // first chunk. KALDI_ASSERT(state_map.empty()); @@ -1455,7 +1507,9 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( if (is_first_chunk) { KALDI_ASSERT(forward_costs_.empty() && start_state == 0); forward_costs_.push_back(0.0); // forward-cost of start state is 0. - clat_.SetStart(state_map[start_state]); + auto iter = state_map.find(start_state); + KALDI_ASSERT(iter != state_map.end()); + clat_.SetStart(iter->second); } // Now transfer arcs from chunk_clat to clat_. @@ -1490,6 +1544,8 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( // The normal case (when the .nextstate has a corresponding // state in clat_) is very simple. Just copy the arc over. arc.nextstate = next_iter->second; + KALDI_ASSERT(arc.ilabel < kTokenLabelOffset || + arc.ilabel > kMaxTokenLabel); clat_.AddArc(clat_state, arc); } else { // This is the case when the arc is to a `token-final` state (see @@ -1527,7 +1583,17 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( } GetNonFinalRedetStates(); - UpdateForwardCosts(state_map); + { + // `clat_new_states` is states that are newly added + // or which were modified (i.e. redeterminized-states). + std::vector clat_new_states; + for (auto p: state_map) + if (p.second < kStateLabelOffset) + clat_new_states.push_back(p.second); + std::sort(clat_new_states.begin(), + clat_new_states.end()); + UpdateForwardCosts(clat_new_states); + } return determinized_till_beam; } diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 138e8241573..5e6c3494ed4 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -140,7 +140,9 @@ struct LatticeIncrementalDecoderConfig { hash_ratio(2.0), prune_scale(0.1), determinize_max_delay(60), - determinize_min_chunk_size(20) { } + determinize_min_chunk_size(20) { + det_opts.minimize = false; + } void Register(OptionsItf *opts) { det_opts.Register(opts); opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); @@ -169,12 +171,17 @@ struct LatticeIncrementalDecoderConfig { } void Check() const { - KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 && - min_active <= max_active && prune_interval > 0 && - beam_delta > 0.0 && hash_ratio >= 1.0 && - prune_scale > 0.0 && prune_scale < 1.0 && - determinize_max_delay > determinize_min_chunk_size && - determinize_min_chunk_size > 0); + if (!(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 && + min_active <= max_active && prune_interval > 0 && + beam_delta > 0.0 && hash_ratio >= 1.0 && + prune_scale > 0.0 && prune_scale < 1.0 && + determinize_max_delay > determinize_min_chunk_size && + determinize_min_chunk_size > 0)) + KALDI_ERR << "Invalid options given to decoder"; + /* Minimization of the chunks is not compatible withour algorithm (or at + least, would require additional complexity to implement.) */ + if (det_opts.minimize || !det_opts.word_determinize) + KALDI_ERR << "Invalid determinization options given to decoder."; } }; @@ -299,7 +306,7 @@ class LatticeIncrementalDeterminizer { void GetNonFinalRedetStates(); // Updates forward_costs_ for all the states which are successors of states - // appearing as values in `state_map`. (By "a is a successor of b" I mean + // appearing as values in `clat_new_states`. (By "a is a successor of b" I mean // there is an arc from a to b.) // For states that already had entries in the forward_costs_ array, this // will never decrease their forward costs. This may in theory make @@ -312,8 +319,13 @@ class LatticeIncrementalDeterminizer { // how we set the betas/final-probs/extra-costs on the tokens, which // makes the distance of a state or arc from the best path a lower bound on what that // distance will eventually become after we finish decoding.) + // + // @param [in] clat_new_states Sorted list of state-ids in + // clat_ which were either given new arcs, or + // created, in the most recent iteration of + // pruned lattice determinization void UpdateForwardCosts( - const std::unordered_map &state_map); + const std::vector &clat_new_states); // Reweights `chunk_clat`. Must not be called if this is the first chunk. diff --git a/src/online2bin/online2-wav-nnet3-latgen-incremental.cc b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc index b48337af5fb..5e4554e4a8c 100644 --- a/src/online2bin/online2-wav-nnet3-latgen-incremental.cc +++ b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc @@ -265,9 +265,9 @@ int main(int argc, char *argv[]) { } decoder.FinalizeDecoding(); - CompactLattice clat; - bool end_of_utterance = true; - decoder.GetLattice(end_of_utterance, &clat); + bool use_final_probs = true; + CompactLattice clat = decoder.GetLattice(decoder.NumFramesDecoded(), + use_final_probs); GetDiagnosticsAndPrintOutput(utt, word_syms, clat, &num_frames, &tot_like); From 075c915bcb8a99a8f723c97796cecf9c60a13153 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 17 Nov 2019 03:20:26 -0500 Subject: [PATCH 51/60] Further progress towards working version --- src/decoder/lattice-incremental-decoder.cc | 294 ++++++++++----------- src/decoder/lattice-incremental-decoder.h | 52 +--- 2 files changed, 148 insertions(+), 198 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 20bf3b815c9..23c7c7adc3c 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -243,6 +243,7 @@ template void LatticeIncrementalDecoderTpl::PruneForwardLinks( int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned, BaseFloat delta) { + KALDI_LOG << "In PruneForwardLinks " << frame_plus_one; // delta is the amount by which the extra_costs must change // If delta is larger, we'll tend to go back less far // toward the beginning of the file. @@ -419,6 +420,7 @@ BaseFloat LatticeIncrementalDecoderTpl::FinalRelativeCost() const { template void LatticeIncrementalDecoderTpl::PruneTokensForFrame( int32 frame_plus_one) { + KALDI_LOG << "In PruneTokensForFrame: " << frame_plus_one; KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); Token *&toks = active_toks_[frame_plus_one].toks; if (toks == NULL) KALDI_WARN << "No tokens alive [doing pruning]"; @@ -484,7 +486,6 @@ void LatticeIncrementalDecoderTpl::PruneActiveTokens(BaseFloat delta active_toks_[f + 1].must_prune_tokens = false; } } - PruneTokensForFrame(0); KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin << " to " << num_toks_; } @@ -493,7 +494,6 @@ template void LatticeIncrementalDecoderTpl::ComputeFinalCosts( unordered_map *final_costs, BaseFloat *final_relative_cost, BaseFloat *final_best_cost) const { - KALDI_ASSERT(!decoding_finalized_); if (final_costs != NULL) final_costs->clear(); const Elem *final_toks = toks_.GetList(); BaseFloat infinity = std::numeric_limits::infinity(); @@ -582,7 +582,7 @@ template void LatticeIncrementalDecoderTpl::FinalizeDecoding() { int32 final_frame_plus_one = NumFramesDecoded(); int32 num_toks_begin = num_toks_; - // PruneForwardLinksFinal() prunes final frame (with final-probs), and + // PruneForwardLinksFinal() prunes the final frame (with final-probs), and // sets decoding_finalized_. PruneForwardLinksFinal(); for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { @@ -913,25 +913,28 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( /* If we included the final-costs at this stage, they will cause non-final states to be pruned out from the end of the lattice. */ BaseFloat final_cost; - if (decoding_finalized_) { - if (final_costs_.empty()) { - final_cost = 0.0; /* No final-state survived, so treat all as final - * with probability One(). */ + { // This block computes final_cost + if (decoding_finalized_) { + if (final_costs_.empty()) { + final_cost = 0.0; /* No final-state survived, so treat all as final + * with probability One(). */ + } else { + auto iter = final_costs_.find(tok); + if (iter == final_costs_.end()) + final_cost = std::numeric_limits::infinity(); + else + final_cost = iter->second; + } } else { - auto iter = final_costs_.find(tok); - if (iter == final_costs_.end()) - continue; - final_cost = iter->second; + /* this is a `fake` final-cost used to guide pruning. This equals + the alpha+beta of the state, if we were to set the betas on + the final frame to the negatives of the alphas (this is a trick + to make all such tokens on the best path, to avoid pruning out + anything that might be within `lattice-beam` of the eventual + best path). + */ + final_cost = -(tok->tot_cost + tok->extra_cost); } - } else { - /* this is a `fake` final-cost used to guide pruning. This equals - the alpha+beta of the state, if we were to set the betas on - the final frame to the negatives of the alphas (this is a trick - to make all such tokens on the best path, to avoid pruning out - anything that might be within `lattice-beam` of the eventual - best path). - */ - final_cost = -(tok->tot_cost + tok->extra_cost); } StateId state = chunk_lat.AddState(); @@ -945,8 +948,12 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( LatticeWeight::One(), token_final_state)); KALDI_ASSERT(final_cost - final_cost == 0.0); // no inf. - KALDI_LOG << "Setting state " << token_final_state << " final w.p. " << final_cost; chunk_lat.SetFinal(token_final_state, LatticeWeight(final_cost, 0.0)); + + if (decoding_finalized_ && frame == num_frames_to_include) { + // If this is the last chunk, we need to include epsilon transitions + // on the last frame. + } } } @@ -1094,50 +1101,29 @@ void LatticeIncrementalDeterminizer::Init() { clat_.DeleteStates(); final_arcs_.clear(); forward_costs_.clear(); + arcs_in_.clear(); } -// See documentation in header -void LatticeIncrementalDeterminizer::ReweightChunk( - CompactLattice *chunk_clat) const { - using StateId = CompactLatticeArc::StateId; - using Label = CompactLatticeArc::Label; - StateId chunk_start_state = chunk_clat->Start(), - clat_num_states = clat_.NumStates(); - std::vector potentials(chunk_clat->NumStates(), - CompactLatticeWeight::One()); - - for (fst::MutableArcIterator aiter(chunk_clat, - chunk_start_state); - !aiter.Done(); aiter.Next()) { - CompactLatticeArc arc = aiter.Value(); - Label label = arc.ilabel; // ilabel == olabel. - StateId clat_state = label - kStateLabelOffset; - KALDI_ASSERT(clat_state >= 0 && clat_state < clat_num_states); - // `extra_weight` serves to cancel out the weight - // `forward_costs_[clat_state]` that we introduced in - // InitializeRawLatticeChunk(); the purpose of that was to - // make the pruned determinization work right, but they are - // no longer needed. - LatticeWeight extra_weight(-forward_costs_[clat_state], 0.0); - arc.weight.SetWeight(fst::Times(arc.weight.Weight(), extra_weight)); - aiter.SetValue(arc); - potentials[arc.nextstate] = arc.weight; - } - // TODO: consider doing the following manually for this special case, - // since most states are not reweighted. - fst::Reweight(chunk_clat, potentials, fst::REWEIGHT_TO_FINAL); - - // Below is just a check that weights on arcs leaving initial state - // are all One(). - // TODO: remove the following. - for (fst::ArcIterator aiter(*chunk_clat, - chunk_start_state); - !aiter.Done(); aiter.Next()) { - KALDI_ASSERT(fst::ApproxEqual(aiter.Value().weight, - CompactLatticeWeight::One())); - } +CompactLatticeArc::StateId LatticeIncrementalDeterminizer::AddStateToClat() { + CompactLatticeArc::StateId ans = clat_.AddState(); + forward_costs_.push_back(std::numeric_limits::infinity()); + KALDI_ASSERT(forward_costs_.size() == ans + 1); + arcs_in_.resize(ans + 1); + return ans; } +void LatticeIncrementalDeterminizer::AddArcToClat( + CompactLatticeArc::StateId state, + const CompactLatticeArc &arc) { + int32 arc_idx = clat_.NumArcs(state); + clat_.AddArc(state, arc); + arcs_in_[arc.nextstate].push_back({state, arc_idx}); + BaseFloat forward_cost = forward_costs_[state] + + ConvertToCost(arc.weight); + if (forward_cost < forward_costs_[arc.nextstate]) + forward_costs_[arc.nextstate] = forward_cost; + KALDI_ASSERT(forward_cost - forward_cost == 0); // TODO: remove this. TEMP. +} // See documentation in header void LatticeIncrementalDeterminizer::IdentifyTokenFinalStates( @@ -1240,6 +1226,8 @@ void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( clat_arc.nextstate = lat_nextstate; AddCompactLatticeArcToLattice(clat_arc, lat_state, olat); } + clat_.DeleteArcs(redet_state); + clat_.SetFinal(redet_state, CompactLatticeWeight::Zero()); } for (const CompactLatticeArc &arc: final_arcs_) { @@ -1297,47 +1285,6 @@ void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( } } - -static bool incr_det_warned = false; -void LatticeIncrementalDeterminizer::UpdateForwardCosts( - const std::vector &new_states) { - using StateId = CompactLattice::StateId; - BaseFloat infinity = std::numeric_limits::infinity(); - StateId cur_size = forward_costs_.size(); - for (StateId state: new_states) { - // The reason we can make the following assertion is that the states should - // be in topological order and each state should be reachable from an - // earlier state (and we should have processed that earlier state by now). - KALDI_ASSERT(state < cur_size); - BaseFloat cur_cost = forward_costs_[state]; - if (cur_cost == infinity) { - // I don't think I can exclude that there might be unreachable - // states - if (!incr_det_warned) { - KALDI_WARN << "Found unreachable state in compact lattice while determinizing"; - incr_det_warned = true; - } - continue; - } - KALDI_ASSERT(cur_cost < infinity); - for (fst::ArcIterator aiter(clat_, state); - !aiter.Done(); aiter.Next()) { - const CompactLatticeArc &arc = aiter.Value(); - BaseFloat arc_cost = arc.weight.Weight().Value1() + - arc.weight.Weight().Value2(), - next_cost = cur_cost + arc_cost; - if (arc.nextstate >= cur_size) { - forward_costs_.resize(arc.nextstate + 1, infinity); - cur_size = arc.nextstate + 1; - forward_costs_[arc.nextstate] = next_cost; - } else if (forward_costs_[arc.nextstate] > next_cost) { - forward_costs_[arc.nextstate] = next_cost; - } - } - } -} - - void LatticeIncrementalDeterminizer::GetRawLatticeFinalCosts( const Lattice &raw_fst, std::unordered_map *old_final_costs) { @@ -1377,16 +1324,6 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( using Label = CompactLatticeArc::Label; using StateId = CompactLatticeArc::StateId; - - { - std::ostringstream os; - bool acceptor = false, write_one = false; - fst::FstPrinter printer(*raw_fst, NULL, NULL, - NULL, acceptor, write_one, "\t"); - printer.Print(&os, ""); - KALDI_LOG << "Raw FST is " << os.str(); - } - // old_final_costs is a map from a `token-label` (see glossary) to the // associated final-prob in a final-state of `raw_fst`, that is associated // with that Token. These are Tokens that were active at the end of the @@ -1403,23 +1340,11 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( trans_model_, raw_fst, config_.lattice_beam, &chunk_clat, config_.det_opts); - { - std::ostringstream os; - bool acceptor = false, write_one = false; - fst::FstPrinter printer(chunk_clat, NULL, NULL, - NULL, acceptor, write_one, "\t"); - printer.Print(&os, ""); - KALDI_LOG << "Determinized chunk FST is " << os.str(); - } - - - TopSortCompactLatticeIfNeeded(&chunk_clat); std::unordered_map chunk_state_to_token; IdentifyTokenFinalStates(chunk_clat, &chunk_state_to_token); - KALDI_LOG << "num token-final states is " << chunk_state_to_token.size(); StateId chunk_num_states = chunk_clat.NumStates(); if (chunk_num_states == 0) { // This will be an error but user-level calling code can detect it from the @@ -1438,22 +1363,28 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( // clat_), and will transition to a state in `chunk_clat` that we can identify // with that redeterminized-state. - // state_map maps from (non-initial state s in chunk_clat) to: - // if s is not token-final, then a state in clat_, - - // if s is token-final, then a state-label allocated by - // AllocateNewStateLabel(); this will become a .nextstate in final_arcs_). + // state_map maps from (non-initial, non-token-final state s in chunk_clat) to + // a state in clat_. std::unordered_map state_map; - std::unordered_map rev_state_map; - bool is_first_chunk = false; StateId clat_num_states = clat_.NumStates(); - // Process arcs leaving the start state of chunk_clat. These will - // have state-labels on them. The weights will all be One(); - // this is ensured in ReweightChunk(). + + + if (0) { + std::ostringstream os; + bool acceptor = false, write_one = false; + fst::FstPrinter printer(chunk_clat, NULL, NULL, + NULL, acceptor, write_one, "\t"); + printer.Print(&os, ""); + KALDI_LOG << "Determinized chunk FST is " << os.str(); + } + + + // Process arcs leaving the start state of chunk_clat. These arcs will have + // state-labels on them (unless this is the first chunk). for (fst::ArcIterator aiter(chunk_clat, start_state); !aiter.Done(); aiter.Next()) { const CompactLatticeArc &arc = aiter.Value(); @@ -1469,14 +1400,70 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( } StateId clat_state = label - kStateLabelOffset; StateId chunk_state = arc.nextstate; - bool inserted = state_map.insert({chunk_state, clat_state}).second; - // Should not have been in the map before. - KALDI_ASSERT(inserted); + auto p = state_map.insert({chunk_state, clat_state}); + StateId dest_clat_state = p.first->second; + /* + In almost all cases, dest_clat_state and clat_state will be the sme state; + but there may be situations where two arcs with different state-labels + left the start state and entered the same next-state in chunk_clat; and in + these cases, they will be different. + + We didn't address this issue in the paper (or actually realize it could be + a problem). What we do is pick one of the clat_states as the "canonical" + one, and redirect all incoming transitions of the others to enter the + "canonical" one. (Search below for new_in_arc.nextstate = + dest_clat_state). + */ + + // in_weight is an extra weight that we'll include on arcs entering this + // state from the previous chunk. We need to cancel out + // `forward_costs[clat_state]`, which was included in the corresponding arc + // in the raw lattice for pruning purposes; and we need to include + // the weight from the start-state of `chunk_clat` to this state. + CompactLatticeWeight extra_weight_in = arc.weight; + extra_weight_in.SetWeight( + fst::Times(extra_weight_in.Weight(), + LatticeWeight(-forward_costs_[clat_state], 0.0))); + + // Note: 0 is the start state of clat_. This was checked. + forward_costs_[clat_state] = (clat_state == 0 ? 0 : + std::numeric_limits::infinity()); + std::vector > &arcs_in(arcs_in_[clat_state]); + std::vector > new_arcs_in; + for (auto p: arcs_in) { + // Note: we'll be doing `continue` below if this input arc came from + // another redeterminized-state, because we did DeleteStates() for them in + // InitializeRawLatticeChunk(). + CompactLattice::StateId src_state = p.first; + int32 arc_pos = p.second; + if (arc_pos >= (int32)clat_.NumArcs(src_state)) + continue; + fst::MutableArcIterator aiter(&clat_, src_state); + aiter.Seek(arc_pos); + if (aiter.Value().nextstate != clat_state) + continue; // This arc record has become invalidated. + CompactLatticeArc new_in_arc(aiter.Value()); + // In most cases we will have dest_clat_state == clat_state, so the next + // line won't change the value of .nextstate + new_in_arc.nextstate = dest_clat_state; + new_in_arc.weight = fst::Times(new_in_arc.weight, extra_weight_in); + aiter.SetValue(new_in_arc); + + BaseFloat new_forward_cost = forward_costs_[src_state] + + ConvertToCost(new_in_arc.weight); + if (new_forward_cost < forward_costs_[dest_clat_state]) + forward_costs_[dest_clat_state] = new_forward_cost; + new_arcs_in.push_back(p); + } + // Commit a cleaned-up version of `arcs_in` that doesn't contain any + // no-longer-valid arcs. This may seem redundant, since no-longer-valid + // arcs are permitted in arcs_in_, and we almost certainly won't need this + // information in future; but I am concerned that if some states remain + // active for many chunks, keeping all the old input arcs might cause the + // algorithm to become quadratic-time. + arcs_in.swap(new_arcs_in); } - if (!is_first_chunk) - ReweightChunk(&chunk_clat); // Note: we haven't inspected any weights yet. - // Remove any existing arcs in clat_ that leave redeterminized-states, and // make those states non-final. Below, we'll add arcs leaving those states // (and possibly new final-probs.) @@ -1485,7 +1472,7 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( clat_.SetFinal(clat_state, CompactLatticeWeight::Zero()); } - // The final-arc info is no longer relevant; we'll recreate it below. + // The previous final-arc info is no longer relevant; we'll recreate it below. final_arcs_.clear(); // assume chunk_lat.Start() == 0; we asserted it above. Allocate state-ids @@ -1498,18 +1485,19 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( StateId new_clat_state = clat_.NumStates(); if (state_map.insert({state, new_clat_state}).second) { // If it was inserted then we need to actually allocate that state - StateId s = clat_.AddState(); + StateId s = AddStateToClat(); KALDI_ASSERT(s == new_clat_state); } // else do nothing; it would have been a redeterminized-state and no } // allocation is needed since they already exist in clat_. and // in state_map. if (is_first_chunk) { - KALDI_ASSERT(forward_costs_.empty() && start_state == 0); - forward_costs_.push_back(0.0); // forward-cost of start state is 0. auto iter = state_map.find(start_state); KALDI_ASSERT(iter != state_map.end()); - clat_.SetStart(iter->second); + CompactLatticeArc::StateId clat_start_state = iter->second; + KALDI_ASSERT(clat_start_state == 0); // topological order. + clat_.SetStart(clat_start_state); + forward_costs_[clat_start_state] = 0.0; } // Now transfer arcs from chunk_clat to clat_. @@ -1525,7 +1513,7 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( StateId clat_state = iter->second; // We know that this point that `clat_state` is not a token-final state - // (see glossary for definition) as if it were we would have done + // (see glossary for definition) as if it were, we would have done // `continue` above. // // Only in the last chunk of the lattice would be there be a final-prob on @@ -1546,7 +1534,7 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( arc.nextstate = next_iter->second; KALDI_ASSERT(arc.ilabel < kTokenLabelOffset || arc.ilabel > kMaxTokenLabel); - clat_.AddArc(clat_state, arc); + AddArcToClat(clat_state, arc); } else { // This is the case when the arc is to a `token-final` state (see // glossary.) @@ -1583,18 +1571,6 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( } GetNonFinalRedetStates(); - { - // `clat_new_states` is states that are newly added - // or which were modified (i.e. redeterminized-states). - std::vector clat_new_states; - for (auto p: state_map) - if (p.second < kStateLabelOffset) - clat_new_states.push_back(p.second); - std::sort(clat_new_states.begin(), - clat_new_states.end()); - UpdateForwardCosts(clat_new_states); - } - return determinized_till_beam; } diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 5e6c3494ed4..09f100e7956 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -305,45 +305,10 @@ class LatticeIncrementalDeterminizer { // Sets up non_final_redet_states_. See documentation for that variable. void GetNonFinalRedetStates(); - // Updates forward_costs_ for all the states which are successors of states - // appearing as values in `clat_new_states`. (By "a is a successor of b" I mean - // there is an arc from a to b.) - // For states that already had entries in the forward_costs_ array, this - // will never decrease their forward costs. This may in theory make - // the forward-costs inaccurate (too large) in cases where arcs - // between redeterminized-states were removed by pruned determinization. - // But the forward_costs_ are anyway only used for the pruned determinization, - // and this type of error would never cause things to be pruned away that - // should not have been pruned away. (Bear in mind that - // such paths can never become the best-path; this is true because of - // how we set the betas/final-probs/extra-costs on the tokens, which - // makes the distance of a state or arc from the best path a lower bound on what that - // distance will eventually become after we finish decoding.) - // - // @param [in] clat_new_states Sorted list of state-ids in - // clat_ which were either given new arcs, or - // created, in the most recent iteration of - // pruned lattice determinization - void UpdateForwardCosts( - const std::vector &clat_new_states); - - - // Reweights `chunk_clat`. Must not be called if this is the first chunk. - // This does: - // (1) For arcs leaving chunk_clat->Start(), identify the redeterminized-state - // clat_state in clat_ that its .nextstate corresponds to, and multiply the weight - // by LatticeWeight(-forward_costs_[clat_state], 0). This is the opposite - // of a cost that we introduced when constructing the raw lattice chunk, - // in order to make sure that determinized pruning works right. We need to - // cancel it out because it's not really part of this chunk. - // (2) After doing (1), modifies chunk_clat so that the weights on arcs - // leaving its start state are all CompactLatticeWeight::One()... - // does this while maintaining equivalence, using OpenFst's - // Reweight() function. This is done for convenience, because - // the start state doesn't correspond to any state in clat_, - // and if there were weights on arcs leaving it we'd need to take - // them into account somehow. - void ReweightChunk(CompactLattice *chunk_clat) const; + + void AddArcToClat(CompactLatticeArc::StateId state, + const CompactLatticeArc &arc); + CompactLattice::StateId AddStateToClat(); // Identifies token-final states in `chunk_clat`; see glossary above for @@ -380,6 +345,15 @@ class LatticeIncrementalDeterminizer { // should have final-arcs leaving them will instead have a final-prob. CompactLattice clat_; + + // arcs_in_ is indexed by (state-id in clat_), and is a list of + // arcs that come into this state, in the form (prev-state, + // arc-index). CAUTION: not all these input-arc records will always + // be valid (some may be out-of-date, and may refer to an out-of-range + // arc or an arc that does not point to this state). But all + // input arcs will always be listed. + std::vector > > arcs_in_; + // final_arcs_ contains arcs which would appear in the canonical appended // lattice but for implementation reasons are not physically present in clat_. // These are arcs to final states in the canonical appended lattice. The From 017e395d1c805a1d244afe015b05ee6bc5fe7e8c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 17 Nov 2019 21:17:20 -0500 Subject: [PATCH 52/60] [src] Various fixes in lattice determinization --- src/decoder/decoder-wrappers.cc | 1 + src/decoder/lattice-incremental-decoder.cc | 109 +++++++++--------- src/decoder/lattice-incremental-decoder.h | 14 +-- .../online2-wav-nnet3-latgen-faster.cc | 3 +- .../online2-wav-nnet3-latgen-incremental.cc | 3 +- 5 files changed, 65 insertions(+), 65 deletions(-) diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index 6eff770f8ea..e13e9c892bb 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -271,6 +271,7 @@ bool DecodeUtteranceLatticeIncremental( // We'll write the lattice without acoustic scaling. if (acoustic_scale != 0.0) fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); + Connect(&clat); compact_lattice_writer->Write(utt, clat); KALDI_LOG << "Log-like per frame for utterance " << utt << " is " << (likelihood / num_frames) << " over " diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 23c7c7adc3c..201de969a0d 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -243,7 +243,6 @@ template void LatticeIncrementalDecoderTpl::PruneForwardLinks( int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned, BaseFloat delta) { - KALDI_LOG << "In PruneForwardLinks " << frame_plus_one; // delta is the amount by which the extra_costs must change // If delta is larger, we'll tend to go back less far // toward the beginning of the file. @@ -420,7 +419,6 @@ BaseFloat LatticeIncrementalDecoderTpl::FinalRelativeCost() const { template void LatticeIncrementalDecoderTpl::PruneTokensForFrame( int32 frame_plus_one) { - KALDI_LOG << "In PruneTokensForFrame: " << frame_plus_one; KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); Token *&toks = active_toks_[frame_plus_one].toks; if (toks == NULL) KALDI_WARN << "No tokens alive [doing pruning]"; @@ -494,6 +492,14 @@ template void LatticeIncrementalDecoderTpl::ComputeFinalCosts( unordered_map *final_costs, BaseFloat *final_relative_cost, BaseFloat *final_best_cost) const { + if (decoding_finalized_) { + // If we finalized decoding, the list toks_ will no longer exist, so return + // something we already computed. + if (final_costs) *final_costs = final_costs_; + if (final_relative_cost) *final_relative_cost = final_relative_cost_; + if (final_best_cost) *final_best_cost = final_best_cost_; + return; + } if (final_costs != NULL) final_costs->clear(); const Elem *final_toks = toks_.GetList(); BaseFloat infinity = std::numeric_limits::infinity(); @@ -939,29 +945,27 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( StateId state = chunk_lat.AddState(); tok2state_map[tok] = state; - next_token2label_map[tok] = AllocateNewTokenLabel(); - StateId token_final_state = chunk_lat.AddState(); - LatticeArc::Label ilabel = 0, - olabel = (next_token2label_map[tok] = AllocateNewTokenLabel()); - chunk_lat.AddArc(state, - LatticeArc(ilabel, olabel, - LatticeWeight::One(), - token_final_state)); - KALDI_ASSERT(final_cost - final_cost == 0.0); // no inf. - chunk_lat.SetFinal(token_final_state, LatticeWeight(final_cost, 0.0)); - - if (decoding_finalized_ && frame == num_frames_to_include) { - // If this is the last chunk, we need to include epsilon transitions - // on the last frame. + if (final_cost < std::numeric_limits::infinity()) { + next_token2label_map[tok] = AllocateNewTokenLabel(); + StateId token_final_state = chunk_lat.AddState(); + LatticeArc::Label ilabel = 0, + olabel = (next_token2label_map[tok] = AllocateNewTokenLabel()); + chunk_lat.AddArc(state, + LatticeArc(ilabel, olabel, + LatticeWeight::One(), + token_final_state)); + chunk_lat.SetFinal(token_final_state, LatticeWeight(final_cost, 0.0)); } } } // Go in reverse order over the remaining frames so we can create arcs as we // go, and their destination-states will already be in the map. - for (int32 frame = num_frames_to_include - 1; + for (int32 frame = num_frames_to_include; frame >= num_frames_in_lattice_; frame--) { - BaseFloat cost_offset = cost_offsets_[frame]; + // The conditional below is needed for the last frame of the utterance. + BaseFloat cost_offset = (frame < cost_offsets_.size() ? + cost_offsets_[frame] : 0.0); // For the first frame of the chunk, we need to make sure the states are // the ones created by InitializeRawLatticeChunk() (where not pruned away). @@ -983,7 +987,8 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( tok2state_map[tok] = state; } } - } else { + } else if (frame != num_frames_to_include) { // We already created states + // for the last frame. for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { StateId state = chunk_lat.AddState(); tok2state_map[tok] = state; @@ -995,12 +1000,19 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( StateId cur_state = iter->second; for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { auto next_iter = tok2state_map.find(l->next_tok); - KALDI_ASSERT(next_iter != tok2state_map.end()); + if (next_iter == tok2state_map.end()) { + // Emitting arcs from the last frame we're including -- ignore + // these. + KALDI_ASSERT(frame == num_frames_to_include); + continue; + } StateId next_state = next_iter->second; BaseFloat this_offset = (l->ilabel != 0 ? cost_offset : 0); LatticeArc arc(l->ilabel, l->olabel, LatticeWeight(l->graph_cost, l->acoustic_cost - this_offset), next_state); + // Note: the epsilons get redundantly included at the end and beginning + // of successive chunks. These will get removed in the determinization. chunk_lat.AddArc(cur_state, arc); } } @@ -1040,10 +1052,12 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( Token *tok = p.first; BaseFloat cost = p.second; auto iter = token2label_map_.find(tok); - KALDI_ASSERT(iter != token2label_map_.end()); - Label token_label = iter->second; - bool ret = token_label2final_cost.insert({token_label, cost}).second; - KALDI_ASSERT(ret); /* Make sure it was inserted. */ + if (iter != token2label_map_.end()) { + /* Some tokens may not have survived the pruned determinization. */ + Label token_label = iter->second; + bool ret = token_label2final_cost.insert({token_label, cost}).second; + KALDI_ASSERT(ret); /* Make sure it was inserted. */ + } } } /* Note: these final-probs won't affect the next chunk, only the lattice @@ -1115,14 +1129,15 @@ CompactLatticeArc::StateId LatticeIncrementalDeterminizer::AddStateToClat() { void LatticeIncrementalDeterminizer::AddArcToClat( CompactLatticeArc::StateId state, const CompactLatticeArc &arc) { + BaseFloat forward_cost = forward_costs_[state] + + ConvertToCost(arc.weight); + if (forward_cost == std::numeric_limits::infinity()) + return; int32 arc_idx = clat_.NumArcs(state); clat_.AddArc(state, arc); arcs_in_[arc.nextstate].push_back({state, arc_idx}); - BaseFloat forward_cost = forward_costs_[state] + - ConvertToCost(arc.weight); if (forward_cost < forward_costs_[arc.nextstate]) forward_costs_[arc.nextstate] = forward_cost; - KALDI_ASSERT(forward_cost - forward_cost == 0); // TODO: remove this. TEMP. } // See documentation in header @@ -1161,9 +1176,12 @@ void LatticeIncrementalDeterminizer::GetNonFinalRedetStates() { // Note: we abuse the .nextstate field to store the state which is really // the source of that arc. StateId redet_state = arc.nextstate; - if (non_final_redet_states_.insert(redet_state).second) { - // it was not already there - state_queue.push_back(redet_state); + if (forward_costs_[redet_state] != std::numeric_limits::infinity()) { + // if it is accessible.. + if (non_final_redet_states_.insert(redet_state).second) { + // it was not already there + state_queue.push_back(redet_state); + } } } // Add any states that are reachable from the states above. @@ -1371,20 +1389,10 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( bool is_first_chunk = false; StateId clat_num_states = clat_.NumStates(); - - - if (0) { - std::ostringstream os; - bool acceptor = false, write_one = false; - fst::FstPrinter printer(chunk_clat, NULL, NULL, - NULL, acceptor, write_one, "\t"); - printer.Print(&os, ""); - KALDI_LOG << "Determinized chunk FST is " << os.str(); - } - - // Process arcs leaving the start state of chunk_clat. These arcs will have // state-labels on them (unless this is the first chunk). + // For destination-states of those arcs, work out which states in + // clat_ they correspond to and update their forward_costs. for (fst::ArcIterator aiter(chunk_clat, start_state); !aiter.Done(); aiter.Next()) { const CompactLatticeArc &arc = aiter.Value(); @@ -1402,8 +1410,10 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( StateId chunk_state = arc.nextstate; auto p = state_map.insert({chunk_state, clat_state}); StateId dest_clat_state = p.first->second; + // We deleted all its arcs in InitializeRawLatticeChunk + KALDI_ASSERT(clat_.NumArcs(clat_state) == 0); /* - In almost all cases, dest_clat_state and clat_state will be the sme state; + In almost all cases, dest_clat_state and clat_state will be the same state; but there may be situations where two arcs with different state-labels left the start state and entered the same next-state in chunk_clat; and in these cases, they will be different. @@ -1428,8 +1438,8 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( // Note: 0 is the start state of clat_. This was checked. forward_costs_[clat_state] = (clat_state == 0 ? 0 : std::numeric_limits::infinity()); - std::vector > &arcs_in(arcs_in_[clat_state]); - std::vector > new_arcs_in; + std::vector > arcs_in; + arcs_in.swap(arcs_in_[clat_state]); for (auto p: arcs_in) { // Note: we'll be doing `continue` below if this input arc came from // another redeterminized-state, because we did DeleteStates() for them in @@ -1453,15 +1463,8 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( ConvertToCost(new_in_arc.weight); if (new_forward_cost < forward_costs_[dest_clat_state]) forward_costs_[dest_clat_state] = new_forward_cost; - new_arcs_in.push_back(p); + arcs_in_[dest_clat_state].push_back(p); } - // Commit a cleaned-up version of `arcs_in` that doesn't contain any - // no-longer-valid arcs. This may seem redundant, since no-longer-valid - // arcs are permitted in arcs_in_, and we almost certainly won't need this - // information in future; but I am concerned that if some states remain - // active for many chunks, keeping all the old input arcs might cause the - // algorithm to become quadratic-time. - arcs_in.swap(new_arcs_in); } // Remove any existing arcs in clat_ that leave redeterminized-states, and diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 09f100e7956..fc1d322c02b 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -485,6 +485,9 @@ class LatticeIncrementalDecoderTpl { See also UpdateLatticeDeterminizaton(). Caution: this const ref is only valid until the next time you call AdvanceDecoding() or GetLattice(). + + CAUTION: the lattice may contain disconnnected states; you should + call Connect() on the output before writing it out. */ const CompactLattice &GetLattice(int32 num_frames_to_include, bool use_final_probs = false); @@ -495,17 +498,8 @@ class LatticeIncrementalDecoderTpl { be the largest number that GetLattice() was called with, but note that GetLattice() may be called from UpdateLatticeDeterminization(). - Made - available in case the user wants to give that same number to + Made available in case the user wants to give that same number to GetLattice(). - - CAUTION: if the caller wants to call GetLattice() with use_final_probs == - true, or use_final_probs == true and finalize == true, and you previously - updated the lattice with UpdateLatticeDeterminization() up to a certain - frame, it's necessary to call GetLattice() with a *higher* frame than the - previously used one, because the interface doesn't support adding/removing - the final-probs unless you are processing a new chunk. This is - not a fundamental limitation, it's just the way we have coded it. */ int NumFramesInLattice() const { return num_frames_in_lattice_; } diff --git a/src/online2bin/online2-wav-nnet3-latgen-faster.cc b/src/online2bin/online2-wav-nnet3-latgen-faster.cc index 1549dd6ae52..c7fb3806e6b 100644 --- a/src/online2bin/online2-wav-nnet3-latgen-faster.cc +++ b/src/online2bin/online2-wav-nnet3-latgen-faster.cc @@ -58,7 +58,8 @@ void GetDiagnosticsAndPrintOutput(const std::string &utt, *tot_like += likelihood; KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is " << (likelihood / num_frames) << " over " << num_frames - << " frames."; + << " frames, = " << (-weight.Value1() / num_frames) + << ',' << (weight.Value2() / num_frames); if (word_syms != NULL) { std::cerr << utt << ' '; diff --git a/src/online2bin/online2-wav-nnet3-latgen-incremental.cc b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc index 5e4554e4a8c..cf36c4cbae5 100644 --- a/src/online2bin/online2-wav-nnet3-latgen-incremental.cc +++ b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc @@ -57,7 +57,8 @@ void GetDiagnosticsAndPrintOutput(const std::string &utt, *tot_like += likelihood; KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is " << (likelihood / num_frames) << " over " << num_frames - << " frames."; + << " frames, = " << (-weight.Value1() / num_frames) + << ',' << (weight.Value2() / num_frames); if (word_syms != NULL) { std::cerr << utt << ' '; From 95c7771ad22ce45a216387cfc7533fb08764b2b0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 Nov 2019 11:16:07 +0800 Subject: [PATCH 53/60] [src] Fix to incr-det code --- src/decoder/lattice-incremental-decoder.cc | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 201de969a0d..4737c8c08ac 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -932,14 +932,16 @@ const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( final_cost = iter->second; } } else { - /* this is a `fake` final-cost used to guide pruning. This equals - the alpha+beta of the state, if we were to set the betas on - the final frame to the negatives of the alphas (this is a trick - to make all such tokens on the best path, to avoid pruning out - anything that might be within `lattice-beam` of the eventual - best path). + /* this is a `fake` final-cost used to guide pruning. It's as if we + set the betas (backward-probs) on the final frame to the + negatives of the corresponding alphas, so all tokens on the last + frae will be on a best path.. the extra_cost for each token + always corresponds to its alpha+beta on this assumption. We want + the final_cost here to correspond to the beta (backward-prob), so + we get that by final_cost = extra_cost - tot_cost. + [The tot_cost is the forward/alpha cost.] */ - final_cost = -(tok->tot_cost + tok->extra_cost); + final_cost = tok->extra_cost - tok->tot_cost; } } @@ -1443,7 +1445,8 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( for (auto p: arcs_in) { // Note: we'll be doing `continue` below if this input arc came from // another redeterminized-state, because we did DeleteStates() for them in - // InitializeRawLatticeChunk(). + // InitializeRawLatticeChunk(). Those arcs will be transferred + // from chunk_clat later on. CompactLattice::StateId src_state = p.first; int32 arc_pos = p.second; if (arc_pos >= (int32)clat_.NumArcs(src_state)) From 19f062c9205468eec93a5539be4880c6f1050877 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 Nov 2019 14:13:09 +0800 Subject: [PATCH 54/60] [src] Code refactor in incr-det; fix assert in CompactLatticeShortestPath(). --- src/lat/lattice-functions.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lat/lattice-functions.cc b/src/lat/lattice-functions.cc index 7f484f95233..a82c4b4a297 100644 --- a/src/lat/lattice-functions.cc +++ b/src/lat/lattice-functions.cc @@ -1107,7 +1107,6 @@ void CompactLatticeShortestPath(const CompactLattice &clat, // Now we can assume it's topologically sorted. shortest_path->DeleteStates(); if (clat.Start() == kNoStateId) return; - KALDI_ASSERT(clat.Start() == 0); // since top-sorted. typedef CompactLatticeArc Arc; typedef Arc::StateId StateId; typedef CompactLatticeWeight Weight; @@ -1117,7 +1116,7 @@ void CompactLatticeShortestPath(const CompactLattice &clat, best_cost_and_pred[s].first = std::numeric_limits::infinity(); best_cost_and_pred[s].second = fst::kNoStateId; } - best_cost_and_pred[0].first = 0; + best_cost_and_pred[clat.Start()].first = 0; for (StateId s = 0; s < clat.NumStates(); s++) { double my_cost = best_cost_and_pred[s].first; for (ArcIterator aiter(clat, s); From 3ef55441e87f5d9ec629b590a3ee8980344010e1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 Nov 2019 14:16:27 +0800 Subject: [PATCH 55/60] [src] incr-decoder refactor --- src/decoder/lattice-incremental-decoder.cc | 117 +++++++++++---------- src/decoder/lattice-incremental-decoder.h | 29 +++++ 2 files changed, 91 insertions(+), 55 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 4737c8c08ac..2de245554c0 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1339,63 +1339,17 @@ void LatticeIncrementalDeterminizer::GetRawLatticeFinalCosts( } -bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( - Lattice *raw_fst) { - using Label = CompactLatticeArc::Label; - using StateId = CompactLatticeArc::StateId; - - // old_final_costs is a map from a `token-label` (see glossary) to the - // associated final-prob in a final-state of `raw_fst`, that is associated - // with that Token. These are Tokens that were active at the end of the - // chunk. The final-probs may arise from beta (backward) costs, introduced - // for pruning purposes, and/or from final-probs in HCLG. Those costs will - // not be included in anything we store permamently in this class; they used - // only to guide pruned determinization, and we will use `old_final_costs` - // later to cancel them out. - std::unordered_map old_final_costs; - GetRawLatticeFinalCosts(*raw_fst, &old_final_costs); - - CompactLattice chunk_clat; - bool determinized_till_beam = DeterminizeLatticePhonePrunedWrapper( - trans_model_, raw_fst, config_.lattice_beam, &chunk_clat, - config_.det_opts); - - TopSortCompactLatticeIfNeeded(&chunk_clat); - - std::unordered_map chunk_state_to_token; - IdentifyTokenFinalStates(chunk_clat, - &chunk_state_to_token); - StateId chunk_num_states = chunk_clat.NumStates(); - if (chunk_num_states == 0) { - // This will be an error but user-level calling code can detect it from the - // lattice being empty. - KALDI_WARN << "Empty lattice, something went wrong."; - clat_.DeleteStates(); - return false; - } - - StateId start_state = chunk_clat.Start(); // would be 0. - KALDI_ASSERT(start_state == 0); - - // Process arcs leaving the start state of chunk_clat. Unless this is the - // first chunk in the lattice, all arcs leaving the start state of chunk_clat - // will have `state labels` on them (identifying redeterminized-states in - // clat_), and will transition to a state in `chunk_clat` that we can identify - // with that redeterminized-state. - - // state_map maps from (non-initial, non-token-final state s in chunk_clat) to - // a state in clat_. - std::unordered_map state_map; - - - bool is_first_chunk = false; +bool LatticeIncrementalDeterminizer::ProcessArcsFromStartState( + const CompactLattice &chunk_clat, + std::unordered_map *state_map) { + using StateId = CompactLattice::StateId; StateId clat_num_states = clat_.NumStates(); // Process arcs leaving the start state of chunk_clat. These arcs will have // state-labels on them (unless this is the first chunk). // For destination-states of those arcs, work out which states in // clat_ they correspond to and update their forward_costs. - for (fst::ArcIterator aiter(chunk_clat, start_state); + for (fst::ArcIterator aiter(chunk_clat, chunk_clat.Start()); !aiter.Done(); aiter.Next()) { const CompactLatticeArc &arc = aiter.Value(); Label label = arc.ilabel; // ilabel == olabel; would be the olabel @@ -1404,13 +1358,12 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( label - kStateLabelOffset < clat_num_states)) { // The label was not a state-label. This should only be possible on the // first chunk. - KALDI_ASSERT(state_map.empty()); - is_first_chunk = true; - break; + KALDI_ASSERT(state_map->empty()); + return true; // this is the first chunk. } StateId clat_state = label - kStateLabelOffset; StateId chunk_state = arc.nextstate; - auto p = state_map.insert({chunk_state, clat_state}); + auto p = state_map->insert({chunk_state, clat_state}); StateId dest_clat_state = p.first->second; // We deleted all its arcs in InitializeRawLatticeChunk KALDI_ASSERT(clat_.NumArcs(clat_state) == 0); @@ -1469,6 +1422,60 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( arcs_in_[dest_clat_state].push_back(p); } } + return false; // this is not the first chunk. +} + + +bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( + Lattice *raw_fst) { + using Label = CompactLatticeArc::Label; + using StateId = CompactLatticeArc::StateId; + + // old_final_costs is a map from a `token-label` (see glossary) to the + // associated final-prob in a final-state of `raw_fst`, that is associated + // with that Token. These are Tokens that were active at the end of the + // chunk. The final-probs may arise from beta (backward) costs, introduced + // for pruning purposes, and/or from final-probs in HCLG. Those costs will + // not be included in anything we store permamently in this class; they used + // only to guide pruned determinization, and we will use `old_final_costs` + // later to cancel them out. + std::unordered_map old_final_costs; + GetRawLatticeFinalCosts(*raw_fst, &old_final_costs); + + CompactLattice chunk_clat; + bool determinized_till_beam = DeterminizeLatticePhonePrunedWrapper( + trans_model_, raw_fst, config_.lattice_beam, &chunk_clat, + config_.det_opts); + + TopSortCompactLatticeIfNeeded(&chunk_clat); + + std::unordered_map chunk_state_to_token; + IdentifyTokenFinalStates(chunk_clat, + &chunk_state_to_token); + StateId chunk_num_states = chunk_clat.NumStates(); + if (chunk_num_states == 0) { + // This will be an error but user-level calling code can detect it from the + // lattice being empty. + KALDI_WARN << "Empty lattice, something went wrong."; + clat_.DeleteStates(); + return false; + } + + StateId start_state = chunk_clat.Start(); // would be 0. + KALDI_ASSERT(start_state == 0); + + // Process arcs leaving the start state of chunk_clat. Unless this is the + // first chunk in the lattice, all arcs leaving the start state of chunk_clat + // will have `state labels` on them (identifying redeterminized-states in + // clat_), and will transition to a state in `chunk_clat` that we can identify + // with that redeterminized-state. + + // state_map maps from (non-initial, non-token-final state s in chunk_clat) to + // a state in clat_. + std::unordered_map state_map; + + + bool is_first_chunk = ProcessArcsFromStartState(chunk_clat, &state_map); // Remove any existing arcs in clat_ that leave redeterminized-states, and // make those states non-final. Below, we'll add arcs leaving those states diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index fc1d322c02b..ba16c282663 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -291,6 +291,7 @@ class LatticeIncrementalDeterminizer { private: + // [called from AcceptRawLatticeChunk()] // Gets the final costs from token-final states in the raw lattice (see // glossary for definition). These final costs will be subtracted after // determinization; in the normal case they are `temporaries` used to guide @@ -305,6 +306,34 @@ class LatticeIncrementalDeterminizer { // Sets up non_final_redet_states_. See documentation for that variable. void GetNonFinalRedetStates(); + /** [called from AcceptRawLatticeChunk()] Processes arcs that leave the + start-state of `chunk_clat` (if this is not the first chunk); does nothing + if this is the first chunk. This includes using the `state-labels` to + work out which states in clat_ these states correspond to, and writing + that mapping to `state_map`. + + Also modifies forward_costs_, because it has to do a kind of reweighting + of the clat states that are the values it puts in `state_map`, to take + account of the probabilities on the arcs from the start state of + chunk_clat to the states corresponding to those redeterminized-states + (i.e. the states in clat corresponding to the values it puts in + `*state_map`). It also modifies arcs_in_, mostly because there + are rare cases when we end up `merging` sets of those redeterminized-states, + because the determinization process mapped them to a single state, + and that means we need to reroute the arcs into members of that + set into one single member (which will appear as a value in + `*state_map`). + + @param [in] chunk_clat The determinized chunk of lattice we are + processing + @param [out] state_map Mapping from states in chunk_clat to + the state in clat_ they correspond to. + @return Returns true if this is the first chunk. + */ + bool ProcessArcsFromStartState( + const CompactLattice &chunk_clat, + std::unordered_map *state_map); + void AddArcToClat(CompactLatticeArc::StateId state, const CompactLatticeArc &arc); From aaa2484b2150946c45d914c3e618cf476ed7dd22 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 Nov 2019 14:28:44 +0800 Subject: [PATCH 56/60] [src] Further fix to CompactLatticeShortestPath --- src/lat/lattice-functions.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lat/lattice-functions.cc b/src/lat/lattice-functions.cc index a82c4b4a297..f4a184f3cd4 100644 --- a/src/lat/lattice-functions.cc +++ b/src/lat/lattice-functions.cc @@ -1138,8 +1138,8 @@ void CompactLatticeShortestPath(const CompactLattice &clat, } } std::vector states; // states on best path. - StateId cur_state = superfinal; - while (cur_state != 0) { + StateId cur_state = superfinal, start_state = clat.Start(); + while (cur_state != start_state) { StateId prev_state = best_cost_and_pred[cur_state].second; if (prev_state == kNoStateId) { KALDI_WARN << "Failure in best-path algorithm for lattice (infinite costs?)"; From 9104f5816b837ccacd5e6a81bb7c721361a0d008 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 Nov 2019 15:20:51 +0800 Subject: [PATCH 57/60] [src] Hopefully fix issue with start state --- src/decoder/lattice-incremental-decoder.cc | 40 ++++++++++++++++++++-- src/decoder/lattice-incremental-decoder.h | 26 ++++++++++++-- 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 2de245554c0..44688b08ef3 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1339,9 +1339,10 @@ void LatticeIncrementalDeterminizer::GetRawLatticeFinalCosts( } -bool LatticeIncrementalDeterminizer::ProcessArcsFromStartState( +bool LatticeIncrementalDeterminizer::ProcessArcsFromChunkStartState( const CompactLattice &chunk_clat, - std::unordered_map *state_map) { + std::unordered_map *state_map, + CompactLatticeWeight *extra_start_weight) { using StateId = CompactLattice::StateId; StateId clat_num_states = clat_.NumStates(); @@ -1379,6 +1380,15 @@ bool LatticeIncrementalDeterminizer::ProcessArcsFromStartState( "canonical" one. (Search below for new_in_arc.nextstate = dest_clat_state). */ + if (clat_state != dest_clat_state) { + // Check that the start state isn't getting merged with any other state. + // If this were possible, we'd need to deal with it specially, but it + // can't be, because to be merged, 2 states must have identical arcs + // leaving them with identical weights, so we'd need to have another state + // on frame 0 identical to the start state, which is not possible if the + // lattice is deterministic and epsilon-free. + KALDI_ASSERT(clat_state != 0 && dest_clat_state != 0); + } // in_weight is an extra weight that we'll include on arcs entering this // state from the previous chunk. We need to cancel out @@ -1390,6 +1400,15 @@ bool LatticeIncrementalDeterminizer::ProcessArcsFromStartState( fst::Times(extra_weight_in.Weight(), LatticeWeight(-forward_costs_[clat_state], 0.0))); + if (clat_state == 0) { + // if clat_state is the star-state of clat_ (state 0), we can't modify + // incoming arcs; we need to modify outgoing arcs, but we'll do that + // later, after we add them. + *extra_start_weight = extra_weight_in; + forward_costs_[0] = forward_costs_[0] + ConvertToCost(extra_weight_in); + continue; + } + // Note: 0 is the start state of clat_. This was checked. forward_costs_[clat_state] = (clat_state == 0 ? 0 : std::numeric_limits::infinity()); @@ -1425,6 +1444,15 @@ bool LatticeIncrementalDeterminizer::ProcessArcsFromStartState( return false; // this is not the first chunk. } +void LatticeIncrementalDeterminizer::ReweightStartState( + CompactLatticeWeight &extra_start_weight) { + for (fst::MutableArcIterator aiter(&clat_, 0); + !aiter.Done(); aiter.Next()) { + CompactLatticeArc arc(aiter.Value()); + arc.weight = fst::Times(extra_start_weight, arc.weight); + aiter.SetValue(arc); + } +} bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( Lattice *raw_fst) { @@ -1475,7 +1503,9 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( std::unordered_map state_map; - bool is_first_chunk = ProcessArcsFromStartState(chunk_clat, &state_map); + CompactLatticeWeight extra_start_weight = CompactLatticeWeight::One(); + bool is_first_chunk = ProcessArcsFromChunkStartState(chunk_clat, &state_map, + &extra_start_weight); // Remove any existing arcs in clat_ that leave redeterminized-states, and // make those states non-final. Below, we'll add arcs leaving those states @@ -1582,6 +1612,10 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( } } } + + if (extra_start_weight != CompactLatticeWeight::One()) + ReweightStartState(extra_start_weight); + GetNonFinalRedetStates(); return determinized_till_beam; diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index ba16c282663..5b697846e31 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -328,11 +328,33 @@ class LatticeIncrementalDeterminizer { processing @param [out] state_map Mapping from states in chunk_clat to the state in clat_ they correspond to. + @param [out] extra_start_weight If the start-state of + clat_ (its state 0) needs to be modified as + if its incoming arcs were multiplied by + `extra_start_weight`, this isn't possible + using the `in_arcs_` data-structure, + so we remember the extra weight and multiply + it in later, after processing arcs leaving + the start state of clat_. This is set + only if the start-state of clat_ is a + redeterminized state. @return Returns true if this is the first chunk. */ - bool ProcessArcsFromStartState( + bool ProcessArcsFromChunkStartState( const CompactLattice &chunk_clat, - std::unordered_map *state_map); + std::unordered_map *state_map, + CompactLatticeWeight *extra_start_weight); + + /** + This function, called from AcceptRawLatticeChunk(), takes care of an + unusual situation where we need to reweight the start state of clat_. This + `extra_start_weight` is to be thought of as an extra `incoming` weight, and + we need to left-multiply all the arcs leaving the start state, by it. + + This function does not need to modify forward_costs_; that will + already have been done by ProcessArcsFromChunkStartState(). + */ + void ReweightStartState(CompactLatticeWeight &extra_start_weight); void AddArcToClat(CompactLatticeArc::StateId state, From 611187425bbee97b2d42e58a9e855b372e17f3a6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 18 Nov 2019 15:44:40 +0800 Subject: [PATCH 58/60] [src] Refactor/cleanup incr-det code --- src/decoder/lattice-incremental-decoder.cc | 187 ++++++++++-------- src/decoder/lattice-incremental-decoder.h | 45 ++++- .../online2-wav-nnet3-latgen-incremental.cc | 1 + 3 files changed, 144 insertions(+), 89 deletions(-) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index 44688b08ef3..de952b7e1cd 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1120,8 +1120,8 @@ void LatticeIncrementalDeterminizer::Init() { arcs_in_.clear(); } -CompactLatticeArc::StateId LatticeIncrementalDeterminizer::AddStateToClat() { - CompactLatticeArc::StateId ans = clat_.AddState(); +CompactLattice::StateId LatticeIncrementalDeterminizer::AddStateToClat() { + CompactLattice::StateId ans = clat_.AddState(); forward_costs_.push_back(std::numeric_limits::infinity()); KALDI_ASSERT(forward_costs_.size() == ans + 1); arcs_in_.resize(ans + 1); @@ -1129,7 +1129,7 @@ CompactLatticeArc::StateId LatticeIncrementalDeterminizer::AddStateToClat() { } void LatticeIncrementalDeterminizer::AddArcToClat( - CompactLatticeArc::StateId state, + CompactLattice::StateId state, const CompactLatticeArc &arc) { BaseFloat forward_cost = forward_costs_[state] + ConvertToCost(arc.weight); @@ -1145,9 +1145,9 @@ void LatticeIncrementalDeterminizer::AddArcToClat( // See documentation in header void LatticeIncrementalDeterminizer::IdentifyTokenFinalStates( const CompactLattice &chunk_clat, - std::unordered_map *token_map) const { + std::unordered_map *token_map) const { token_map->clear(); - using StateId = CompactLatticeArc::StateId; + using StateId = CompactLattice::StateId; using Label = CompactLatticeArc::Label; StateId num_states = chunk_clat.NumStates(); @@ -1169,7 +1169,7 @@ void LatticeIncrementalDeterminizer::IdentifyTokenFinalStates( void LatticeIncrementalDeterminizer::GetNonFinalRedetStates() { - using StateId = CompactLatticeArc::StateId; + using StateId = CompactLattice::StateId; non_final_redet_states_.clear(); non_final_redet_states_.reserve(final_arcs_.size()); @@ -1216,9 +1216,9 @@ void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( // canonical appended lattice leave (physically, these are in the .nextstate // elements of arcs_, since we use that field for the source state), plus any // states reachable from those states. - unordered_map redet_state_map; + unordered_map redet_state_map; - for (CompactLatticeArc::StateId redet_state: non_final_redet_states_) + for (CompactLattice::StateId redet_state: non_final_redet_states_) redet_state_map[redet_state] = olat->AddState(); // First, process any arcs leaving the non-final redeterminized states that @@ -1226,13 +1226,13 @@ void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( // stats that are final in the `canonical appended lattice`.. they may // actually be physically final in clat_, because we make clat_ what we want // to return to the user. - for (CompactLatticeArc::StateId redet_state: non_final_redet_states_) { + for (CompactLattice::StateId redet_state: non_final_redet_states_) { LatticeArc::StateId lat_state = redet_state_map[redet_state]; for (ArcIterator aiter(clat_, redet_state); !aiter.Done(); aiter.Next()) { const CompactLatticeArc &arc = aiter.Value(); - CompactLatticeArc::StateId nextstate = arc.nextstate; + CompactLattice::StateId nextstate = arc.nextstate; LatticeArc::StateId lat_nextstate = olat->NumStates(); auto r = redet_state_map.insert({nextstate, lat_nextstate}); if (r.second) { // Was inserted. @@ -1252,7 +1252,7 @@ void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( for (const CompactLatticeArc &arc: final_arcs_) { // We abuse the `nextstate` field to store the source state. - CompactLatticeArc::StateId src_state = arc.nextstate; + CompactLattice::StateId src_state = arc.nextstate; auto iter = redet_state_map.find(src_state); KALDI_ASSERT(iter != redet_state_map.end()); LatticeArc::StateId src_lat_state = iter->second; @@ -1287,7 +1287,7 @@ void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( // a state that is not a redeterminized state." In fact, we include these // arcs for all redeterminized states. I realized that it won't make a // difference to the outcome, and it's easier to do it this way. - for (CompactLatticeArc::StateId state_id: non_final_redet_states_) { + for (CompactLattice::StateId state_id: non_final_redet_states_) { BaseFloat forward_cost = forward_costs_[state_id]; LatticeArc arc; arc.ilabel = 0; @@ -1454,10 +1454,93 @@ void LatticeIncrementalDeterminizer::ReweightStartState( } } +void LatticeIncrementalDeterminizer::TransferArcsToClat( + const CompactLattice &chunk_clat, + bool is_first_chunk, + const std::unordered_map &state_map, + const std::unordered_map &chunk_state_to_token, + const std::unordered_map &old_final_costs) { + using StateId = CompactLattice::StateId; + StateId chunk_num_states = chunk_clat.NumStates(); + + // Now transfer arcs from chunk_clat to clat_. + for (StateId chunk_state = (is_first_chunk ? 0 : 1); + chunk_state < chunk_num_states; chunk_state++) { + auto iter = state_map.find(chunk_state); + if (iter == state_map.end()) { + KALDI_ASSERT(chunk_state_to_token.count(chunk_state) != 0); + // Don't process token-final states. Anyway they have no arcs leaving + // them. + continue; + } + StateId clat_state = iter->second; + + // We know that this point that `clat_state` is not a token-final state + // (see glossary for definition) as if it were, we would have done + // `continue` above. + // + // Only in the last chunk of the lattice would be there be a final-prob on + // states that are not `token-final states`; these final-probs would + // normally all be Zero() at this point. So in almost all cases the following + // call will do nothing. + clat_.SetFinal(clat_state, chunk_clat.Final(chunk_state)); + + // Process arcs leaving this state. + for (fst::ArcIterator aiter(chunk_clat, chunk_state); + !aiter.Done(); aiter.Next()) { + CompactLatticeArc arc(aiter.Value()); + + auto next_iter = state_map.find(arc.nextstate); + if (next_iter != state_map.end()) { + // The normal case (when the .nextstate has a corresponding + // state in clat_) is very simple. Just copy the arc over. + arc.nextstate = next_iter->second; + KALDI_ASSERT(arc.ilabel < kTokenLabelOffset || + arc.ilabel > kMaxTokenLabel); + AddArcToClat(clat_state, arc); + } else { + // This is the case when the arc is to a `token-final` state (see + // glossary.) + + // TODO: remove the following slightly excessive assertion? + KALDI_ASSERT(chunk_clat.Final(arc.nextstate) != CompactLatticeWeight::Zero() && + arc.olabel >= (Label)kTokenLabelOffset && + arc.olabel < (Label)kMaxTokenLabel && + chunk_state_to_token.count(arc.nextstate) != 0 && + old_final_costs.count(arc.olabel) != 0); + + // Include the final-cost of the next state (which should be final) + // in arc.weight. + arc.weight = fst::Times(arc.weight, + chunk_clat.Final(arc.nextstate)); + + auto cost_iter = old_final_costs.find(arc.olabel); + KALDI_ASSERT(cost_iter != old_final_costs.end()); + BaseFloat old_final_cost = cost_iter->second; + + // `arc` is going to become an element of final_arcs_. These + // contain information about transitions from states in clat_ to + // `token-final` states (i.e. states that have a token-label on the arc + // to them and that are final in the canonical compact lattice). + // We subtract the old_final_cost as it was just a temporary cost + // introduced for pruning purposes. + arc.weight.SetWeight(fst::Times(arc.weight.Weight(), + LatticeWeight{-old_final_cost, 0.0})); + // In a slight abuse of the Arc data structure, the nextstate is set to + // the source state. The label (ilabel == olabel) indicates the + // token it is associated with. + arc.nextstate = clat_state; + final_arcs_.push_back(arc); + } + } + } + +} + bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( Lattice *raw_fst) { using Label = CompactLatticeArc::Label; - using StateId = CompactLatticeArc::StateId; + using StateId = CompactLattice::StateId; // old_final_costs is a map from a `token-label` (see glossary) to the // associated final-prob in a final-state of `raw_fst`, that is associated @@ -1480,6 +1563,7 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( std::unordered_map chunk_state_to_token; IdentifyTokenFinalStates(chunk_clat, &chunk_state_to_token); + StateId chunk_num_states = chunk_clat.NumStates(); if (chunk_num_states == 0) { // This will be an error but user-level calling code can detect it from the @@ -1537,81 +1621,14 @@ bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( if (is_first_chunk) { auto iter = state_map.find(start_state); KALDI_ASSERT(iter != state_map.end()); - CompactLatticeArc::StateId clat_start_state = iter->second; + CompactLattice::StateId clat_start_state = iter->second; KALDI_ASSERT(clat_start_state == 0); // topological order. clat_.SetStart(clat_start_state); forward_costs_[clat_start_state] = 0.0; } - // Now transfer arcs from chunk_clat to clat_. - for (StateId chunk_state = (is_first_chunk ? 0 : 1); - chunk_state < chunk_num_states; chunk_state++) { - auto iter = state_map.find(chunk_state); - if (iter == state_map.end()) { - KALDI_ASSERT(chunk_state_to_token.count(chunk_state) != 0); - // Don't process token-final states. Anyway they have no arcs leaving - // them. - continue; - } - StateId clat_state = iter->second; - - // We know that this point that `clat_state` is not a token-final state - // (see glossary for definition) as if it were, we would have done - // `continue` above. - // - // Only in the last chunk of the lattice would be there be a final-prob on - // states that are not `token-final states`; these final-probs would - // normally all be Zero() at this point. So in almost all cases the following - // call will do nothing. - clat_.SetFinal(clat_state, chunk_clat.Final(chunk_state)); - - // Process arcs leaving this state. - for (fst::ArcIterator aiter(chunk_clat, chunk_state); - !aiter.Done(); aiter.Next()) { - CompactLatticeArc arc(aiter.Value()); - - auto next_iter = state_map.find(arc.nextstate); - if (next_iter != state_map.end()) { - // The normal case (when the .nextstate has a corresponding - // state in clat_) is very simple. Just copy the arc over. - arc.nextstate = next_iter->second; - KALDI_ASSERT(arc.ilabel < kTokenLabelOffset || - arc.ilabel > kMaxTokenLabel); - AddArcToClat(clat_state, arc); - } else { - // This is the case when the arc is to a `token-final` state (see - // glossary.) - - // TODO: remove the following slightly excessive assertion? - KALDI_ASSERT(chunk_clat.Final(arc.nextstate) != CompactLatticeWeight::Zero() && - arc.olabel >= (Label)kTokenLabelOffset && - arc.olabel < (Label)kMaxTokenLabel && - chunk_state_to_token.count(arc.nextstate) != 0 && - old_final_costs.count(arc.olabel) != 0); - - // Include the final-cost of the next state (which should be final) - // in arc.weight. - arc.weight = fst::Times(arc.weight, - chunk_clat.Final(arc.nextstate)); - - BaseFloat old_final_cost = old_final_costs[arc.olabel]; - - // `arc` is going to become an element of final_arcs_. These - // contain information about transitions from states in clat_ to - // `token-final` states (i.e. states that have a token-label on the arc - // to them and that are final in the canonical compact lattice). - // We subtract the old_final_cost as it was just a temporary cost - // introduced for pruning purposes. - arc.weight.SetWeight(fst::Times(arc.weight.Weight(), - LatticeWeight{-old_final_cost, 0.0})); - // In a slight abuse of the Arc data structure, the nextstate is set to - // the source state. The label (ilabel == olabel) indicates the - // token it is associated with. - arc.nextstate = clat_state; - final_arcs_.push_back(arc); - } - } - } + TransferArcsToClat(chunk_clat, is_first_chunk, + state_map, chunk_state_to_token, old_final_costs); if (extra_start_weight != CompactLatticeWeight::One()) ReweightStartState(extra_start_weight); @@ -1642,7 +1659,7 @@ void LatticeIncrementalDeterminizer::SetFinalCosts( /* Caution: `state` is actually the state the arc would leave from in the canonical appended lattice; we just store that in the .nextstate field. */ - CompactLatticeArc::StateId state = arc.nextstate; + CompactLattice::StateId state = arc.nextstate; prefinal_states.insert(state); } @@ -1653,7 +1670,7 @@ void LatticeIncrementalDeterminizer::SetFinalCosts( for (const CompactLatticeArc &arc: final_arcs_) { Label token_label = arc.ilabel; /* Note: we store the source state in the .nextstate field. */ - CompactLatticeArc::StateId src_state = arc.nextstate; + CompactLattice::StateId src_state = arc.nextstate; BaseFloat graph_final_cost; if (token_label2final_cost == NULL) { graph_final_cost = 0.0; diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h index 5b697846e31..7abc370178a 100644 --- a/src/decoder/lattice-incremental-decoder.h +++ b/src/decoder/lattice-incremental-decoder.h @@ -357,7 +357,44 @@ class LatticeIncrementalDeterminizer { void ReweightStartState(CompactLatticeWeight &extra_start_weight); - void AddArcToClat(CompactLatticeArc::StateId state, + /** + This function, called from AcceptRawLatticeChunk(), transfers arcs from + `chunk_clat` to clat_. For those arcs that have `token-labels` on them, + they don't get written to clat_ but instead are stored in the arcs_ array. + + @param [in] chunk_clat The determinized lattice for the chunk + we are processing; this is the source of the arcs + we are moving. + @param [in] is_first_chunk True if this is the first chunk in the + utterance; it's needed because if it is, we + will also transfer arcs from the start state of + chunk_clat. + @param [in] state_map Map from state-ids in chunk_clat to state-ids + in clat_. + @param [in] chunk_state_to_token Map from `token-final states` + (see glossary) in chunk_clat, to the token-label + on arcs entering those states. + @param [in] old_final_costs Map from token-label to the + final-costs that were on the corresponding + token-final states in the undeterminized lattice; + these final-costs need to be removed when + we record the weights in final_arcs_, because + they were just temporary. + */ + void TransferArcsToClat( + const CompactLattice &chunk_clat, + bool is_first_chunk, + const std::unordered_map &state_map, + const std::unordered_map &chunk_state_to_token, + const std::unordered_map &old_final_costs); + + + + /** + Adds one arc to `clat_`. It's like clat_.AddArc(state, arc), except + it also modifies arcs_in_ and forward_costs_. + */ + void AddArcToClat(CompactLattice::StateId state, const CompactLatticeArc &arc); CompactLattice::StateId AddStateToClat(); @@ -371,7 +408,7 @@ class LatticeIncrementalDeterminizer { // construct the raw lattice.) void IdentifyTokenFinalStates( const CompactLattice &chunk_clat, - std::unordered_map *token_map) const; + std::unordered_map *token_map) const; // trans_model_ is needed by DeterminizeLatticePhonePrunedWrapper() which this // class calls. @@ -386,7 +423,7 @@ class LatticeIncrementalDeterminizer { // in clat_, this means the set of redeterminized-states which are physically // in clat_. In code terms, this means set of .first elements in final_arcs, // plus whatever other states in clat_ are reachable from such states. - std::unordered_set non_final_redet_states_; + std::unordered_set non_final_redet_states_; // clat_ is the appended lattice (containing all chunks processed so @@ -403,7 +440,7 @@ class LatticeIncrementalDeterminizer { // be valid (some may be out-of-date, and may refer to an out-of-range // arc or an arc that does not point to this state). But all // input arcs will always be listed. - std::vector > > arcs_in_; + std::vector > > arcs_in_; // final_arcs_ contains arcs which would appear in the canonical appended // lattice but for implementation reasons are not physically present in clat_. diff --git a/src/online2bin/online2-wav-nnet3-latgen-incremental.cc b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc index cf36c4cbae5..aaa87f24de1 100644 --- a/src/online2bin/online2-wav-nnet3-latgen-incremental.cc +++ b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc @@ -270,6 +270,7 @@ int main(int argc, char *argv[]) { CompactLattice clat = decoder.GetLattice(decoder.NumFramesDecoded(), use_final_probs); + Connect(&clat); GetDiagnosticsAndPrintOutput(utt, word_syms, clat, &num_frames, &tot_like); From 66c462f2df07073d0a2864e1b4c49e6be0f7fbb9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 22 Nov 2019 11:00:02 +0800 Subject: [PATCH 59/60] [src] Bug-fix in incremental decoder --- src/decoder/lattice-incremental-decoder.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc index de952b7e1cd..81e70083301 100644 --- a/src/decoder/lattice-incremental-decoder.cc +++ b/src/decoder/lattice-incremental-decoder.cc @@ -1254,6 +1254,8 @@ void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( // We abuse the `nextstate` field to store the source state. CompactLattice::StateId src_state = arc.nextstate; auto iter = redet_state_map.find(src_state); + if (forward_costs_[src_state] == std::numeric_limits::infinity()) + continue; /* Unreachable state */ KALDI_ASSERT(iter != redet_state_map.end()); LatticeArc::StateId src_lat_state = iter->second; Label token_label = arc.ilabel; // will be == arc.olabel. From c198620ae1136bce2260bd04583420a8c8689e8c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 23 Nov 2019 10:06:10 +0800 Subject: [PATCH 60/60] Fix bug in decoder wrapper --- src/decoder/decoder-wrappers.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index e13e9c892bb..f63b3caa7c0 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -228,8 +228,7 @@ bool DecodeUtteranceLatticeIncremental( } // Get lattice - CompactLattice clat; - decoder.GetLattice(decoder.NumFramesDecoded(), true); + CompactLattice clat = decoder.GetLattice(decoder.NumFramesDecoded(), true); if (clat.NumStates() == 0) KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt;