Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions src/nnet3bin/nnet3-compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ int main(int argc, char *argv[]) {
"If --apply-exp=true, apply the Exp() function to the output "
"before writing it out.\n"
"\n"
"Usage: nnet3-compute [options] <raw-nnet-in> <features-rspecifier> <matrix-wspecifier>\n"
"Usage: nnet3-compute [options] <nnet-in> <features-rspecifier> <matrix-wspecifier>\n"
" e.g.: nnet3-compute final.raw scp:feats.scp ark:nnet_prediction.ark\n"
"See also: nnet3-compute-from-egs\n";

Expand All @@ -49,7 +49,7 @@ int main(int argc, char *argv[]) {
NnetSimpleComputationOptions opts;
opts.acoustic_scale = 1.0; // by default do no scaling in this recipe.

bool apply_exp = false;
bool apply_exp = false, use_priors = false;
std::string use_gpu = "yes";

std::string word_syms_filename;
Expand All @@ -74,6 +74,9 @@ int main(int argc, char *argv[]) {
"output");
po.Register("use-gpu", &use_gpu,
"yes|no|optional|wait, only has effect if compiled with CUDA");
po.Register("use-priors", &use_priors, "If true, subtract the logs of the "
"priors stored with the model (in this case, "
"a .mdl file is expected as input).");

po.Read(argc, argv);

Expand All @@ -90,12 +93,26 @@ int main(int argc, char *argv[]) {
feature_rspecifier = po.GetArg(2),
matrix_wspecifier = po.GetArg(3);

Nnet nnet;
ReadKaldiObject(nnet_rxfilename, &nnet);
Nnet raw_nnet;
AmNnetSimple am_nnet;
if (use_priors) {
bool binary;
TransitionModel trans_model;
Input ki(nnet_rxfilename, &binary);
trans_model.Read(ki.Stream(), binary);
am_nnet.Read(ki.Stream(), binary);
} else {
ReadKaldiObject(nnet_rxfilename, &raw_nnet);
}
Nnet &nnet = (use_priors ? am_nnet.GetNnet() : raw_nnet);
SetBatchnormTestMode(true, &nnet);
SetDropoutTestMode(true, &nnet);
CollapseModel(CollapseModelConfig(), &nnet);

Vector<BaseFloat> priors;
if (use_priors)
priors = am_nnet.Priors();

RandomAccessBaseFloatMatrixReader online_ivector_reader(
online_ivector_rspecifier);
RandomAccessBaseFloatVectorReaderMapped ivector_reader(
Expand Down Expand Up @@ -139,7 +156,6 @@ int main(int argc, char *argv[]) {
}
}

Vector<BaseFloat> priors;
DecodableNnetSimple nnet_computer(
opts, nnet, priors,
features, &compiler,
Expand Down