Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/chainbin/nnet3-chain-combine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ int main(int argc, char *argv[]) {
// means we use the freshest batch-norm stats. (Since the batch-norm
// stats are not technically parameters, they are not subject to
// combination like the rest of the model parameters).
SetTestMode(true, &nnet);
SetBatchnormTestMode(true, &nnet);
SetDropoutTestMode(true, &nnet);

std::vector<NnetChainExample> egs;
egs.reserve(10000); // reserve a lot of space to minimize the chance of
Expand Down
15 changes: 11 additions & 4 deletions src/chainbin/nnet3-chain-compute-prob.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ int main(int argc, char *argv[]) {
"Usage: nnet3-chain-compute-prob [options] <raw-nnet3-model-in> <denominator-fst> <training-examples-in>\n"
"e.g.: nnet3-chain-compute-prob 0.mdl den.fst ark:valid.egs\n";

bool test_mode = true;
bool batchnorm_test_mode = true, dropout_test_mode = true;

// This program doesn't support using a GPU, because these probabilities are
// used for diagnostics, and you can just compute them with a small enough
Expand All @@ -50,9 +50,13 @@ int main(int argc, char *argv[]) {

ParseOptions po(usage);

po.Register("test-mode", &test_mode,
po.Register("batchnorm-test-mode", &batchnorm_test_mode,
"If true, set test-mode to true on any BatchNormComponents.");

po.Register("dropout-test-mode", &dropout_test_mode,
"If true, set test-mode to true on any DropoutComponents and "
"DropoutMaskComponents.");

nnet_opts.Register(&po);
chain_opts.Register(&po);

Expand All @@ -70,8 +74,11 @@ int main(int argc, char *argv[]) {
Nnet nnet;
ReadKaldiObject(nnet_rxfilename, &nnet);

if (test_mode)
SetTestMode(true, &nnet);
if (batchnorm_test_mode)
SetBatchnormTestMode(true, &nnet);

if (dropout_test_mode)
SetDropoutTestMode(true, &nnet);

fst::StdVectorFst den_fst;
ReadFstKaldi(den_fst_rxfilename, &den_fst);
Expand Down
10 changes: 10 additions & 0 deletions src/nnet3/nnet-component-itf.h
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,18 @@ class RandomComponent: public Component {
// validation-set performance), but check where else we call srand(). You'll
// need to call srand prior to making this call.
void ResetGenerator() { random_generator_.SeedGpu(); }

// Call this with 'true' to set 'test mode' where the behavior is different
// from normal mode.
void SetTestMode(bool test_mode) { test_mode_ = test_mode; }

RandomComponent(): test_mode_(false) { }
protected:
CuRand<BaseFloat> random_generator_;

  // This is true if we want the behavior at inference (test) time to differ
  // from the behavior at training time.
bool test_mode_;
};

/**
Expand Down
3 changes: 2 additions & 1 deletion src/nnet3/nnet-compute-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ void TestNnetDecodable(Nnet *nnet) {
ivector_dim = std::max<int32>(0, nnet->InputDim("ivector"));
Matrix<BaseFloat> input(num_frames, input_dim);

SetTestMode(true, nnet);
SetBatchnormTestMode(true, nnet);
SetDropoutTestMode(true, nnet);

input.SetRandn();
Vector<BaseFloat> ivector(ivector_dim);
Expand Down
18 changes: 17 additions & 1 deletion src/nnet3/nnet-general-component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,10 @@ void* DropoutMaskComponent::Propagate(
out->Set(1.0);
return NULL;
}
if (test_mode_) {
out->Set(1.0 - dropout_proportion);
return NULL;
}
const_cast<CuRand<BaseFloat>&>(random_generator_).RandUniform(out);
out->Add(-dropout_proportion);
out->ApplyHeaviside();
Expand Down Expand Up @@ -1442,7 +1446,15 @@ void DropoutMaskComponent::Read(std::istream &is, bool binary) {
ReadBasicType(is, binary, &output_dim_);
ExpectToken(is, binary, "<DropoutProportion>");
ReadBasicType(is, binary, &dropout_proportion_);
ExpectToken(is, binary, "</DropoutMaskComponent>");
std::string token;
ReadToken(is, binary, &token);
if (token == "<TestMode>") {
ReadBasicType(is, binary, &test_mode_); // read test mode
ExpectToken(is, binary, "</DropoutMaskComponent>");
} else {
test_mode_ = false;
KALDI_ASSERT(token == "</DropoutMaskComponent>");
}
}


Expand All @@ -1452,6 +1464,8 @@ void DropoutMaskComponent::Write(std::ostream &os, bool binary) const {
WriteBasicType(os, binary, output_dim_);
WriteToken(os, binary, "<DropoutProportion>");
WriteBasicType(os, binary, dropout_proportion_);
WriteToken(os, binary, "<TestMode>");
WriteBasicType(os, binary, test_mode_);
WriteToken(os, binary, "</DropoutMaskComponent>");
}

Expand All @@ -1465,6 +1479,8 @@ void DropoutMaskComponent::InitFromConfig(ConfigLine *cfl) {
KALDI_ASSERT(ok && output_dim_ > 0);
dropout_proportion_ = 0.5;
cfl->GetValue("dropout-proportion", &dropout_proportion_);
test_mode_ = false;
cfl->GetValue("test-mode", &test_mode_);
}


Expand Down
17 changes: 16 additions & 1 deletion src/nnet3/nnet-simple-component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,12 @@ void DropoutComponent::InitFromConfig(ConfigLine *cfl) {
int32 dim = 0;
BaseFloat dropout_proportion = 0.0;
bool dropout_per_frame = false;
test_mode_ = false;
bool ok = cfl->GetValue("dim", &dim) &&
cfl->GetValue("dropout-proportion", &dropout_proportion);
cfl->GetValue("dropout-per-frame", &dropout_per_frame);
// It only makes sense to set test-mode in the config for testing purposes.
cfl->GetValue("test-mode", &test_mode_);
// for this stage, dropout is hard coded in
// normal mode if not declared in config
if (!ok || cfl->HasUnusedValues() || dim <= 0 ||
Expand All @@ -128,6 +131,11 @@ void* DropoutComponent::Propagate(const ComponentPrecomputedIndexes *indexes,

BaseFloat dropout = dropout_proportion_;
KALDI_ASSERT(dropout >= 0.0 && dropout <= 1.0);
if (test_mode_) {
out->CopyFromMat(in);
out->Scale(1.0 - dropout);
return NULL;
}
if (!dropout_per_frame_) {
// This const_cast is only safe assuming you don't attempt
// to use multi-threaded code with the GPU.
Expand Down Expand Up @@ -188,9 +196,14 @@ void DropoutComponent::Read(std::istream &is, bool binary) {
if (token == "<DropoutPerFrame>") {
ReadBasicType(is, binary, &dropout_per_frame_); // read dropout mode
ReadToken(is, binary, &token);
KALDI_ASSERT(token == "</DropoutComponent>");
} else {
dropout_per_frame_ = false;
}
if (token == "<TestMode>") {
ReadBasicType(is, binary, &test_mode_); // read test mode
ExpectToken(is, binary, "</DropoutComponent>");
} else {
test_mode_ = false;
KALDI_ASSERT(token == "</DropoutComponent>");
}
}
Expand All @@ -203,6 +216,8 @@ void DropoutComponent::Write(std::ostream &os, bool binary) const {
WriteBasicType(os, binary, dropout_proportion_);
WriteToken(os, binary, "<DropoutPerFrame>");
WriteBasicType(os, binary, dropout_per_frame_);
WriteToken(os, binary, "<TestMode>");
WriteBasicType(os, binary, test_mode_);
WriteToken(os, binary, "</DropoutComponent>");
}

Expand Down
4 changes: 3 additions & 1 deletion src/nnet3/nnet-test-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1583,8 +1583,10 @@ static void GenerateRandomComponentConfig(std::string *component_type,
}
case 29: {
*component_type = "DropoutComponent";
bool test_mode = (RandInt(0, 1) == 0);
os << "dim=" << RandInt(1, 200)
<< " dropout-proportion=" << RandUniform();
<< " dropout-proportion=" << RandUniform() << " test-mode="
<< (test_mode ? "true" : "false");
break;
}
case 30: {
Expand Down
10 changes: 9 additions & 1 deletion src/nnet3/nnet-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ void SetDropoutProportion(BaseFloat dropout_proportion,
}


void SetTestMode(bool test_mode, Nnet *nnet) {
void SetBatchnormTestMode(bool test_mode, Nnet *nnet) {
for (int32 c = 0; c < nnet->NumComponents(); c++) {
Component *comp = nnet->GetComponent(c);
BatchNormComponent *bc = dynamic_cast<BatchNormComponent*>(comp);
Expand All @@ -479,6 +479,14 @@ void SetTestMode(bool test_mode, Nnet *nnet) {
}
}

void SetDropoutTestMode(bool test_mode, Nnet *nnet) {
for (int32 c = 0; c < nnet->NumComponents(); c++) {
Component *comp = nnet->GetComponent(c);
RandomComponent *rc = dynamic_cast<RandomComponent*>(comp);
if (rc != NULL)
rc->SetTestMode(test_mode);
}
}

void FindOrphanComponents(const Nnet &nnet, std::vector<int32> *components) {
int32 num_components = nnet.NumComponents(), num_nodes = nnet.NumNodes();
Expand Down
12 changes: 10 additions & 2 deletions src/nnet3/nnet-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,21 @@ std::string NnetInfo(const Nnet &nnet);
/// dropout_proportion value.
void SetDropoutProportion(BaseFloat dropout_proportion, Nnet *nnet);

/// This function currently affects only components of type BatchNormComponent.
/// This function affects only components of type BatchNormComponent.
/// It sets "test mode" on such components (if you call it with test_mode =
/// true, otherwise it would set normal mode, but this wouldn't be needed
/// often). "test mode" means that instead of using statistics from the batch,
/// it does a deterministic normalization based on statistics stored at training
/// time.
void SetTestMode(bool test_mode, Nnet *nnet);
void SetBatchnormTestMode(bool test_mode, Nnet *nnet);

/// This function affects components of child classes of RandomComponent
/// (currently only DropoutComponent and DropoutMaskComponent).
/// It sets "test mode" on such components (if you call it with test_mode =
/// true; otherwise it would set normal mode, but this wouldn't be needed often).
/// In "test mode" the random dropout mask is replaced by a deterministic mask
/// containing (1 - dropout_proportion) in all elements.
void SetDropoutTestMode(bool test_mode, Nnet *nnet);

/// This function finds a list of components that are never used, and outputs
/// the integer component indexes (you can use these to index
Expand Down
3 changes: 2 additions & 1 deletion src/nnet3bin/nnet3-align-compiled.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ int main(int argc, char *argv[]) {
trans_model.Read(ki.Stream(), binary);
am_nnet.Read(ki.Stream(), binary);
}
SetTestMode(true, &(am_nnet.GetNnet()));
SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
SetDropoutTestMode(true, &(am_nnet.GetNnet()));
// this compiler object allows caching of computations across
// different utterances.
CachingOptimizingCompiler compiler(am_nnet.GetNnet(),
Expand Down
3 changes: 2 additions & 1 deletion src/nnet3bin/nnet3-combine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ int main(int argc, char *argv[]) {
// means we use the freshest batch-norm stats. (Since the batch-norm
// stats are not technically parameters, they are not subject to
// combination like the rest of the model parameters).
SetTestMode(true, &nnet);
SetBatchnormTestMode(true, &nnet);
SetDropoutTestMode(true, &nnet);

std::vector<NnetExample> egs;
egs.reserve(10000); // reserve a lot of space to minimize the chance of
Expand Down
15 changes: 11 additions & 4 deletions src/nnet3bin/nnet3-compute-prob.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ int main(int argc, char *argv[]) {
"e.g.: nnet3-compute-prob 0.raw ark:valid.egs\n";


bool test_mode = true;
bool batchnorm_test_mode = true, dropout_test_mode = true;

// This program doesn't support using a GPU, because these probabilities are
// used for diagnostics, and you can just compute them with a small enough
Expand All @@ -49,9 +49,13 @@ int main(int argc, char *argv[]) {

ParseOptions po(usage);

po.Register("test-mode", &test_mode,
po.Register("batchnorm-test-mode", &batchnorm_test_mode,
"If true, set test-mode to true on any BatchNormComponents.");

po.Register("dropout-test-mode", &dropout_test_mode,
"If true, set test-mode to true on any DropoutComponents and "
"DropoutMaskComponents.");

opts.Register(&po);

po.Read(argc, argv);
Expand All @@ -67,8 +71,11 @@ int main(int argc, char *argv[]) {
Nnet nnet;
ReadKaldiObject(raw_nnet_rxfilename, &nnet);

if (test_mode)
SetTestMode(true, &nnet);
if (batchnorm_test_mode)
SetBatchnormTestMode(true, &nnet);

if (dropout_test_mode)
SetDropoutTestMode(true, &nnet);

NnetComputeProb prob_computer(opts, nnet);

Expand Down
3 changes: 2 additions & 1 deletion src/nnet3bin/nnet3-compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ int main(int argc, char *argv[]) {

Nnet nnet;
ReadKaldiObject(nnet_rxfilename, &nnet);
SetTestMode(true, &nnet);
SetBatchnormTestMode(true, &nnet);
SetDropoutTestMode(true, &nnet);

RandomAccessBaseFloatMatrixReader online_ivector_reader(
online_ivector_rspecifier);
Expand Down
3 changes: 2 additions & 1 deletion src/nnet3bin/nnet3-latgen-faster-looped.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ int main(int argc, char *argv[]) {
Input ki(model_in_filename, &binary);
trans_model.Read(ki.Stream(), binary);
am_nnet.Read(ki.Stream(), binary);
SetTestMode(true, &(am_nnet.GetNnet()));
SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
SetDropoutTestMode(true, &(am_nnet.GetNnet()));
}

bool determinize = config.determinize_lattice;
Expand Down
3 changes: 2 additions & 1 deletion src/nnet3bin/nnet3-latgen-faster-parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ int main(int argc, char *argv[]) {
Input ki(model_in_filename, &binary);
trans_model.Read(ki.Stream(), binary);
am_nnet.Read(ki.Stream(), binary);
SetTestMode(true, &(am_nnet.GetNnet()));
SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
SetDropoutTestMode(true, &(am_nnet.GetNnet()));
}

bool determinize = config.determinize_lattice;
Expand Down
3 changes: 2 additions & 1 deletion src/nnet3bin/nnet3-latgen-faster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ int main(int argc, char *argv[]) {
Input ki(model_in_filename, &binary);
trans_model.Read(ki.Stream(), binary);
am_nnet.Read(ki.Stream(), binary);
SetTestMode(true, &(am_nnet.GetNnet()));
SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
SetDropoutTestMode(true, &(am_nnet.GetNnet()));
}

bool determinize = config.determinize_lattice;
Expand Down