From f77a735425f523026e14db3bb58efc79a5254779 Mon Sep 17 00:00:00 2001 From: LvHang Date: Tue, 25 Jun 2019 23:55:43 -0400 Subject: [PATCH] [src] Update Insert function of hashlist and decoders (#3402) makes interface of HashList more standard; slight speed improvement. fix alignment bug which is caused by hashlist insert --- src/decoder/biglm-faster-decoder.h | 26 +++++++-------- src/decoder/faster-decoder.cc | 26 +++++++-------- src/decoder/faster-decoder.h | 2 +- src/decoder/lattice-biglm-faster-decoder.h | 37 +++++++++++----------- src/decoder/lattice-faster-decoder.cc | 32 ++++++++++--------- src/decoder/lattice-faster-decoder.h | 8 ++--- src/util/hash-list-inl.h | 16 ++++++++-- src/util/hash-list.h | 14 ++++---- 8 files changed, 84 insertions(+), 77 deletions(-) diff --git a/src/decoder/biglm-faster-decoder.h b/src/decoder/biglm-faster-decoder.h index a6b99fba95e..8e36deb8bb6 100644 --- a/src/decoder/biglm-faster-decoder.h +++ b/src/decoder/biglm-faster-decoder.h @@ -397,13 +397,11 @@ class BiglmFasterDecoder { if (new_weight < next_weight_cutoff) { // not pruned.. PairId next_pair = ConstructPair(arc.nextstate, next_lm_state); Token *new_tok = new Token(arc, ac_weight, tok); - Elem *e_found = toks_.Find(next_pair); + Elem *e_found = toks_.Insert(next_pair, new_tok); if (new_weight + adaptive_beam < next_weight_cutoff) next_weight_cutoff = new_weight + adaptive_beam; - if (e_found == NULL) { - toks_.Insert(next_pair, new_tok); - } else { - if ( *(e_found->val) < *new_tok ) { + if (e_found->val != new_tok) { + if (*(e_found->val) < *new_tok) { Token::TokenDelete(e_found->val); e_found->val = new_tok; } else { @@ -426,11 +424,12 @@ class BiglmFasterDecoder { // Processes nonemitting arcs for one frame. KALDI_ASSERT(queue_.empty()); for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) - queue_.push_back(e->key); + queue_.push_back(e); while (!queue_.empty()) { - PairId state_pair = queue_.back(); + const Elem *e = queue_.back(); queue_.pop_back(); - Token *tok = toks_.Find(state_pair)->val; // would segfault if state not + PairId state_pair = e->key; + Token *tok = e->val; // would segfault if state not // in toks_ but this can't happen. if (tok->weight_.Value() > cutoff) { // Don't bother processing successors. continue; @@ -450,15 +449,14 @@ class BiglmFasterDecoder { if (new_tok->weight_.Value() > cutoff) { // prune Token::TokenDelete(new_tok); } else { - Elem *e_found = toks_.Find(next_pair); - if (e_found == NULL) { - toks_.Insert(next_pair, new_tok); - queue_.push_back(next_pair); + Elem *e_found = toks_.Insert(next_pair, new_tok); + if (e_found->val == new_tok) { + queue_.push_back(e_found); } else { if ( *(e_found->val) < *new_tok ) { Token::TokenDelete(e_found->val); e_found->val = new_tok; - queue_.push_back(next_pair); + queue_.push_back(e_found); } else { Token::TokenDelete(new_tok); } @@ -477,7 +475,7 @@ class BiglmFasterDecoder { fst::DeterministicOnDemandFst *lm_diff_fst_; BiglmFasterDecoderOptions opts_; bool warned_noarc_; - std::vector queue_; // temp variable used in ProcessNonemitting, + std::vector queue_; // temp variable used in ProcessNonemitting, std::vector tmp_array_; // used in GetCutoff. // make it class member to avoid internal new/delete. diff --git a/src/decoder/faster-decoder.cc b/src/decoder/faster-decoder.cc index 105289eb6d7..84b3424f119 100644 --- a/src/decoder/faster-decoder.cc +++ b/src/decoder/faster-decoder.cc @@ -277,13 +277,11 @@ double FasterDecoder::ProcessEmitting(DecodableInterface *decodable) { double new_weight = arc.weight.Value() + tok->cost_ + ac_cost; if (new_weight < next_weight_cutoff) { // not pruned.. Token *new_tok = new Token(arc, ac_cost, tok); - Elem *e_found = toks_.Find(arc.nextstate); + Elem *e_found = toks_.Insert(arc.nextstate, new_tok); if (new_weight + adaptive_beam < next_weight_cutoff) next_weight_cutoff = new_weight + adaptive_beam; - if (e_found == NULL) { - toks_.Insert(arc.nextstate, new_tok); - } else { - if ( *(e_found->val) < *new_tok ) { + if (e_found->val != new_tok) { + if (*(e_found->val) < *new_tok) { Token::TokenDelete(e_found->val); e_found->val = new_tok; } else { @@ -307,11 +305,12 @@ void FasterDecoder::ProcessNonemitting(double cutoff) { // Processes nonemitting arcs for one frame. KALDI_ASSERT(queue_.empty()); for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) - queue_.push_back(e->key); + queue_.push_back(e); while (!queue_.empty()) { - StateId state = queue_.back(); + const Elem* e = queue_.back(); queue_.pop_back(); - Token *tok = toks_.Find(state)->val; // would segfault if state not + StateId state = e->key; + Token *tok = e->val; // would segfault if state not // in toks_ but this can't happen. if (tok->cost_ > cutoff) { // Don't bother processing successors. continue; @@ -326,15 +325,14 @@ void FasterDecoder::ProcessNonemitting(double cutoff) { if (new_tok->cost_ > cutoff) { // prune Token::TokenDelete(new_tok); } else { - Elem *e_found = toks_.Find(arc.nextstate); - if (e_found == NULL) { - toks_.Insert(arc.nextstate, new_tok); - queue_.push_back(arc.nextstate); + Elem *e_found = toks_.Insert(arc.nextstate, new_tok); + if (e_found->val == new_tok) { + queue_.push_back(e_found); } else { - if ( *(e_found->val) < *new_tok ) { + if (*(e_found->val) < *new_tok) { Token::TokenDelete(e_found->val); e_found->val = new_tok; - queue_.push_back(arc.nextstate); + queue_.push_back(e_found); } else { Token::TokenDelete(new_tok); } diff --git a/src/decoder/faster-decoder.h b/src/decoder/faster-decoder.h index baedcc022b6..db03569614f 100644 --- a/src/decoder/faster-decoder.h +++ b/src/decoder/faster-decoder.h @@ -170,7 +170,7 @@ class FasterDecoder { HashList toks_; const fst::Fst &fst_; FasterDecoderOptions config_; - std::vector queue_; // temp variable used in ProcessNonemitting, + std::vector queue_; // temp variable used in ProcessNonemitting, std::vector tmp_array_; // used in GetCutoff. // make it class member to avoid internal new/delete. diff --git a/src/decoder/lattice-biglm-faster-decoder.h b/src/decoder/lattice-biglm-faster-decoder.h index 6276c25a83d..9ea53a95836 100644 --- a/src/decoder/lattice-biglm-faster-decoder.h +++ b/src/decoder/lattice-biglm-faster-decoder.h @@ -312,14 +312,14 @@ class LatticeBiglmFasterDecoder { // for the current frame. [note: it's inserted if necessary into hash toks_ // and also into the singly linked list of tokens active on this frame // (whose head is at active_toks_[frame]). - inline Token *FindOrAddToken(PairId state_pair, int32 frame, BaseFloat tot_cost, - bool emitting, bool *changed) { + inline Elem *FindOrAddToken(PairId state_pair, int32 frame, + BaseFloat tot_cost, bool emitting, bool *changed) { // Returns the Token pointer. Sets "changed" (if non-NULL) to true // if the token was newly created or the cost changed. KALDI_ASSERT(frame < active_toks_.size()); Token *&toks = active_toks_[frame].toks; - Elem *e_found = toks_.Find(state_pair); - if (e_found == NULL) { // no such token presently. + Elem *e_found = toks_.Insert(state_pair, NULL); + if (e_found->val == NULL) { // no such token presently. const BaseFloat extra_cost = 0.0; // tokens on the currently final frame have zero extra_cost // as any of them could end up @@ -328,9 +328,9 @@ class LatticeBiglmFasterDecoder { // NULL: no forward links yet toks = new_tok; num_toks_++; - toks_.Insert(state_pair, new_tok); + e_found->val = new_tok; if (changed) *changed = true; - return new_tok; + return e_found; } else { Token *tok = e_found->val; // There is an existing Token for this state. if (tok->tot_cost > tot_cost) { // replace old token @@ -346,7 +346,7 @@ class LatticeBiglmFasterDecoder { } else { if (changed) *changed = false; } - return tok; + return e_found; } } @@ -744,11 +744,11 @@ class LatticeBiglmFasterDecoder { else if (tot_cost + config_.beam < next_cutoff) next_cutoff = tot_cost + config_.beam; // prune by best current token PairId next_pair = ConstructPair(arc.nextstate, next_lm_state); - Token *next_tok = FindOrAddToken(next_pair, frame, tot_cost, true, NULL); + Elem *e_next = FindOrAddToken(next_pair, frame, tot_cost, true, NULL); // true: emitting, NULL: no change indicator needed // Add ForwardLink from tok to next_tok (put on head of list tok->links) - tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel, + tok->links = new ForwardLink(e_next->val, arc.ilabel, arc.olabel, graph_cost, ac_cost, tok->links); } } // for all arcs @@ -770,7 +770,7 @@ class LatticeBiglmFasterDecoder { KALDI_ASSERT(queue_.empty()); BaseFloat best_cost = std::numeric_limits::infinity(); for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) { - queue_.push_back(e->key); + queue_.push_back(e); // for pruning with current best token best_cost = std::min(best_cost, static_cast(e->val->tot_cost)); } @@ -784,11 +784,12 @@ class LatticeBiglmFasterDecoder { BaseFloat cutoff = best_cost + config_.beam; while (!queue_.empty()) { - PairId state_pair = queue_.back(); + const Elem *e = queue_.back(); queue_.pop_back(); - Token *tok = toks_.Find(state_pair)->val; // would segfault if state not in - // toks_ but this can't happen. + PairId state_pair = e->key; + Token *tok = e->val; // would segfault if state not in + // toks_ but this can't happen. BaseFloat cur_cost = tok->tot_cost; if (cur_cost > cutoff) // Don't bother processing successors. continue; @@ -812,15 +813,15 @@ class LatticeBiglmFasterDecoder { if (tot_cost < cutoff) { bool changed; PairId next_pair = ConstructPair(arc.nextstate, next_lm_state); - Token *new_tok = FindOrAddToken(next_pair, frame, tot_cost, - false, &changed); // false: non-emit + Elem *e_new = FindOrAddToken(next_pair, frame, tot_cost, + false, &changed); // false: non-emit - tok->links = new ForwardLink(new_tok, 0, arc.olabel, + tok->links = new ForwardLink(e_new->val, 0, arc.olabel, graph_cost, 0, tok->links); // "changed" tells us whether the new token has a different // cost from before, or is new [if so, add into queue]. - if (changed) queue_.push_back(next_pair); + if (changed) queue_.push_back(e_new); } } } // for all arcs @@ -835,7 +836,7 @@ class LatticeBiglmFasterDecoder { std::vector active_toks_; // Lists of tokens, indexed by // frame (members of TokenList are toks, must_prune_forward_links, // must_prune_tokens). - std::vector queue_; // temp variable used in ProcessNonemitting, + std::vector queue_; // temp variable used in ProcessNonemitting, std::vector tmp_array_; // used in GetCutoff. // make it class member to avoid internal new/delete. const fst::Fst &fst_; diff --git a/src/decoder/lattice-faster-decoder.cc b/src/decoder/lattice-faster-decoder.cc index 2bc8c7cdef4..7b5725bc1c2 100644 --- a/src/decoder/lattice-faster-decoder.cc +++ b/src/decoder/lattice-faster-decoder.cc @@ -263,15 +263,16 @@ void LatticeFasterDecoderTpl::PossiblyResizeHash(size_t num_toks) { // and also into the singly linked list of tokens active on this frame // (whose head is at active_toks_[frame]). template -inline Token* LatticeFasterDecoderTpl::FindOrAddToken( +inline typename LatticeFasterDecoderTpl::Elem* +LatticeFasterDecoderTpl::FindOrAddToken( StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer, bool *changed) { // Returns the Token pointer. Sets "changed" (if non-NULL) to true // if the token was newly created or the cost changed. KALDI_ASSERT(frame_plus_one < active_toks_.size()); Token *&toks = active_toks_[frame_plus_one].toks; - Elem *e_found = toks_.Find(state); - if (e_found == NULL) { // no such token presently. + Elem *e_found = toks_.Insert(state, NULL); + if (e_found->val == NULL) { // no such token presently. const BaseFloat extra_cost = 0.0; // tokens on the currently final frame have zero extra_cost // as any of them could end up @@ -280,9 +281,9 @@ inline Token* LatticeFasterDecoderTpl::FindOrAddToken( // NULL: no forward links yet toks = new_tok; num_toks_++; - toks_.Insert(state, new_tok); + e_found->val = new_tok; if (changed) *changed = true; - return new_tok; + return e_found; } else { Token *tok = e_found->val; // There is an existing Token for this state. if (tok->tot_cost > tot_cost) { // replace old token @@ -301,7 +302,7 @@ inline Token* LatticeFasterDecoderTpl::FindOrAddToken( } else { if (changed) *changed = false; } - return tok; + return e_found; } } @@ -800,12 +801,12 @@ BaseFloat LatticeFasterDecoderTpl::ProcessEmitting( next_cutoff = tot_cost + adaptive_beam; // prune by best current token // Note: the frame indexes into active_toks_ are one-based, // hence the + 1. - Token *next_tok = FindOrAddToken(arc.nextstate, - frame + 1, tot_cost, tok, NULL); + Elem *e_next = FindOrAddToken(arc.nextstate, + frame + 1, tot_cost, tok, NULL); // NULL: no change indicator needed // Add ForwardLink from tok to next_tok (put on head of list tok->links) - tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, + tok->links = new ForwardLinkT(e_next->val, arc.ilabel, arc.olabel, graph_cost, ac_cost, tok->links); } } // for all arcs @@ -855,14 +856,15 @@ void LatticeFasterDecoderTpl::ProcessNonemitting(BaseFloat cutoff) { for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) { StateId state = e->key; if (fst_->NumInputEpsilons(state) != 0) - queue_.push_back(state); + queue_.push_back(e); } while (!queue_.empty()) { - StateId state = queue_.back(); + const Elem *e = queue_.back(); queue_.pop_back(); - Token *tok = toks_.Find(state)->val; // would segfault if state not in toks_ but this can't happen. + StateId state = e->key; + Token *tok = e->val; // would segfault if e is a NULL pointer but this can't happen. BaseFloat cur_cost = tok->tot_cost; if (cur_cost > cutoff) // Don't bother processing successors. continue; @@ -882,16 +884,16 @@ void LatticeFasterDecoderTpl::ProcessNonemitting(BaseFloat cutoff) { if (tot_cost < cutoff) { bool changed; - Token *new_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, + Elem *e_new = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, &changed); - tok->links = new ForwardLinkT(new_tok, 0, arc.olabel, + tok->links = new ForwardLinkT(e_new->val, 0, arc.olabel, graph_cost, 0, tok->links); // "changed" tells us whether the new token has a different // cost from before, or is new [if so, add into queue]. if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0) - queue_.push_back(arc.nextstate); + queue_.push_back(e_new); } } } // for all arcs diff --git a/src/decoder/lattice-faster-decoder.h b/src/decoder/lattice-faster-decoder.h index 5f8c0778723..eb725c45559 100644 --- a/src/decoder/lattice-faster-decoder.h +++ b/src/decoder/lattice-faster-decoder.h @@ -380,9 +380,9 @@ class LatticeFasterDecoderTpl { // token was newly created or the cost changed. // If Token == StdToken, the 'backpointer' argument has no purpose (and will // hopefully be optimized out). - inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, - BaseFloat tot_cost, Token *backpointer, - bool *changed); + inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one, + BaseFloat tot_cost, Token *backpointer, + bool *changed); // prunes outgoing links for all tokens in active_toks_[frame] // it's called by PruneActiveTokens @@ -464,7 +464,7 @@ class LatticeFasterDecoderTpl { std::vector active_toks_; // Lists of tokens, indexed by // frame (members of TokenList are toks, must_prune_forward_links, // must_prune_tokens). - std::vector queue_; // temp variable used in ProcessNonemitting, + std::vector queue_; // temp variable used in ProcessNonemitting, std::vector tmp_array_; // used in GetCutoff. // fst_ is a pointer to the FST we are decoding from. diff --git a/src/util/hash-list-inl.h b/src/util/hash-list-inl.h index 3fe16182b82..da6165af784 100644 --- a/src/util/hash-list-inl.h +++ b/src/util/hash-list-inl.h @@ -121,15 +121,24 @@ HashList::~HashList() { } } - template -void HashList::Insert(I key, T val) { +inline typename HashList::Elem* HashList::Insert(I key, T val) { size_t index = (static_cast(key) % hash_size_); HashBucket &bucket = buckets_[index]; + // Check the element is existing or not. + if (bucket.last_elem != NULL) { + Elem *head = (bucket.prev_bucket == static_cast(-1) ? + list_head_ : + buckets_[bucket.prev_bucket].last_elem->tail), + *tail = bucket.last_elem->tail; + for (Elem *e = head; e != tail; e = e->tail) + if (e->key == key) return e; + } + + // This is a new element. Insert it. Elem *elem = New(); elem->key = key; elem->val = val; - if (bucket.last_elem == NULL) { // Unoccupied bucket. Insert at // head of bucket list (which is tail of regular list, they go in // opposite directions). @@ -152,6 +161,7 @@ void HashList::Insert(I key, T val) { bucket.last_elem->tail = elem; bucket.last_elem = elem; } + return elem; } template diff --git a/src/util/hash-list.h b/src/util/hash-list.h index 67257d053cd..9ae0043f050 100644 --- a/src/util/hash-list.h +++ b/src/util/hash-list.h @@ -86,14 +86,12 @@ template class HashList { /// is free to modify the "val" element. inline Elem *Find(I key); - /// Insert inserts a new element into the hashtable/stored list. By calling - /// this, - /// the user asserts that it is not already present (e.g. Find was called and - /// returned NULL). With current code, calling this if an element already - /// exists will result in duplicate elements in the structure, and Find() - /// will find the first one that was added. - /// [but we don't guarantee this behavior]. - inline void Insert(I key, T val); + /// Insert inserts a new element into the hashtable/stored list. + /// Because element keys in a hashtable are unique, this operation checks + /// whether each inserted element has a key equivalent to the one of an + /// element already in the hashtable. If so, the element is not inserted, + /// returning an pointer to this existing element. + inline Elem *Insert(I key, T val); /// Insert inserts another element with same key into the hashtable/ /// stored list.