Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions egs/ami/s5/local/run_nnet3_rnnlm.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash

# Trains a CUED RNNLM on the AMI training transcripts (via
# local/train_cued_rnnlms.sh) and then rescores the existing nnet3 TDNN
# decoding directories of the dev and eval sets twice: once through N-best
# lists (steps/rnnlmrescore.sh) and once directly on the lattices
# (steps/lmrescore_rnnlm_lat.sh).
#
# Options, overridable on the command line through utils/parse_options.sh:
#   --mic <name>          selects data/<name> and exp/<name>
#   --crit <name>         forwarded as --crit to the RNNLM training script and
#                         used in the model / output directory names
#   --n <int>             forwarded as --N to the N-best rescoring script
#   --ngram-order <int>   forwarded as --max-ngram-order to lattice rescoring

mic=sdm1
crit=vr
n=50
ngram_order=4

. ./utils/parse_options.sh
. ./cmd.sh
. ./path.sh

set -e

# Directory that receives the trained RNNLM and is later handed to both
# rescoring scripts.
rnnlm_dir=data/$mic/cued_rnn_$crit

local/train_cued_rnnlms.sh --crit "$crit" \
  --train-text "data/$mic/train/text" "$rnnlm_dir"

final_lm=ami_fsh.o3g.kn
LM=$final_lm.pr1-7

# Loop-invariant, so set once. No trailing slash: the decode directories are
# built as "$dir/decode_<set>", and a trailing slash would put "//" into every
# path passed on to the rescoring scripts.
dir=exp/$mic/nnet3/tdnn_sp

for decode_set in dev eval; do
  decode_dir=$dir/decode_${decode_set}

  # N-best rescoring.
  # The bare 0.5 is the first positional argument (presumably the RNNLM
  # interpolation weight -- confirm against steps/rnnlmrescore.sh).
  steps/rnnlmrescore.sh \
    --rnnlm-ver cuedrnnlm \
    --N "$n" --cmd "$decode_cmd --mem 16G" --inv-acwt 10 0.5 \
    "data/lang_$LM" "$rnnlm_dir" \
    "data/$mic/$decode_set" "$decode_dir" \
    "${decode_dir}.rnnlm.${crit}.cued.${n}-best"

  # Lattice rescoring.
  # NOTE(review): this requests --rnnlm-ver nnet3rnnlm but passes the CUED
  # model directory and writes to a ".cued.lat." output directory -- confirm
  # that this combination is intended.
  steps/lmrescore_rnnlm_lat.sh \
    --cmd "$decode_cmd --mem 16G" \
    --rnnlm-ver nnet3rnnlm --weight 0.5 --max-ngram-order "$ngram_order" \
    "data/lang_$LM" "$rnnlm_dir" \
    "data/$mic/${decode_set}_hires" "$decode_dir" \
    "${decode_dir}.rnnlm.${crit}.cued.lat.${ngram_order}gram"

done
28 changes: 20 additions & 8 deletions egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ N=10
inv_acwt=12
weight=1.0 # Interpolation weight for RNNLM.
# End configuration section.
rnnlm_ver=

echo "$0 $@" # Print the command line for logging

Expand All @@ -39,6 +40,18 @@ data=$3
indir=$4
outdir=$5

# Select the rescoring program and its first positional argument according to
# --rnnlm-ver. The defaults below are used whenever $rnnlm_ver is anything
# other than "nnet3rnnlm" (including empty).
rescoring_binary=lattice-lmrescore-rnnlm

# First positional argument handed to $rescoring_binary in the rescoring
# pipeline further down.
first_arg=ark:$rnnlm_dir/unk.probs # this is for mikolov's rnnlm
# Extra options inserted right after the binary name; never set to a non-empty
# value in the hunks visible here.
extra_arg=

if [ "$rnnlm_ver" == "nnet3rnnlm" ]; then
# Number of lines in unigram.counts.
# NOTE(review): total_size is not referenced in any hunk visible here --
# confirm it is needed, since it makes the script read unigram.counts.
total_size=`wc -l $rnnlm_dir/unigram.counts | awk '{print $1}'`
rescoring_binary="lattice-lmrescore-nnet3-rnnlm"
# Build the word list for the nnet3 binary: drop the first line of
# rnnlm.input.wlist.index and subtract one from the first field of every
# remaining line. The result is (re)written into $rnnlm_dir on every run.
cat $rnnlm_dir/rnnlm.input.wlist.index | tail -n +2 | awk '{print $1-1,$2}' > $rnnlm_dir/rnn.wlist
first_arg=$rnnlm_dir/rnn.wlist
fi

oldlm=$oldlang/G.fst
if [ -f $oldlang/G.carpa ]; then
oldlm=$oldlang/G.carpa
Expand Down Expand Up @@ -72,20 +85,19 @@ if [ "$oldlm" == "$oldlang/G.fst" ]; then
$cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \
lattice-lmrescore --lm-scale=$oldlm_weight \
"ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm_command" ark:- \| \
lattice-lmrescore-rnnlm --lm-scale=$weight \
--max-ngram-order=$max_ngram_order ark:$rnnlm_dir/unk.probs \
$oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \
$rescoring_binary $extra_arg --lm-scale=$weight \
--max-ngram-order=$max_ngram_order \
$first_arg $oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \
"ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1;
else
$cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \
lattice-lmrescore-const-arpa --lm-scale=$oldlm_weight \
"ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm" ark:- \| \
lattice-lmrescore-rnnlm --lm-scale=$weight \
--max-ngram-order=$max_ngram_order ark:$rnnlm_dir/unk.probs \
$oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \
"ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm_command" ark:- \| \
$rescoring_binary $extra_arg --lm-scale=$weight \
--max-ngram-order=$max_ngram_order \
$first_arg $oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \
"ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1;
fi

if ! $skip_scoring ; then
err_msg="Not scoring because local/score.sh does not exist or not executable."
[ ! -x local/score.sh ] && echo $err_msg && exit 1;
Expand Down
14 changes: 10 additions & 4 deletions src/latbin/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ all:
EXTRA_CXXFLAGS = -Wno-sign-compare
include ../kaldi.mk

LDFLAGS += $(CUDA_LDFLAGS)
LDLIBS += $(CUDA_LDLIBS)


BINFILES = lattice-best-path lattice-prune lattice-equivalent lattice-to-nbest \
lattice-lmrescore lattice-scale lattice-union lattice-to-post \
lattice-determinize lattice-oracle lattice-rmali \
Expand All @@ -21,6 +25,7 @@ BINFILES = lattice-best-path lattice-prune lattice-equivalent lattice-to-nbest \
lattice-confidence lattice-determinize-phone-pruned \
lattice-determinize-phone-pruned-parallel lattice-expand-ngram \
lattice-lmrescore-const-arpa lattice-lmrescore-rnnlm nbest-to-prons \
lattice-lmrescore-nnet3-rnnlm \
lattice-arc-post lattice-determinize-non-compact

OBJFILES =
Expand All @@ -29,9 +34,10 @@ OBJFILES =

TESTFILES =

ADDLIBS = ../lat/kaldi-lat.a ../lm/kaldi-lm.a ../fstext/kaldi-fstext.a \
../hmm/kaldi-hmm.a ../tree/kaldi-tree.a ../util/kaldi-util.a \
../thread/kaldi-thread.a ../matrix/kaldi-matrix.a \
../base/kaldi-base.a
# Static archives are resolved left to right, so a library has to appear
# before the libraries it depends on. kaldi-rnnlm.a is followed directly by
# kaldi-nnet3.a and kaldi-cudamatrix.a so that their references into the
# lattice/util/matrix/base archives listed afterwards can still be resolved.
# NOTE(review): the ordering assumes nnet3 depends on (and is not depended on
# by) the archives that follow it -- confirm against the library Makefiles.
ADDLIBS = ../rnnlm/kaldi-rnnlm.a ../nnet3/kaldi-nnet3.a \
          ../cudamatrix/kaldi-cudamatrix.a ../lat/kaldi-lat.a \
          ../lm/kaldi-lm.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \
          ../tree/kaldi-tree.a ../util/kaldi-util.a \
          ../thread/kaldi-thread.a ../matrix/kaldi-matrix.a \
          ../base/kaldi-base.a

include ../makefiles/default_rules.mk
139 changes: 139 additions & 0 deletions src/latbin/lattice-lmrescore-nnet3-rnnlm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// latbin/lattice-lmrescore-nnet3-rnnlm.cc

// Copyright 2017 Johns Hopkins University (author: Daniel Povey)
// Yiming Wang

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.


#include "base/kaldi-common.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
#include "rnnlm/kaldi-rnnlm-rescoring.h"
#include "util/common-utils.h"

int main(int argc, char *argv[]) {
try {
using namespace kaldi;
typedef kaldi::int32 int32;
typedef kaldi::int64 int64;

const char *usage =
"Rescores lattice with rnnlm. The LM will be wrapped into the\n"
"DeterministicOnDemandFst interface and the rescoring is done by\n"
"composing with the wrapped LM using a special type of composition\n"
"algorithm. Determinization will be applied on the composed lattice.\n"
"\n"
"Usage: lattice-lmrescore-nnet3-rnnlm [options] <rnnlm-wordlist> \\\n"
" <word-symbol-table-rxfilename> <lattice-rspecifier> \\\n"
" <rnnlm-rxfilename> <lattice-wspecifier>\n"
" e.g.: lattice-lmrescore-nnet3-rnnlm --lm-scale=-1.0 words.txt \\\n"
" ark:in.lats rnnlm ark:out.lats\n";

ParseOptions po(usage);
int32 max_ngram_order = 3;
BaseFloat lm_scale = 1.0;

po.Register("lm-scale", &lm_scale, "Scaling factor for language model "
"costs; frequently 1.0 or -1.0");
po.Register("max-ngram-order", &max_ngram_order, "If positive, limit the "
"rnnlm context to the given number, -1 means we are not going "
"to limit it.");

po.Read(argc, argv);

if (po.NumArgs() != 4 && po.NumArgs() != 5) {
po.PrintUsage();
exit(1);
}

std::string lats_rspecifier, rnn_wordlist,
word_symbols_rxfilename, rnnlm_rxfilename, lats_wspecifier;
KALDI_ASSERT (po.NumArgs() == 5);

rnn_wordlist = po.GetArg(1);
word_symbols_rxfilename = po.GetArg(2);
lats_rspecifier = po.GetArg(3);
rnnlm_rxfilename = po.GetArg(4);
lats_wspecifier = po.GetArg(5);

// Reads the language model.
rnnlm::LmNnet lm_nnet;
ReadKaldiObject(rnnlm_rxfilename, &lm_nnet);

const nnet3::DecodableRnnlmSimpleLoopedComputationOptions opts;
const nnet3::DecodableRnnlmSimpleLoopedInfo info(opts, lm_nnet);

// Reads and writes as compact lattice.
SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier);
CompactLatticeWriter compact_lattice_writer(lats_wspecifier);

int32 n_done = 0, n_fail = 0;
for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) {
std::string key = compact_lattice_reader.Key();
CompactLattice clat = compact_lattice_reader.Value();
compact_lattice_reader.FreeCurrent();

if (lm_scale != 0.0) {
// Before composing with the LM FST, we scale the lattice weights
// by the inverse of "lm_scale". We'll later scale by "lm_scale".
// We do it this way so we can determinize and it will give the
// right effect (taking the "best path" through the LM) regardless
// of the sign of lm_scale.
fst::ScaleLattice(fst::GraphLatticeScale(1.0 / lm_scale), &clat);
ArcSort(&clat, fst::OLabelCompare<CompactLatticeArc>());

// Wraps the rnnlm into FST. We re-create it for each lattice to prevent
// memory usage increasing with time.
nnet3::KaldiRnnlmDeterministicFst rnnlm_fst(max_ngram_order,
rnn_wordlist,
word_symbols_rxfilename,
info);

// Composes lattice with language model.
CompactLattice composed_clat;
ComposeCompactLatticeDeterministic(clat, &rnnlm_fst, &composed_clat);

// Determinizes the composed lattice.
Lattice composed_lat;
ConvertLattice(composed_clat, &composed_lat);
Invert(&composed_lat);
CompactLattice determinized_clat;
DeterminizeLattice(composed_lat, &determinized_clat);
fst::ScaleLattice(fst::GraphLatticeScale(lm_scale), &determinized_clat);
if (determinized_clat.Start() == fst::kNoStateId) {
KALDI_WARN << "Empty lattice for utterance " << key
<< " (incompatible LM?)";
n_fail++;
} else {
compact_lattice_writer.Write(key, determinized_clat);
n_done++;
}
} else {
// Zero scale so nothing to do.
n_done++;
compact_lattice_writer.Write(key, clat);
}
}

KALDI_LOG << "Done " << n_done << " lattices, failed for " << n_fail;
return (n_done != 0 ? 0 : 1);
} catch(const std::exception &e) {
std::cerr << e.what();
return -1;
}
}
1 change: 1 addition & 0 deletions src/rnnlm/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ OBJFILES = rnnlm-component-itf.o rnnlm-utils.o rnnlm-nnet.o rnnlm-component.o
rnnlm-training.o \
rnnlm-diagnostics.o \
arpa-sampling.o \
kaldi-rnnlm-decodable-simple-looped.o kaldi-rnnlm-rescoring.o

LIBNAME = kaldi-rnnlm

Expand Down
Loading