diff --git a/src/bin/vector-sum.cc b/src/bin/vector-sum.cc
index 20f58d52b7d..42404e38384 100644
--- a/src/bin/vector-sum.cc
+++ b/src/bin/vector-sum.cc
@@ -101,7 +101,8 @@ int32 TypeOneUsage(const ParseOptions &po) {
 }
 
 int32 TypeTwoUsage(const ParseOptions &po,
-                   bool binary) {
+                   bool binary,
+                   bool average = false) {
   KALDI_ASSERT(po.NumArgs() == 2);
   KALDI_ASSERT(ClassifyRspecifier(po.GetArg(1), NULL, NULL) != kNoRspecifier &&
                "vector-sum: first argument must be an rspecifier");
@@ -133,6 +134,8 @@ int32 TypeTwoUsage(const ParseOptions &po,
       }
     }
   }
+
+  if (num_done > 0 && average) sum.Scale(1.0 / num_done);
 
   Vector<BaseFloat> sum_float(sum);
   WriteKaldiObject(sum_float, po.GetArg(2), binary);
@@ -199,12 +202,13 @@ int main(int argc, char *argv[]) {
         " e.g.: vector-sum --binary=false 1.vec 2.vec 3.vec sum.vec\n"
         "See also: copy-vector, dot-weights\n";
 
-    bool binary;
+    bool binary, average = false;
 
     ParseOptions po(usage);
 
     po.Register("binary", &binary, "If true, write output as binary (only "
                 "relevant for usage types two or three)");
+    po.Register("average", &average, "Compute the average instead of the sum");
 
     po.Read(argc, argv);
 
@@ -219,7 +223,7 @@ int main(int argc, char *argv[]) {
                ClassifyWspecifier(po.GetArg(N), NULL, NULL, NULL) ==
                kNoWspecifier) {
       // input from a single table, output not to table.
-      exit_status = TypeTwoUsage(po, binary);
+      exit_status = TypeTwoUsage(po, binary, average);
     } else if (po.NumArgs() >= 2 &&
                ClassifyRspecifier(po.GetArg(1), NULL, NULL) == kNoRspecifier &&
                ClassifyWspecifier(po.GetArg(N), NULL, NULL, NULL) ==
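Reviewer note: the new --average option only changes the final step of TypeTwoUsage(); the tool still accumulates a plain sum and, when requested, rescales it by 1/num_done. A minimal standalone sketch of that accumulate-then-scale pattern (hypothetical helper, not code from this patch; assumes a Kaldi build environment and equal-dimension inputs):

```cpp
#include <vector>
#include "matrix/kaldi-vector.h"

// Accumulate a set of vectors; optionally turn the sum into a mean,
// mirroring the new code path in TypeTwoUsage() above.
kaldi::Vector<double> SumOrAverage(
    const std::vector<kaldi::Vector<double> > &vecs, bool average) {
  kaldi::Vector<double> sum(vecs.empty() ? 0 : vecs[0].Dim());
  for (size_t i = 0; i < vecs.size(); i++) {
    KALDI_ASSERT(vecs[i].Dim() == sum.Dim());
    sum.AddVec(1.0, vecs[i]);          // sum += vecs[i]
  }
  if (average && !vecs.empty())
    sum.Scale(1.0 / vecs.size());      // the same scaling the patch adds
  return sum;
}
```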
diff --git a/src/chainbin/nnet3-chain-get-egs.cc b/src/chainbin/nnet3-chain-get-egs.cc
index d6094bd3cc8..ed162d1d18b 100644
--- a/src/chainbin/nnet3-chain-get-egs.cc
+++ b/src/chainbin/nnet3-chain-get-egs.cc
@@ -25,6 +25,7 @@
 #include "hmm/posterior.h"
 #include "nnet3/nnet-example.h"
 #include "nnet3/nnet-chain-example.h"
+#include "nnet3/nnet-example-utils.h"
 
 namespace kaldi {
 namespace nnet3 {
@@ -207,35 +208,6 @@ static bool ProcessFile(const fst::StdVectorFst &normalization_fst,
   return true;
 }
 
-void RoundUpNumFrames(int32 frame_subsampling_factor,
-                      int32 *num_frames,
-                      int32 *num_frames_overlap) {
-  if (*num_frames % frame_subsampling_factor != 0) {
-    int32 new_num_frames = frame_subsampling_factor *
-        (*num_frames / frame_subsampling_factor + 1);
-    KALDI_LOG << "Rounding up --num-frames=" << (*num_frames)
-              << " to a multiple of --frame-subsampling-factor="
-              << frame_subsampling_factor
-              << ", now --num-frames=" << new_num_frames;
-    *num_frames = new_num_frames;
-  }
-  if (*num_frames_overlap % frame_subsampling_factor != 0) {
-    int32 new_num_frames_overlap = frame_subsampling_factor *
-        (*num_frames_overlap / frame_subsampling_factor + 1);
-    KALDI_LOG << "Rounding up --num-frames-overlap=" << (*num_frames_overlap)
-              << " to a multiple of --frame-subsampling-factor="
-              << frame_subsampling_factor
-              << ", now --num-frames-overlap=" << new_num_frames_overlap;
-    *num_frames_overlap = new_num_frames_overlap;
-  }
-  if (*num_frames_overlap < 0 || *num_frames_overlap >= *num_frames) {
-    KALDI_ERR << "--num-frames-overlap=" << (*num_frames_overlap) << " < "
-              << "--num-frames=" << (*num_frames);
-  }
-
-}
-
-
 } // namespace nnet3
 } // namespace kaldi
diff --git a/src/hmm/posterior.cc b/src/hmm/posterior.cc
index 25acf48a7d1..4e5cbd45282 100644
--- a/src/hmm/posterior.cc
+++ b/src/hmm/posterior.cc
@@ -429,18 +429,6 @@ void WeightSilencePostDistributed(const TransitionModel &trans_model,
   }
 }
 
-// comparator object that can be used to sort from greatest to
-// least posterior.
-struct CompareReverseSecond {
-  // view this as an "<" operator used for sorting, except it behaves like
-  // a ">" operator on the .second field of the pair because we want the
-  // sort to be in reverse order (greatest to least) on posterior.
-  bool operator() (const std::pair<int32, BaseFloat> &a,
-                   const std::pair<int32, BaseFloat> &b) {
-    return (a.second > b.second);
-  }
-};
-
 BaseFloat VectorToPosteriorEntry(
     const VectorBase<BaseFloat> &log_likes,
     int32 num_gselect,
diff --git a/src/hmm/posterior.h b/src/hmm/posterior.h
index 18bbd65a86a..4f5896da7c6 100644
--- a/src/hmm/posterior.h
+++ b/src/hmm/posterior.h
@@ -155,6 +155,18 @@ int32 MergePosteriors(const Posterior &post1,
                       bool drop_frames,
                       Posterior *post);
 
+// comparator object that can be used to sort from greatest to
+// least posterior.
+struct CompareReverseSecond {
+  // view this as an "<" operator used for sorting, except it behaves like
+  // a ">" operator on the .second field of the pair because we want the
+  // sort to be in reverse order (greatest to least) on posterior.
+  bool operator() (const std::pair<int32, BaseFloat> &a,
+                   const std::pair<int32, BaseFloat> &b) {
+    return (a.second > b.second);
+  }
+};
+
 /// Given a vector of log-likelihoods (typically of Gaussians in a GMM
 /// but could be of pdf-ids), a number gselect >= 1 and a minimum posterior
 /// 0 <= min_post < 1, it gets the posterior for each element of log-likes
diff --git a/src/lat/lattice-functions.cc b/src/lat/lattice-functions.cc
index 0ea66712eda..d8443bd7434 100644
--- a/src/lat/lattice-functions.cc
+++ b/src/lat/lattice-functions.cc
@@ -405,15 +405,11 @@ static inline double LogAddOrMax(bool viterbi, double a, double b) {
     return LogAdd(a, b);
 }
 
-// Computes (normal or Viterbi) alphas and betas; returns (total-prob, or
-// best-path negated cost). Note: in either case, the alphas and betas are
-// negated costs.  Requires that lat be topologically sorted.  This code
-// will work for either CompactLattice or Lattice.
-template<typename LatticeType>
-static double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
-                                           bool viterbi,
-                                           vector<double> *alpha,
-                                           vector<double> *beta) {
+template<typename LatticeType>
+double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
+                                    bool viterbi,
+                                    vector<double> *alpha,
+                                    vector<double> *beta) {
   typedef typename LatticeType::Arc Arc;
   typedef typename Arc::Weight Weight;
   typedef typename Arc::StateId StateId;
@@ -462,6 +458,19 @@ static double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
   return 0.5 * (tot_backward_prob + tot_forward_prob);
 }
 
+// Instantiate the template for Lattice and CompactLattice.
+template
+double ComputeLatticeAlphasAndBetas(const Lattice &lat,
+                                    bool viterbi,
+                                    vector<double> *alpha,
+                                    vector<double> *beta);
+
+template
+double ComputeLatticeAlphasAndBetas(const CompactLattice &lat,
+                                    bool viterbi,
+                                    vector<double> *alpha,
+                                    vector<double> *beta);
+
 /// This is used in CompactLatticeLimitDepth.
diff --git a/src/lat/lattice-functions.h b/src/lat/lattice-functions.h
index 0b3e9f8ecc4..c58b2ec32b8 100644
--- a/src/lat/lattice-functions.h
+++ b/src/lat/lattice-functions.h
@@ -74,6 +74,18 @@ bool ComputeCompactLatticeAlphas(const CompactLattice &lat,
 bool ComputeCompactLatticeBetas(const CompactLattice &lat,
                                 vector<double> *beta);
 
+
+// Computes (normal or Viterbi) alphas and betas; returns (total-prob, or
+// best-path negated cost). Note: in either case, the alphas and betas are
+// negated costs.  Requires that lat be topologically sorted.  This code
+// will work for either CompactLattice or Lattice.
+template<typename LatticeType>
+double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
+                                    bool viterbi,
+                                    vector<double> *alpha,
+                                    vector<double> *beta);
+
+
 /// Topologically sort the compact lattice if not already topologically sorted.
 /// Will crash if the lattice cannot be topologically sorted.
 void TopSortCompactLatticeIfNeeded(CompactLattice *clat);
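Reviewer note: the explicit instantiations added to lattice-functions.cc are what make this refactoring link. Once the template definition moves out of the header, other translation units only see the declaration above, so the .cc file must instantiate the template for every LatticeType a caller may use (here Lattice and CompactLattice). A toy illustration of the same pattern, with hypothetical names:

```cpp
// foo.h -- declaration only; the definition lives in foo.cc.
template<typename T> T Triple(T t);

// foo.cc -- definition plus explicit instantiations.  Without these,
// callers in other .cc files would fail at link time for lack of a
// definition of Triple<int> / Triple<double>.
template<typename T> T Triple(T t) { return t + t + t; }
template int Triple<int>(int);
template double Triple<double>(double);
```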
diff --git a/src/latbin/lattice-copy.cc b/src/latbin/lattice-copy.cc
index 76ca034b2e4..f66eb699705 100644
--- a/src/latbin/lattice-copy.cc
+++ b/src/latbin/lattice-copy.cc
@@ -24,6 +24,108 @@
 #include "fstext/fstext-lib.h"
 #include "lat/kaldi-lattice.h"
 
+namespace kaldi {
+  int32 CopySubsetLattices(const std::string &filename,
+                           SequentialLatticeReader *lattice_reader,
+                           LatticeWriter *lattice_writer,
+                           bool include = true, bool ignore_missing = false) {
+    // 'subset' is sorted, so its last element bounds the keys we still
+    // care about when 'include' is true; the early stop below relies on
+    // the input archive being sorted on key.
+    std::set<std::string> subset;
+
+    bool binary;
+    Input ki(filename, &binary);
+    KALDI_ASSERT(!binary);
+    std::string line;
+    while (std::getline(ki.Stream(), line)) {
+      std::vector<std::string> split_line;
+      SplitStringToVector(line, " \t\r", true, &split_line);
+      if (split_line.empty()) {
+        KALDI_ERR << "Unable to parse line \"" << line
+                  << "\" encountered in input in " << filename;
+      }
+      subset.insert(split_line[0]);
+    }
+
+    int32 num_total = 0;
+    size_t num_success = 0;
+    for (; !lattice_reader->Done(); lattice_reader->Next(), num_total++) {
+      if (include && !subset.empty() &&
+          lattice_reader->Key() > *(subset.rbegin())) {
+        KALDI_LOG << "The utterance " << lattice_reader->Key()
+                  << " comes after the last key in the include list; "
+                  << "not reading further.";
+        KALDI_LOG << "Wrote " << num_success << " utterances";
+        return (num_success != 0 || ignore_missing ? 0 : 1);
+      }
+
+      if (include && subset.count(lattice_reader->Key()) > 0) {
+        lattice_writer->Write(lattice_reader->Key(), lattice_reader->Value());
+        num_success++;
+      } else if (!include && subset.count(lattice_reader->Key()) == 0) {
+        lattice_writer->Write(lattice_reader->Key(), lattice_reader->Value());
+        num_success++;
+      }
+    }
+
+    KALDI_LOG << "Wrote " << num_success << " out of " << num_total
+              << " utterances.";
+
+    if (ignore_missing) return 0;
+
+    return (num_success != 0 ? 0 : 1);
+  }
+
+  int32 CopySubsetLattices(const std::string &filename,
+                           SequentialCompactLatticeReader *lattice_reader,
+                           CompactLatticeWriter *lattice_writer,
+                           bool include = true, bool ignore_missing = false) {
+    // As above; 'subset' is sorted and the early stop assumes a sorted
+    // input archive.
+    std::set<std::string> subset;
+
+    bool binary;
+    Input ki(filename, &binary);
+    KALDI_ASSERT(!binary);
+    std::string line;
+    while (std::getline(ki.Stream(), line)) {
+      std::vector<std::string> split_line;
+      SplitStringToVector(line, " \t\r", true, &split_line);
+      if (split_line.empty()) {
+        KALDI_ERR << "Unable to parse line \"" << line
+                  << "\" encountered in input in " << filename;
+      }
+      subset.insert(split_line[0]);
+    }
+
+    int32 num_total = 0;
+    size_t num_success = 0;
+    for (; !lattice_reader->Done(); lattice_reader->Next(), num_total++) {
+      if (include && !subset.empty() &&
+          lattice_reader->Key() > *(subset.rbegin())) {
+        KALDI_LOG << "The utterance " << lattice_reader->Key()
Not reading further."; + KALDI_LOG << "Wrote " << num_success << " utterances"; + return 0; + } + + if (include && subset.count(lattice_reader->Key()) > 0) { + lattice_writer->Write(lattice_reader->Key(), lattice_reader->Value()); + num_success++; + } else if (!include && subset.count(lattice_reader->Key()) == 0) { + lattice_writer->Write(lattice_reader->Key(), lattice_reader->Value()); + num_success++; + } + } + + KALDI_LOG << " Wrote " << num_success << " out of " << num_total + << " utterances."; + + if (ignore_missing) return 0; + + return (num_success != 0 ? 0 : 1); + } +} + int main(int argc, char *argv[]) { try { using namespace kaldi; @@ -36,14 +138,32 @@ int main(int argc, char *argv[]) { const char *usage = "Copy lattices (e.g. useful for changing to text mode or changing\n" "format to standard from compact lattice.)\n" + "The --include and --exclude options can be used to copy only a subset " + "of lattices, where are the --include option specifies the " + "whitelisted utterances that would be copied and --exclude option " + "specifies the blacklisted utterances that would not be copied.\n" + "Only one of --include and --exclude can be supplied.\n" "Usage: lattice-copy [options] lattice-rspecifier lattice-wspecifier\n" " e.g.: lattice-copy --write-compact=false ark:1.lats ark,t:text.lats\n" "See also: lattice-to-fst, and the script egs/wsj/s5/utils/convert_slf.pl\n"; ParseOptions po(usage); - bool write_compact = true; + bool write_compact = true, ignore_missing = false; + std::string include_rxfilename; + std::string exclude_rxfilename; + po.Register("write-compact", &write_compact, "If true, write in normal (compact) form."); - + po.Register("include", &include_rxfilename, + "Text file, the first field of each " + "line being interpreted as the " + "utterance-id whose lattices will be included"); + po.Register("exclude", &exclude_rxfilename, + "Text file, the first field of each " + "line being interpreted as an utterance-id " + "whose lattices will be excluded"); + po.Register("ignore-missing", &ignore_missing, + "Exit with status 0 even if no lattices are copied"); + po.Read(argc, argv); if (po.NumArgs() != 2) { @@ -59,15 +179,46 @@ int main(int argc, char *argv[]) { if (write_compact) { SequentialCompactLatticeReader lattice_reader(lats_rspecifier); CompactLatticeWriter lattice_writer(lats_wspecifier); + + if (include_rxfilename != "") { + if (exclude_rxfilename != "") { + KALDI_ERR << "should not have both --exclude and --include option!"; + } + return CopySubsetLattices(include_rxfilename, + &lattice_reader, &lattice_writer, + true, ignore_missing); + } else if (exclude_rxfilename != "") { + return CopySubsetLattices(exclude_rxfilename, + &lattice_reader, &lattice_writer, + false, ignore_missing); + } + for (; !lattice_reader.Done(); lattice_reader.Next(), n_done++) lattice_writer.Write(lattice_reader.Key(), lattice_reader.Value()); } else { SequentialLatticeReader lattice_reader(lats_rspecifier); LatticeWriter lattice_writer(lats_wspecifier); + + if (include_rxfilename != "") { + if (exclude_rxfilename != "") { + KALDI_ERR << "should not have both --exclude and --include option!"; + } + return CopySubsetLattices(include_rxfilename, + &lattice_reader, &lattice_writer, + true, ignore_missing); + } else if (exclude_rxfilename != "") { + return CopySubsetLattices(exclude_rxfilename, + &lattice_reader, &lattice_writer, + true, ignore_missing); + } + for (; !lattice_reader.Done(); lattice_reader.Next(), n_done++) lattice_writer.Write(lattice_reader.Key(), 
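Reviewer note: the early stop in CopySubsetLattices() compares each key against the largest key in the include set, so it is only valid when the input archive is sorted on utterance-id; on an unsorted archive it would silently drop lattices. A condensed, hypothetical version of the include path (generic over the reader/writer types; a sketch, not code from this patch):

```cpp
#include <set>
#include <string>

// Copy only the whitelisted keys, assuming the reader yields keys in
// sorted order (otherwise the early 'break' is unsafe).
template<class Reader, class Writer>
size_t CopyIncluded(const std::set<std::string> &include,
                    Reader *reader, Writer *writer) {
  size_t num_copied = 0;
  for (; !reader->Done(); reader->Next()) {
    const std::string &key = reader->Key();
    if (!include.empty() && key > *include.rbegin())
      break;                        // past the last wanted key; stop early
    if (include.count(key)) {
      writer->Write(key, reader->Value());
      num_copied++;
    }
  }
  return num_copied;
}
```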
                             lattice_reader.Value());
     }
     KALDI_LOG << "Done copying " << n_done << " lattices.";
+
+    if (ignore_missing) return 0;
+
     return (n_done != 0 ? 0 : 1);
   } catch(const std::exception &e) {
     std::cerr << e.what();
diff --git a/src/nnet3/nnet-chain-example.cc b/src/nnet3/nnet-chain-example.cc
index 01e6cb80daf..74e8be80240 100644
--- a/src/nnet3/nnet-chain-example.cc
+++ b/src/nnet3/nnet-chain-example.cc
@@ -25,49 +25,6 @@
 namespace kaldi {
 namespace nnet3 {
 
-// writes compressed as unsigned char a vector 'vec' that is required to have
-// values between 0 and 1.
-static inline void WriteVectorAsChar(std::ostream &os,
-                                     bool binary,
-                                     const VectorBase<BaseFloat> &vec) {
-  if (binary) {
-    int32 dim = vec.Dim();
-    std::vector<unsigned char> char_vec(dim);
-    const BaseFloat *data = vec.Data();
-    for (int32 i = 0; i < dim; i++) {
-      BaseFloat value = data[i];
-      KALDI_ASSERT(value >= 0.0 && value <= 1.0);
-      // below, the adding 0.5 is done so that we round to the closest integer
-      // rather than rounding down (since static_cast will round down).
-      char_vec[i] = static_cast<unsigned char>(255.0 * value + 0.5);
-    }
-    WriteIntegerVector(os, binary, char_vec);
-  } else {
-    // the regular floating-point format will be more readable for text mode.
-    vec.Write(os, binary);
-  }
-}
-
-// reads data written by WriteVectorAsChar.
-static inline void ReadVectorAsChar(std::istream &is,
-                                    bool binary,
-                                    Vector<BaseFloat> *vec) {
-  if (binary) {
-    BaseFloat scale = 1.0 / 255.0;
-    std::vector<unsigned char> char_vec;
-    ReadIntegerVector(is, binary, &char_vec);
-    int32 dim = char_vec.size();
-    vec->Resize(dim, kUndefined);
-    BaseFloat *data = vec->Data();
-    for (int32 i = 0; i < dim; i++)
-      data[i] = scale * char_vec[i];
-  } else {
-    vec->Read(is, binary);
-  }
-}
-
-
-
 void NnetChainSupervision::Write(std::ostream &os, bool binary) const {
   CheckDim();
   WriteToken(os, binary, "<NnetChainSup>");
diff --git a/src/nnet3/nnet-example-utils.cc b/src/nnet3/nnet-example-utils.cc
index 99d41fb06c4..30f7840f6f8 100644
--- a/src/nnet3/nnet-example-utils.cc
+++ b/src/nnet3/nnet-example-utils.cc
@@ -219,5 +219,72 @@ void GetComputationRequest(const Nnet &nnet,
     KALDI_ERR << "No outputs in computation request.";
 }
 
+void WriteVectorAsChar(std::ostream &os,
+                       bool binary,
+                       const VectorBase<BaseFloat> &vec) {
+  if (binary) {
+    int32 dim = vec.Dim();
+    std::vector<unsigned char> char_vec(dim);
+    const BaseFloat *data = vec.Data();
+    for (int32 i = 0; i < dim; i++) {
+      BaseFloat value = data[i];
+      KALDI_ASSERT(value >= 0.0 && value <= 1.0);
+      // below, the adding 0.5 is done so that we round to the closest integer
+      // rather than rounding down (since static_cast will round down).
+      char_vec[i] = static_cast<unsigned char>(255.0 * value + 0.5);
+    }
+    WriteIntegerVector(os, binary, char_vec);
+  } else {
+    // the regular floating-point format will be more readable for text mode.
+    vec.Write(os, binary);
+  }
+}
+
+void ReadVectorAsChar(std::istream &is,
+                      bool binary,
+                      Vector<BaseFloat> *vec) {
+  if (binary) {
+    BaseFloat scale = 1.0 / 255.0;
+    std::vector<unsigned char> char_vec;
+    ReadIntegerVector(is, binary, &char_vec);
+    int32 dim = char_vec.size();
+    vec->Resize(dim, kUndefined);
+    BaseFloat *data = vec->Data();
+    for (int32 i = 0; i < dim; i++)
+      data[i] = scale * char_vec[i];
+  } else {
+    vec->Read(is, binary);
+  }
+}
+
+void RoundUpNumFrames(int32 frame_subsampling_factor,
+                      int32 *num_frames,
+                      int32 *num_frames_overlap) {
+  if (*num_frames % frame_subsampling_factor != 0) {
+    int32 new_num_frames = frame_subsampling_factor *
+        (*num_frames / frame_subsampling_factor + 1);
+    KALDI_LOG << "Rounding up --num-frames=" << (*num_frames)
+              << " to a multiple of --frame-subsampling-factor="
+              << frame_subsampling_factor
+              << ", now --num-frames=" << new_num_frames;
+    *num_frames = new_num_frames;
+  }
+  if (*num_frames_overlap % frame_subsampling_factor != 0) {
+    int32 new_num_frames_overlap = frame_subsampling_factor *
+        (*num_frames_overlap / frame_subsampling_factor + 1);
+    KALDI_LOG << "Rounding up --num-frames-overlap=" << (*num_frames_overlap)
+              << " to a multiple of --frame-subsampling-factor="
+              << frame_subsampling_factor
+              << ", now --num-frames-overlap=" << new_num_frames_overlap;
+    *num_frames_overlap = new_num_frames_overlap;
+  }
+  if (*num_frames_overlap < 0 || *num_frames_overlap >= *num_frames) {
+    KALDI_ERR << "Invalid option values: --num-frames-overlap="
+              << (*num_frames_overlap) << " must be >= 0 and less than "
+              << "--num-frames=" << (*num_frames);
+  }
+}
+
 } // namespace nnet3
 } // namespace kaldi
diff --git a/src/nnet3/nnet-example-utils.h b/src/nnet3/nnet-example-utils.h
index d54c3296dac..3e309e18915 100644
--- a/src/nnet3/nnet-example-utils.h
+++ b/src/nnet3/nnet-example-utils.h
@@ -63,6 +63,23 @@ void GetComputationRequest(const Nnet &nnet,
                            ComputationRequest *computation_request);
 
 
+// Writes as unsigned char a vector 'vec' that is required to have
+// values between 0 and 1.
+void WriteVectorAsChar(std::ostream &os,
+                       bool binary,
+                       const VectorBase<BaseFloat> &vec);
+
+// Reads data written by WriteVectorAsChar.
+void ReadVectorAsChar(std::istream &is,
+                      bool binary,
+                      Vector<BaseFloat> *vec);
+
+// Rounds 'num_frames' and 'num_frames_overlap' up to the next multiple
+// of 'frame_subsampling_factor'.
+void RoundUpNumFrames(int32 frame_subsampling_factor,
+                      int32 *num_frames,
+                      int32 *num_frames_overlap);
+
 } // namespace nnet3
 } // namespace kaldi
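Reviewer note: since WriteVectorAsChar()/ReadVectorAsChar() are now public, it may be worth stating their precision contract: each value in [0, 1] is quantized to one of 256 levels, so a binary-mode round trip reconstructs each element to within 0.5/255 (about 0.002). A sketch of a round-trip check under that assumption (hypothetical helper, not part of the patch):

```cpp
#include <cmath>
#include <sstream>
#include "nnet3/nnet-example-utils.h"

// Write 'vec' in the compressed char format and read it back; the
// per-element quantization error should be at most 0.5/255.
void CheckCharRoundTrip(const kaldi::VectorBase<kaldi::BaseFloat> &vec) {
  using namespace kaldi;
  std::ostringstream os;
  nnet3::WriteVectorAsChar(os, true /* binary */, vec);
  std::istringstream is(os.str());
  Vector<BaseFloat> vec2;
  nnet3::ReadVectorAsChar(is, true /* binary */, &vec2);
  KALDI_ASSERT(vec2.Dim() == vec.Dim());
  for (int32 i = 0; i < vec.Dim(); i++)
    KALDI_ASSERT(std::abs(vec(i) - vec2(i)) <= 0.5 / 255.0 + 1.0e-05);
}
```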
diff --git a/src/nnet3/nnet-utils.cc b/src/nnet3/nnet-utils.cc
index 1f45a3b90a4..3315bd1d31f 100644
--- a/src/nnet3/nnet-utils.cc
+++ b/src/nnet3/nnet-utils.cc
@@ -248,6 +248,22 @@ void ZeroComponentStats(Nnet *nnet) {
   }
 }
 
+void ScaleLearningRate(BaseFloat learning_rate_scale,
+                       Nnet *nnet) {
+  for (int32 c = 0; c < nnet->NumComponents(); c++) {
+    Component *comp = nnet->GetComponent(c);
+    if (comp->Properties() & kUpdatableComponent) {
+      // For now all updatable components inherit from class UpdatableComponent.
+      // If that changes in future, we will change this code.
+      UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(comp);
+      if (uc == NULL)
+        KALDI_ERR << "Updatable component does not inherit from class "
+            "UpdatableComponent; change this code.";
+      uc->SetActualLearningRate(uc->LearningRate() * learning_rate_scale);
+    }
+  }
+}
+
 void SetLearningRate(BaseFloat learning_rate,
                      Nnet *nnet) {
   for (int32 c = 0; c < nnet->NumComponents(); c++) {
@@ -264,6 +280,57 @@ void SetLearningRate(BaseFloat learning_rate,
   }
 }
 
+void SetLearningRates(const Vector<BaseFloat> &learning_rates,
+                      Nnet *nnet) {
+  for (int32 c = 0, i = 0; c < nnet->NumComponents(); c++) {
+    Component *comp = nnet->GetComponent(c);
+    if (comp->Properties() & kUpdatableComponent) {
+      // For now all updatable components inherit from class UpdatableComponent.
+      // If that changes in future, we will change this code.
+      UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(comp);
+      if (uc == NULL)
+        KALDI_ERR << "Updatable component does not inherit from class "
+            "UpdatableComponent; change this code.";
+      KALDI_ASSERT(i < learning_rates.Dim());
+      uc->SetActualLearningRate(learning_rates(i++));
+    }
+  }
+}
+
+void GetLearningRates(const Nnet &nnet,
+                      Vector<BaseFloat> *learning_rates) {
+  learning_rates->Resize(NumUpdatableComponents(nnet));
+  for (int32 c = 0, i = 0; c < nnet.NumComponents(); c++) {
+    const Component *comp = nnet.GetComponent(c);
+    if (comp->Properties() & kUpdatableComponent) {
+      // For now all updatable components inherit from class UpdatableComponent.
+      // If that changes in future, we will change this code.
+      const UpdatableComponent *uc =
+          dynamic_cast<const UpdatableComponent*>(comp);
+      if (uc == NULL)
+        KALDI_ERR << "Updatable component does not inherit from class "
+            "UpdatableComponent; change this code.";
+      (*learning_rates)(i++) = uc->LearningRate();
+    }
+  }
+}
+
+void ScaleNnetComponents(const Vector<BaseFloat> &scale_factors,
+                         Nnet *nnet) {
+  for (int32 c = 0, i = 0; c < nnet->NumComponents(); c++) {
+    Component *comp = nnet->GetComponent(c);
+    if (comp->Properties() & kUpdatableComponent) {
+      // For now all updatable components inherit from class UpdatableComponent.
+      // If that changes in future, we will change this code.
+      UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(comp);
+      if (uc == NULL)
+        KALDI_ERR << "Updatable component does not inherit from class "
+            "UpdatableComponent; change this code.";
+      KALDI_ASSERT(i < scale_factors.Dim());
+      uc->Scale(scale_factors(i++));
+    }
+  }
+}
+
 void ScaleNnet(BaseFloat scale, Nnet *nnet) {
   if (scale == 1.0) return;
   else if (scale == 0.0) {
diff --git a/src/nnet3/nnet-utils.h b/src/nnet3/nnet-utils.h
index a08a26bf27b..9b869aa7933 100644
--- a/src/nnet3/nnet-utils.h
+++ b/src/nnet3/nnet-utils.h
@@ -116,11 +116,33 @@ void ComputeSimpleNnetContext(const Nnet &nnet,
 void SetLearningRate(BaseFloat learning_rate,
                      Nnet *nnet);
 
+/// Scales the actual learning rate for all the updatable components in the
+/// nnet by this factor.
+void ScaleLearningRate(BaseFloat learning_rate_scale,
+                       Nnet *nnet);
+
+/// Sets the actual learning rates for all the updatable components in the
+/// neural net to the values in the 'learning_rates' vector
+/// (one for each updatable component).
+void SetLearningRates(const Vector<BaseFloat> &learning_rates,
+                      Nnet *nnet);
+
+/// Gets the learning rates of all the updatable components in the neural net
+/// (the output vector is resized to the number of updatable components).
+void GetLearningRates(const Nnet &nnet,
+                      Vector<BaseFloat> *learning_rates);
+
 /// Scales the nnet parameters and stats by this scale.
 void ScaleNnet(BaseFloat scale, Nnet *nnet);
 
+/// Scales the parameters of each of the updatable components.
+/// Here, 'scales' is a vector of size equal to the number of updatable
+/// components.
+void ScaleNnetComponents(const Vector<BaseFloat> &scales,
+                         Nnet *nnet);
+
 /// Does *dest += alpha * src (affects nnet parameters and
-/// stored stats).
+/// stored stats).
 void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest);
 
 /// Returns the total of the number of parameters in the updatable components of
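Reviewer note: the three new learning-rate helpers compose naturally; ScaleLearningRate() is equivalent to a GetLearningRates() / Scale() / SetLearningRates() round trip, one entry per updatable component, in component order. A sketch of that equivalence (hypothetical helper; assumes a built kaldi::nnet3::Nnet):

```cpp
#include "nnet3/nnet-utils.h"

// Halve every updatable component's learning rate; the same effect as
// calling ScaleLearningRate(0.5, nnet) directly.
void HalveLearningRates(kaldi::nnet3::Nnet *nnet) {
  using namespace kaldi;
  using namespace kaldi::nnet3;
  Vector<BaseFloat> rates;
  GetLearningRates(*nnet, &rates);  // one entry per updatable component
  rates.Scale(0.5);
  SetLearningRates(rates, nnet);
}
```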
diff --git a/src/nnet3bin/Makefile b/src/nnet3bin/Makefile
index 51ae535eebf..9fa21de442d 100644
--- a/src/nnet3bin/Makefile
+++ b/src/nnet3bin/Makefile
@@ -12,7 +12,7 @@ BINFILES = nnet3-init nnet3-info nnet3-get-egs nnet3-copy-egs nnet3-subset-egs \
            nnet3-am-adjust-priors nnet3-am-copy nnet3-compute-prob \
            nnet3-average nnet3-am-info nnet3-combine nnet3-latgen-faster \
            nnet3-copy nnet3-show-progress nnet3-align-compiled \
-           nnet3-get-egs-dense-targets nnet3-compute
+           nnet3-get-egs-dense-targets nnet3-compute nnet3-modify-learning-rates
 
 OBJFILES =
diff --git a/src/nnet3bin/nnet3-am-copy.cc b/src/nnet3bin/nnet3-am-copy.cc
index 907a9993918..dd38288418e 100644
--- a/src/nnet3bin/nnet3-am-copy.cc
+++ b/src/nnet3bin/nnet3-am-copy.cc
@@ -47,6 +47,7 @@ int main(int argc, char *argv[]) {
 
     bool binary_write = true, raw = false;
     BaseFloat learning_rate = -1;
+    BaseFloat learning_rate_scale = 1.0;
     std::string set_raw_nnet = "";
     bool convert_repeated_to_block = false;
     BaseFloat scale = 1.0;
@@ -66,6 +67,9 @@ int main(int argc, char *argv[]) {
     po.Register("learning-rate", &learning_rate,
                 "If supplied, all the learning rates of updatable components"
                 " are set to this value.");
+    po.Register("learning-rate-scale", &learning_rate_scale,
+                "Scales the learning rates of updatable components by this "
+                "factor.");
     po.Register("scale", &scale, "The parameter matrices are scaled"
                 " by the specified value.");
@@ -100,6 +104,11 @@ int main(int argc, char *argv[]) {
 
       if (learning_rate >= 0)
         SetLearningRate(learning_rate, &(am_nnet.GetNnet()));
+
+      KALDI_ASSERT(learning_rate_scale >= 0.0);
+
+      if (learning_rate_scale != 1.0)
+        ScaleLearningRate(learning_rate_scale, &(am_nnet.GetNnet()));
 
       if (scale != 1.0)
         ScaleNnet(scale, &(am_nnet.GetNnet()));
diff --git a/src/nnet3bin/nnet3-modify-learning-rates.cc b/src/nnet3bin/nnet3-modify-learning-rates.cc
new file mode 100644
index 00000000000..89e14a5e819
--- /dev/null
+++ b/src/nnet3bin/nnet3-modify-learning-rates.cc
@@ -0,0 +1,184 @@
+// nnet3bin/nnet3-modify-learning-rates.cc
+
+// Copyright 2013  Guoguo Chen
+//           2015  Vimal Manohar
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "nnet3/nnet-utils.h"
+
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+    using namespace kaldi::nnet3;
+    typedef kaldi::int32 int32;
+    typedef kaldi::int64 int64;
+
+    const char *usage =
+        "This program modifies the learning rates so as to equalize the\n"
+        "relative changes in parameters for each layer, while keeping their\n"
+        "geometric mean the same (or changing it to a value specified using\n"
+        "the --average-learning-rate option).\n"
+        "\n"
+        "Usage: nnet3-modify-learning-rates [options] \\\n"
+        "           <prev-model-in> <cur-model-in> <modified-cur-model-out>\n"
+        "e.g.: nnet3-modify-learning-rates --average-learning-rate=0.0002 \\\n"
+        "          5.raw 6.raw 6.raw\n";
+
+    bool binary_write = true;
+    bool retroactive = false;
+    BaseFloat average_learning_rate = 0.0;
+    BaseFloat first_layer_factor = 1.0;
+    BaseFloat last_layer_factor = 1.0;
+
+    ParseOptions po(usage);
+    po.Register("binary", &binary_write, "Write output in binary mode");
+    po.Register("average-learning-rate", &average_learning_rate,
+                "If supplied, change the learning-rate geometric mean to the "
+                "given value.");
+    po.Register("first-layer-factor", &first_layer_factor, "Factor that "
+                "reduces the target relative learning rate for the first layer.");
+    po.Register("last-layer-factor", &last_layer_factor, "Factor that "
+                "reduces the target relative learning rate for the last layer.");
+    po.Register("retroactive", &retroactive, "If true, scale the parameter "
+                "differences as well.");
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 3) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    KALDI_ASSERT(average_learning_rate >= 0);
+
+    std::string prev_nnet_rxfilename = po.GetArg(1),
+        cur_nnet_rxfilename = po.GetArg(2),
+        modified_cur_nnet_wxfilename = po.GetArg(3);
+
+    Nnet prev_nnet, cur_nnet;
+    {
+      bool binary_read;
+      Input ki(prev_nnet_rxfilename, &binary_read);
+      prev_nnet.Read(ki.Stream(), binary_read);
+    }
+    {
+      bool binary_read;
+      Input ki(cur_nnet_rxfilename, &binary_read);
+      cur_nnet.Read(ki.Stream(), binary_read);
+    }
+
+    int32 ret = 0;
+
+    // Get info about the magnitude of the parameter change.
+    Nnet diff_nnet(prev_nnet);
+    AddNnet(cur_nnet, -1.0, &diff_nnet);
+    int32 num_updatable = NumUpdatableComponents(diff_nnet);
+    Vector<BaseFloat> dot_prod(num_updatable);
+    ComponentDotProducts(diff_nnet, diff_nnet, &dot_prod);
+    dot_prod.ApplyPow(0.5);  // take sqrt to get the l2 norm of the diff
+    KALDI_LOG << "Parameter differences per layer are "
+              << PrintVectorPerUpdatableComponent(prev_nnet, dot_prod);
+
+    Vector<BaseFloat> baseline_prod(num_updatable);
+    ComponentDotProducts(prev_nnet, prev_nnet, &baseline_prod);
+    baseline_prod.ApplyPow(0.5);
+    dot_prod.DivElements(baseline_prod);
+    KALDI_LOG << "Relative parameter differences per layer are "
+              << PrintVectorPerUpdatableComponent(prev_nnet, dot_prod);
+
+    // If the relative parameter difference for a certain component is zero,
+    // set it to the mean of the remaining values.
+    int32 num_zero = 0;
+    for (int32 i = 0; i < num_updatable; i++) {
+      if (dot_prod(i) == 0.0) {
+        num_zero++;
+      }
+    }
+
+    if (num_zero > 0) {
+      BaseFloat average_diff = dot_prod.Sum()
+          / static_cast<BaseFloat>(num_updatable - num_zero);
+      for (int32 i = 0; i < num_updatable; i++) {
+        if (dot_prod(i) == 0.0) {
+          dot_prod(i) = average_diff;
+        }
+      }
+      KALDI_LOG << "Zeros detected in the relative parameter difference "
+                << "vector, updating the vector to " << dot_prod;
+    }
+
+    // Gets learning rates for the previous neural net.
+    Vector<BaseFloat> prev_nnet_learning_rates(num_updatable),
+        cur_nnet_learning_rates(num_updatable);
+    GetLearningRates(prev_nnet, &prev_nnet_learning_rates);
+    GetLearningRates(cur_nnet, &cur_nnet_learning_rates);
+    KALDI_LOG << "Learning rates for previous model per layer are "
+              << prev_nnet_learning_rates;
+    KALDI_LOG << "Learning rates for current model per layer are "
+              << cur_nnet_learning_rates;
+
+    // Gets the target geometric mean.
+    BaseFloat target_geometric_mean = 0.0;
+    if (average_learning_rate == 0.0) {
+      target_geometric_mean = Exp(cur_nnet_learning_rates.SumLog()
+                                  / static_cast<BaseFloat>(num_updatable));
+    } else {
+      target_geometric_mean = average_learning_rate;
+    }
+    KALDI_ASSERT(target_geometric_mean > 0.0);
+
+    // Works out the new learning rates. We start from the previous model;
+    // this ensures that if this program is run twice, we get consistent
+    // results even if the current model has been overwritten.
+    Vector<BaseFloat> nnet_learning_rates(prev_nnet_learning_rates);
+    nnet_learning_rates.DivElements(dot_prod);
+    KALDI_ASSERT(last_layer_factor > 0.0);
+    nnet_learning_rates(num_updatable - 1) *= last_layer_factor;
+    KALDI_ASSERT(first_layer_factor > 0.0);
+    nnet_learning_rates(0) *= first_layer_factor;
+    BaseFloat cur_geometric_mean = Exp(nnet_learning_rates.SumLog()
+                                       / static_cast<BaseFloat>(num_updatable));
+    nnet_learning_rates.Scale(target_geometric_mean / cur_geometric_mean);
+    KALDI_LOG << "New learning rates for current model per layer are "
+              << nnet_learning_rates;
+
+    // Changes the parameter differences if --retroactive is set to true.
+    if (retroactive) {
+      Vector<BaseFloat> scale_factors(nnet_learning_rates);
+      scale_factors.DivElements(prev_nnet_learning_rates);
+      AddNnet(prev_nnet, -1.0, &cur_nnet);
+      ScaleNnetComponents(scale_factors, &cur_nnet);
+      AddNnet(prev_nnet, 1.0, &cur_nnet);
+      KALDI_LOG << "Scaled parameter differences retroactively. Scaling "
+                << "factors are " << scale_factors;
+    }
+
+    // Sets the learning rates and writes the updated model.
+    SetLearningRates(nnet_learning_rates, &cur_nnet);
+
+    Output ko(modified_cur_nnet_wxfilename, binary_write);
+    cur_nnet.Write(ko.Stream(), binary_write);
+
+    return ret;
+  } catch(const std::exception &e) {
+    std::cerr << e.what() << '\n';
+    return -1;
+  }
+}
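Reviewer note: the core update rule above is easy to miss inside the binary. Each new learning rate is the previous rate divided by that layer's relative parameter change (so layers that moved a lot slow down and layers that barely moved speed up), and the whole vector is then rescaled so its geometric mean matches the target. A condensed sketch of just that computation, ignoring the first/last-layer factors (hypothetical helper, not code from this patch):

```cpp
#include "base/kaldi-math.h"
#include "matrix/kaldi-vector.h"

// new_lr_i = prev_lr_i / rel_change_i, rescaled so that the geometric
// mean of the result equals 'target_geometric_mean'.
void RebalanceLearningRates(const kaldi::Vector<kaldi::BaseFloat> &prev_lr,
                            const kaldi::Vector<kaldi::BaseFloat> &rel_change,
                            kaldi::BaseFloat target_geometric_mean,
                            kaldi::Vector<kaldi::BaseFloat> *new_lr) {
  using kaldi::BaseFloat;
  *new_lr = prev_lr;
  new_lr->DivElements(rel_change);  // equalize relative parameter change
  BaseFloat mean = kaldi::Exp(new_lr->SumLog() /
                              static_cast<BaseFloat>(new_lr->Dim()));
  new_lr->Scale(target_geometric_mean / mean);  // restore the geometric mean
}
```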