diff --git a/src/bin/Makefile b/src/bin/Makefile index bfb037fc792..a04a84e21af 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -22,7 +22,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ matrix-sum build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ transform-vec align-text matrix-dim post-to-smat compile-graph \ - compare-int-vector compute-gop + compare-int-vector latgen-incremental-mapped compute-gop OBJFILES = diff --git a/src/bin/latgen-incremental-mapped.cc b/src/bin/latgen-incremental-mapped.cc new file mode 100644 index 00000000000..80c65bfb535 --- /dev/null +++ b/src/bin/latgen-incremental-mapped.cc @@ -0,0 +1,183 @@ +// bin/latgen-incremental-mapped.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "hmm/transition-model.h" +#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" +#include "decoder/decodable-matrix.h" +#include "base/timer.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::Fst; + using fst::StdArc; + + const char *usage = + "Generate lattices, reading log-likelihoods as matrices\n" + " (model is needed only for the integer mappings in its transition-model)\n" + "The lattice determinization algorithm here can operate\n" + "incrementally.\n" + "Usage: latgen-incremental-mapped [options] trans-model-in " + "(fst-in|fsts-rspecifier) loglikes-rspecifier" + " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n"; + ParseOptions po(usage); + Timer timer; + bool allow_partial = false; + BaseFloat acoustic_scale = 0.1; + LatticeIncrementalDecoderConfig config; + + std::string word_syms_filename; + config.Register(&po); + po.Register("acoustic-scale", &acoustic_scale, + "Scaling factor for acoustic likelihoods"); + + po.Register("word-symbol-table", &word_syms_filename, + "Symbol table for words [for debug output]"); + po.Register("allow-partial", &allow_partial, + "If true, produce output even if end state was not reached."); + + po.Read(argc, argv); + + if (po.NumArgs() < 4 || po.NumArgs() > 6) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), fst_in_str = po.GetArg(2), + feature_rspecifier = po.GetArg(3), lattice_wspecifier = po.GetArg(4), + words_wspecifier = po.GetOptArg(5), + alignment_wspecifier = po.GetOptArg(6); + + TransitionModel trans_model; + ReadKaldiObject(model_in_filename, &trans_model); + + bool determinize = true; + CompactLatticeWriter compact_lattice_writer; + LatticeWriter lattice_writer; + if (!(determinize ? compact_lattice_writer.Open(lattice_wspecifier) + : lattice_writer.Open(lattice_wspecifier))) + KALDI_ERR << "Could not open table for writing lattices: " + << lattice_wspecifier; + + Int32VectorWriter words_writer(words_wspecifier); + + Int32VectorWriter alignment_writer(alignment_wspecifier); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_filename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) + KALDI_ERR << "Could not read symbol table from file " << word_syms_filename; + + double tot_like = 0.0; + kaldi::int64 frame_count = 0; + int num_success = 0, num_fail = 0; + + if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) { + SequentialBaseFloatMatrixReader loglike_reader(feature_rspecifier); + // Input FST is just one FST, not a table of FSTs. + Fst *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); + timer.Reset(); + + { + LatticeIncrementalDecoder decoder(*decode_fst, trans_model, config); + + for (; !loglike_reader.Done(); loglike_reader.Next()) { + std::string utt = loglike_reader.Key(); + Matrix loglikes(loglike_reader.Value()); + loglike_reader.FreeCurrent(); + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + + DecodableMatrixScaledMapped decodable(trans_model, loglikes, + acoustic_scale); + + double like; + if (DecodeUtteranceLatticeIncremental( + decoder, decodable, trans_model, word_syms, utt, acoustic_scale, + determinize, allow_partial, &alignment_writer, &words_writer, + &compact_lattice_writer, &lattice_writer, &like)) { + tot_like += like; + frame_count += loglikes.NumRows(); + num_success++; + } else { + num_fail++; + } + } + } + delete decode_fst; // delete this only after decoder goes out of scope. + } else { // We have different FSTs for different utterances. + SequentialTableReader fst_reader(fst_in_str); + RandomAccessBaseFloatMatrixReader loglike_reader(feature_rspecifier); + for (; !fst_reader.Done(); fst_reader.Next()) { + std::string utt = fst_reader.Key(); + if (!loglike_reader.HasKey(utt)) { + KALDI_WARN << "Not decoding utterance " << utt + << " because no loglikes available."; + num_fail++; + continue; + } + const Matrix &loglikes = loglike_reader.Value(utt); + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + LatticeIncrementalDecoder decoder(fst_reader.Value(), trans_model, config); + DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale); + double like; + if (DecodeUtteranceLatticeIncremental( + decoder, decodable, trans_model, word_syms, utt, acoustic_scale, + determinize, allow_partial, &alignment_writer, &words_writer, + &compact_lattice_writer, &lattice_writer, &like)) { + tot_like += like; + frame_count += loglikes.NumRows(); + num_success++; + } else { + num_fail++; + } + } + } + + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken " << elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (elapsed * 100.0 / frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " << num_fail; + KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like / frame_count) + << " over " << frame_count << " frames."; + + delete word_syms; + if (num_success != 0) + return 0; + else + return 1; + } catch (const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/decoder/Makefile b/src/decoder/Makefile index fbd8386f005..a814931f694 100644 --- a/src/decoder/Makefile +++ b/src/decoder/Makefile @@ -7,7 +7,8 @@ TESTFILES = OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \ lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \ - decoder-wrappers.o grammar-fst.o decodable-matrix.o + decoder-wrappers.o grammar-fst.o decodable-matrix.o \ + lattice-incremental-decoder.o lattice-incremental-online-decoder.o LIBNAME = kaldi-decoder diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index 588274e113b..f63b3caa7c0 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -68,7 +68,7 @@ void DecodeUtteranceLatticeFasterClass::operator () () { success_ = true; using fst::VectorFst; if (!decoder_->Decode(decodable_)) { - KALDI_WARN << "Failed to decode file " << utt_; + KALDI_WARN << "Failed to decode utterance with id " << utt_; success_ = false; } if (!decoder_->ReachedFinal()) { @@ -195,6 +195,92 @@ DecodeUtteranceLatticeFasterClass::~DecodeUtteranceLatticeFasterClass() { delete decodable_; } +template +bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl &decoder, // not const but is really an input. + DecodableInterface &decodable, // not const but is really an input. + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr) { // puts utterance's like in like_ptr on success. + using fst::VectorFst; + if (!decoder.Decode(&decodable)) { + KALDI_WARN << "Failed to decode utterance with id " << utt; + return false; + } + if (!decoder.ReachedFinal()) { + if (allow_partial) { + KALDI_WARN << "Outputting partial output for utterance " << utt + << " since no final-state reached\n"; + } else { + KALDI_WARN << "Not producing output for utterance " << utt + << " since no final-state reached and " + << "--allow-partial=false.\n"; + return false; + } + } + + // Get lattice + CompactLattice clat = decoder.GetLattice(decoder.NumFramesDecoded(), true); + if (clat.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; + + double likelihood; + LatticeWeight weight; + int32 num_frames; + { // First do some stuff with word-level traceback... + CompactLattice decoded_clat; + CompactLatticeShortestPath(clat, &decoded_clat); + Lattice decoded; + fst::ConvertLattice(decoded_clat, &decoded); + + if (decoded.Start() == fst::kNoStateId) + // Shouldn't really reach this point as already checked success. + KALDI_ERR << "Failed to get traceback for utterance " << utt; + + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + num_frames = alignment.size(); + KALDI_ASSERT(num_frames == decoder.NumFramesDecoded()); + if (words_writer->IsOpen()) + words_writer->Write(utt, words); + if (alignment_writer->IsOpen()) + alignment_writer->Write(utt, alignment); + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << '\n'; + } + likelihood = -(weight.Value1() + weight.Value2()); + } + + // We'll write the lattice without acoustic scaling. + if (acoustic_scale != 0.0) + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); + Connect(&clat); + compact_lattice_writer->Write(utt, clat); + KALDI_LOG << "Log-like per frame for utterance " << utt << " is " + << (likelihood / num_frames) << " over " + << num_frames << " frames."; + KALDI_VLOG(2) << "Cost for utterance " << utt << " is " + << weight.Value1() << " + " << weight.Value2(); + *like_ptr = likelihood; + return true; +} + // Takes care of output. Returns true on success. template @@ -215,7 +301,7 @@ bool DecodeUtteranceLatticeFaster( using fst::VectorFst; if (!decoder.Decode(&decodable)) { - KALDI_WARN << "Failed to decode file " << utt; + KALDI_WARN << "Failed to decode utterance with id " << utt; return false; } if (!decoder.ReachedFinal()) { @@ -296,6 +382,37 @@ bool DecodeUtteranceLatticeFaster( } // Instantiate the template above for the two required FST types. +template bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl > &decoder, + DecodableInterface &decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); + +template bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl &decoder, + DecodableInterface &decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); + + template bool DecodeUtteranceLatticeFaster( LatticeFasterDecoderTpl > &decoder, DecodableInterface &decodable, @@ -345,7 +462,7 @@ bool DecodeUtteranceLatticeSimple( using fst::VectorFst; if (!decoder.Decode(&decodable)) { - KALDI_WARN << "Failed to decode file " << utt; + KALDI_WARN << "Failed to decode utterance with id " << utt; return false; } if (!decoder.ReachedFinal()) { diff --git a/src/decoder/decoder-wrappers.h b/src/decoder/decoder-wrappers.h index 17592d0282b..085c8e94e73 100644 --- a/src/decoder/decoder-wrappers.h +++ b/src/decoder/decoder-wrappers.h @@ -22,6 +22,7 @@ #include "itf/options-itf.h" #include "decoder/lattice-faster-decoder.h" +#include "decoder/lattice-incremental-decoder.h" #include "decoder/lattice-simple-decoder.h" // This header contains declarations from various convenience functions that are called @@ -88,6 +89,23 @@ void AlignUtteranceWrapper( void ModifyGraphForCarefulAlignment( fst::VectorFst *fst); +/// TODO +template +bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl &decoder, // not const but is really an input. + DecodableInterface &decodable, // not const but is really an input. + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignments_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); // puts utterance's likelihood in like_ptr on success. + /// This function DecodeUtteranceLatticeFaster is used in several decoders, and /// we have moved it here. Note: this is really "binary-level" code as it diff --git a/src/decoder/lattice-faster-decoder.cc b/src/decoder/lattice-faster-decoder.cc index 9106309eb84..83c582d3b5e 100644 --- a/src/decoder/lattice-faster-decoder.cc +++ b/src/decoder/lattice-faster-decoder.cc @@ -229,24 +229,17 @@ void LatticeFasterDecoderTpl::PossiblyResizeHash(size_t num_toks) { extra_cost is used in pruning tokens, to save memory. - Define the 'forward cost' of a token as zero for any token on the frame - we're currently decoding; and for other frames, as the shortest-path cost - between that token and a token on the frame we're currently decoding. - (by "currently decoding" I mean the most recently processed frame). - - Then define the extra_cost of a token (always >= 0) as the forward-cost of - the token minus the smallest forward-cost of any token on the same frame. + extra_cost can be thought of as a beta (backward) cost assuming + we had set the betas on currently-active tokens to all be the negative + of the alphas for those tokens. (So all currently active tokens would + be on (tied) best paths). We can use the extra_cost to accurately prune away tokens that we know will never appear in the lattice. If the extra_cost is greater than the desired lattice beam, the token would provably never appear in the lattice, so we can prune away the token. - The advantage of storing the extra_cost rather than the forward-cost, is that - it is less costly to keep the extra_cost up-to-date when we process new frames. - When we process a new frame, *all* the previous frames' forward-costs would change; - but in general the extra_cost will change only for a finite number of frames. - (Actually we don't update all the extra_costs every time we update a frame; we + (Note: we don't update all the extra_costs every time we update a frame; we only do it every 'config_.prune_interval' frames). */ diff --git a/src/decoder/lattice-faster-decoder.h b/src/decoder/lattice-faster-decoder.h index e0cf7dea8d6..57cbe5fe178 100644 --- a/src/decoder/lattice-faster-decoder.h +++ b/src/decoder/lattice-faster-decoder.h @@ -43,11 +43,13 @@ struct LatticeFasterDecoderConfig { int32 prune_interval; bool determinize_lattice; // not inspected by this class... used in // command-line program. - BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat beam_delta; BaseFloat hash_ratio; - BaseFloat prune_scale; // Note: we don't make this configurable on the command line, - // it's not a very important parameter. It affects the - // algorithm that prunes the tokens as we go. + // Note: we don't make prune_scale configurable on the command line, it's not + // a very important parameter. It affects the algorithm that prunes the + // tokens as we go. + BaseFloat prune_scale; + // Most of the options inside det_opts are not actually queried by the // LatticeFasterDecoder class itself, but by the code that calls it, for // example in the function DecodeUtteranceLatticeFaster. @@ -316,15 +318,10 @@ class LatticeFasterDecoderTpl { /// This function may be optionally called after AdvanceDecoding(), when you /// do not plan to decode any further. It does an extra pruning step that /// will help to prune the lattices output by GetLattice and (particularly) - /// GetRawLattice more accurately, particularly toward the end of the - /// utterance. It does this by using the final-probs in pruning (if any - /// final-state survived); it also does a final pruning step that visits all - /// states (the pruning that is done during decoding may fail to prune states - /// that are within kPruningScale = 0.1 outside of the beam). If you call - /// this, you cannot call AdvanceDecoding again (it will fail), and you - /// cannot call GetLattice() and related functions with use_final_probs = - /// false. - /// Used to be called PruneActiveTokensFinal(). + /// GetRawLattice more completely, particularly toward the end of the + /// utterance. If you call this, you cannot call AdvanceDecoding again (it + /// will fail), and you cannot call GetLattice() and related functions with + /// use_final_probs = false. Used to be called PruneActiveTokensFinal(). void FinalizeDecoding(); /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc new file mode 100644 index 00000000000..81e70083301 --- /dev/null +++ b/src/decoder/lattice-incremental-decoder.cc @@ -0,0 +1,1720 @@ +// decoder/lattice-incremental-decoder.cc + +// Copyright 2019 Zhehuai Chen, Daniel Povey + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/lattice-incremental-decoder.h" +#include "lat/lattice-functions.h" +#include "base/timer.h" + +namespace kaldi { + +// instantiate this class once for each thing you have to decode. +template +LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( + const FST &fst, const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config) + : fst_(&fst), + delete_fst_(false), + num_toks_(0), + config_(config), + determinizer_(trans_model, config) { + config.Check(); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. +} + +template +LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( + const LatticeIncrementalDecoderConfig &config, FST *fst, + const TransitionModel &trans_model) + : fst_(fst), + delete_fst_(true), + num_toks_(0), + config_(config), + determinizer_(trans_model, config) { + config.Check(); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. +} + +template +LatticeIncrementalDecoderTpl::~LatticeIncrementalDecoderTpl() { + DeleteElems(toks_.Clear()); + ClearActiveTokens(); + if (delete_fst_) delete fst_; +} + +template +void LatticeIncrementalDecoderTpl::InitDecoding() { + // clean up from last time: + DeleteElems(toks_.Clear()); + cost_offsets_.clear(); + ClearActiveTokens(); + warned_ = false; + num_toks_ = 0; + decoding_finalized_ = false; + final_costs_.clear(); + StateId start_state = fst_->Start(); + KALDI_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL); + active_toks_[0].toks = start_tok; + toks_.Insert(start_state, start_tok); + num_toks_++; + + determinizer_.Init(); + num_frames_in_lattice_ = 0; + token2label_map_.clear(); + next_token_label_ = LatticeIncrementalDeterminizer::kTokenLabelOffset; + ProcessNonemitting(config_.beam); +} + +template +void LatticeIncrementalDecoderTpl::UpdateLatticeDeterminization() { + if (NumFramesDecoded() - num_frames_in_lattice_ < + config_.determinize_max_delay) + return; + + + /* Make sure the token-pruning is active. Note: PruneActiveTokens() has + internal logic that prevents it from doing unnecessary work if you + call it and then immediately call it again. */ + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + + int32 first = num_frames_in_lattice_ + config_.determinize_min_chunk_size, + last = NumFramesDecoded(), + fewest_tokens = std::numeric_limits::max(), + best_frame = -1; + for (int32 t = last; t >= first; t--) { + /* Make sure PruneActiveTokens() has computed num_toks for all these + frames... */ + KALDI_ASSERT(active_toks_[t].num_toks != -1); + if (active_toks_[t].num_toks < fewest_tokens) { + // <= because we want the latest one in case of ties. + fewest_tokens = active_toks_[t].num_toks; + best_frame = t; + } + } + /* OK, determinize the chunk that spans from num_frames_in_lattice_ to + best_frame. */ + bool use_final_probs = false; + GetLattice(best_frame, use_final_probs); + return; +} +// Returns true if any kind of traceback is available (not necessarily from +// a final state). It should only very rarely return false; this indicates +// an unusual search error. +template +bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decodable) { + InitDecoding(); + + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. + + while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + UpdateLatticeDeterminization(); + + BaseFloat cost_cutoff = ProcessEmitting(decodable); + ProcessNonemitting(cost_cutoff); + } + Timer timer; + FinalizeDecoding(); + bool use_final_probs = true; + GetLattice(NumFramesDecoded(), use_final_probs); + KALDI_VLOG(2) << "Delay time during and after FinalizeDecoding()" + << "(secs): " << timer.Elapsed(); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !active_toks_.empty() && active_toks_.back().toks != NULL; +} + + +template +void LatticeIncrementalDecoderTpl::PossiblyResizeHash(size_t num_toks) { + size_t new_sz = + static_cast(static_cast(num_toks) * config_.hash_ratio); + if (new_sz > toks_.Size()) { + toks_.SetSize(new_sz); + } +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + extra_cost can be thought of as a beta (backward) cost assuming + we had set the betas on currently-active tokens to all be the negative + of the alphas for those tokens. (So all currently active tokens would + be on (tied) best paths). + + + Define the 'forward cost' of a token as zero for any token on the frame + we're currently decoding; and for other frames, as the shortest-path cost + between that token and a token on the frame we're currently decoding. + (by "currently decoding" I mean the most recently processed frame). + + Then define the extra_cost of a token (always >= 0) as the forward-cost of + the token minus the smallest forward-cost of any token on the same frame. + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. + + The advantage of storing the extra_cost rather than the forward-cost, is that + it is less costly to keep the extra_cost up-to-date when we process new frames. + When we process a new frame, *all* the previous frames' forward-costs would change; + but in general the extra_cost will change only for a finite number of frames. + (Actually we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash of toks_, +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline Token *LatticeIncrementalDecoderTpl::FindOrAddToken( + StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer, + bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + Elem *e_found = toks_.Find(state); + if (e_found == NULL) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. + Token *new_tok = new Token(tot_cost, extra_cost, NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + toks_.Insert(state, new_tok); + if (changed) *changed = true; + return new_tok; + } else { + Token *tok = e_found->val; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. + tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return tok; + } +} + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeIncrementalDecoderTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned, + BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. + if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL; + tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL;) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = + next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - + next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) + prev_link->next = next_link; + else + tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. + // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses +// the final-probs for pruning, otherwise it treats all tokens as final. +template +void LatticeIncrementalDecoderTpl::PruneForwardLinksFinal() { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file"; + + typedef typename unordered_map::const_iterator IterType; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + // We call DeleteElems() as a nicety, not because it's really necessary; + // otherwise there would be a time, after calling PruneTokensForFrame() on the + // final frame, when toks_.GetList() or toks_.Clear() would contain pointers + // to nonexistent tokens. + DeleteElems(toks_.Clear()); + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. + bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL; + tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this + // token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL;) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = + next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - + next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) + prev_link->next = next_link; + else + tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeIncrementalDecoderTpl::FinalRelativeCost() const { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; +} + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. +// It's called by PruneActiveTokens if any forward links have been pruned +template +void LatticeIncrementalDecoderTpl::PruneTokensForFrame( + int32 frame_plus_one) { + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + if (toks == NULL) KALDI_WARN << "No tokens alive [doing pruning]"; + Token *tok, *next_tok, *prev_tok = NULL; + int32 num_toks = 0; + for (tok = toks; tok != NULL; tok = next_tok, num_toks++) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) + prev_tok->next = tok->next; + else + toks = tok->next; + delete tok; + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } + active_toks_[frame_plus_one].num_toks = num_toks; +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame before +// that. We go backwards through the frames and stop when we reach a point +// where the delta-costs are not changing (and the delta controls when we consider +// a cost to have "not changed"). +template +void LatticeIncrementalDecoderTpl::PruneActiveTokens(BaseFloat delta) { + int32 cur_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + + if (active_toks_[cur_frame_plus_one].num_toks == -1){ + // The current frame's tokens don't get pruned so they don't get counted + // (the count is needed by the incremental determinization code). + // Fix this. + int this_frame_num_toks = 0; + for (Token *t = active_toks_[cur_frame_plus_one].toks; t != NULL; t = t->next) + this_frame_num_toks++; + active_toks_[cur_frame_plus_one].num_toks = this_frame_num_toks; + } + + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract + // one to get the corresponding index for the decodable object. + for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them (new TokenList) + // (2) we have not yet pruned the forward links to the next f, + // after any of those tokens have changed their extra_cost. + if (active_toks_[f].must_prune_forward_links) { + bool extra_costs_changed = false, links_pruned = false; + PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); + if (extra_costs_changed && f > 0) // any token has changed extra_cost + active_toks_[f - 1].must_prune_forward_links = true; + if (links_pruned) // any link was pruned + active_toks_[f].must_prune_tokens = true; + active_toks_[f].must_prune_forward_links = false; // job done + } + if (f + 1 < cur_frame_plus_one && // except for last f (no forward links) + active_toks_[f + 1].must_prune_tokens) { + PruneTokensForFrame(f + 1); + active_toks_[f + 1].must_prune_tokens = false; + } + } + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +template +void LatticeIncrementalDecoderTpl::ComputeFinalCosts( + unordered_map *final_costs, BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const { + if (decoding_finalized_) { + // If we finalized decoding, the list toks_ will no longer exist, so return + // something we already computed. + if (final_costs) *final_costs = final_costs_; + if (final_relative_cost) *final_relative_cost = final_relative_cost_; + if (final_best_cost) *final_best_cost = final_best_cost_; + return; + } + if (final_costs != NULL) final_costs->clear(); + const Elem *final_toks = toks_.GetList(); + BaseFloat infinity = std::numeric_limits::infinity(); + BaseFloat best_cost = infinity, best_cost_with_final = infinity; + + while (final_toks != NULL) { + StateId state = final_toks->key; + Token *tok = final_toks->val; + const Elem *next = final_toks->tail; + BaseFloat final_cost = fst_->Final(state).Value(); + BaseFloat cost = tok->tot_cost, cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != NULL && final_cost != infinity) + (*final_costs)[tok] = final_cost; + final_toks = next; + } + if (final_relative_cost != NULL) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != NULL) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +template +void LatticeIncrementalDecoderTpl::AdvanceDecoding( + DecodableInterface *decodable, int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. + if (fst_->Type() == "const") { + LatticeIncrementalDecoderTpl, Token> *this_cast = + reinterpret_cast< + LatticeIncrementalDecoderTpl, Token> *>( + this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeIncrementalDecoderTpl, Token> *this_cast = + reinterpret_cast< + LatticeIncrementalDecoderTpl, Token> *>( + this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && + "You must call InitDecoding() before AdvanceDecoding"); + int32 num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). + KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = + std::min(target_frames_decoded, NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + BaseFloat cost_cutoff = ProcessEmitting(decodable); + ProcessNonemitting(cost_cutoff); + } + UpdateLatticeDeterminization(); +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeIncrementalDecoderTpl::FinalizeDecoding() { + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes the final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. + BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin << " to " << num_toks_; +} + +/// Gets the weight cutoff. Also counts the active tokens. +template +BaseFloat LatticeIncrementalDecoderTpl::GetCutoff( + Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem) { + BaseFloat best_weight = std::numeric_limits::infinity(); + // positive == high cost == bad. + size_t count = 0; + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = static_cast(e->val->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; + } + } + if (tok_count != NULL) *tok_count = count; + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = e->val->tot_cost; + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; + } + } + if (tok_count != NULL) *tok_count = count; + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() + << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) + min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) + ? tmp_array_.begin() + config_.max_active + : tmp_array_.end()); + min_active_cutoff = tmp_array_[config_.min_active]; + } + } + if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. + if (adaptive_beam) + *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; + return min_active_cutoff; + } else { + *adaptive_beam = config_.beam; + return beam_cutoff; + } + } +} + +template +BaseFloat LatticeIncrementalDecoderTpl::ProcessEmitting( + DecodableInterface *decodable) { + KALDI_ASSERT(active_toks_.size() > 0); + int32 frame = active_toks_.size() - 1; // frame is the frame-index + // (zero-based) used to get likelihoods + // from the decodable object. + active_toks_.resize(active_toks_.size() + 1); + + Elem *final_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_ + // in simple-decoder.h. Removes the Elems from + // being indexed in the hash in toks_. + Elem *best_elem = NULL; + BaseFloat adaptive_beam; + size_t tok_cnt; + BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem); + KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " + << adaptive_beam; + + PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough. + + BaseFloat next_cutoff = std::numeric_limits::infinity(); + // pruning "online" before having seen all tokens + + BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good + // dynamic range. + + // First process the best token to get a hopefully + // reasonably tight bound on the next cutoff. The only + // products of the next block are "next_cutoff" and "cost_offset". + if (best_elem) { + StateId state = best_elem->key; + Token *tok = best_elem->val; + cost_offset = -tok->tot_cost; + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + BaseFloat new_weight = arc.weight.Value() + cost_offset - + decodable->LogLikelihood(frame, arc.ilabel) + + tok->tot_cost; + if (new_weight + adaptive_beam < next_cutoff) + next_cutoff = new_weight + adaptive_beam; + } + } + } + + // Store the offset on the acoustic likelihoods that we're applying. + // Could just do cost_offsets_.push_back(cost_offset), but we + // do it this way as it's more robust to future code changes. + cost_offsets_.resize(frame + 1, 0.0); + cost_offsets_[frame] = cost_offset; + + // the tokens are now owned here, in final_toks, and the hash is empty. + // 'owned' is a complex thing here; the point is we need to call DeleteElem + // on each elem 'e' to let toks_ know we're done with them. + for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) { + // loop this way because we delete "e" as we go. + StateId state = e->key; + Token *tok = e->val; + if (tok->tot_cost <= cur_cutoff) { + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + BaseFloat ac_cost = + cost_offset - decodable->LogLikelihood(frame, arc.ilabel), + graph_cost = arc.weight.Value(), cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost > next_cutoff) + continue; + else if (tot_cost + adaptive_beam < next_cutoff) + next_cutoff = tot_cost + adaptive_beam; // prune by best current token + // Note: the frame indexes into active_toks_ are one-based, + // hence the + 1. + Token *next_tok = + FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, NULL); + // NULL: no change indicator needed + + // Add ForwardLink from tok to next_tok (put on head of list tok->links) + tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, graph_cost, + ac_cost, tok->links); + } + } // for all arcs + } + e_tail = e->tail; + toks_.Delete(e); // delete Elem + } + return next_cutoff; +} + +// static inline +template +void LatticeIncrementalDecoderTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; + } + tok->links = NULL; +} + +template +void LatticeIncrementalDecoderTpl::ProcessNonemitting(BaseFloat cutoff) { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame = static_cast(active_toks_.size()) - 2; + // Note: "frame" is the time-index we just processed, or -1 if + // we are processing the nonemitting transitions before the + // first frame (called from InitDecoding()). + + // Processes nonemitting arcs for one frame. Propagates within toks_. + // Note-- this queue structure is is not very optimal as + // it may cause us to process states unnecessarily (e.g. more than once), + // but in the baseline code, turning this vector into a set to fix this + // problem did not improve overall speed. + + KALDI_ASSERT(queue_.empty()); + + if (toks_.GetList() == NULL) { + if (!warned_) { + KALDI_WARN << "Error, no surviving tokens: frame is " << frame; + warned_ = true; + } + } + + for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) { + StateId state = e->key; + if (fst_->NumInputEpsilons(state) != 0) queue_.push_back(state); + } + + while (!queue_.empty()) { + StateId state = queue_.back(); + queue_.pop_back(); + + Token *tok = + toks_.Find(state) + ->val; // would segfault if state not in toks_ but this can't happen. + BaseFloat cur_cost = tok->tot_cost; + if (cur_cost > cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + // but since most states are emitting it's not a huge issue. + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel == 0) { // propagate nonemitting only... + BaseFloat graph_cost = arc.weight.Value(), tot_cost = cur_cost + graph_cost; + if (tot_cost < cutoff) { + bool changed; + + Token *new_tok = + FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, &changed); + + tok->links = + new ForwardLinkT(new_tok, 0, arc.olabel, graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new [if so, add into queue]. + if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0) + queue_.push_back(arc.nextstate); + } + } + } // for all arcs + } // while queue not empty +} + +template +void LatticeIncrementalDecoderTpl::DeleteElems(Elem *list) { + for (Elem *e = list, *e_tail; e != NULL; e = e_tail) { + e_tail = e->tail; + toks_.Delete(e); + } +} + +template +void LatticeIncrementalDecoderTpl< + FST, Token>::ClearActiveTokens() { // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL;) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + + +template +const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( + int32 num_frames_to_include, + bool use_final_probs) { + KALDI_ASSERT(num_frames_to_include >= num_frames_in_lattice_ && + num_frames_to_include <= NumFramesDecoded()); + + if (decoding_finalized_ && !use_final_probs) { + // This is not supported + KALDI_ERR << "You cannot get the lattice without final-probs after " + "calling FinalizeDecoding()."; + } + if (use_final_probs && num_frames_to_include != NumFramesDecoded()) { + /* This is because we only remember the relation between HCLG states and + Tokens for the current frame; the Token does not have a `state` field. */ + KALDI_ERR << "use-final-probs may no be true if you are not " + "getting a lattice for all frames decoded so far."; + } + + + if (num_frames_to_include > num_frames_in_lattice_) { + /* Make sure the token-pruning is up to date. If we just pruned the tokens, + this will do very little work. */ + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + + Lattice chunk_lat; + + unordered_map token_label2state; + if (num_frames_in_lattice_ != 0) { + determinizer_.InitializeRawLatticeChunk(&chunk_lat, + &token_label2state); + } + + // tok_map will map from Token* to state-id in chunk_lat. + // The cur and prev versions alternate on different frames. + unordered_map &tok2state_map(temp_token_map_); + tok2state_map.clear(); + + unordered_map &next_token2label_map(token2label_map_temp_); + next_token2label_map.clear(); + + { // Deal with the last frame in the chunk, the one numbered `num_frames_to_include`. + // (Yes, this is backwards). We allocate token labels, and set tokens as + // final, but don't add any transitions. This may leave some states + // disconnected (e.g. due to chains of nonemitting arcs), but it's OK; we'll + // fix it when we generate the next chunk of lattice. + int32 frame = num_frames_to_include; + // Allocate state-ids for all tokens on this frame. + + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + /* If we included the final-costs at this stage, they will cause + non-final states to be pruned out from the end of the lattice. */ + BaseFloat final_cost; + { // This block computes final_cost + if (decoding_finalized_) { + if (final_costs_.empty()) { + final_cost = 0.0; /* No final-state survived, so treat all as final + * with probability One(). */ + } else { + auto iter = final_costs_.find(tok); + if (iter == final_costs_.end()) + final_cost = std::numeric_limits::infinity(); + else + final_cost = iter->second; + } + } else { + /* this is a `fake` final-cost used to guide pruning. It's as if we + set the betas (backward-probs) on the final frame to the + negatives of the corresponding alphas, so all tokens on the last + frae will be on a best path.. the extra_cost for each token + always corresponds to its alpha+beta on this assumption. We want + the final_cost here to correspond to the beta (backward-prob), so + we get that by final_cost = extra_cost - tot_cost. + [The tot_cost is the forward/alpha cost.] + */ + final_cost = tok->extra_cost - tok->tot_cost; + } + } + + StateId state = chunk_lat.AddState(); + tok2state_map[tok] = state; + if (final_cost < std::numeric_limits::infinity()) { + next_token2label_map[tok] = AllocateNewTokenLabel(); + StateId token_final_state = chunk_lat.AddState(); + LatticeArc::Label ilabel = 0, + olabel = (next_token2label_map[tok] = AllocateNewTokenLabel()); + chunk_lat.AddArc(state, + LatticeArc(ilabel, olabel, + LatticeWeight::One(), + token_final_state)); + chunk_lat.SetFinal(token_final_state, LatticeWeight(final_cost, 0.0)); + } + } + } + + // Go in reverse order over the remaining frames so we can create arcs as we + // go, and their destination-states will already be in the map. + for (int32 frame = num_frames_to_include; + frame >= num_frames_in_lattice_; frame--) { + // The conditional below is needed for the last frame of the utterance. + BaseFloat cost_offset = (frame < cost_offsets_.size() ? + cost_offsets_[frame] : 0.0); + + // For the first frame of the chunk, we need to make sure the states are + // the ones created by InitializeRawLatticeChunk() (where not pruned away). + if (frame == num_frames_in_lattice_ && num_frames_in_lattice_ != 0) { + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + auto iter = token2label_map_.find(tok); + KALDI_ASSERT(iter != token2label_map_.end()); + Label token_label = iter->second; + auto iter2 = token_label2state.find(token_label); + if (iter2 != token_label2state.end()) { + StateId state = iter2->second; + tok2state_map[tok] = state; + } else { + // Some states may have been pruned out, but we should still allocate + // them. They might have been part of chains of nonemitting arcs + // where the state became disconnected because the last chunk didn't + // include arcs starting at this frame. + StateId state = chunk_lat.AddState(); + tok2state_map[tok] = state; + } + } + } else if (frame != num_frames_to_include) { // We already created states + // for the last frame. + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + StateId state = chunk_lat.AddState(); + tok2state_map[tok] = state; + } + } + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + auto iter = tok2state_map.find(tok); + KALDI_ASSERT(iter != tok2state_map.end()); + StateId cur_state = iter->second; + for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { + auto next_iter = tok2state_map.find(l->next_tok); + if (next_iter == tok2state_map.end()) { + // Emitting arcs from the last frame we're including -- ignore + // these. + KALDI_ASSERT(frame == num_frames_to_include); + continue; + } + StateId next_state = next_iter->second; + BaseFloat this_offset = (l->ilabel != 0 ? cost_offset : 0); + LatticeArc arc(l->ilabel, l->olabel, + LatticeWeight(l->graph_cost, l->acoustic_cost - this_offset), + next_state); + // Note: the epsilons get redundantly included at the end and beginning + // of successive chunks. These will get removed in the determinization. + chunk_lat.AddArc(cur_state, arc); + } + } + } + if (num_frames_in_lattice_ == 0) { + // This block locates the start token. NOTE: we use the fact that in the + // linked list of tokens, things are added at the head, so the start state + // must be at the tail. If this data structure is changed in future, we + // might need to explicitly store the start token as a class member. + Token *tok = active_toks_[0].toks; + if (tok == NULL) { + KALDI_WARN << "No tokens exist on start frame"; + return determinizer_.GetLattice(); // will be empty. + } + while (tok->next != NULL) + tok = tok->next; + Token *start_token = tok; + auto iter = tok2state_map.find(start_token); + KALDI_ASSERT(iter != tok2state_map.end()); + StateId start_state = iter->second; + chunk_lat.SetStart(start_state); + } + token2label_map_.swap(next_token2label_map); + + // bool finished_before_beam = + determinizer_.AcceptRawLatticeChunk(&chunk_lat); + // We are ignoring the return status, which say whether it finished before the beam. + + num_frames_in_lattice_ = num_frames_to_include; + } + + unordered_map token2final_cost; + unordered_map token_label2final_cost; + if (use_final_probs) { + ComputeFinalCosts(&token2final_cost, NULL, NULL); + for (const auto &p: token2final_cost) { + Token *tok = p.first; + BaseFloat cost = p.second; + auto iter = token2label_map_.find(tok); + if (iter != token2label_map_.end()) { + /* Some tokens may not have survived the pruned determinization. */ + Label token_label = iter->second; + bool ret = token_label2final_cost.insert({token_label, cost}).second; + KALDI_ASSERT(ret); /* Make sure it was inserted. */ + } + } + } + /* Note: these final-probs won't affect the next chunk, only the lattice + returned from GetLattice(). They are kind of temporaries. */ + determinizer_.SetFinalCosts(token_label2final_cost.empty() ? NULL : + &token_label2final_cost); + + return determinizer_.GetLattice(); +} + + +template +int32 LatticeIncrementalDecoderTpl::GetNumToksForFrame(int32 frame) { + int32 r = 0; + for (Token *tok = active_toks_[frame].toks; tok; tok = tok->next) r++; + return r; +} + + + +/* This utility function adds an arc to a Lattice, but where the source is a + CompactLatticeArc. If the CompactLatticeArc has a string with length greater + than 1, this will require adding extra states to `lat`. + */ +static void AddCompactLatticeArcToLattice( + const CompactLatticeArc &clat_arc, + LatticeArc::StateId src_state, + Lattice *lat) { + const std::vector &string = clat_arc.weight.String(); + size_t N = string.size(); + if (N == 0) { + LatticeArc arc; + arc.ilabel = 0; + arc.olabel = clat_arc.ilabel; + arc.nextstate = clat_arc.nextstate; + arc.weight = clat_arc.weight.Weight(); + lat->AddArc(src_state, arc); + } else { + LatticeArc::StateId cur_state = src_state; + for (size_t i = 0; i < N; i++) { + LatticeArc arc; + arc.ilabel = string[i]; + arc.olabel = (i == 0 ? clat_arc.ilabel : 0); + arc.nextstate = (i + 1 == N ? clat_arc.nextstate : lat->AddState()); + arc.weight = (i == 0 ? clat_arc.weight.Weight() : LatticeWeight::One()); + lat->AddArc(cur_state, arc); + cur_state = arc.nextstate; + } + } +} + + +void LatticeIncrementalDeterminizer::Init() { + non_final_redet_states_.clear(); + clat_.DeleteStates(); + final_arcs_.clear(); + forward_costs_.clear(); + arcs_in_.clear(); +} + +CompactLattice::StateId LatticeIncrementalDeterminizer::AddStateToClat() { + CompactLattice::StateId ans = clat_.AddState(); + forward_costs_.push_back(std::numeric_limits::infinity()); + KALDI_ASSERT(forward_costs_.size() == ans + 1); + arcs_in_.resize(ans + 1); + return ans; +} + +void LatticeIncrementalDeterminizer::AddArcToClat( + CompactLattice::StateId state, + const CompactLatticeArc &arc) { + BaseFloat forward_cost = forward_costs_[state] + + ConvertToCost(arc.weight); + if (forward_cost == std::numeric_limits::infinity()) + return; + int32 arc_idx = clat_.NumArcs(state); + clat_.AddArc(state, arc); + arcs_in_[arc.nextstate].push_back({state, arc_idx}); + if (forward_cost < forward_costs_[arc.nextstate]) + forward_costs_[arc.nextstate] = forward_cost; +} + +// See documentation in header +void LatticeIncrementalDeterminizer::IdentifyTokenFinalStates( + const CompactLattice &chunk_clat, + std::unordered_map *token_map) const { + token_map->clear(); + using StateId = CompactLattice::StateId; + using Label = CompactLatticeArc::Label; + + StateId num_states = chunk_clat.NumStates(); + for (StateId state = 0; state < num_states; state++) { + for (fst::ArcIterator aiter(chunk_clat, state); + !aiter.Done(); aiter.Next()) { + const CompactLatticeArc &arc = aiter.Value(); + if (arc.olabel >= kTokenLabelOffset && arc.olabel < kMaxTokenLabel) { + StateId nextstate = arc.nextstate; + auto r = token_map->insert({nextstate, arc.olabel}); + // Check consistency of labels on incoming arcs + KALDI_ASSERT(r.first->second == arc.olabel); + } + } + } +} + + + + +void LatticeIncrementalDeterminizer::GetNonFinalRedetStates() { + using StateId = CompactLattice::StateId; + non_final_redet_states_.clear(); + non_final_redet_states_.reserve(final_arcs_.size()); + + std::vector state_queue; + for (const CompactLatticeArc &arc: final_arcs_) { + // Note: we abuse the .nextstate field to store the state which is really + // the source of that arc. + StateId redet_state = arc.nextstate; + if (forward_costs_[redet_state] != std::numeric_limits::infinity()) { + // if it is accessible.. + if (non_final_redet_states_.insert(redet_state).second) { + // it was not already there + state_queue.push_back(redet_state); + } + } + } + // Add any states that are reachable from the states above. + while (!state_queue.empty()) { + StateId s = state_queue.back(); + state_queue.pop_back(); + for (fst::ArcIterator aiter(clat_, s); !aiter.Done(); + aiter.Next()) { + const CompactLatticeArc &arc = aiter.Value(); + StateId nextstate = arc.nextstate; + if (non_final_redet_states_.insert(nextstate).second) + state_queue.push_back(nextstate); // it was not already there + } + } +} + + +void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( + Lattice *olat, + unordered_map *token_label2state) { + using namespace fst; + + olat->DeleteStates(); + LatticeArc::StateId start_state = olat->AddState(); + olat->SetStart(start_state); + token_label2state->clear(); + + // redet_state_map maps from state-ids in clat_ to state-ids in olat. This + // will be the set of states from which the arcs to final-states in the + // canonical appended lattice leave (physically, these are in the .nextstate + // elements of arcs_, since we use that field for the source state), plus any + // states reachable from those states. + unordered_map redet_state_map; + + for (CompactLattice::StateId redet_state: non_final_redet_states_) + redet_state_map[redet_state] = olat->AddState(); + + // First, process any arcs leaving the non-final redeterminized states that + // are not to final-states. (What we mean by "not to final states" is, not to + // stats that are final in the `canonical appended lattice`.. they may + // actually be physically final in clat_, because we make clat_ what we want + // to return to the user. + for (CompactLattice::StateId redet_state: non_final_redet_states_) { + LatticeArc::StateId lat_state = redet_state_map[redet_state]; + + for (ArcIterator aiter(clat_, redet_state); + !aiter.Done(); aiter.Next()) { + const CompactLatticeArc &arc = aiter.Value(); + CompactLattice::StateId nextstate = arc.nextstate; + LatticeArc::StateId lat_nextstate = olat->NumStates(); + auto r = redet_state_map.insert({nextstate, lat_nextstate}); + if (r.second) { // Was inserted. + LatticeArc::StateId s = olat->AddState(); + KALDI_ASSERT(s == lat_nextstate); + } else { + // was not inserted -> was already there. + lat_nextstate = r.first->second; + } + CompactLatticeArc clat_arc(arc); + clat_arc.nextstate = lat_nextstate; + AddCompactLatticeArcToLattice(clat_arc, lat_state, olat); + } + clat_.DeleteArcs(redet_state); + clat_.SetFinal(redet_state, CompactLatticeWeight::Zero()); + } + + for (const CompactLatticeArc &arc: final_arcs_) { + // We abuse the `nextstate` field to store the source state. + CompactLattice::StateId src_state = arc.nextstate; + auto iter = redet_state_map.find(src_state); + if (forward_costs_[src_state] == std::numeric_limits::infinity()) + continue; /* Unreachable state */ + KALDI_ASSERT(iter != redet_state_map.end()); + LatticeArc::StateId src_lat_state = iter->second; + Label token_label = arc.ilabel; // will be == arc.olabel. + KALDI_ASSERT(token_label >= kTokenLabelOffset && + token_label < kMaxTokenLabel); + auto r = token_label2state->insert({token_label, + olat->NumStates()}); + LatticeArc::StateId dest_lat_state = r.first->second; + if (r.second) { // was inserted + LatticeArc::StateId new_state = olat->AddState(); + KALDI_ASSERT(new_state == dest_lat_state); + } + CompactLatticeArc new_arc; + new_arc.nextstate = dest_lat_state; + /* We convert the token-label to epsilon; it's not needed anymore. */ + new_arc.ilabel = new_arc.olabel = 0; + new_arc.weight = arc.weight; + AddCompactLatticeArcToLattice(new_arc, src_lat_state, olat); + } + + // Now deal with the initial-probs. Arcs from initial-states to + // redeterminized-states in the raw lattice have an olabel that identifies the + // id of that redeterminized-state in clat_, and a cost that is derived from + // its entry in forward_costs_. These forward-probs are used to get the + // pruned lattice determinization to behave correctly, and will be canceled + // out later on. + // + // In the paper this is the second-from-last bullet in Sec. 5.2. NOTE: in the + // paper we state that we only include such arcs for "each redeterminized + // state that is either initial in det(A) or that has an arc entering it from + // a state that is not a redeterminized state." In fact, we include these + // arcs for all redeterminized states. I realized that it won't make a + // difference to the outcome, and it's easier to do it this way. + for (CompactLattice::StateId state_id: non_final_redet_states_) { + BaseFloat forward_cost = forward_costs_[state_id]; + LatticeArc arc; + arc.ilabel = 0; + // The olabel (which appears where the word-id would) is what + // we call a 'state-label'. It identifies a state in clat_. + arc.olabel = state_id + kStateLabelOffset; + // It doesn't matter what field we put forward_cost in (or whether we + // divide it among them both; the effect on pruning is the same, and + // we will cancel it out later anyway. + arc.weight = LatticeWeight(forward_cost, 0); + auto iter = redet_state_map.find(state_id); + KALDI_ASSERT(iter != redet_state_map.end()); + arc.nextstate = iter->second; + olat->AddArc(start_state, arc); + } +} + +void LatticeIncrementalDeterminizer::GetRawLatticeFinalCosts( + const Lattice &raw_fst, + std::unordered_map *old_final_costs) { + LatticeArc::StateId raw_fst_num_states = raw_fst.NumStates(); + for (LatticeArc::StateId s = 0; s < raw_fst_num_states; s++) { + for (fst::ArcIterator aiter(raw_fst, s); !aiter.Done(); + aiter.Next()) { + const LatticeArc &value = aiter.Value(); + if (value.olabel >= (Label)kTokenLabelOffset && + value.olabel < (Label)kMaxTokenLabel) { + LatticeWeight final_weight = raw_fst.Final(value.nextstate); + if (final_weight == LatticeWeight::Zero() || + final_weight.Value2() != 0) { + KALDI_ERR << "Label " << value.olabel << " from state " << s + << " looks like a token-label but its next-state " + << value.nextstate << + " has unexpected final-weight " << final_weight.Value1() << ',' + << final_weight.Value2(); + } + auto r = old_final_costs->insert({value.olabel, + final_weight.Value1()}); + if (!r.second && r.first->second != final_weight.Value1()) { + // For any given token-label, all arcs in raw_fst with that + // olabel should go to the same state, so this should be + // impossible. + KALDI_ERR << "Unexpected mismatch in final-costs for tokens, " + << r.first->second << " vs " << final_weight.Value1(); + } + } + } + } +} + + +bool LatticeIncrementalDeterminizer::ProcessArcsFromChunkStartState( + const CompactLattice &chunk_clat, + std::unordered_map *state_map, + CompactLatticeWeight *extra_start_weight) { + using StateId = CompactLattice::StateId; + StateId clat_num_states = clat_.NumStates(); + + // Process arcs leaving the start state of chunk_clat. These arcs will have + // state-labels on them (unless this is the first chunk). + // For destination-states of those arcs, work out which states in + // clat_ they correspond to and update their forward_costs. + for (fst::ArcIterator aiter(chunk_clat, chunk_clat.Start()); + !aiter.Done(); aiter.Next()) { + const CompactLatticeArc &arc = aiter.Value(); + Label label = arc.ilabel; // ilabel == olabel; would be the olabel + // in a Lattice. + if (!(label >= kStateLabelOffset && + label - kStateLabelOffset < clat_num_states)) { + // The label was not a state-label. This should only be possible on the + // first chunk. + KALDI_ASSERT(state_map->empty()); + return true; // this is the first chunk. + } + StateId clat_state = label - kStateLabelOffset; + StateId chunk_state = arc.nextstate; + auto p = state_map->insert({chunk_state, clat_state}); + StateId dest_clat_state = p.first->second; + // We deleted all its arcs in InitializeRawLatticeChunk + KALDI_ASSERT(clat_.NumArcs(clat_state) == 0); + /* + In almost all cases, dest_clat_state and clat_state will be the same state; + but there may be situations where two arcs with different state-labels + left the start state and entered the same next-state in chunk_clat; and in + these cases, they will be different. + + We didn't address this issue in the paper (or actually realize it could be + a problem). What we do is pick one of the clat_states as the "canonical" + one, and redirect all incoming transitions of the others to enter the + "canonical" one. (Search below for new_in_arc.nextstate = + dest_clat_state). + */ + if (clat_state != dest_clat_state) { + // Check that the start state isn't getting merged with any other state. + // If this were possible, we'd need to deal with it specially, but it + // can't be, because to be merged, 2 states must have identical arcs + // leaving them with identical weights, so we'd need to have another state + // on frame 0 identical to the start state, which is not possible if the + // lattice is deterministic and epsilon-free. + KALDI_ASSERT(clat_state != 0 && dest_clat_state != 0); + } + + // in_weight is an extra weight that we'll include on arcs entering this + // state from the previous chunk. We need to cancel out + // `forward_costs[clat_state]`, which was included in the corresponding arc + // in the raw lattice for pruning purposes; and we need to include + // the weight from the start-state of `chunk_clat` to this state. + CompactLatticeWeight extra_weight_in = arc.weight; + extra_weight_in.SetWeight( + fst::Times(extra_weight_in.Weight(), + LatticeWeight(-forward_costs_[clat_state], 0.0))); + + if (clat_state == 0) { + // if clat_state is the star-state of clat_ (state 0), we can't modify + // incoming arcs; we need to modify outgoing arcs, but we'll do that + // later, after we add them. + *extra_start_weight = extra_weight_in; + forward_costs_[0] = forward_costs_[0] + ConvertToCost(extra_weight_in); + continue; + } + + // Note: 0 is the start state of clat_. This was checked. + forward_costs_[clat_state] = (clat_state == 0 ? 0 : + std::numeric_limits::infinity()); + std::vector > arcs_in; + arcs_in.swap(arcs_in_[clat_state]); + for (auto p: arcs_in) { + // Note: we'll be doing `continue` below if this input arc came from + // another redeterminized-state, because we did DeleteStates() for them in + // InitializeRawLatticeChunk(). Those arcs will be transferred + // from chunk_clat later on. + CompactLattice::StateId src_state = p.first; + int32 arc_pos = p.second; + if (arc_pos >= (int32)clat_.NumArcs(src_state)) + continue; + fst::MutableArcIterator aiter(&clat_, src_state); + aiter.Seek(arc_pos); + if (aiter.Value().nextstate != clat_state) + continue; // This arc record has become invalidated. + CompactLatticeArc new_in_arc(aiter.Value()); + // In most cases we will have dest_clat_state == clat_state, so the next + // line won't change the value of .nextstate + new_in_arc.nextstate = dest_clat_state; + new_in_arc.weight = fst::Times(new_in_arc.weight, extra_weight_in); + aiter.SetValue(new_in_arc); + + BaseFloat new_forward_cost = forward_costs_[src_state] + + ConvertToCost(new_in_arc.weight); + if (new_forward_cost < forward_costs_[dest_clat_state]) + forward_costs_[dest_clat_state] = new_forward_cost; + arcs_in_[dest_clat_state].push_back(p); + } + } + return false; // this is not the first chunk. +} + +void LatticeIncrementalDeterminizer::ReweightStartState( + CompactLatticeWeight &extra_start_weight) { + for (fst::MutableArcIterator aiter(&clat_, 0); + !aiter.Done(); aiter.Next()) { + CompactLatticeArc arc(aiter.Value()); + arc.weight = fst::Times(extra_start_weight, arc.weight); + aiter.SetValue(arc); + } +} + +void LatticeIncrementalDeterminizer::TransferArcsToClat( + const CompactLattice &chunk_clat, + bool is_first_chunk, + const std::unordered_map &state_map, + const std::unordered_map &chunk_state_to_token, + const std::unordered_map &old_final_costs) { + using StateId = CompactLattice::StateId; + StateId chunk_num_states = chunk_clat.NumStates(); + + // Now transfer arcs from chunk_clat to clat_. + for (StateId chunk_state = (is_first_chunk ? 0 : 1); + chunk_state < chunk_num_states; chunk_state++) { + auto iter = state_map.find(chunk_state); + if (iter == state_map.end()) { + KALDI_ASSERT(chunk_state_to_token.count(chunk_state) != 0); + // Don't process token-final states. Anyway they have no arcs leaving + // them. + continue; + } + StateId clat_state = iter->second; + + // We know that this point that `clat_state` is not a token-final state + // (see glossary for definition) as if it were, we would have done + // `continue` above. + // + // Only in the last chunk of the lattice would be there be a final-prob on + // states that are not `token-final states`; these final-probs would + // normally all be Zero() at this point. So in almost all cases the following + // call will do nothing. + clat_.SetFinal(clat_state, chunk_clat.Final(chunk_state)); + + // Process arcs leaving this state. + for (fst::ArcIterator aiter(chunk_clat, chunk_state); + !aiter.Done(); aiter.Next()) { + CompactLatticeArc arc(aiter.Value()); + + auto next_iter = state_map.find(arc.nextstate); + if (next_iter != state_map.end()) { + // The normal case (when the .nextstate has a corresponding + // state in clat_) is very simple. Just copy the arc over. + arc.nextstate = next_iter->second; + KALDI_ASSERT(arc.ilabel < kTokenLabelOffset || + arc.ilabel > kMaxTokenLabel); + AddArcToClat(clat_state, arc); + } else { + // This is the case when the arc is to a `token-final` state (see + // glossary.) + + // TODO: remove the following slightly excessive assertion? + KALDI_ASSERT(chunk_clat.Final(arc.nextstate) != CompactLatticeWeight::Zero() && + arc.olabel >= (Label)kTokenLabelOffset && + arc.olabel < (Label)kMaxTokenLabel && + chunk_state_to_token.count(arc.nextstate) != 0 && + old_final_costs.count(arc.olabel) != 0); + + // Include the final-cost of the next state (which should be final) + // in arc.weight. + arc.weight = fst::Times(arc.weight, + chunk_clat.Final(arc.nextstate)); + + auto cost_iter = old_final_costs.find(arc.olabel); + KALDI_ASSERT(cost_iter != old_final_costs.end()); + BaseFloat old_final_cost = cost_iter->second; + + // `arc` is going to become an element of final_arcs_. These + // contain information about transitions from states in clat_ to + // `token-final` states (i.e. states that have a token-label on the arc + // to them and that are final in the canonical compact lattice). + // We subtract the old_final_cost as it was just a temporary cost + // introduced for pruning purposes. + arc.weight.SetWeight(fst::Times(arc.weight.Weight(), + LatticeWeight{-old_final_cost, 0.0})); + // In a slight abuse of the Arc data structure, the nextstate is set to + // the source state. The label (ilabel == olabel) indicates the + // token it is associated with. + arc.nextstate = clat_state; + final_arcs_.push_back(arc); + } + } + } + +} + +bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( + Lattice *raw_fst) { + using Label = CompactLatticeArc::Label; + using StateId = CompactLattice::StateId; + + // old_final_costs is a map from a `token-label` (see glossary) to the + // associated final-prob in a final-state of `raw_fst`, that is associated + // with that Token. These are Tokens that were active at the end of the + // chunk. The final-probs may arise from beta (backward) costs, introduced + // for pruning purposes, and/or from final-probs in HCLG. Those costs will + // not be included in anything we store permamently in this class; they used + // only to guide pruned determinization, and we will use `old_final_costs` + // later to cancel them out. + std::unordered_map old_final_costs; + GetRawLatticeFinalCosts(*raw_fst, &old_final_costs); + + CompactLattice chunk_clat; + bool determinized_till_beam = DeterminizeLatticePhonePrunedWrapper( + trans_model_, raw_fst, config_.lattice_beam, &chunk_clat, + config_.det_opts); + + TopSortCompactLatticeIfNeeded(&chunk_clat); + + std::unordered_map chunk_state_to_token; + IdentifyTokenFinalStates(chunk_clat, + &chunk_state_to_token); + + StateId chunk_num_states = chunk_clat.NumStates(); + if (chunk_num_states == 0) { + // This will be an error but user-level calling code can detect it from the + // lattice being empty. + KALDI_WARN << "Empty lattice, something went wrong."; + clat_.DeleteStates(); + return false; + } + + StateId start_state = chunk_clat.Start(); // would be 0. + KALDI_ASSERT(start_state == 0); + + // Process arcs leaving the start state of chunk_clat. Unless this is the + // first chunk in the lattice, all arcs leaving the start state of chunk_clat + // will have `state labels` on them (identifying redeterminized-states in + // clat_), and will transition to a state in `chunk_clat` that we can identify + // with that redeterminized-state. + + // state_map maps from (non-initial, non-token-final state s in chunk_clat) to + // a state in clat_. + std::unordered_map state_map; + + + CompactLatticeWeight extra_start_weight = CompactLatticeWeight::One(); + bool is_first_chunk = ProcessArcsFromChunkStartState(chunk_clat, &state_map, + &extra_start_weight); + + // Remove any existing arcs in clat_ that leave redeterminized-states, and + // make those states non-final. Below, we'll add arcs leaving those states + // (and possibly new final-probs.) + for (StateId clat_state: non_final_redet_states_) { + clat_.DeleteArcs(clat_state); + clat_.SetFinal(clat_state, CompactLatticeWeight::Zero()); + } + + // The previous final-arc info is no longer relevant; we'll recreate it below. + final_arcs_.clear(); + + // assume chunk_lat.Start() == 0; we asserted it above. Allocate state-ids + // for all remaining states in chunk_clat, except for token-final states. + for (StateId state = (is_first_chunk ? 0 : 1); + state < chunk_num_states; state++) { + if (chunk_state_to_token.count(state) != 0) + continue; // these `token-final` states don't get a state allocated. + + StateId new_clat_state = clat_.NumStates(); + if (state_map.insert({state, new_clat_state}).second) { + // If it was inserted then we need to actually allocate that state + StateId s = AddStateToClat(); + KALDI_ASSERT(s == new_clat_state); + } // else do nothing; it would have been a redeterminized-state and no + } // allocation is needed since they already exist in clat_. and + // in state_map. + + if (is_first_chunk) { + auto iter = state_map.find(start_state); + KALDI_ASSERT(iter != state_map.end()); + CompactLattice::StateId clat_start_state = iter->second; + KALDI_ASSERT(clat_start_state == 0); // topological order. + clat_.SetStart(clat_start_state); + forward_costs_[clat_start_state] = 0.0; + } + + TransferArcsToClat(chunk_clat, is_first_chunk, + state_map, chunk_state_to_token, old_final_costs); + + if (extra_start_weight != CompactLatticeWeight::One()) + ReweightStartState(extra_start_weight); + + GetNonFinalRedetStates(); + + return determinized_till_beam; +} + + + +void LatticeIncrementalDeterminizer::SetFinalCosts( + const unordered_map *token_label2final_cost) { + if (final_arcs_.empty()) { + KALDI_WARN << "SetFinalCosts() called when final_arcs_.empty()... possibly " + "means you are calling this after Finalize()? Not allowed: could " + "indicate a code error. Or possibly decoding failed somehow."; + } + + /* + prefinal states a terminology that does not appear in the paper. What it + means is: the set of states that have an arc with a Token-label as the label + leaving them in the canonical appended lattice. + */ + std::unordered_set &prefinal_states(temp_); + prefinal_states.clear(); + for (const auto &arc: final_arcs_) { + /* Caution: `state` is actually the state the arc would + leave from in the canonical appended lattice; we just store + that in the .nextstate field. */ + CompactLattice::StateId state = arc.nextstate; + prefinal_states.insert(state); + } + + for (int32 state: prefinal_states) + clat_.SetFinal(state, CompactLatticeWeight::Zero()); + + + for (const CompactLatticeArc &arc: final_arcs_) { + Label token_label = arc.ilabel; + /* Note: we store the source state in the .nextstate field. */ + CompactLattice::StateId src_state = arc.nextstate; + BaseFloat graph_final_cost; + if (token_label2final_cost == NULL) { + graph_final_cost = 0.0; + } else { + auto iter = token_label2final_cost->find(token_label); + if (iter == token_label2final_cost->end()) + continue; + else + graph_final_cost = iter->second; + } + /* It might seem odd to set a final-prob on the src-state of the arc.. + the point is that the symbol on the arc is a token-label, which should not + appear in the lattice the user sees, so after that token-label is removed + the arc would just become a final-prob. + */ + clat_.SetFinal(src_state, + fst::Plus(clat_.Final(src_state), + fst::Times(arc.weight, + CompactLatticeWeight( + LatticeWeight(graph_final_cost, 0), {})))); + } +} + + + + +// Instantiate the template for the combination of token types and FST types +// that we'll need. +template class LatticeIncrementalDecoderTpl, decoder::StdToken>; +template class LatticeIncrementalDecoderTpl, + decoder::StdToken>; +template class LatticeIncrementalDecoderTpl, + decoder::StdToken>; +template class LatticeIncrementalDecoderTpl; + +template class LatticeIncrementalDecoderTpl, + decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl, + decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl, + decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl; + +} // end namespace kaldi. diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h new file mode 100644 index 00000000000..7abc370178a --- /dev/null +++ b/src/decoder/lattice-incremental-decoder.h @@ -0,0 +1,752 @@ +// decoder/lattice-incremental-decoder.h + +// Copyright 2019 Zhehuai Chen, Daniel Povey + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_INCREMENTAL_DECODER_H_ +#define KALDI_DECODER_LATTICE_INCREMENTAL_DECODER_H_ + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "lattice-faster-decoder.h" + +namespace kaldi { +/** + The normal decoder, lattice-faster-decoder.h, sometimes has an issue when + doing real-time applications with long utterances, that each time you get the + lattice the lattice determinization can take a considerable amount of time; + this introduces latency. This version of the decoder spreads the work of + lattice determinization out throughout the decoding process. + + NOTE: + + Please see https://www.danielpovey.com/files/ *TBD* .pdf for a technical + explanation of what is going on here. + + GLOSSARY OF TERMS: + chunk: We do the determinization on chunks of frames; these + may coincide with the chunks on which the user calls + AdvanceDecoding(). The basic idea is to extract chunks + of the raw lattice and determinize them individually, but + it gets much more complicated than that. The chunks + should normally be at least as long as a word (let's say, + at least 20 frames), or the overhead of this algorithm + might become excessive and affect RTF. + + raw lattice chunk: A chunk of raw (i.e. undeterminized) lattice + that we will determinize. In the paper this corresponds + to the FST B that is described in Section 5.2. + + token_label, state_label / token-label, state-label: + + In the paper these are both referred to as `state labels` (these are + special, large integer id's that refer to states in the undeterminized + lattice and in the the determinized lattice); but we use two separate + terms here, for more clarity, when referring to the undeterminized + vs. determinized lattice. + + token_label conceptually refers to states in the + raw lattice, but we don't materialize the entire + raw lattice as a physical FST and and these tokens + are actually tokens (template type Token) held by + the decoder + + state_label when used in this code refers specifically + to labels that identify states in the determinized + lattice (i.e. state indexes in lat_). + + token-final state + A state in a raw lattice or in a determinized chunk that has an arc + entering it that has a `token-label` on it (as defined above). + These states will have nonzero final-probs. + + redeterminized-non-splice-state, aka ns_redet: + A redeterminized state which is not also a splice state; + refer to the paper for explanation. In the already-determinized + part this means a redeterminized state which is not final. + + canonical appended lattice: This is the appended compact lattice + that we conceptually have (i.e. what we described in the paper). + The difference from the "actual appended lattice" stored + in LatticeIncrementalDeterminizer::clat_ is that the + actual appended lattice has all its final-arcs replaced with + final-probs, and we keep the real final-arcs "on the side" in a + separate data structure. The final-probs in clat_ aren't + necessarily related to the costs on the final-arcs; instead + they can have arbitrary values passed in by the user (e.g. + if we want to include final-probs). This means that the + clat_ can be returned without modification to the user who wants + a partially determinized result. + + final-arc: An arc in the canonical appended CompactLattice which + goes to a final-state. These arcs will have `state-labels` as + their labels. + + */ +struct LatticeIncrementalDecoderConfig { + // All the configuration values until det_opts are the same as in + // LatticeFasterDecoder. For clarity we repeat them rather than inheriting. + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeIncrementalDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeIncremental. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + // The configuration values from this point on are specific to the + // incremental determinization. See where they are registered for + // explanation. + // Caution: these are only inspected in UpdateLatticeDeterminization(). + // If you call + int32 determinize_max_delay; + int32 determinize_min_chunk_size; + + + LatticeIncrementalDecoderConfig() + : beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1), + determinize_max_delay(60), + determinize_min_chunk_size(20) { + det_opts.minimize = false; + } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, + "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, + "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, + "Interval (in frames) at " + "which to prune tokens"); + opts->Register("beam-delta", &beam_delta, + "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, + "Setting used in decoder to " + "control hash behavior"); + opts->Register("determinize-max-delay", &determinize_max_delay, + "Maximum frames of delay between decoding a frame and " + "determinizing it"); + opts->Register("determinize-min-chunk-size", &determinize_min_chunk_size, + "Minimum chunk size used in determinization"); + + } + void Check() const { + if (!(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 && + min_active <= max_active && prune_interval > 0 && + beam_delta > 0.0 && hash_ratio >= 1.0 && + prune_scale > 0.0 && prune_scale < 1.0 && + determinize_max_delay > determinize_min_chunk_size && + determinize_min_chunk_size > 0)) + KALDI_ERR << "Invalid options given to decoder"; + /* Minimization of the chunks is not compatible withour algorithm (or at + least, would require additional complexity to implement.) */ + if (det_opts.minimize || !det_opts.word_determinize) + KALDI_ERR << "Invalid determinization options given to decoder."; + } +}; + + + +/** + This class is used inside LatticeIncrementalDecoderTpl; it handles + some of the details of incremental determinization. + https://www.danielpovey.com/files/ *TBD*.pdf for the paper. + +*/ +class LatticeIncrementalDeterminizer { + public: + using Label = typename LatticeArc::Label; /* Actualy the same labels appear + in both lattice and compact + lattice, so we don't use the + specific type all the time but + just say 'Label' */ + LatticeIncrementalDeterminizer( + const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config): + trans_model_(trans_model), config_(config) { } + + // Resets the lattice determinization data for new utterance + void Init(); + + // Returns the current determinized lattice. + const CompactLattice &GetDeterminizedLattice() const { return clat_; } + + /** + Starts the process of creating a raw lattice chunk. (Search the glossary + for "raw lattice chunk"). This just sets up the initial states and + redeterminized-states in the chunk. Relates to sec. 5.2 in the paper, + specifically the initial-state i and the redeterminized-states. + + After calling this, the caller would add the remaining arcs and states + to `olat` and then call AcceptRawLatticeChunk() with the result. + + @param [out] olat The lattice to be (partially) created + + @param [out] token_label2state This function outputs to here + a map from `token-label` to the state we created for + it in *olat. See glossary for `token-label`. + The keys actually correspond to the .nextstate fields + in the arcs in final_arcs_; values are states in `olat`. + See the last bullet point before Sec. 5.3 in the paper. + */ + void InitializeRawLatticeChunk( + Lattice *olat, + unordered_map *token_label2state); + + /** + This function accepts the raw FST (state-level lattice) corresponding to a + single chunk of the lattice, determinizes it and appends it to this->clat_. + Unless this was the + + Note: final-probs in `raw_fst` are treated specially: they are used to + guide the pruned determinization, but when you call GetLattice() it will be + -- except for pruning effects-- as if all nonzero final-probs in `raw_fst` + were: One() if final_costs == NULL; else the value present in `final_costs`. + + @param [in] raw_fst (Consumed destructively). The input + raw (state-level) lattice. Would correspond to the + FST A in the paper if first_frame == 0, and B + otherwise. + + @return returns false if determinization finished earlier than the beam + or the determinized lattice was empty; true otherwise. + + NOTE: if this is not the final chunk, you will probably want to call + SetFinalCosts() directly after calling this. + */ + bool AcceptRawLatticeChunk(Lattice *raw_fst); + + /* + Sets final-probs in `clat_`. Must only be called if the final chunk + has not been processed. (The final chunk is whenever GetLattice() is + called with finalize == true). + + The reason this is a separate function from AcceptRawLatticeChunk() is that + there may be situations where a user wants to get the latice with + final-probs in it, after previously getting it without final-probs; or + vice versa. By final-probs, we mean the Final() probabilities in the + HCLG (decoding graph; this->fst_). + + @param [in] token_label2final_cost A map from the token-label + corresponding to Tokens active on the final frame of the + lattice in the object, to the final-cost we want to use for + those tokens. If NULL, it means all Tokens should be treated + as final with probability One(). If non-NULL, and a particular + token-label is not a key of this map, it means that Token + corresponded to a state that was not final in HCLG; and + such tokens will be treated as non-final. However, + if this would result in no states in the lattice being final, + we will treat all Tokens as final with probability One(), + a warning will be printed (this should not happen.) + */ + void SetFinalCosts(const unordered_map *token_label2final_cost = NULL); + + const CompactLattice &GetLattice() { return clat_; } + + // kStateLabelOffset is what we add to state-ids in clat_ to produce labels + // to identify them in the raw lattice chunk + // kTokenLabelOffset is where we start allocating labels corresponding to Tokens + // (these correspond with raw lattice states); + enum { kStateLabelOffset = (int)1e8, kTokenLabelOffset = (int)2e8, kMaxTokenLabel = (int)3e8 }; + + private: + + // [called from AcceptRawLatticeChunk()] + // Gets the final costs from token-final states in the raw lattice (see + // glossary for definition). These final costs will be subtracted after + // determinization; in the normal case they are `temporaries` used to guide + // pruning. NOTE: the index of the array is not the FST state that is final, + // but the label on arcs entering it (these will be `token-labels`). Each + // token-final state will have the same label on all arcs entering it. + // + // `old_final_costs` is assumed to be empty at entry. + void GetRawLatticeFinalCosts(const Lattice &raw_fst, + std::unordered_map *old_final_costs); + + // Sets up non_final_redet_states_. See documentation for that variable. + void GetNonFinalRedetStates(); + + /** [called from AcceptRawLatticeChunk()] Processes arcs that leave the + start-state of `chunk_clat` (if this is not the first chunk); does nothing + if this is the first chunk. This includes using the `state-labels` to + work out which states in clat_ these states correspond to, and writing + that mapping to `state_map`. + + Also modifies forward_costs_, because it has to do a kind of reweighting + of the clat states that are the values it puts in `state_map`, to take + account of the probabilities on the arcs from the start state of + chunk_clat to the states corresponding to those redeterminized-states + (i.e. the states in clat corresponding to the values it puts in + `*state_map`). It also modifies arcs_in_, mostly because there + are rare cases when we end up `merging` sets of those redeterminized-states, + because the determinization process mapped them to a single state, + and that means we need to reroute the arcs into members of that + set into one single member (which will appear as a value in + `*state_map`). + + @param [in] chunk_clat The determinized chunk of lattice we are + processing + @param [out] state_map Mapping from states in chunk_clat to + the state in clat_ they correspond to. + @param [out] extra_start_weight If the start-state of + clat_ (its state 0) needs to be modified as + if its incoming arcs were multiplied by + `extra_start_weight`, this isn't possible + using the `in_arcs_` data-structure, + so we remember the extra weight and multiply + it in later, after processing arcs leaving + the start state of clat_. This is set + only if the start-state of clat_ is a + redeterminized state. + @return Returns true if this is the first chunk. + */ + bool ProcessArcsFromChunkStartState( + const CompactLattice &chunk_clat, + std::unordered_map *state_map, + CompactLatticeWeight *extra_start_weight); + + /** + This function, called from AcceptRawLatticeChunk(), takes care of an + unusual situation where we need to reweight the start state of clat_. This + `extra_start_weight` is to be thought of as an extra `incoming` weight, and + we need to left-multiply all the arcs leaving the start state, by it. + + This function does not need to modify forward_costs_; that will + already have been done by ProcessArcsFromChunkStartState(). + */ + void ReweightStartState(CompactLatticeWeight &extra_start_weight); + + + /** + This function, called from AcceptRawLatticeChunk(), transfers arcs from + `chunk_clat` to clat_. For those arcs that have `token-labels` on them, + they don't get written to clat_ but instead are stored in the arcs_ array. + + @param [in] chunk_clat The determinized lattice for the chunk + we are processing; this is the source of the arcs + we are moving. + @param [in] is_first_chunk True if this is the first chunk in the + utterance; it's needed because if it is, we + will also transfer arcs from the start state of + chunk_clat. + @param [in] state_map Map from state-ids in chunk_clat to state-ids + in clat_. + @param [in] chunk_state_to_token Map from `token-final states` + (see glossary) in chunk_clat, to the token-label + on arcs entering those states. + @param [in] old_final_costs Map from token-label to the + final-costs that were on the corresponding + token-final states in the undeterminized lattice; + these final-costs need to be removed when + we record the weights in final_arcs_, because + they were just temporary. + */ + void TransferArcsToClat( + const CompactLattice &chunk_clat, + bool is_first_chunk, + const std::unordered_map &state_map, + const std::unordered_map &chunk_state_to_token, + const std::unordered_map &old_final_costs); + + + + /** + Adds one arc to `clat_`. It's like clat_.AddArc(state, arc), except + it also modifies arcs_in_ and forward_costs_. + */ + void AddArcToClat(CompactLattice::StateId state, + const CompactLatticeArc &arc); + CompactLattice::StateId AddStateToClat(); + + + // Identifies token-final states in `chunk_clat`; see glossary above for + // definition of `token-final`. This function outputs a map from such states + // in chunk_clat, to the `token-label` on arcs entering them. (It is not + // possible that the same state would have multiple arcs entering it with + // different token-labels, or some arcs entering with one token-label and some + // another, or be both initial and have such arcs; this is true due to how we + // construct the raw lattice.) + void IdentifyTokenFinalStates( + const CompactLattice &chunk_clat, + std::unordered_map *token_map) const; + + // trans_model_ is needed by DeterminizeLatticePhonePrunedWrapper() which this + // class calls. + const TransitionModel &trans_model_; + // config_ is needed by DeterminizeLatticePhonePrunedWrapper() which this + // class calls. + const LatticeIncrementalDecoderConfig &config_; + + + // Contains the set of redeterminized-states which are not final in the + // canonical appended lattice. Since the final ones don't physically appear + // in clat_, this means the set of redeterminized-states which are physically + // in clat_. In code terms, this means set of .first elements in final_arcs, + // plus whatever other states in clat_ are reachable from such states. + std::unordered_set non_final_redet_states_; + + + // clat_ is the appended lattice (containing all chunks processed so + // far), except its `final-arcs` (i.e. arcs which in the canonical + // lattice would go to final-states) are not present (they are stored + // separately in final_arcs_) and states which in the canonical lattice + // should have final-arcs leaving them will instead have a final-prob. + CompactLattice clat_; + + + // arcs_in_ is indexed by (state-id in clat_), and is a list of + // arcs that come into this state, in the form (prev-state, + // arc-index). CAUTION: not all these input-arc records will always + // be valid (some may be out-of-date, and may refer to an out-of-range + // arc or an arc that does not point to this state). But all + // input arcs will always be listed. + std::vector > > arcs_in_; + + // final_arcs_ contains arcs which would appear in the canonical appended + // lattice but for implementation reasons are not physically present in clat_. + // These are arcs to final states in the canonical appended lattice. The + // .first elements are the source states in clat_ (these will all be elements + // of non_final_redet_states_); the .nextstate elements of the arcs does not + // contain a physical state, but contain state-labels allocated by + // AllocateNewStateLabel(). + std::vector final_arcs_; + + // forward_costs_, indexed by the state-id in clat_, stores the alpha + // (forward) costs, i.e. the minimum cost from the start state to each state + // in clat_. This is relevant for pruned determinization. The BaseFloat can + // be thought of as the sum of a Value1() + Value2() in a LatticeWeight. + std::vector forward_costs_; + + // temporary used in a function, kept here to avoid excessive reallocation. + std::unordered_set temp_; + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDeterminizer); +}; + + +/** This is an extention to the "normal" lattice-generating decoder. + See \ref lattices_generation \ref decoders_faster and \ref decoders_simple + for more information. + + The main difference is the incremental determinization which will be + discussed in the function GetLattice(). This means that the work of determinizatin + isn't done all at once at the end of the file, but incrementally while decoding. + See the comment at the top of this file for more explanation. + + The decoder is templated on the FST type and the token type. The token type + will normally be StdToken, but also may be BackpointerToken which is to support + quick lookup of the current best path (see lattice-faster-online-decoder.h) + + The FST you invoke this decoder with is expected to be of type + Fst::Fst, a.k.a. StdFst, or GrammarFst. If you invoke it with + FST == StdFst and it notices that the actual FST type is + fst::VectorFst or fst::ConstFst, the decoder object + will internally cast itself to one that is templated on those more specific + types; this is an optimization for speed. + */ +template +class LatticeIncrementalDecoderTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeIncrementalDecoderTpl(const FST &fst, const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config); + + // This version of the constructor takes ownership of the fst, and will delete + // it when this object is destroyed. + LatticeIncrementalDecoderTpl(const LatticeIncrementalDecoderConfig &config, + FST *fst, const TransitionModel &trans_model); + + void SetOptions(const LatticeIncrementalDecoderConfig &config) { config_ = config; } + + const LatticeIncrementalDecoderConfig &GetOptions() const { return config_; } + + ~LatticeIncrementalDecoderTpl(); + + /** + CAUTION: it's unlikely that you will ever want to call this function. In a + scenario where you have the entire file and just want to decode it, there + is no point using this decoder. + + An example of how to do decoding together with incremental + determinization. It decodes until there are no more frames left in the + "decodable" object. + + In this example, config_.determinize_delay, config_.determinize_period + and config_.determinize_max_active are used to determine the time to + call GetLattice(). + + Users will probably want to use appropriate combinations of + AdvanceDecoding() and GetLattice() to build their application; this just + gives you some idea how. + + The function returns true if any kind of traceback is available (not + necessarily from a final state). + */ + bool Decode(DecodableInterface *decodable); + + /// says whether a final-state was active on the last frame. If it was not, + /// the lattice (or traceback) will end with states that are not final-states. + bool ReachedFinal() const { + return FinalRelativeCost() != std::numeric_limits::infinity(); + } + + /** + This decoder has no GetBestPath() function. + If you need that functionality you should probably use lattice-incremental-online-decoder.h, + which makes it very efficient to obtain the best path. */ + + /** + This GetLattice() function returns the lattice containing + `num_frames_to_decode` frames; this will be all frames decoded so + far, if you let num_frames_to_decode == NumFramesDecoded(), + but it will generally be better to make it a few frames less than + that to avoid the lattice having too many active states at + the end. + + @param [in] num_frames_to_include The number of frames that you want + to be included in the lattice. Must be >= + NumFramesInLattice() and <= NumFramesDecoded(). + + @param [in] use_final_probs True if you want the final-probs + of HCLG to be included in the output lattice. Must not + be set to true if num_frames_to_include != + NumFramesDecoded(). Must be set to true if you have + previously called FinalizeDecoding(). + + (If no state was final on frame `num_frames_to_include`, the + final-probs won't be included regardless of + `use_final_probs`; you can test whether this + was the case by calling ReachedFinal(). + + @return clat The CompactLattice representing what has been decoded + up until `num_frames_to_include` (e.g., LatticeStateTimes() + on this lattice would return `num_frames_to_include`). + + See also UpdateLatticeDeterminizaton(). Caution: this const ref + is only valid until the next time you call AdvanceDecoding() or + GetLattice(). + + CAUTION: the lattice may contain disconnnected states; you should + call Connect() on the output before writing it out. + */ + const CompactLattice &GetLattice(int32 num_frames_to_include, + bool use_final_probs = false); + + /* + Returns the number of frames in the currently-determinized part of the + lattice which will be a number in [0, NumFramesDecoded()]. It will + be the largest number that GetLattice() was called with, but note + that GetLattice() may be called from UpdateLatticeDeterminization(). + + Made available in case the user wants to give that same number to + GetLattice(). + */ + int NumFramesInLattice() const { return num_frames_in_lattice_; } + + /** + InitDecoding initializes the decoding, and should only be used if you + intend to call AdvanceDecoding(). If you call Decode(), you don't need to + call this. You can also call InitDecoding if you have already decoded an + utterance and want to start with a new utterance. + */ + void InitDecoding(); + + /** + This will decode until there are no more frames ready in the decodable + object. You can keep calling it each time more frames become available + (this is the normal pattern in a real-time/online decoding scenario). + If max_num_frames is specified, it specifies the maximum number of frames + the function will decode before returning. + */ + void AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames = -1); + + + /** FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + more information. It returns the difference between the best (final-cost + plus cost) of any token on the final frame, and the best cost of any token + on the final frame. If it is infinity it means no final-states were + present on the final frame. It will usually be nonnegative. If it not + too positive (e.g. < 5 is my first guess, but this is not tested) you can + take it as a good indication that we reached the final-state with + reasonable likelihood. */ + BaseFloat FinalRelativeCost() const; + + /** Returns the number of frames decoded so far. */ + inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + + /** + Finalizes the decoding, doing an extra pruning step on the last frame + that uses the final-probs. May be called only once. + */ + void FinalizeDecoding(); + + protected: + /* Some protected things are needed in LatticeIncrementalOnlineDecoderTpl. */ + + /** NOTE: for parts the internal implementation that are shared with LatticeFasterDecoer, + we have removed the comments.*/ + inline static void DeleteForwardLinks(Token *tok); + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + int32 num_toks; /* Note: you can only trust `num_toks` if must_prune_tokens + * == false, because it is only set in + * PruneTokensForFrame(). */ + TokenList() + : toks(NULL), must_prune_forward_links(true), must_prune_tokens(true), + num_toks(-1) {} + }; + using Elem = typename HashList::Elem; + void PossiblyResizeHash(size_t num_toks); + inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, + BaseFloat tot_cost, Token *backpointer, bool *changed); + void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta); + void ComputeFinalCosts(unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const; + void PruneForwardLinksFinal(); + void PruneTokensForFrame(int32 frame_plus_one); + void PruneActiveTokens(BaseFloat delta); + BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, + Elem **best_elem); + BaseFloat ProcessEmitting(DecodableInterface *decodable); + void ProcessNonemitting(BaseFloat cost_cutoff); + + HashList toks_; + std::vector active_toks_; // indexed by frame. + std::vector queue_; // temp variable used in ProcessNonemitting, + std::vector tmp_array_; // used in GetCutoff. + const FST *fst_; + bool delete_fst_; + std::vector cost_offsets_; + int32 num_toks_; + bool warned_; + bool decoding_finalized_; + + unordered_map final_costs_; + BaseFloat final_relative_cost_; + BaseFloat final_best_cost_; + + /*********************** + Variables below this point relate to the incremental + determinization. + *********************/ + LatticeIncrementalDecoderConfig config_; + /** Much of the the incremental determinization algorithm is encapsulated in + the determinize_ object. */ + LatticeIncrementalDeterminizer determinizer_; + + + /* Just a temporary used in a function; stored here to avoid reallocation. */ + unordered_map temp_token_map_; + + /** num_frames_in_lattice_ is the highest `num_frames_to_include_` argument + for any prior call to GetLattice(). */ + int32 num_frames_in_lattice_; + + // A map from Token to its token_label. Will contain an entry for + // each Token in active_toks_[num_frames_in_lattice_]. + unordered_map token2label_map_; + + // A temporary used in a function, kept here to avoid reallocation. + unordered_map token2label_map_temp_; + + // we allocate a unique id for each Token + Label next_token_label_; + + inline Label AllocateNewTokenLabel() { return next_token_label_++; } + + + // There are various cleanup tasks... the the toks_ structure contains + // singly linked lists of Token pointers, where Elem is the list type. + // It also indexes them in a hash, indexed by state (this hash is only + // maintained for the most recent frame). toks_.Clear() + // deletes them from the hash and returns the list of Elems. The + // function DeleteElems calls toks_.Delete(elem) for each elem in + // the list, which returns ownership of the Elem to the toks_ structure + // for reuse, but does not delete the Token pointer. The Token pointers + // are reference-counted and are ultimately deleted in PruneTokensForFrame, + // but are also linked together on each frame by their own linked-list, + // using the "next" pointer. We delete them manually. + void DeleteElems(Elem *list); + + void ClearActiveTokens(); + + + // Returns the number of active tokens on frame `frame`. Can be used as part + // of a heuristic to decide which frame to determinize until, if you are not + // at the end of an utterance. + int32 GetNumToksForFrame(int32 frame); + + /** + UpdateLatticeDeterminization() ensures the work of determinization is kept + up to date so that when you do need the lattice you can get it fast. It + uses the configuration values `determinize_delay`, `determinize_max_delay` + and `determinize_min_chunk_size` to decide whether and when to call + GetLattice(). You can safely call this as often as you want (e.g. after + each time you call AdvanceDecoding(); it won't do subtantially more work if + it is called frequently. + */ + void UpdateLatticeDeterminization(); + + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); +}; + +typedef LatticeIncrementalDecoderTpl + LatticeIncrementalDecoder; + + +} // end namespace kaldi. + +#endif diff --git a/src/decoder/lattice-incremental-online-decoder.cc b/src/decoder/lattice-incremental-online-decoder.cc new file mode 100644 index 00000000000..85f902bde3d --- /dev/null +++ b/src/decoder/lattice-incremental-online-decoder.cc @@ -0,0 +1,150 @@ +// decoder/lattice-incremental-online-decoder.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// see note at the top of lattice-faster-decoder.cc, about how to maintain this +// file in sync with lattice-faster-decoder.cc + +#include "decoder/lattice-incremental-decoder.h" +#include "decoder/lattice-incremental-online-decoder.h" +#include "lat/lattice-functions.h" +#include "base/timer.h" + +namespace kaldi { + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeIncrementalOnlineDecoderTpl::GetBestPath(Lattice *olat, + bool use_final_probs) const { + olat->DeleteStates(); + BaseFloat final_graph_cost; + BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost); + if (iter.Done()) + return false; // would have printed warning. + StateId state = olat->AddState(); + olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0)); + while (!iter.Done()) { + LatticeArc arc; + iter = TraceBackBestPath(iter, &arc); + arc.nextstate = state; + StateId new_state = olat->AddState(); + olat->AddArc(new_state, arc); + state = new_state; + } + olat->SetStart(state); + return true; +} + +template +typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator LatticeIncrementalOnlineDecoderTpl::BestPathEnd( + bool use_final_probs, + BaseFloat *final_cost_out) const { + if (this->decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "BestPathEnd() with use_final_probs == false"; + KALDI_ASSERT(this->NumFramesDecoded() > 0 && + "You cannot call BestPathEnd if no frames were decoded."); + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (this->decoding_finalized_ ? this->final_costs_ :final_costs_local); + if (!this->decoding_finalized_ && use_final_probs) + this->ComputeFinalCosts(&final_costs_local, NULL, NULL); + + // Singly linked list of tokens on last frame (access list through "next" + // pointer). + BaseFloat best_cost = std::numeric_limits::infinity(); + BaseFloat best_final_cost = 0; + Token *best_tok = NULL; + for (Token *tok = this->active_toks_.back().toks; + tok != NULL; tok = tok->next) { + BaseFloat cost = tok->tot_cost, final_cost = 0.0; + if (use_final_probs && !final_costs.empty()) { + // if we are instructed to use final-probs, and any final tokens were + // active on final frame, include the final-prob in the cost of the token. + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) { + final_cost = iter->second; + cost += final_cost; + } else { + cost = std::numeric_limits::infinity(); + } + } + if (cost < best_cost) { + best_cost = cost; + best_tok = tok; + best_final_cost = final_cost; + } + } + if (best_tok == NULL) { // this should not happen, and is likely a code error or + // caused by infinities in likelihoods, but I'm not making + // it a fatal error for now. + KALDI_WARN << "No final token found."; + } + if (final_cost_out == NULL) + *final_cost_out = best_final_cost; + return BestPathIterator(best_tok, this->NumFramesDecoded() - 1); +} + + +template +typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator LatticeIncrementalOnlineDecoderTpl::TraceBackBestPath( + BestPathIterator iter, LatticeArc *oarc) const { + KALDI_ASSERT(!iter.Done() && oarc != NULL); + Token *tok = static_cast(iter.tok); + int32 cur_t = iter.frame, ret_t = cur_t; + if (tok->backpointer != NULL) { + ForwardLinkT *link; + for (link = tok->backpointer->links; + link != NULL; link = link->next) { + if (link->next_tok == tok) { // this is the link to "tok" + oarc->ilabel = link->ilabel; + oarc->olabel = link->olabel; + BaseFloat graph_cost = link->graph_cost, + acoustic_cost = link->acoustic_cost; + if (link->ilabel != 0) { + KALDI_ASSERT(static_cast(cur_t) < this->cost_offsets_.size()); + acoustic_cost -= this->cost_offsets_[cur_t]; + ret_t--; + } + oarc->weight = LatticeWeight(graph_cost, acoustic_cost); + break; + } + } + if (link == NULL) { // Did not find correct link. + KALDI_ERR << "Error tracing best-path back (likely " + << "bug in token-pruning algorithm)"; + } + } else { + oarc->ilabel = 0; + oarc->olabel = 0; + oarc->weight = LatticeWeight::One(); // zero costs. + } + return BestPathIterator(tok->backpointer, ret_t); +} + +// Instantiate the template for the FST types that we'll need. +template class LatticeIncrementalOnlineDecoderTpl >; +template class LatticeIncrementalOnlineDecoderTpl >; +template class LatticeIncrementalOnlineDecoderTpl >; +template class LatticeIncrementalOnlineDecoderTpl; + + +} // end namespace kaldi. diff --git a/src/decoder/lattice-incremental-online-decoder.h b/src/decoder/lattice-incremental-online-decoder.h new file mode 100644 index 00000000000..8bd41c851ab --- /dev/null +++ b/src/decoder/lattice-incremental-online-decoder.h @@ -0,0 +1,132 @@ +// decoder/lattice-incremental-online-decoder.h + +// Copyright 2019 Zhehuai Chen +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// see note at the top of lattice-faster-decoder.h, about how to maintain this +// file in sync with lattice-faster-decoder.h + + +#ifndef KALDI_DECODER_LATTICE_INCREMENTAL_ONLINE_DECODER_H_ +#define KALDI_DECODER_LATTICE_INCREMENTAL_ONLINE_DECODER_H_ + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/lattice-incremental-decoder.h" + +namespace kaldi { + + + +/** LatticeIncrementalOnlineDecoderTpl is as LatticeIncrementalDecoderTpl but also + supports an efficient way to get the best path (see the function + BestPathEnd()), which is useful in endpointing and in situations where you + might want to frequently access the best path. + + This is only templated on the FST type, since the Token type is required to + be BackpointerToken. Actually it only makes sense to instantiate + LatticeIncrementalDecoderTpl with Token == BackpointerToken if you do so indirectly via + this child class. + */ +template +class LatticeIncrementalOnlineDecoderTpl: + public LatticeIncrementalDecoderTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using Token = decoder::BackpointerToken; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeIncrementalOnlineDecoderTpl(const FST &fst, + const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config): + LatticeIncrementalDecoderTpl(fst, trans_model, config) { } + + // This version of the initializer takes ownership of 'fst', and will delete + // it when this object is destroyed. + LatticeIncrementalOnlineDecoderTpl(const LatticeIncrementalDecoderConfig &config, + FST *fst, + const TransitionModel &trans_model): + LatticeIncrementalDecoderTpl(config, fst, trans_model) { } + + + struct BestPathIterator { + void *tok; + int32 frame; + // note, "frame" is the frame-index of the frame you'll get the + // transition-id for next time, if you call TraceBackBestPath on this + // iterator (assuming it's not an epsilon transition). Note that this + // is one less than you might reasonably expect, e.g. it's -1 for + // the nonemitting transitions before the first frame. + BestPathIterator(void *t, int32 f): tok(t), frame(f) { } + bool Done() { return tok == NULL; } + }; + + + /// Outputs an FST corresponding to the single best path through the lattice. + /// This is quite efficient because it doesn't get the entire raw lattice and find + /// the best path through it; instead, it uses the BestPathEnd and BestPathIterator + /// so it basically traces it back through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true) const; + + + + /// This function returns an iterator that can be used to trace back + /// the best path. If use_final_probs == true and at least one final state + /// survived till the end, it will use the final-probs in working out the best + /// final Token, and will output the final cost to *final_cost (if non-NULL), + /// else it will use only the forward likelihood, and will put zero in + /// *final_cost (if non-NULL). + /// Requires that NumFramesDecoded() > 0. + BestPathIterator BestPathEnd(bool use_final_probs, + BaseFloat *final_cost = NULL) const; + + + /// This function can be used in conjunction with BestPathEnd() to trace back + /// the best path one link at a time (e.g. this can be useful in endpoint + /// detection). By "link" we mean a link in the graph; not all links cross + /// frame boundaries, but each time you see a nonzero ilabel you can interpret + /// that as a frame. The return value is the updated iterator. It outputs + /// the ilabel and olabel, and the (graph and acoustic) weight to the "arc" pointer, + /// while leaving its "nextstate" variable unchanged. + BestPathIterator TraceBackBestPath( + BestPathIterator iter, LatticeArc *arc) const; + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalOnlineDecoderTpl); +}; + +typedef LatticeIncrementalOnlineDecoderTpl LatticeIncrementalOnlineDecoder; + + +} // end namespace kaldi. + +#endif diff --git a/src/lat/determinize-lattice-pruned.h b/src/lat/determinize-lattice-pruned.h index 8e1858aa2b1..35323709815 100644 --- a/src/lat/determinize-lattice-pruned.h +++ b/src/lat/determinize-lattice-pruned.h @@ -105,8 +105,8 @@ namespace fst { representation" and hence the "minimal representation" will be the same. We can use this to reduce compute. Note that if two initial representations are different, this does not preclude the other representations from being the same. - -*/ + +*/ struct DeterminizeLatticePrunedOptions { @@ -190,7 +190,7 @@ template bool DeterminizeLatticePruned( const ExpandedFst > &ifst, double prune, - MutableFst > *ofst, + MutableFst > *ofst, DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions()); @@ -199,7 +199,7 @@ bool DeterminizeLatticePruned( (i.e. the sequences of output symbols are represented directly as strings The input FST must be topologically sorted in order for the algorithm to work. For efficiency it is recommended to sort the ilabel for the input FST as well. - Returns true on success, and false if it had to terminate the determinization + Returns true on normal success, and false if it had to terminate the determinization earlier than specified by the "prune" beam-- that is, if it terminated because of the max_mem, max_loop or max_arcs constraints in the options. CAUTION: if Lattice is the input, you need to Invert() before calling this, @@ -261,7 +261,7 @@ bool DeterminizeLatticePhonePruned( = DeterminizeLatticePhonePrunedOptions()); /** "Destructive" version of DeterminizeLatticePhonePruned() where the input - lattice might be changed. + lattice might be changed. */ template bool DeterminizeLatticePhonePruned( diff --git a/src/lat/lattice-functions.cc b/src/lat/lattice-functions.cc index 7f484f95233..f4a184f3cd4 100644 --- a/src/lat/lattice-functions.cc +++ b/src/lat/lattice-functions.cc @@ -1107,7 +1107,6 @@ void CompactLatticeShortestPath(const CompactLattice &clat, // Now we can assume it's topologically sorted. shortest_path->DeleteStates(); if (clat.Start() == kNoStateId) return; - KALDI_ASSERT(clat.Start() == 0); // since top-sorted. typedef CompactLatticeArc Arc; typedef Arc::StateId StateId; typedef CompactLatticeWeight Weight; @@ -1117,7 +1116,7 @@ void CompactLatticeShortestPath(const CompactLattice &clat, best_cost_and_pred[s].first = std::numeric_limits::infinity(); best_cost_and_pred[s].second = fst::kNoStateId; } - best_cost_and_pred[0].first = 0; + best_cost_and_pred[clat.Start()].first = 0; for (StateId s = 0; s < clat.NumStates(); s++) { double my_cost = best_cost_and_pred[s].first; for (ArcIterator aiter(clat, s); @@ -1139,8 +1138,8 @@ void CompactLatticeShortestPath(const CompactLattice &clat, } } std::vector states; // states on best path. - StateId cur_state = superfinal; - while (cur_state != 0) { + StateId cur_state = superfinal, start_state = clat.Start(); + while (cur_state != start_state) { StateId prev_state = best_cost_and_pred[cur_state].second; if (prev_state == kNoStateId) { KALDI_WARN << "Failure in best-path algorithm for lattice (infinite costs?)"; diff --git a/src/online2/Makefile b/src/online2/Makefile index 242c7be6da6..bbc7ac07bb1 100644 --- a/src/online2/Makefile +++ b/src/online2/Makefile @@ -9,7 +9,7 @@ OBJFILES = online-gmm-decodable.o online-feature-pipeline.o online-ivector-featu online-nnet2-feature-pipeline.o online-gmm-decoding.o online-timing.o \ online-endpoint.o onlinebin-util.o online-speex-wrapper.o \ online-nnet2-decoding.o online-nnet2-decoding-threaded.o \ - online-nnet3-decoding.o + online-nnet3-decoding.o online-nnet3-incremental-decoding.o LIBNAME = kaldi-online2 diff --git a/src/online2/online-endpoint.cc b/src/online2/online-endpoint.cc index aa7752c4484..a3be0791f03 100644 --- a/src/online2/online-endpoint.cc +++ b/src/online2/online-endpoint.cc @@ -71,10 +71,10 @@ bool EndpointDetected(const OnlineEndpointConfig &config, return false; } -template +template int32 TrailingSilenceLength(const TransitionModel &tmodel, const std::string &silence_phones_str, - const LatticeFasterOnlineDecoderTpl &decoder) { + const DEC &decoder) { std::vector silence_phones; if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones)) KALDI_ERR << "Bad --silence-phones option in endpointing config: " @@ -87,7 +87,7 @@ int32 TrailingSilenceLength(const TransitionModel &tmodel, ConstIntegerSet silence_set(silence_phones); bool use_final_probs = false; - typename LatticeFasterOnlineDecoderTpl::BestPathIterator iter = + typename DEC::BestPathIterator iter = decoder.BestPathEnd(use_final_probs, NULL); int32 num_silence_frames = 0; while (!iter.Done()) { // we're going backwards in time from the most @@ -117,7 +117,7 @@ bool EndpointDetected( BaseFloat final_relative_cost = decoder.FinalRelativeCost(); int32 num_frames_decoded = decoder.NumFramesDecoded(), - trailing_silence_frames = TrailingSilenceLength(tmodel, + trailing_silence_frames = TrailingSilenceLength>(tmodel, config.silence_phones, decoder); @@ -125,6 +125,26 @@ bool EndpointDetected( frame_shift_in_seconds, final_relative_cost); } +template +bool EndpointDetected( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl &decoder) { + if (decoder.NumFramesDecoded() == 0) return false; + + BaseFloat final_relative_cost = decoder.FinalRelativeCost(); + + int32 num_frames_decoded = decoder.NumFramesDecoded(), + trailing_silence_frames = TrailingSilenceLength>(tmodel, + config.silence_phones, + decoder); + + return EndpointDetected(config, num_frames_decoded, trailing_silence_frames, + frame_shift_in_seconds, final_relative_cost); +} + + // Instantiate EndpointDetected for the types we need. // It will require TrailingSilenceLength so we don't have to instantiate that. @@ -143,5 +163,21 @@ bool EndpointDetected( BaseFloat frame_shift_in_seconds, const LatticeFasterOnlineDecoderTpl &decoder); +template +bool EndpointDetected >( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl > &decoder); + + +template +bool EndpointDetected( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl &decoder); + + } // namespace kaldi diff --git a/src/online2/online-endpoint.h b/src/online2/online-endpoint.h index aaf9232db13..3171f0c532c 100644 --- a/src/online2/online-endpoint.h +++ b/src/online2/online-endpoint.h @@ -35,6 +35,7 @@ #include "lat/kaldi-lattice.h" #include "hmm/transition-model.h" #include "decoder/lattice-faster-online-decoder.h" +#include "decoder/lattice-incremental-online-decoder.h" namespace kaldi { /// @addtogroup onlinedecoding OnlineDecoding @@ -187,10 +188,10 @@ bool EndpointDetected(const OnlineEndpointConfig &config, /// integer id's of phones that we consider silence. We use the the /// BestPathEnd() and TraceBackOneLink() functions of LatticeFasterOnlineDecoder /// to do this efficiently. -template +template int32 TrailingSilenceLength(const TransitionModel &tmodel, const std::string &silence_phones, - const LatticeFasterOnlineDecoderTpl &decoder); + const DEC &decoder); /// This is a higher-level convenience function that works out the @@ -202,6 +203,15 @@ bool EndpointDetected( BaseFloat frame_shift_in_seconds, const LatticeFasterOnlineDecoderTpl &decoder); +/// This is a higher-level convenience function that works out the +/// arguments to the EndpointDetected function above, from the decoder. +template +bool EndpointDetected( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl &decoder); + diff --git a/src/online2/online-ivector-feature.cc b/src/online2/online-ivector-feature.cc index 32a4db70097..3a15ac9a325 100644 --- a/src/online2/online-ivector-feature.cc +++ b/src/online2/online-ivector-feature.cc @@ -519,6 +519,57 @@ void OnlineSilenceWeighting::ComputeCurrentTraceback( } } +template +void OnlineSilenceWeighting::ComputeCurrentTraceback( + const LatticeIncrementalOnlineDecoderTpl &decoder) { + int32 num_frames_decoded = decoder.NumFramesDecoded(), + num_frames_prev = frame_info_.size(); + // note, num_frames_prev is not the number of frames previously decoded, + // it's the generally-larger number of frames that we were requested to + // provide weights for. + if (num_frames_prev < num_frames_decoded) + frame_info_.resize(num_frames_decoded); + if (num_frames_prev > num_frames_decoded && + frame_info_[num_frames_decoded].transition_id != -1) + KALDI_ERR << "Number of frames decoded decreased"; // Likely bug + + if (num_frames_decoded == 0) + return; + int32 frame = num_frames_decoded - 1; + bool use_final_probs = false; + typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator iter = + decoder.BestPathEnd(use_final_probs, NULL); + while (frame >= 0) { + LatticeArc arc; + arc.ilabel = 0; + while (arc.ilabel == 0) // the while loop skips over input-epsilons + iter = decoder.TraceBackBestPath(iter, &arc); + // note, the iter.frame values are slightly unintuitively defined, + // they are one less than you might expect. + KALDI_ASSERT(iter.frame == frame - 1); + + if (frame_info_[frame].token == iter.tok) { + // we know that the traceback from this point back will be identical, so + // no point tracing back further. Note: we are comparing memory addresses + // of tokens of the decoder; this guarantees it's the same exact token, + // because tokens, once allocated on a frame, are only deleted, never + // reallocated for that frame. + break; + } + + if (num_frames_output_and_correct_ > frame) + num_frames_output_and_correct_ = frame; + + frame_info_[frame].token = iter.tok; + frame_info_[frame].transition_id = arc.ilabel; + frame--; + // leave frame_info_.current_weight at zero for now (as set in the + // constructor), reflecting that we haven't already output a weight for that + // frame. + } +} + + // Instantiate the template OnlineSilenceWeighting::ComputeCurrentTraceback(). template void OnlineSilenceWeighting::ComputeCurrentTraceback >( @@ -526,6 +577,13 @@ void OnlineSilenceWeighting::ComputeCurrentTraceback >( template void OnlineSilenceWeighting::ComputeCurrentTraceback( const LatticeFasterOnlineDecoderTpl &decoder); +template +void OnlineSilenceWeighting::ComputeCurrentTraceback >( + const LatticeIncrementalOnlineDecoderTpl > &decoder); +template +void OnlineSilenceWeighting::ComputeCurrentTraceback( + const LatticeIncrementalOnlineDecoderTpl &decoder); + void OnlineSilenceWeighting::GetDeltaWeights( int32 num_frames_ready, int32 first_decoder_frame, diff --git a/src/online2/online-ivector-feature.h b/src/online2/online-ivector-feature.h index 12bc5c6bb2f..0d02ab06eff 100644 --- a/src/online2/online-ivector-feature.h +++ b/src/online2/online-ivector-feature.h @@ -33,6 +33,7 @@ #include "feat/online-feature.h" #include "ivector/ivector-extractor.h" #include "decoder/lattice-faster-online-decoder.h" +#include "decoder/lattice-incremental-online-decoder.h" namespace kaldi { /// @addtogroup onlinefeat OnlineFeatureExtraction @@ -480,6 +481,8 @@ class OnlineSilenceWeighting { // It will be instantiated for FST == fst::Fst and fst::GrammarFst. template void ComputeCurrentTraceback(const LatticeFasterOnlineDecoderTpl &decoder); + template + void ComputeCurrentTraceback(const LatticeIncrementalOnlineDecoderTpl &decoder); // Calling this function gets the changes in weight that require us to modify // the stats... the output format is (frame-index, delta-weight). diff --git a/src/online2/online-nnet3-incremental-decoding.cc b/src/online2/online-nnet3-incremental-decoding.cc new file mode 100644 index 00000000000..5e7acf147ee --- /dev/null +++ b/src/online2/online-nnet3-incremental-decoding.cc @@ -0,0 +1,75 @@ +// online2/online-nnet3-incremental-decoding.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "online2/online-nnet3-incremental-decoding.h" +#include "lat/lattice-functions.h" +#include "lat/determinize-lattice-pruned.h" +#include "decoder/grammar-fst.h" + +namespace kaldi { + +template +SingleUtteranceNnet3IncrementalDecoderTpl::SingleUtteranceNnet3IncrementalDecoderTpl( + const LatticeIncrementalDecoderConfig &decoder_opts, + const TransitionModel &trans_model, + const nnet3::DecodableNnetSimpleLoopedInfo &info, + const FST &fst, + OnlineNnet2FeaturePipeline *features): + decoder_opts_(decoder_opts), + input_feature_frame_shift_in_seconds_(features->FrameShiftInSeconds()), + trans_model_(trans_model), + decodable_(trans_model_, info, + features->InputFeature(), features->IvectorFeature()), + decoder_(fst, trans_model, decoder_opts_) { + decoder_.InitDecoding(); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::InitDecoding(int32 frame_offset) { + decoder_.InitDecoding(); + decodable_.SetFrameOffset(frame_offset); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::AdvanceDecoding() { + decoder_.AdvanceDecoding(&decodable_); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::GetBestPath(bool end_of_utterance, + Lattice *best_path) const { + decoder_.GetBestPath(best_path, end_of_utterance); +} + +template +bool SingleUtteranceNnet3IncrementalDecoderTpl::EndpointDetected( + const OnlineEndpointConfig &config) { + BaseFloat output_frame_shift = + input_feature_frame_shift_in_seconds_ * + decodable_.FrameSubsamplingFactor(); + return kaldi::EndpointDetected(config, trans_model_, + output_frame_shift, decoder_); +} + + +// Instantiate the template for the types needed. +template class SingleUtteranceNnet3IncrementalDecoderTpl >; +template class SingleUtteranceNnet3IncrementalDecoderTpl; + +} // namespace kaldi diff --git a/src/online2/online-nnet3-incremental-decoding.h b/src/online2/online-nnet3-incremental-decoding.h new file mode 100644 index 00000000000..e407cc2be2b --- /dev/null +++ b/src/online2/online-nnet3-incremental-decoding.h @@ -0,0 +1,148 @@ +// online2/online-nnet3-incremental-decoding.h + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_ONLINE2_ONLINE_NNET3_INCREMENTAL_DECODING_H_ +#define KALDI_ONLINE2_ONLINE_NNET3_INCREMENTAL_DECODING_H_ + +#include +#include +#include + +#include "nnet3/decodable-online-looped.h" +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" +#include "base/kaldi-error.h" +#include "itf/online-feature-itf.h" +#include "online2/online-endpoint.h" +#include "online2/online-nnet2-feature-pipeline.h" +#include "decoder/lattice-incremental-online-decoder.h" +#include "hmm/transition-model.h" +#include "hmm/posterior.h" + +namespace kaldi { +/// @addtogroup onlinedecoding OnlineDecoding +/// @{ + + +/** + You will instantiate this class when you want to decode a single utterance + using the online-decoding setup for neural nets. The template will be + instantiated only for FST = fst::Fst and FST = fst::GrammarFst. +*/ + +template +class SingleUtteranceNnet3IncrementalDecoderTpl { + public: + + // Constructor. The pointer 'features' is not being given to this class to own + // and deallocate, it is owned externally. + SingleUtteranceNnet3IncrementalDecoderTpl(const LatticeIncrementalDecoderConfig &decoder_opts, + const TransitionModel &trans_model, + const nnet3::DecodableNnetSimpleLoopedInfo &info, + const FST &fst, + OnlineNnet2FeaturePipeline *features); + + /// Initializes the decoding and sets the frame offset of the underlying + /// decodable object. This method is called by the constructor. You can also + /// call this method when you want to reset the decoder state, but want to + /// keep using the same decodable object, e.g. in case of an endpoint. + void InitDecoding(int32 frame_offset = 0); + + /// Advances the decoding as far as we can. + void AdvanceDecoding(); + + /// Finalizes the decoding. Cleans up and prunes remaining tokens, so the + /// GetLattice() call will return faster. You must not call this before + /// calling (TerminateDecoding() or InputIsFinished()) and then Wait(). + void FinalizeDecoding() { decoder_.FinalizeDecoding(); } + + int32 NumFramesDecoded() const { return decoder_.NumFramesDecoded(); } + + int32 NumFramesInLattice() const { return decoder_.NumFramesInLattice(); } + + /* Gets the lattice. The output lattice has any acoustic scaling in it + (which will typically be desirable in an online-decoding context); if you + want an un-scaled lattice, scale it using ScaleLattice() with the inverse + of the acoustic weight. + + @param [in] num_frames_to_include The number of frames you want + to be included in the lattice. Must be in the range + [NumFramesInLattice().. NumFramesDecoded()]. If you + make it a few frames less than NumFramesDecoded(), it + will save significant computation. + @param [in] use_final_probs True if you want the lattice to + contain final-probs (if at least one state was final + on the most recently decoded frame). Must be false + if num_frames_to_include < NumFramesDecoded(). + Must be true if you have previously called + FinalizeDecoding(). + */ + const CompactLattice &GetLattice(int32 num_frames_to_include, + bool use_final_probs = false) { + return decoder_.GetLattice(num_frames_to_include, use_final_probs); + } + + + + + + /// Outputs an FST corresponding to the single best path through the current + /// lattice. If "use_final_probs" is true AND we reached the final-state of + /// the graph then it will include those as final-probs, else it will treat + /// all final-probs as one. + void GetBestPath(bool end_of_utterance, + Lattice *best_path) const; + + + /// This function calls EndpointDetected from online-endpoint.h, + /// with the required arguments. + bool EndpointDetected(const OnlineEndpointConfig &config); + + const LatticeIncrementalOnlineDecoderTpl &Decoder() const { return decoder_; } + + ~SingleUtteranceNnet3IncrementalDecoderTpl() { } + private: + + const LatticeIncrementalDecoderConfig &decoder_opts_; + + // this is remembered from the constructor; it's ultimately + // derived from calling FrameShiftInSeconds() on the feature pipeline. + BaseFloat input_feature_frame_shift_in_seconds_; + + // we need to keep a reference to the transition model around only because + // it's needed by the endpointing code. + const TransitionModel &trans_model_; + + nnet3::DecodableAmNnetLoopedOnline decodable_; + + LatticeIncrementalOnlineDecoderTpl decoder_; + +}; + + +typedef SingleUtteranceNnet3IncrementalDecoderTpl > SingleUtteranceNnet3IncrementalDecoder; + +/// @} End of "addtogroup onlinedecoding" + +} // namespace kaldi + + + +#endif // KALDI_ONLINE2_ONLINE_NNET3_DECODING_H_ diff --git a/src/online2bin/Makefile b/src/online2bin/Makefile index 28c135eb950..2552e7148dc 100644 --- a/src/online2bin/Makefile +++ b/src/online2bin/Makefile @@ -12,7 +12,7 @@ BINFILES = online2-wav-gmm-latgen-faster apply-cmvn-online \ online2-wav-dump-features ivector-randomize \ online2-wav-nnet2-am-compute online2-wav-nnet2-latgen-threaded \ online2-wav-nnet3-latgen-faster online2-wav-nnet3-latgen-grammar \ - online2-tcp-nnet3-decode-faster + online2-tcp-nnet3-decode-faster online2-wav-nnet3-latgen-incremental OBJFILES = diff --git a/src/online2bin/online2-wav-nnet3-latgen-faster.cc b/src/online2bin/online2-wav-nnet3-latgen-faster.cc index 1549dd6ae52..c7fb3806e6b 100644 --- a/src/online2bin/online2-wav-nnet3-latgen-faster.cc +++ b/src/online2bin/online2-wav-nnet3-latgen-faster.cc @@ -58,7 +58,8 @@ void GetDiagnosticsAndPrintOutput(const std::string &utt, *tot_like += likelihood; KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is " << (likelihood / num_frames) << " over " << num_frames - << " frames."; + << " frames, = " << (-weight.Value1() / num_frames) + << ',' << (weight.Value2() / num_frames); if (word_syms != NULL) { std::cerr << utt << ' '; diff --git a/src/online2bin/online2-wav-nnet3-latgen-incremental.cc b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc new file mode 100644 index 00000000000..aaa87f24de1 --- /dev/null +++ b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc @@ -0,0 +1,306 @@ +// online2bin/online2-wav-nnet3-latgen-incremental.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "feat/wave-reader.h" +#include "online2/online-nnet3-incremental-decoding.h" +#include "online2/online-nnet2-feature-pipeline.h" +#include "online2/onlinebin-util.h" +#include "online2/online-timing.h" +#include "online2/online-endpoint.h" +#include "fstext/fstext-lib.h" +#include "lat/lattice-functions.h" +#include "util/kaldi-thread.h" +#include "nnet3/nnet-utils.h" + +namespace kaldi { + +void GetDiagnosticsAndPrintOutput(const std::string &utt, + const fst::SymbolTable *word_syms, + const CompactLattice &clat, + int64 *tot_num_frames, + double *tot_like) { + if (clat.NumStates() == 0) { + KALDI_WARN << "Empty lattice."; + return; + } + CompactLattice best_path_clat; + CompactLatticeShortestPath(clat, &best_path_clat); + + Lattice best_path_lat; + ConvertLattice(best_path_clat, &best_path_lat); + + double likelihood; + LatticeWeight weight; + int32 num_frames; + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight); + num_frames = alignment.size(); + likelihood = -(weight.Value1() + weight.Value2()); + *tot_num_frames += num_frames; + *tot_like += likelihood; + KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is " + << (likelihood / num_frames) << " over " << num_frames + << " frames, = " << (-weight.Value1() / num_frames) + << ',' << (weight.Value2() / num_frames); + + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << std::endl; + } +} + +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace fst; + + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Reads in wav file(s) and simulates online decoding with neural nets\n" + "(nnet3 setup), with optional iVector-based speaker adaptation and\n" + "optional endpointing. Note: some configuration values and inputs are\n" + "set via config files whose filenames are passed as options\n" + "The lattice determinization algorithm here can operate\n" + "incrementally.\n" + "\n" + "Usage: online2-wav-nnet3-latgen-incremental [options] " + " \n" + "The spk2utt-rspecifier can just be if\n" + "you want to decode utterance by utterance.\n"; + + ParseOptions po(usage); + + std::string word_syms_rxfilename; + + // feature_opts includes configuration for the iVector adaptation, + // as well as the basic features. + OnlineNnet2FeaturePipelineConfig feature_opts; + nnet3::NnetSimpleLoopedComputationOptions decodable_opts; + LatticeIncrementalDecoderConfig decoder_opts; + OnlineEndpointConfig endpoint_opts; + + BaseFloat chunk_length_secs = 0.18; + bool do_endpointing = false; + bool online = true; + + po.Register("chunk-length", &chunk_length_secs, + "Length of chunk size in seconds, that we process. Set to <= 0 " + "to use all input in one chunk."); + po.Register("word-symbol-table", &word_syms_rxfilename, + "Symbol table for words [for debug output]"); + po.Register("do-endpointing", &do_endpointing, + "If true, apply endpoint detection"); + po.Register("online", &online, + "You can set this to false to disable online iVector estimation " + "and have all the data for each utterance used, even at " + "utterance start. This is useful where you just want the best " + "results and don't care about online operation. Setting this to " + "false has the same effect as setting " + "--use-most-recent-ivector=true and --greedy-ivector-extractor=true " + "in the file given to --ivector-extraction-config, and " + "--chunk-length=-1."); + po.Register("num-threads-startup", &g_num_threads, + "Number of threads used when initializing iVector extractor."); + + feature_opts.Register(&po); + decodable_opts.Register(&po); + decoder_opts.Register(&po); + endpoint_opts.Register(&po); + + + po.Read(argc, argv); + + if (po.NumArgs() != 5) { + po.PrintUsage(); + return 1; + } + + std::string nnet3_rxfilename = po.GetArg(1), + fst_rxfilename = po.GetArg(2), + spk2utt_rspecifier = po.GetArg(3), + wav_rspecifier = po.GetArg(4), + clat_wspecifier = po.GetArg(5); + + OnlineNnet2FeaturePipelineInfo feature_info(feature_opts); + + if (!online) { + feature_info.ivector_extractor_info.use_most_recent_ivector = true; + feature_info.ivector_extractor_info.greedy_ivector_extractor = true; + chunk_length_secs = -1.0; + } + + TransitionModel trans_model; + nnet3::AmNnetSimple am_nnet; + { + bool binary; + Input ki(nnet3_rxfilename, &binary); + trans_model.Read(ki.Stream(), binary); + am_nnet.Read(ki.Stream(), binary); + SetBatchnormTestMode(true, &(am_nnet.GetNnet())); + SetDropoutTestMode(true, &(am_nnet.GetNnet())); + nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(am_nnet.GetNnet())); + } + + // this object contains precomputed stuff that is used by all decodable + // objects. It takes a pointer to am_nnet because if it has iVectors it has + // to modify the nnet to accept iVectors at intervals. + nnet3::DecodableNnetSimpleLoopedInfo decodable_info(decodable_opts, + &am_nnet); + + + fst::Fst *decode_fst = ReadFstKaldiGeneric(fst_rxfilename); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_rxfilename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename))) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_rxfilename; + + int32 num_done = 0, num_err = 0; + double tot_like = 0.0; + int64 num_frames = 0; + + SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); + RandomAccessTableReader wav_reader(wav_rspecifier); + CompactLatticeWriter clat_writer(clat_wspecifier); + + OnlineTimingStats timing_stats; + + for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { + std::string spk = spk2utt_reader.Key(); + const std::vector &uttlist = spk2utt_reader.Value(); + OnlineIvectorExtractorAdaptationState adaptation_state( + feature_info.ivector_extractor_info); + for (size_t i = 0; i < uttlist.size(); i++) { + std::string utt = uttlist[i]; + if (!wav_reader.HasKey(utt)) { + KALDI_WARN << "Did not find audio for utterance " << utt; + num_err++; + continue; + } + const WaveData &wave_data = wav_reader.Value(utt); + // get the data for channel zero (if the signal is not mono, we only + // take the first channel). + SubVector data(wave_data.Data(), 0); + + OnlineNnet2FeaturePipeline feature_pipeline(feature_info); + feature_pipeline.SetAdaptationState(adaptation_state); + + OnlineSilenceWeighting silence_weighting( + trans_model, + feature_info.silence_weighting_config, + decodable_opts.frame_subsampling_factor); + + SingleUtteranceNnet3IncrementalDecoder decoder(decoder_opts, trans_model, + decodable_info, + *decode_fst, &feature_pipeline); + OnlineTimer decoding_timer(utt); + + BaseFloat samp_freq = wave_data.SampFreq(); + int32 chunk_length; + if (chunk_length_secs > 0) { + chunk_length = int32(samp_freq * chunk_length_secs); + if (chunk_length == 0) chunk_length = 1; + } else { + chunk_length = std::numeric_limits::max(); + } + + int32 samp_offset = 0; + std::vector > delta_weights; + + while (samp_offset < data.Dim()) { + int32 samp_remaining = data.Dim() - samp_offset; + int32 num_samp = chunk_length < samp_remaining ? chunk_length + : samp_remaining; + + SubVector wave_part(data, samp_offset, num_samp); + feature_pipeline.AcceptWaveform(samp_freq, wave_part); + + samp_offset += num_samp; + decoding_timer.WaitUntil(samp_offset / samp_freq); + if (samp_offset == data.Dim()) { + // no more input. flush out last frames + feature_pipeline.InputFinished(); + } + + if (silence_weighting.Active() && + feature_pipeline.IvectorFeature() != NULL) { + silence_weighting.ComputeCurrentTraceback(decoder.Decoder()); + silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(), + &delta_weights); + feature_pipeline.IvectorFeature()->UpdateFrameWeights(delta_weights); + } + + decoder.AdvanceDecoding(); + + if (do_endpointing && decoder.EndpointDetected(endpoint_opts)) { + break; + } + } + decoder.FinalizeDecoding(); + + bool use_final_probs = true; + CompactLattice clat = decoder.GetLattice(decoder.NumFramesDecoded(), + use_final_probs); + + Connect(&clat); + GetDiagnosticsAndPrintOutput(utt, word_syms, clat, + &num_frames, &tot_like); + + decoding_timer.OutputStats(&timing_stats); + + // In an application you might avoid updating the adaptation state if + // you felt the utterance had low confidence. See lat/confidence.h + feature_pipeline.GetAdaptationState(&adaptation_state); + + // we want to output the lattice with un-scaled acoustics. + BaseFloat inv_acoustic_scale = + 1.0 / decodable_opts.acoustic_scale; + ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat); + + clat_writer.Write(utt, clat); + KALDI_LOG << "Decoded utterance " << utt; + num_done++; + } + } + timing_stats.Print(online); + + KALDI_LOG << "Decoded " << num_done << " utterances, " + << num_err << " with errors."; + KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames) + << " per frame over " << num_frames << " frames."; + delete decode_fst; + delete word_syms; // will delete if non-NULL. + return (num_done != 0 ? 0 : 1); + } catch(const std::exception& e) { + std::cerr << e.what(); + return -1; + } +} // main()