diff --git a/egs/ami/s5/local/run_nnet3_rnnlm.sh b/egs/ami/s5/local/run_nnet3_rnnlm.sh new file mode 100644 index 00000000000..76696de0dd9 --- /dev/null +++ b/egs/ami/s5/local/run_nnet3_rnnlm.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +mic=sdm1 +crit=vr +n=50 +ngram_order=4 + +. ./utils/parse_options.sh +. ./cmd.sh +. ./path.sh + +set -e + +local/train_cued_rnnlms.sh --crit $crit --train-text data/$mic/train/text data/$mic/cued_rnn_$crit + +final_lm=ami_fsh.o3g.kn +LM=$final_lm.pr1-7 + +for decode_set in dev eval; do + dir=exp/$mic/nnet3/tdnn_sp/ + decode_dir=${dir}/decode_${decode_set} + + # N-best rescoring + steps/rnnlmrescore.sh \ + --rnnlm-ver cuedrnnlm \ + --N $n --cmd "$decode_cmd --mem 16G" --inv-acwt 10 0.5 \ + data/lang_$LM data/$mic/cued_rnn_$crit \ + data/$mic/$decode_set ${decode_dir} \ + ${decode_dir}.rnnlm.$crit.cued.$n-best + + # Lattice rescoring + steps/lmrescore_rnnlm_lat.sh \ + --cmd "$decode_cmd --mem 16G" \ + --rnnlm-ver nnet3rnnlm --weight 0.5 --max-ngram-order $ngram_order \ + data/lang_$LM data/$mic/cued_rnn_$crit \ + data/$mic/${decode_set}_hires ${decode_dir} \ + ${decode_dir}.rnnlm.$crit.cued.lat.${ngram_order}gram + +done diff --git a/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh b/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh index 75b08bc4779..65ad5315c48 100755 --- a/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh +++ b/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh @@ -14,6 +14,7 @@ N=10 inv_acwt=12 weight=1.0 # Interpolation weight for RNNLM. # End configuration section. +rnnlm_ver= echo "$0 $@" # Print the command line for logging @@ -39,6 +40,18 @@ data=$3 indir=$4 outdir=$5 +rescoring_binary=lattice-lmrescore-rnnlm + +first_arg=ark:$rnnlm_dir/unk.probs # this is for mikolov's rnnlm +extra_arg= + +if [ "$rnnlm_ver" == "nnet3rnnlm" ]; then + total_size=`wc -l $rnnlm_dir/unigram.counts | awk '{print $1}'` + rescoring_binary="lattice-lmrescore-nnet3-rnnlm" + cat $rnnlm_dir/rnnlm.input.wlist.index | tail -n +2 | awk '{print $1-1,$2}' > $rnnlm_dir/rnn.wlist + first_arg=$rnnlm_dir/rnn.wlist +fi + oldlm=$oldlang/G.fst if [ -f $oldlang/G.carpa ]; then oldlm=$oldlang/G.carpa @@ -72,20 +85,19 @@ if [ "$oldlm" == "$oldlang/G.fst" ]; then $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \ lattice-lmrescore --lm-scale=$oldlm_weight \ "ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm_command" ark:- \| \ - lattice-lmrescore-rnnlm --lm-scale=$weight \ - --max-ngram-order=$max_ngram_order ark:$rnnlm_dir/unk.probs \ - $oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \ + $rescoring_binary $extra_arg --lm-scale=$weight \ + --max-ngram-order=$max_ngram_order \ + $first_arg $oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \ "ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1; else $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \ lattice-lmrescore-const-arpa --lm-scale=$oldlm_weight \ - "ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm" ark:- \| \ - lattice-lmrescore-rnnlm --lm-scale=$weight \ - --max-ngram-order=$max_ngram_order ark:$rnnlm_dir/unk.probs \ - $oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \ + "ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm_command" ark:- \| \ + $rescoring_binary $extra_arg --lm-scale=$weight \ + --max-ngram-order=$max_ngram_order \ + $first_arg $oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \ "ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1; fi - if ! $skip_scoring ; then err_msg="Not scoring because local/score.sh does not exist or not executable." [ ! -x local/score.sh ] && echo $err_msg && exit 1; diff --git a/src/latbin/Makefile b/src/latbin/Makefile index 357e90f2349..ddd4fa2f80a 100644 --- a/src/latbin/Makefile +++ b/src/latbin/Makefile @@ -4,6 +4,10 @@ all: EXTRA_CXXFLAGS = -Wno-sign-compare include ../kaldi.mk +LDFLAGS += $(CUDA_LDFLAGS) +LDLIBS += $(CUDA_LDLIBS) + + BINFILES = lattice-best-path lattice-prune lattice-equivalent lattice-to-nbest \ lattice-lmrescore lattice-scale lattice-union lattice-to-post \ lattice-determinize lattice-oracle lattice-rmali \ @@ -21,6 +25,7 @@ BINFILES = lattice-best-path lattice-prune lattice-equivalent lattice-to-nbest \ lattice-confidence lattice-determinize-phone-pruned \ lattice-determinize-phone-pruned-parallel lattice-expand-ngram \ lattice-lmrescore-const-arpa lattice-lmrescore-rnnlm nbest-to-prons \ + lattice-lmrescore-nnet3-rnnlm \ lattice-arc-post lattice-determinize-non-compact OBJFILES = @@ -29,9 +34,10 @@ OBJFILES = TESTFILES = -ADDLIBS = ../lat/kaldi-lat.a ../lm/kaldi-lm.a ../fstext/kaldi-fstext.a \ - ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \ - ../thread/kaldi-thread.a ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a +ADDLIBS = ../rnnlm/kaldi-rnnlm.a ../lat/kaldi-lat.a ../lm/kaldi-lm.a \ + ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a \ + ../util/kaldi-util.a ../thread/kaldi-thread.a \ + ../cudamatrix/kaldi-cudamatrix.a ../matrix/kaldi-matrix.a \ + ../nnet3/kaldi-nnet3.a ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/latbin/lattice-lmrescore-nnet3-rnnlm.cc b/src/latbin/lattice-lmrescore-nnet3-rnnlm.cc new file mode 100644 index 00000000000..776c0608d0d --- /dev/null +++ b/src/latbin/lattice-lmrescore-nnet3-rnnlm.cc @@ -0,0 +1,139 @@ +// latbin/lattice-lmrescore-nnet3-rnnlm.cc + +// Copyright 2017 Johns Hopkins University (author: Daniel Povey) +// Yiming Wang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "fstext/fstext-lib.h" +#include "lat/kaldi-lattice.h" +#include "lat/lattice-functions.h" +#include "rnnlm/kaldi-rnnlm-rescoring.h" +#include "util/common-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Rescores lattice with rnnlm. The LM will be wrapped into the\n" + "DeterministicOnDemandFst interface and the rescoring is done by\n" + "composing with the wrapped LM using a special type of composition\n" + "algorithm. Determinization will be applied on the composed lattice.\n" + "\n" + "Usage: lattice-lmrescore-nnet3-rnnlm [options] \\\n" + " \\\n" + " \n" + " e.g.: lattice-lmrescore-nnet3-rnnlm --lm-scale=-1.0 words.txt \\\n" + " ark:in.lats rnnlm ark:out.lats\n"; + + ParseOptions po(usage); + int32 max_ngram_order = 3; + BaseFloat lm_scale = 1.0; + + po.Register("lm-scale", &lm_scale, "Scaling factor for language model " + "costs; frequently 1.0 or -1.0"); + po.Register("max-ngram-order", &max_ngram_order, "If positive, limit the " + "rnnlm context to the given number, -1 means we are not going " + "to limit it."); + + po.Read(argc, argv); + + if (po.NumArgs() != 4 && po.NumArgs() != 5) { + po.PrintUsage(); + exit(1); + } + + std::string lats_rspecifier, rnn_wordlist, + word_symbols_rxfilename, rnnlm_rxfilename, lats_wspecifier; + KALDI_ASSERT (po.NumArgs() == 5); + + rnn_wordlist = po.GetArg(1); + word_symbols_rxfilename = po.GetArg(2); + lats_rspecifier = po.GetArg(3); + rnnlm_rxfilename = po.GetArg(4); + lats_wspecifier = po.GetArg(5); + + // Reads the language model. + rnnlm::LmNnet lm_nnet; + ReadKaldiObject(rnnlm_rxfilename, &lm_nnet); + + const nnet3::DecodableRnnlmSimpleLoopedComputationOptions opts; + const nnet3::DecodableRnnlmSimpleLoopedInfo info(opts, lm_nnet); + + // Reads and writes as compact lattice. + SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier); + CompactLatticeWriter compact_lattice_writer(lats_wspecifier); + + int32 n_done = 0, n_fail = 0; + for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) { + std::string key = compact_lattice_reader.Key(); + CompactLattice clat = compact_lattice_reader.Value(); + compact_lattice_reader.FreeCurrent(); + + if (lm_scale != 0.0) { + // Before composing with the LM FST, we scale the lattice weights + // by the inverse of "lm_scale". We'll later scale by "lm_scale". + // We do it this way so we can determinize and it will give the + // right effect (taking the "best path" through the LM) regardless + // of the sign of lm_scale. + fst::ScaleLattice(fst::GraphLatticeScale(1.0 / lm_scale), &clat); + ArcSort(&clat, fst::OLabelCompare()); + + // Wraps the rnnlm into FST. We re-create it for each lattice to prevent + // memory usage increasing with time. + nnet3::KaldiRnnlmDeterministicFst rnnlm_fst(max_ngram_order, + rnn_wordlist, + word_symbols_rxfilename, + info); + + // Composes lattice with language model. + CompactLattice composed_clat; + ComposeCompactLatticeDeterministic(clat, &rnnlm_fst, &composed_clat); + + // Determinizes the composed lattice. + Lattice composed_lat; + ConvertLattice(composed_clat, &composed_lat); + Invert(&composed_lat); + CompactLattice determinized_clat; + DeterminizeLattice(composed_lat, &determinized_clat); + fst::ScaleLattice(fst::GraphLatticeScale(lm_scale), &determinized_clat); + if (determinized_clat.Start() == fst::kNoStateId) { + KALDI_WARN << "Empty lattice for utterance " << key + << " (incompatible LM?)"; + n_fail++; + } else { + compact_lattice_writer.Write(key, determinized_clat); + n_done++; + } + } else { + // Zero scale so nothing to do. + n_done++; + compact_lattice_writer.Write(key, clat); + } + } + + KALDI_LOG << "Done " << n_done << " lattices, failed for " << n_fail; + return (n_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/rnnlm/Makefile b/src/rnnlm/Makefile index 3a11f0556ed..7cacfa686df 100644 --- a/src/rnnlm/Makefile +++ b/src/rnnlm/Makefile @@ -12,6 +12,7 @@ OBJFILES = rnnlm-component-itf.o rnnlm-utils.o rnnlm-nnet.o rnnlm-component.o rnnlm-training.o \ rnnlm-diagnostics.o \ arpa-sampling.o \ + kaldi-rnnlm-decodable-simple-looped.o kaldi-rnnlm-rescoring.o LIBNAME = kaldi-rnnlm diff --git a/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.cc b/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.cc new file mode 100644 index 00000000000..202ef3b77bd --- /dev/null +++ b/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.cc @@ -0,0 +1,172 @@ +// rnnlm/kaldi-rnnlm-decodable-simple-looped.cc + +// Copyright 2017 Johns Hopkins University (author: Daniel Povey) +// 2017 Yiming Wang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "rnnlm/kaldi-rnnlm-decodable-simple-looped.h" +#include "nnet3/nnet-utils.h" +#include "nnet3/nnet-compile-looped.h" + +namespace kaldi { +namespace nnet3 { + + +DecodableRnnlmSimpleLoopedInfo::DecodableRnnlmSimpleLoopedInfo( + const DecodableRnnlmSimpleLoopedComputationOptions &opts, + const rnnlm::LmNnet &lm_nnet): + opts(opts), lm_nnet(lm_nnet) { + Init(opts, lm_nnet); +} + +void DecodableRnnlmSimpleLoopedInfo::Init( + const DecodableRnnlmSimpleLoopedComputationOptions &opts, + const rnnlm::LmNnet &lm_nnet) { + opts.Check(); + KALDI_ASSERT(IsSimpleNnet(lm_nnet.Nnet())); + int32 left_context, right_context; + ComputeSimpleNnetContext(lm_nnet.Nnet(), &left_context, &right_context); + frames_left_context = opts.extra_left_context_initial + left_context; + frames_right_context = right_context; + int32 frame_subsampling_factor = 1; + frames_per_chunk = GetChunkSize(lm_nnet.Nnet(), frame_subsampling_factor, + opts.frames_per_chunk); + KALDI_ASSERT(frames_per_chunk == opts.frames_per_chunk); + nnet_output_dim = lm_nnet.Nnet().OutputDim("output"); + KALDI_ASSERT(nnet_output_dim > 0); + + int32 ivector_period = frames_per_chunk; + int32 extra_right_context = 0; + int32 num_sequences = 1; // we're processing one utterance at a time. + CreateLoopedComputationRequestSimple(lm_nnet.Nnet(), frames_per_chunk, + frame_subsampling_factor, + ivector_period, + opts.extra_left_context_initial, + extra_right_context, + num_sequences, + &request1, &request2, &request3); + + CompileLooped(lm_nnet.Nnet(), opts.optimize_config, request1, request2, + request3, &computation); + computation.ComputeCudaIndexes(); + if (GetVerboseLevel() >= 3) { + KALDI_VLOG(3) << "Computation is:"; + computation.Print(std::cerr, lm_nnet.Nnet()); + } +} + +DecodableRnnlmSimpleLooped::DecodableRnnlmSimpleLooped( + const DecodableRnnlmSimpleLoopedInfo &info) : + info_(info), + computer_(info_.opts.compute_config, info_.computation, + info_.lm_nnet.Nnet(), NULL), // NULL is 'nnet_to_update' + // since everytime we provide one chunk to the decodable object, the size of + // feats_ == frames_per_chunk + feats_(info_.frames_per_chunk, + info_.lm_nnet.InputLayer()->InputDim()), + current_log_post_offset_(-1) { + num_frames_ = feats_.NumRows(); +} + +void DecodableRnnlmSimpleLooped::TakeFeatures( + const std::vector &word_indexes) { + KALDI_ASSERT(word_indexes.size() == num_frames_); + std::vector > > + pairs(word_indexes.size()); + for (int32 i = 0; i < word_indexes.size(); i++) { + std::pair one_hot_index(word_indexes[i], 1.0); + std::vector > row(1, one_hot_index); + pairs[i] = row; + } + SparseMatrix feats_temp(feats_.NumCols(), pairs); + feats_.Swap(&feats_temp); + // resets offset so that AdvanceChunk() would be called in GetOutput() and + // GetNnetOutputForFrame() after taking new features + current_log_post_offset_ = -1; +} + +void DecodableRnnlmSimpleLooped::GetNnetOutputForFrame( + int32 frame, VectorBase *output) { + KALDI_ASSERT(frame >= 0 && frame < feats_.NumRows()); + if (frame >= current_log_post_offset_ + current_nnet_output_.NumRows()) + AdvanceChunk(); + output->CopyFromVec(current_nnet_output_.Row(frame - + current_log_post_offset_)); +} + +BaseFloat DecodableRnnlmSimpleLooped::GetOutput(int32 frame, int32 word_index) { + KALDI_ASSERT(frame >= 0 && frame < feats_.NumRows()); + if (frame >= current_log_post_offset_ + current_nnet_output_.NumRows()) + AdvanceChunk(); + + const rnnlm::LmOutputComponent* output_layer = info_.lm_nnet.OutputLayer(); + CuMatrix current_nnet_output_gpu; + current_nnet_output_gpu.Swap(¤t_nnet_output_); + const CuSubVector hidden(current_nnet_output_gpu, + frame - current_log_post_offset_); + BaseFloat log_prob = + output_layer->ComputeLogprobOfWordGivenHistory(hidden, word_index); + // swap the pointer back so that this function can be called multiple times + // with the same returned value before taking next new feats + current_nnet_output_.Swap(¤t_nnet_output_gpu); + return log_prob; +} + +void DecodableRnnlmSimpleLooped::AdvanceChunk() { + int32 begin_input_frame, end_input_frame; + begin_input_frame = -info_.frames_left_context; + // note: end is last plus one. + end_input_frame = info_.frames_per_chunk + info_.frames_right_context; + // currently there is no left/right context and frames_per_chunk == 1 + KALDI_ASSERT(begin_input_frame == 0 && end_input_frame == 1); + + SparseMatrix feats_chunk(end_input_frame - begin_input_frame, + feats_.NumCols()); + int32 num_features = feats_.NumRows(); + for (int32 r = begin_input_frame; r < end_input_frame; r++) { + int32 input_frame = r; + if (input_frame < 0) input_frame = 0; + if (input_frame >= num_features) input_frame = num_features - 1; + feats_chunk.SetRow(r - begin_input_frame, feats_.Row(input_frame)); + } + + const rnnlm::LmInputComponent* input_layer = info_.lm_nnet.InputLayer(); + CuMatrix new_input(feats_chunk.NumRows(), input_layer->OutputDim()); + input_layer->Propagate(feats_chunk, &new_input); + + computer_.AcceptInput("input", &new_input); + + computer_.Run(); + + { + // Note: here GetOutput() is used instead of GetOutputDestructive(), since + // here we have recurrence that goes directly from the output, and the call + // to GetOutputDestructive() would cause a crash on the next chunk. + CuMatrix output(computer_.GetOutput("output")); + + current_nnet_output_.Resize(0, 0); + current_nnet_output_.Swap(&output); + } + KALDI_ASSERT(current_nnet_output_.NumRows() == info_.frames_per_chunk && + current_nnet_output_.NumCols() == info_.nnet_output_dim); + + current_log_post_offset_ = 0; +} + + +} // namespace nnet3 +} // namespace kaldi diff --git a/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.h b/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.h new file mode 100644 index 00000000000..4c35a8c6215 --- /dev/null +++ b/src/rnnlm/kaldi-rnnlm-decodable-simple-looped.h @@ -0,0 +1,184 @@ +// rnnlm/kaldi-rnnlm-decodable-simple-looped.h + +// Copyright 2017 Johns Hopkins University (author: Daniel Povey) +// 2017 Yiming Wang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_RNNLM_DECODABLE_SIMPLE_LOOPED_H_ +#define KALDI_RNNLM_DECODABLE_SIMPLE_LOOPED_H_ + +#include +#include "base/kaldi-common.h" +#include "gmm/am-diag-gmm.h" +#include "hmm/transition-model.h" +#include "itf/decodable-itf.h" +#include "nnet3/nnet-optimize.h" +#include "nnet3/nnet-compute.h" +#include "nnet3/am-nnet-simple.h" +#include "rnnlm/rnnlm-nnet.h" + +namespace kaldi { +namespace nnet3 { + +// See also nnet-am-decodable-simple.h, which is a decodable object that's based +// on breaking up the input into fixed chunks. The decodable object defined here is based on +// 'looped' computations, which naturally handle infinite left-context (but are +// only ideal for systems that have only recurrence in the forward direction, +// i.e. not BLSTMs... because there isn't a natural way to enforce extra right +// context for each chunk.) + + +// Note: the 'simple' in the name means it applies to networks for which +// IsSimpleNnet(nnet) would return true. 'looped' means we use looped +// computations, with a kGotoLabel statement at the end of it. +struct DecodableRnnlmSimpleLoopedComputationOptions { + int32 extra_left_context_initial; + int32 frames_per_chunk; + bool debug_computation; + NnetOptimizeOptions optimize_config; + NnetComputeOptions compute_config; + DecodableRnnlmSimpleLoopedComputationOptions(): + extra_left_context_initial(0), + frames_per_chunk(1), + debug_computation(false) { } + + void Check() const { + KALDI_ASSERT(extra_left_context_initial >= 0 && frames_per_chunk > 0); + } + + void Register(OptionsItf *opts) { + opts->Register("extra-left-context-initial", &extra_left_context_initial, + "Extra left context to use at the first frame of an utterance (note: " + "this will just consist of repeats of the first frame, and should not " + "usually be necessary."); + opts->Register("frames-per-chunk", &frames_per_chunk, + "Number of frames in each chunk that is separately evaluated " + "by the neural net."); + opts->Register("debug-computation", &debug_computation, "If true, turn on " + "debug for the actual computation (very verbose!)"); + + // register the optimization options with the prefix "optimization". + ParseOptions optimization_opts("optimization", opts); + optimize_config.Register(&optimization_opts); + + // register the compute options with the prefix "computation". + ParseOptions compute_opts("computation", opts); + compute_config.Register(&compute_opts); + } +}; + + +/** + When you instantiate class DecodableNnetSimpleLooped, you should give it + a const reference to this class, that has been previously initialized. + */ +class DecodableRnnlmSimpleLoopedInfo { + public: + DecodableRnnlmSimpleLoopedInfo( + const DecodableRnnlmSimpleLoopedComputationOptions &opts, + const rnnlm::LmNnet &lm_nnet); + + void Init(const DecodableRnnlmSimpleLoopedComputationOptions &opts, + const rnnlm::LmNnet &lm_nnet); + + const DecodableRnnlmSimpleLoopedComputationOptions &opts; + + const rnnlm::LmNnet &lm_nnet; + + // frames_left_context equals the model left context plus the value of the + // --extra-left-context-initial option. + int32 frames_left_context; + // frames_right_context is the same as the right-context of the model. + int32 frames_right_context; + // The frames_per_chunk equals the number of input frames we need for each + // chunk (except for the first chunk). + int32 frames_per_chunk; + + // The output dimension of the nnet neural network (not the final output). + int32 nnet_output_dim; + + // The 3 computation requests that are used to create the looped + // computation are stored in the class, as we need them to work out + // exactly shich iVectors are needed. + ComputationRequest request1, request2, request3; + + // The compiled, 'looped' computation. + NnetComputation computation; +}; + +/* + This class handles the neural net computation; it's mostly accessed + via other wrapper classes. + + It accept just input features */ +class DecodableRnnlmSimpleLooped { + public: + /** + This constructor takes features as input. + Note: it stores references to all arguments to the constructor, so don't + delete them till this goes out of scope. + + @param [in] info This helper class contains all the static pre-computed information + this class needs, and contains a pointer to the neural net. + @param [in] feats The input feature matrix. + */ + DecodableRnnlmSimpleLooped(const DecodableRnnlmSimpleLoopedInfo &info); + + // returns the number of frames of likelihoods. The same as feats_.NumRows() + inline int32 NumFrames() const { return num_frames_; } + + inline int32 NnetOutputDim() const { return info_.nnet_output_dim; } + + // Gets the nnet's output for a particular frame, with 0 <= frame < NumFrames(). + // 'output' must be correctly sized (with dimension NnetOutputDim()). Note: + // you're expected to call this, and GetOutput(), in an order of increasing + // frames. If you deviate from this, one of these calls may crash. + void GetNnetOutputForFrame(int32 frame, VectorBase *output); + + // Updates feats_ with the new incoming word specified in word_indexes + void TakeFeatures(const std::vector &word_indexes); + + // Gets the output for a particular frame and word_index, with + // 0 <= frame < NumFrames(). + BaseFloat GetOutput(int32 frame, int32 word_index); + + private: + // This function does the computation for the next chunk. + void AdvanceChunk(); + + const DecodableRnnlmSimpleLoopedInfo &info_; + + NnetComputer computer_; + + SparseMatrix feats_; + + int32 num_frames_; + + // The current nnet's output that we got from the last time we + // ran the computation. + Matrix current_nnet_output_; + + // The time-offset of the current log-posteriors, equals + // -1 when initialized, or 0 once AdvanceChunk() was called + int32 current_log_post_offset_; +}; + + +} // namespace nnet3 +} // namespace kaldi + +#endif // KALDI_RNNLM_DECODABLE_SIMPLE_LOOPED_H_ diff --git a/src/rnnlm/kaldi-rnnlm-rescoring.cc b/src/rnnlm/kaldi-rnnlm-rescoring.cc new file mode 100644 index 00000000000..88ec94d9dc5 --- /dev/null +++ b/src/rnnlm/kaldi-rnnlm-rescoring.cc @@ -0,0 +1,160 @@ +// rnnlm/kaldi-rnnlm-rescoring.cc + +// Copyright 2017 Johns Hopkins University (author: Daniel Povey) +// Yiming Wang +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "rnnlm/kaldi-rnnlm-rescoring.h" +#include "util/stl-utils.h" +#include "util/text-utils.h" + +namespace kaldi { +namespace nnet3 { + +void KaldiRnnlmDeterministicFst::ReadFstWordSymbolTableAndRnnWordlist( + const std::string &rnn_wordlist, + const std::string &word_symbol_table_rxfilename) { + // Reads symbol table. + fst::SymbolTable *fst_word_symbols = NULL; + if (!(fst_word_symbols = + fst::SymbolTable::ReadText(word_symbol_table_rxfilename))) { + KALDI_ERR << "Could not read symbol table from file " + << word_symbol_table_rxfilename; + } + + full_voc_size_ = fst_word_symbols->NumSymbols(); + fst_label_to_word_.resize(full_voc_size_); + + for (int32 i = 0; i < fst_label_to_word_.size(); ++i) { + fst_label_to_word_[i] = fst_word_symbols->Find(i); + if (fst_label_to_word_[i] == "") { + KALDI_ERR << "Could not find word for integer " << i << "in the word " + << "symbol table, mismatched symbol table or you have discoutinuous " + << "integers in your symbol table?"; + } + } + + fst_label_to_rnn_label_.resize(fst_word_symbols->NumSymbols(), -1); + + rnn_label_to_word_.push_back(""); + rnn_label_to_word_.push_back(""); + out_OOS_index_ = 1; + { // input + std::ifstream ifile(rnn_wordlist.c_str()); + int32 id; + string word; + int32 i = 1; + while (ifile >> id >> word) { // TODO(hxu) ugly fix for cued-rnnlm's bug + // will implement a better fix later + if (word == "[UNK]") { + word = ""; + } else if (word == "") { + continue; + } + i++; + assert(i == id + 2); + rnn_label_to_word_.push_back(word); + + int fst_label = fst_word_symbols->Find(rnn_label_to_word_[i]); + KALDI_ASSERT(fst::SymbolTable::kNoSymbol != fst_label); + fst_label_to_rnn_label_[fst_label] = i; + } + } + + for (int32 i = 0; i < fst_label_to_rnn_label_.size(); i++) { + if (fst_label_to_rnn_label_[i] == -1) { + fst_label_to_rnn_label_[i] = out_OOS_index_; + } + } +} + +KaldiRnnlmDeterministicFst::KaldiRnnlmDeterministicFst(int32 max_ngram_order, + const std::string &rnn_wordlist, + const std::string &word_symbol_table_rxfilename, + const DecodableRnnlmSimpleLoopedInfo &info) { + max_ngram_order_ = max_ngram_order; + ReadFstWordSymbolTableAndRnnWordlist(rnn_wordlist, + word_symbol_table_rxfilename); + + std::vector