diff --git a/src/nnet3bin/nnet3-compute.cc b/src/nnet3bin/nnet3-compute.cc index d6b4b1ded5d..3cd56ef1c74 100644 --- a/src/nnet3bin/nnet3-compute.cc +++ b/src/nnet3bin/nnet3-compute.cc @@ -39,7 +39,7 @@ int main(int argc, char *argv[]) { "If --apply-exp=true, apply the Exp() function to the output " "before writing it out.\n" "\n" - "Usage: nnet3-compute [options] \n" + "Usage: nnet3-compute [options] \n" " e.g.: nnet3-compute final.raw scp:feats.scp ark:nnet_prediction.ark\n" "See also: nnet3-compute-from-egs\n"; @@ -49,7 +49,7 @@ int main(int argc, char *argv[]) { NnetSimpleComputationOptions opts; opts.acoustic_scale = 1.0; // by default do no scaling in this recipe. - bool apply_exp = false; + bool apply_exp = false, use_priors = false; std::string use_gpu = "yes"; std::string word_syms_filename; @@ -74,6 +74,9 @@ int main(int argc, char *argv[]) { "output"); po.Register("use-gpu", &use_gpu, "yes|no|optional|wait, only has effect if compiled with CUDA"); + po.Register("use-priors", &use_priors, "If true, subtract the logs of the " + "priors stored with the model (in this case, " + "a .mdl file is expected as input)."); po.Read(argc, argv); @@ -90,12 +93,26 @@ int main(int argc, char *argv[]) { feature_rspecifier = po.GetArg(2), matrix_wspecifier = po.GetArg(3); - Nnet nnet; - ReadKaldiObject(nnet_rxfilename, &nnet); + Nnet raw_nnet; + AmNnetSimple am_nnet; + if (use_priors) { + bool binary; + TransitionModel trans_model; + Input ki(nnet_rxfilename, &binary); + trans_model.Read(ki.Stream(), binary); + am_nnet.Read(ki.Stream(), binary); + } else { + ReadKaldiObject(nnet_rxfilename, &raw_nnet); + } + Nnet &nnet = (use_priors ? am_nnet.GetNnet() : raw_nnet); SetBatchnormTestMode(true, &nnet); SetDropoutTestMode(true, &nnet); CollapseModel(CollapseModelConfig(), &nnet); + Vector priors; + if (use_priors) + priors = am_nnet.Priors(); + RandomAccessBaseFloatMatrixReader online_ivector_reader( online_ivector_rspecifier); RandomAccessBaseFloatVectorReaderMapped ivector_reader( @@ -139,7 +156,6 @@ int main(int argc, char *argv[]) { } } - Vector priors; DecodableNnetSimple nnet_computer( opts, nnet, priors, features, &compiler,