diff --git a/src/fstext/Makefile b/src/fstext/Makefile index d00b97a88bc..099e269e4cd 100644 --- a/src/fstext/Makefile +++ b/src/fstext/Makefile @@ -13,7 +13,7 @@ include ../kaldi.mk TESTFILES = determinize-star-test \ pre-determinize-test trivial-factor-weight-test \ context-fst-test factor-test table-matcher-test fstext-utils-test \ - remove-eps-local-test rescale-test lattice-weight-test \ + remove-eps-local-test lattice-weight-test \ determinize-lattice-test lattice-utils-test deterministic-fst-test \ push-special-test epsilon-property-test prune-special-test diff --git a/src/fstext/rescale-inl.h b/src/fstext/rescale-inl.h deleted file mode 100644 index e7e237fef61..00000000000 --- a/src/fstext/rescale-inl.h +++ /dev/null @@ -1,213 +0,0 @@ -// fstext/rescale-inl.h - -// Copyright 2009-2011 Microsoft Corporation; Jan Silovsky - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FSTEXT_RESCALE_INL_H_ -#define KALDI_FSTEXT_RESCALE_INL_H_ -#include -#include "base/kaldi-common.h" -#include "base/kaldi-math.h" -#include "util/stl-utils.h" -#include "fstext/fstext-utils.h" - -namespace fst { - - -template -inline typename Arc::Weight -ComputeTotalWeight(ExpandedFst &fst, typename Arc::Weight max_weight, float delta) { - typedef typename Arc::StateId StateId; - typedef typename Arc::Weight Weight; - if (fst.Start() == kNoStateId) return Weight::Zero(); - StateId num_states = fst.NumStates(); - - float zero = Weight::Zero().Value(); - - // Should probably use Weight instead of float here, but would - // involve some painful comparators. - vector cur_tot(num_states, zero); - vector cur_delta(num_states, zero); - vector queued(num_states, false); - - std::queue q; // FIFO queue. - - Weight total_final = Weight::Zero(); - { - float f = static_cast(Weight::One().Value()); - q.push(fst.Start()); - cur_delta[fst.Start()] = cur_tot[fst.Start()] = f; - queued[fst.Start()] = true; - } - - while (!q.empty()) { - StateId s = q.front(); - q.pop(); - Weight w = cur_delta[s]; - cur_delta[s] = zero; - queued[s] = false; - cur_tot[s] = Plus(w, Weight(cur_tot[s])).Value(); - - for (ArcIterator > aiter(fst, s); !aiter.Done(); aiter.Next()) { - const Arc &arc = aiter.Value(); - Weight next_weight = Times(w, arc.weight); - cur_delta[arc.nextstate] = Plus(Weight(cur_delta[arc.nextstate]), - next_weight).Value(); - if (!queued[arc.nextstate] && - !ApproxEqual(Weight(cur_tot[arc.nextstate]), - Plus(Weight(cur_delta[arc.nextstate]), - Weight(cur_tot[arc.nextstate])), delta)) { - queued[arc.nextstate] = true; - q.push(arc.nextstate); - } - } - Weight final = fst.Final(s); - if (final != Weight::Zero()) { - total_final = Plus(total_final, Times(w, final)); - if (total_final.Value() < max_weight.Value()) { // Note: this means that total_final is MORE THAN max_weight. - // assuming the weight is of the normal type. - return max_weight; - } - } - } - return total_final; -} - - - -template -inline void Rescale(MutableFst *fst, typename Arc::Weight rescale) { - typedef typename Arc::StateId StateId; - // Multiplies all arcs and final-probs in the FST by this rescaling amount. - // Typically useful with non-stochastic FST, in conjunction with pushing: i.e. - // we rescale, push, and rescale again. - for (StateIterator > siter(*fst); !siter.Done(); siter.Next()) { - StateId s = siter.Value(); - fst->SetFinal(s, Times(rescale, fst->Final(s))); - for (MutableArcIterator > aiter(fst, s); - !aiter.Done(); - aiter.Next()) { - Arc arc = aiter.Value(); - arc.weight = Times(arc.weight, rescale); - aiter.SetValue(arc); - } - } -} - -inline LogWeight RescaleToStochastic(MutableFst *fst, - float approx_delta, - float delta) { - // Rescales the FST until, to within "approx_delta", it sums to one (within - // its own semiring). Returns the amount it rescaled by. Must be of the - // LogArc or StdArc type. - typedef LogArc::Weight Weight; - - if (fst->Start() == kNoStateId) - return Weight::One(); // can't rescale empty FST. - - // total weight). - Weight max = Weight(-kaldi::Log(2.0)); - - // upper_bound and lower_bound are in terms of weight.Value(), - // in terms of weight they would have the reversed names. - float upper_bound = 1.0e+20, lower_bound = -1.0e+20; - bool have_lower = false, have_upper = false; - - Weight last_rescale = Weight::One(); - Weight last_tot = Weight::Zero(); - - Weight cur_rescale = Weight::One(); - Weight cur_tot; - while (1) { - { - FstPrinter fstprinter(*fst, NULL, NULL, NULL, false, true); - fstprinter.Print(&std::cout, "standard output"); - } - - cur_tot = ComputeTotalWeight(*fst, max, delta); - std::cerr << "Current rescaling factor is " << cur_rescale <<", total is: " << cur_tot << '\n'; - if (cur_tot.Value() < Weight::One().Value()) { // read as: cur_tot > 1. - if (cur_rescale.Value() > lower_bound) { - lower_bound = cur_rescale.Value(); - have_lower = true; - } - } else { - if (cur_rescale.Value() < upper_bound) { - upper_bound = cur_rescale.Value(); - have_upper = true; - } - } - - Weight factor; // next factor by which we multiply. - if (ApproxEqual(cur_tot, Weight::One())) return cur_rescale; // Succeeded. - else if (cur_tot == max) { // Maxed out. - if (last_tot == Weight::Zero()) { // First iter. - factor = Weight(0.2); - } else if (last_tot != max) { // Didn't max out last time... - assert(cur_tot.Value() < last_tot.Value()); // or doesn't make sense. - factor = Weight( (last_rescale.Value() - cur_rescale.Value()) / 2.0); - } else { // Maxed out last time. Use larger interval. - factor = Weight(0.5); - } - } else { - // cur_tot != max. - if (last_tot == Weight::Zero() || last_tot == max) { // No derivative information. - if (cur_tot.Value() > 0) factor = Weight(-0.01); // enough to approximate derivative. - else factor = Weight(0.01); - } else { - float dtot_drescale = (cur_tot.Value()-last_tot.Value()) / - (cur_rescale.Value() - last_rescale.Value()); - std::cerr << "dtot_drescale = " < 0) factor = Weight(-0.01); // enough to approximate derivative. - else factor = Weight(0.01); - } else { - factor = Weight(-cur_tot.Value() / dtot_drescale); // want to get the value to zero. - std::cerr << "Factor is " << factor.Value()<<'\n'; - } - } - } - - last_rescale = cur_rescale; - last_tot = cur_tot; - - if (have_lower && have_upper) { - float oldval = cur_rescale.Value(), - newval = oldval+factor.Value(); - assert(upper_bound > lower_bound); - assert(oldval >= lower_bound && oldval <= upper_bound); - float margin = 0.1; // limit must be 0.0 < margin < 0.5. - float new_lower = lower_bound + (oldval-lower_bound)*margin; - float new_upper = upper_bound + (oldval-upper_bound)*margin; - if (newval > new_upper) newval = new_upper; - if (newval < new_lower) newval = new_lower; - factor = Weight(newval - cur_rescale.Value()); - cur_rescale = Weight(newval); - } else { - cur_rescale = Times(cur_rescale, factor); - } - Rescale(fst, factor); - - } -} - - - -} // namespace fst. - -#endif diff --git a/src/fstext/rescale-test.cc b/src/fstext/rescale-test.cc deleted file mode 100644 index 68e21119791..00000000000 --- a/src/fstext/rescale-test.cc +++ /dev/null @@ -1,109 +0,0 @@ -// fstext/rescale-test.cc - -// Copyright 2009-2011 Microsoft Corporation - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "fstext/rescale.h" -#include "fstext/fstext-utils.h" -#include "fstext/fst-test-utils.h" -#include "base/kaldi-math.h" -// Just check that it compiles, for now. - -namespace fst -{ - - using kaldi::Exp; - using kaldi::Log; - -template void TestComputeTotalWeight() { - typedef typename Arc::Weight Weight; - VectorFst *fst = RandFst(); - - std::cout <<" printing FST at start\n"; - { -#ifdef HAVE_OPENFST_GE_10400 - FstPrinter fstprinter(*fst, NULL, NULL, NULL, false, true, "\t"); -#else - FstPrinter fstprinter(*fst, NULL, NULL, NULL, false, true); -#endif - fstprinter.Print(&std::cout, "standard output"); - } - - Weight max(-Log(2.0)); - Weight tot = ComputeTotalWeight(*fst, max); - std::cout << "Total weight is: " << tot.Value() << '\n'; - - - if (tot.Value() > max.Value()) { // didn't max out... - Weight tot2 = ShortestDistance(*fst); - if (!ApproxEqual(tot, tot2, 0.05)) { - KALDI_ERR << tot << " differs from " << tot2; - assert(0); - } - std::cout << "our tot: " < *fst = RandFst(opts); - - std::cout <<" printing FST at start\n"; - { -#ifdef HAVE_OPENFST_GE_10400 - FstPrinter fstprinter(*fst, NULL, NULL, NULL, false, true, "\t"); -#else - FstPrinter fstprinter(*fst, NULL, NULL, NULL, false, true); -#endif - fstprinter.Print(&std::cout, "standard output"); - - } - float diff = 0.01; - - RescaleToStochastic(fst, diff); - Weight tot = ShortestDistance(*fst), - tot2 = ComputeTotalWeight(*fst, Weight(-Log(2.0))); - std::cerr << " tot is " << tot<<", tot2 = "<(); - std::cout << "Testing with log:\n"; - fst::TestComputeTotalWeight(); - } - for (int i = 0;i < 10;i++) { - std::cout << "i = "< -#include -#include -#include -#include -#include - - -namespace fst { - - -/// ComputeTotalWeight computes (approximately) the total weight of the FST, -/// i.e. the sum of all paths. It will only work for arcs of StdArc/LogArc type -/// whose weights we can compare using Value(). If the total weight is greater -/// than max_weight, we just return max_weight (this enables us to avoid -/// pathological cases that would not terminate). - -template -inline typename Arc::Weight -ComputeTotalWeight(ExpandedFst &fst, - typename Arc::Weight max_weight, - float delta = kDelta); - - -/// Rescale multiplies (in the semiring) all weights and final probabilities in -/// the FST by this weight. Does not preserve equivalence. -template -inline void Rescale(MutableFst *fst, - typename Arc::Weight rescale); - - -/// RescaleToStochastic uses a form of line search to compute the weight we must -/// apply to the FST using Rescale to make it so that the "total weight" of the -/// FST is unity, and applies this weight. The projected use-case is that -/// you want to do push-weights but you're scared this might blow up, so you -/// do RescaleToStochastic, push-weights, and then Rescale with the inverse -/// (in the semiring) of that weight, so that you are equivalent to the -/// original FST and the "non-stochasticity" is distributed equally among -/// all states. -inline LogWeight RescaleToStochastic(MutableFst *fst, - float approx_delta = 0.001, - float delta = kDelta); - - -} // end namespace fst - - -#include "fstext/rescale-inl.h" - -#endif