Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions src/decoder/biglm-faster-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,11 @@ class BiglmFasterDecoder {
if (new_weight < next_weight_cutoff) { // not pruned..
PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
Token *new_tok = new Token(arc, ac_weight, tok);
Elem *e_found = toks_.Find(next_pair);
Elem *e_found = toks_.Insert(next_pair, new_tok);
if (new_weight + adaptive_beam < next_weight_cutoff)
next_weight_cutoff = new_weight + adaptive_beam;
if (e_found == NULL) {
toks_.Insert(next_pair, new_tok);
} else {
if ( *(e_found->val) < *new_tok ) {
if (e_found->val != new_tok) {
if (*(e_found->val) < *new_tok) {
Token::TokenDelete(e_found->val);
e_found->val = new_tok;
} else {
Expand All @@ -426,11 +424,12 @@ class BiglmFasterDecoder {
// Processes nonemitting arcs for one frame.
KALDI_ASSERT(queue_.empty());
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
queue_.push_back(e->key);
queue_.push_back(e);
while (!queue_.empty()) {
PairId state_pair = queue_.back();
const Elem *e = queue_.back();
queue_.pop_back();
Token *tok = toks_.Find(state_pair)->val; // would segfault if state not
PairId state_pair = e->key;
Token *tok = e->val; // would segfault if state not
// in toks_ but this can't happen.
if (tok->weight_.Value() > cutoff) { // Don't bother processing successors.
continue;
Expand All @@ -450,15 +449,14 @@ class BiglmFasterDecoder {
if (new_tok->weight_.Value() > cutoff) { // prune
Token::TokenDelete(new_tok);
} else {
Elem *e_found = toks_.Find(next_pair);
if (e_found == NULL) {
toks_.Insert(next_pair, new_tok);
queue_.push_back(next_pair);
Elem *e_found = toks_.Insert(next_pair, new_tok);
if (e_found->val == new_tok) {
queue_.push_back(e_found);
} else {
if ( *(e_found->val) < *new_tok ) {
Token::TokenDelete(e_found->val);
e_found->val = new_tok;
queue_.push_back(next_pair);
queue_.push_back(e_found);
} else {
Token::TokenDelete(new_tok);
}
Expand All @@ -477,7 +475,7 @@ class BiglmFasterDecoder {
fst::DeterministicOnDemandFst<fst::StdArc> *lm_diff_fst_;
BiglmFasterDecoderOptions opts_;
bool warned_noarc_;
std::vector<PairId> queue_; // temp variable used in ProcessNonemitting,
std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
// make it class member to avoid internal new/delete.

Expand Down
26 changes: 12 additions & 14 deletions src/decoder/faster-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,11 @@ double FasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
double new_weight = arc.weight.Value() + tok->cost_ + ac_cost;
if (new_weight < next_weight_cutoff) { // not pruned..
Token *new_tok = new Token(arc, ac_cost, tok);
Elem *e_found = toks_.Find(arc.nextstate);
Elem *e_found = toks_.Insert(arc.nextstate, new_tok);
if (new_weight + adaptive_beam < next_weight_cutoff)
next_weight_cutoff = new_weight + adaptive_beam;
if (e_found == NULL) {
toks_.Insert(arc.nextstate, new_tok);
} else {
if ( *(e_found->val) < *new_tok ) {
if (e_found->val != new_tok) {
if (*(e_found->val) < *new_tok) {
Token::TokenDelete(e_found->val);
e_found->val = new_tok;
} else {
Expand All @@ -307,11 +305,12 @@ void FasterDecoder::ProcessNonemitting(double cutoff) {
// Processes nonemitting arcs for one frame.
KALDI_ASSERT(queue_.empty());
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
queue_.push_back(e->key);
queue_.push_back(e);
while (!queue_.empty()) {
StateId state = queue_.back();
const Elem* e = queue_.back();
queue_.pop_back();
Token *tok = toks_.Find(state)->val; // would segfault if state not
StateId state = e->key;
Token *tok = e->val; // would segfault if state not
// in toks_ but this can't happen.
if (tok->cost_ > cutoff) { // Don't bother processing successors.
continue;
Expand All @@ -326,15 +325,14 @@ void FasterDecoder::ProcessNonemitting(double cutoff) {
if (new_tok->cost_ > cutoff) { // prune
Token::TokenDelete(new_tok);
} else {
Elem *e_found = toks_.Find(arc.nextstate);
if (e_found == NULL) {
toks_.Insert(arc.nextstate, new_tok);
queue_.push_back(arc.nextstate);
Elem *e_found = toks_.Insert(arc.nextstate, new_tok);
if (e_found->val == new_tok) {
queue_.push_back(e_found);
} else {
if ( *(e_found->val) < *new_tok ) {
if (*(e_found->val) < *new_tok) {
Token::TokenDelete(e_found->val);
e_found->val = new_tok;
queue_.push_back(arc.nextstate);
queue_.push_back(e_found);
} else {
Token::TokenDelete(new_tok);
}
Expand Down
2 changes: 1 addition & 1 deletion src/decoder/faster-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class FasterDecoder {
HashList<StateId, Token*> toks_;
const fst::Fst<fst::StdArc> &fst_;
FasterDecoderOptions config_;
std::vector<StateId> queue_; // temp variable used in ProcessNonemitting,
std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
// make it class member to avoid internal new/delete.

Expand Down
37 changes: 19 additions & 18 deletions src/decoder/lattice-biglm-faster-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,14 @@ class LatticeBiglmFasterDecoder {
// for the current frame. [note: it's inserted if necessary into hash toks_
// and also into the singly linked list of tokens active on this frame
// (whose head is at active_toks_[frame]).
inline Token *FindOrAddToken(PairId state_pair, int32 frame, BaseFloat tot_cost,
bool emitting, bool *changed) {
inline Elem *FindOrAddToken(PairId state_pair, int32 frame,
BaseFloat tot_cost, bool emitting, bool *changed) {
// Returns the Token pointer. Sets "changed" (if non-NULL) to true
// if the token was newly created or the cost changed.
KALDI_ASSERT(frame < active_toks_.size());
Token *&toks = active_toks_[frame].toks;
Elem *e_found = toks_.Find(state_pair);
if (e_found == NULL) { // no such token presently.
Elem *e_found = toks_.Insert(state_pair, NULL);
if (e_found->val == NULL) { // no such token presently.
const BaseFloat extra_cost = 0.0;
// tokens on the currently final frame have zero extra_cost
// as any of them could end up
Expand All @@ -328,9 +328,9 @@ class LatticeBiglmFasterDecoder {
// NULL: no forward links yet
toks = new_tok;
num_toks_++;
toks_.Insert(state_pair, new_tok);
e_found->val = new_tok;
if (changed) *changed = true;
return new_tok;
return e_found;
} else {
Token *tok = e_found->val; // There is an existing Token for this state.
if (tok->tot_cost > tot_cost) { // replace old token
Expand All @@ -346,7 +346,7 @@ class LatticeBiglmFasterDecoder {
} else {
if (changed) *changed = false;
}
return tok;
return e_found;
}
}

Expand Down Expand Up @@ -744,11 +744,11 @@ class LatticeBiglmFasterDecoder {
else if (tot_cost + config_.beam < next_cutoff)
next_cutoff = tot_cost + config_.beam; // prune by best current token
PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
Token *next_tok = FindOrAddToken(next_pair, frame, tot_cost, true, NULL);
Elem *e_next = FindOrAddToken(next_pair, frame, tot_cost, true, NULL);
// true: emitting, NULL: no change indicator needed

// Add ForwardLink from tok to next_tok (put on head of list tok->links)
tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel,
tok->links = new ForwardLink(e_next->val, arc.ilabel, arc.olabel,
graph_cost, ac_cost, tok->links);
}
} // for all arcs
Expand All @@ -770,7 +770,7 @@ class LatticeBiglmFasterDecoder {
KALDI_ASSERT(queue_.empty());
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
queue_.push_back(e->key);
queue_.push_back(e);
// for pruning with current best token
best_cost = std::min(best_cost, static_cast<BaseFloat>(e->val->tot_cost));
}
Expand All @@ -784,11 +784,12 @@ class LatticeBiglmFasterDecoder {
BaseFloat cutoff = best_cost + config_.beam;

while (!queue_.empty()) {
PairId state_pair = queue_.back();
const Elem *e = queue_.back();
queue_.pop_back();

Token *tok = toks_.Find(state_pair)->val; // would segfault if state not in
// toks_ but this can't happen.
PairId state_pair = e->key;
Token *tok = e->val; // would segfault if state not in
// toks_ but this can't happen.
BaseFloat cur_cost = tok->tot_cost;
if (cur_cost > cutoff) // Don't bother processing successors.
continue;
Expand All @@ -812,15 +813,15 @@ class LatticeBiglmFasterDecoder {
if (tot_cost < cutoff) {
bool changed;
PairId next_pair = ConstructPair(arc.nextstate, next_lm_state);
Token *new_tok = FindOrAddToken(next_pair, frame, tot_cost,
false, &changed); // false: non-emit
Elem *e_new = FindOrAddToken(next_pair, frame, tot_cost,
false, &changed); // false: non-emit

tok->links = new ForwardLink(new_tok, 0, arc.olabel,
tok->links = new ForwardLink(e_new->val, 0, arc.olabel,
graph_cost, 0, tok->links);

// "changed" tells us whether the new token has a different
// cost from before, or is new [if so, add into queue].
if (changed) queue_.push_back(next_pair);
if (changed) queue_.push_back(e_new);
}
}
} // for all arcs
Expand All @@ -835,7 +836,7 @@ class LatticeBiglmFasterDecoder {
std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
// frame (members of TokenList are toks, must_prune_forward_links,
// must_prune_tokens).
std::vector<PairId> queue_; // temp variable used in ProcessNonemitting,
std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
// make it class member to avoid internal new/delete.
const fst::Fst<fst::StdArc> &fst_;
Expand Down
32 changes: 17 additions & 15 deletions src/decoder/lattice-faster-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,16 @@ void LatticeFasterDecoderTpl<FST, Token>::PossiblyResizeHash(size_t num_toks) {
// and also into the singly linked list of tokens active on this frame
// (whose head is at active_toks_[frame]).
template <typename FST, typename Token>
inline Token* LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(
inline typename LatticeFasterDecoderTpl<FST, Token>::Elem*
LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(
StateId state, int32 frame_plus_one, BaseFloat tot_cost,
Token *backpointer, bool *changed) {
// Returns the Token pointer. Sets "changed" (if non-NULL) to true
// if the token was newly created or the cost changed.
KALDI_ASSERT(frame_plus_one < active_toks_.size());
Token *&toks = active_toks_[frame_plus_one].toks;
Elem *e_found = toks_.Find(state);
if (e_found == NULL) { // no such token presently.
Elem *e_found = toks_.Insert(state, NULL);
if (e_found->val == NULL) { // no such token presently.
const BaseFloat extra_cost = 0.0;
// tokens on the currently final frame have zero extra_cost
// as any of them could end up
Expand All @@ -280,9 +281,9 @@ inline Token* LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(
// NULL: no forward links yet
toks = new_tok;
num_toks_++;
toks_.Insert(state, new_tok);
e_found->val = new_tok;
if (changed) *changed = true;
return new_tok;
return e_found;
} else {
Token *tok = e_found->val; // There is an existing Token for this state.
if (tok->tot_cost > tot_cost) { // replace old token
Expand All @@ -301,7 +302,7 @@ inline Token* LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(
} else {
if (changed) *changed = false;
}
return tok;
return e_found;
}
}

Expand Down Expand Up @@ -800,12 +801,12 @@ BaseFloat LatticeFasterDecoderTpl<FST, Token>::ProcessEmitting(
next_cutoff = tot_cost + adaptive_beam; // prune by best current token
// Note: the frame indexes into active_toks_ are one-based,
// hence the + 1.
Token *next_tok = FindOrAddToken(arc.nextstate,
frame + 1, tot_cost, tok, NULL);
Elem *e_next = FindOrAddToken(arc.nextstate,
frame + 1, tot_cost, tok, NULL);
// NULL: no change indicator needed

// Add ForwardLink from tok to next_tok (put on head of list tok->links)
tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel,
tok->links = new ForwardLinkT(e_next->val, arc.ilabel, arc.olabel,
graph_cost, ac_cost, tok->links);
}
} // for all arcs
Expand Down Expand Up @@ -855,14 +856,15 @@ void LatticeFasterDecoderTpl<FST, Token>::ProcessNonemitting(BaseFloat cutoff) {
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
StateId state = e->key;
if (fst_->NumInputEpsilons(state) != 0)
queue_.push_back(state);
queue_.push_back(e);
}

while (!queue_.empty()) {
StateId state = queue_.back();
const Elem *e = queue_.back();
queue_.pop_back();

Token *tok = toks_.Find(state)->val; // would segfault if state not in toks_ but this can't happen.
StateId state = e->key;
Token *tok = e->val; // would segfault if e is a NULL pointer but this can't happen.
BaseFloat cur_cost = tok->tot_cost;
if (cur_cost > cutoff) // Don't bother processing successors.
continue;
Expand All @@ -882,16 +884,16 @@ void LatticeFasterDecoderTpl<FST, Token>::ProcessNonemitting(BaseFloat cutoff) {
if (tot_cost < cutoff) {
bool changed;

Token *new_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
Elem *e_new = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
tok, &changed);

tok->links = new ForwardLinkT(new_tok, 0, arc.olabel,
tok->links = new ForwardLinkT(e_new->val, 0, arc.olabel,
graph_cost, 0, tok->links);

// "changed" tells us whether the new token has a different
// cost from before, or is new [if so, add into queue].
if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0)
queue_.push_back(arc.nextstate);
queue_.push_back(e_new);
}
}
} // for all arcs
Expand Down
8 changes: 4 additions & 4 deletions src/decoder/lattice-faster-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,9 @@ class LatticeFasterDecoderTpl {
// token was newly created or the cost changed.
// If Token == StdToken, the 'backpointer' argument has no purpose (and will
// hopefully be optimized out).
inline Token *FindOrAddToken(StateId state, int32 frame_plus_one,
BaseFloat tot_cost, Token *backpointer,
bool *changed);
inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one,
BaseFloat tot_cost, Token *backpointer,
bool *changed);

// prunes outgoing links for all tokens in active_toks_[frame]
// it's called by PruneActiveTokens
Expand Down Expand Up @@ -464,7 +464,7 @@ class LatticeFasterDecoderTpl {
std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
// frame (members of TokenList are toks, must_prune_forward_links,
// must_prune_tokens).
std::vector<StateId> queue_; // temp variable used in ProcessNonemitting,
std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.

// fst_ is a pointer to the FST we are decoding from.
Expand Down
16 changes: 13 additions & 3 deletions src/util/hash-list-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,24 @@ HashList<I, T>::~HashList() {
}
}


template<class I, class T>
void HashList<I, T>::Insert(I key, T val) {
inline typename HashList<I, T>::Elem* HashList<I, T>::Insert(I key, T val) {
size_t index = (static_cast<size_t>(key) % hash_size_);
HashBucket &bucket = buckets_[index];
// Check the element is existing or not.
if (bucket.last_elem != NULL) {
Elem *head = (bucket.prev_bucket == static_cast<size_t>(-1) ?
list_head_ :
buckets_[bucket.prev_bucket].last_elem->tail),
*tail = bucket.last_elem->tail;
for (Elem *e = head; e != tail; e = e->tail)
if (e->key == key) return e;
}

// This is a new element. Insert it.
Elem *elem = New();
elem->key = key;
elem->val = val;

if (bucket.last_elem == NULL) { // Unoccupied bucket. Insert at
// head of bucket list (which is tail of regular list, they go in
// opposite directions).
Expand All @@ -152,6 +161,7 @@ void HashList<I, T>::Insert(I key, T val) {
bucket.last_elem->tail = elem;
bucket.last_elem = elem;
}
return elem;
}

template<class I, class T>
Expand Down
Loading