From e35a66c159db0fd10311b0a41cb1e4fa1d8884e8 Mon Sep 17 00:00:00 2001
From: freewym
Date: Tue, 25 Apr 2017 01:56:28 -0400
Subject: [PATCH 1/2] Add 'test mode' to dropout component

---
 src/chainbin/nnet3-chain-combine.cc          |  3 ++-
 src/chainbin/nnet3-chain-compute-prob.cc     | 15 +++++++++++----
 src/nnet3/nnet-component-itf.h               | 10 ++++++++++
 src/nnet3/nnet-compute-test.cc               |  3 ++-
 src/nnet3/nnet-general-component.cc          | 18 +++++++++++++++++-
 src/nnet3/nnet-simple-component.cc           | 16 +++++++++++++++-
 src/nnet3/nnet-test-utils.cc                 |  4 +++-
 src/nnet3/nnet-utils.cc                      | 10 +++++++++-
 src/nnet3/nnet-utils.h                       | 12 ++++++++++--
 src/nnet3bin/nnet3-align-compiled.cc         |  3 ++-
 src/nnet3bin/nnet3-combine.cc                |  3 ++-
 src/nnet3bin/nnet3-compute-prob.cc           | 15 +++++++++++----
 src/nnet3bin/nnet3-compute.cc                |  3 ++-
 src/nnet3bin/nnet3-latgen-faster-looped.cc   |  3 ++-
 src/nnet3bin/nnet3-latgen-faster-parallel.cc |  3 ++-
 src/nnet3bin/nnet3-latgen-faster.cc          |  3 ++-
 16 files changed, 102 insertions(+), 22 deletions(-)

diff --git a/src/chainbin/nnet3-chain-combine.cc b/src/chainbin/nnet3-chain-combine.cc
index 5dbfdfea944..85e328709e0 100644
--- a/src/chainbin/nnet3-chain-combine.cc
+++ b/src/chainbin/nnet3-chain-combine.cc
@@ -82,7 +82,8 @@ int main(int argc, char *argv[]) {
     // means we use the freshest batch-norm stats. (Since the batch-norm
     // stats are not technically parameters, they are not subject to
     // combination like the rest of the model parameters).
-    SetTestMode(true, &nnet);
+    SetBatchnormTestMode(true, &nnet);
+    SetDropoutTestMode(true, &nnet);
 
     std::vector<NnetChainExample> egs;
     egs.reserve(10000);  // reserve a lot of space to minimize the chance of
diff --git a/src/chainbin/nnet3-chain-compute-prob.cc b/src/chainbin/nnet3-chain-compute-prob.cc
index 3d67af84a2b..49827490fab 100644
--- a/src/chainbin/nnet3-chain-compute-prob.cc
+++ b/src/chainbin/nnet3-chain-compute-prob.cc
@@ -38,7 +38,7 @@ int main(int argc, char *argv[]) {
         "Usage: nnet3-chain-compute-prob [options] <raw-nnet3-model-in> <denominator-fst-in> <training-examples-in>\n"
         "e.g.: nnet3-chain-compute-prob 0.mdl den.fst ark:valid.egs\n";
 
-    bool test_mode = true;
+    bool batchnorm_test_mode = true, dropout_test_mode = true;
 
     // This program doesn't support using a GPU, because these probabilities are
     // used for diagnostics, and you can just compute them with a small enough
@@ -50,9 +50,13 @@ int main(int argc, char *argv[]) {
 
     ParseOptions po(usage);
 
-    po.Register("test-mode", &test_mode,
+    po.Register("batchnorm-test-mode", &batchnorm_test_mode,
                 "If true, set test-mode to true on any BatchNormComponents.");
 
+    po.Register("dropout-test-mode", &dropout_test_mode,
+                "If true, set test-mode to true on any DropoutComponents and "
+                "DropoutMaskComponents.");
+
     nnet_opts.Register(&po);
     chain_opts.Register(&po);
 
@@ -70,8 +74,11 @@ int main(int argc, char *argv[]) {
     Nnet nnet;
     ReadKaldiObject(nnet_rxfilename, &nnet);
 
-    if (test_mode)
-      SetTestMode(true, &nnet);
+    if (batchnorm_test_mode)
+      SetBatchnormTestMode(true, &nnet);
+
+    if (dropout_test_mode)
+      SetDropoutTestMode(true, &nnet);
 
     fst::StdVectorFst den_fst;
     ReadFstKaldi(den_fst_rxfilename, &den_fst);
diff --git a/src/nnet3/nnet-component-itf.h b/src/nnet3/nnet-component-itf.h
index af3beae84e2..dc58289721b 100644
--- a/src/nnet3/nnet-component-itf.h
+++ b/src/nnet3/nnet-component-itf.h
@@ -397,8 +397,18 @@ class RandomComponent: public Component {
   // validation-set performance), but check where else we call srand().  You'll
   // need to call srand prior to making this call.
   void ResetGenerator() { random_generator_.SeedGpu(); }
+
+  // Call this with 'true' to set 'test mode' where the behavior is different
+  // from normal mode.
+  void SetTestMode(bool test_mode) { test_mode_ = test_mode; }
+
+  RandomComponent(): test_mode_(false) { }
  protected:
   CuRand<BaseFloat> random_generator_;
+
+  // This is true if we want a different behavior for inference from that for
+  // training.
+  bool test_mode_;
 };
 
 /**
diff --git a/src/nnet3/nnet-compute-test.cc b/src/nnet3/nnet-compute-test.cc
index 301fdd926ad..df37a921ace 100644
--- a/src/nnet3/nnet-compute-test.cc
+++ b/src/nnet3/nnet-compute-test.cc
@@ -83,7 +83,8 @@ void TestNnetDecodable(Nnet *nnet) {
       ivector_dim = std::max(0, nnet->InputDim("ivector"));
   Matrix<BaseFloat> input(num_frames, input_dim);
 
-  SetTestMode(true, nnet);
+  SetBatchnormTestMode(true, nnet);
+  SetDropoutTestMode(true, nnet);
   input.SetRandn();
   Vector<BaseFloat> ivector(ivector_dim);
diff --git a/src/nnet3/nnet-general-component.cc b/src/nnet3/nnet-general-component.cc
index 900869f3add..bfb972f8735 100644
--- a/src/nnet3/nnet-general-component.cc
+++ b/src/nnet3/nnet-general-component.cc
@@ -1414,6 +1414,10 @@ void* DropoutMaskComponent::Propagate(
     out->Set(1.0);
     return NULL;
   }
+  if (test_mode_) {
+    out->Set(1.0 - dropout_proportion);
+    return NULL;
+  }
   const_cast<CuRand<BaseFloat>&>(random_generator_).RandUniform(out);
   out->Add(-dropout_proportion);
   out->ApplyHeaviside();
@@ -1442,7 +1446,15 @@ void DropoutMaskComponent::Read(std::istream &is, bool binary) {
   ReadBasicType(is, binary, &output_dim_);
   ExpectToken(is, binary, "<DropoutProportion>");
   ReadBasicType(is, binary, &dropout_proportion_);
-  ExpectToken(is, binary, "</DropoutMaskComponent>");
+  std::string token;
+  ReadToken(is, binary, &token);
+  if (token == "<TestMode>") {
+    ReadBasicType(is, binary, &test_mode_);  // read test mode
+    ExpectToken(is, binary, "</DropoutMaskComponent>");
+  } else {
+    test_mode_ = false;
+    KALDI_ASSERT(token == "</DropoutMaskComponent>");
+  }
 }
 
 
@@ -1452,6 +1464,8 @@ void DropoutMaskComponent::Write(std::ostream &os, bool binary) const {
   WriteBasicType(os, binary, output_dim_);
   WriteToken(os, binary, "<DropoutProportion>");
   WriteBasicType(os, binary, dropout_proportion_);
+  WriteToken(os, binary, "<TestMode>");
+  WriteBasicType(os, binary, test_mode_);
   WriteToken(os, binary, "</DropoutMaskComponent>");
 }
 
@@ -1465,6 +1479,8 @@ void DropoutMaskComponent::InitFromConfig(ConfigLine *cfl) {
   KALDI_ASSERT(ok && output_dim_ > 0);
   dropout_proportion_ = 0.5;
   cfl->GetValue("dropout-proportion", &dropout_proportion_);
+  test_mode_ = false;
+  cfl->GetValue("test-mode", &test_mode_);
 }
 
diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc
index 8caabd0e0aa..9d986a05254 100644
--- a/src/nnet3/nnet-simple-component.cc
+++ b/src/nnet3/nnet-simple-component.cc
@@ -100,9 +100,11 @@ void DropoutComponent::InitFromConfig(ConfigLine *cfl) {
   int32 dim = 0;
   BaseFloat dropout_proportion = 0.0;
   bool dropout_per_frame = false;
+  test_mode_ = false;
   bool ok = cfl->GetValue("dim", &dim) &&
       cfl->GetValue("dropout-proportion", &dropout_proportion);
   cfl->GetValue("dropout-per-frame", &dropout_per_frame);
+  cfl->GetValue("test-mode", &test_mode_);
   // for this stage, dropout is hard coded in
   // normal mode if not declared in config
   if (!ok || cfl->HasUnusedValues() || dim <= 0 ||
@@ -128,6 +130,11 @@ void* DropoutComponent::Propagate(const ComponentPrecomputedIndexes *indexes,
 
   BaseFloat dropout = dropout_proportion_;
   KALDI_ASSERT(dropout >= 0.0 && dropout <= 1.0);
+  if (test_mode_) {
+    out->Set(1.0 - dropout);
+    out->MulElements(in);
+    return NULL;
+  }
   if (!dropout_per_frame_) {
     // This const_cast is only safe assuming you don't attempt
     // to use multi-threaded code with the GPU.
@@ -188,9 +195,14 @@ void DropoutComponent::Read(std::istream &is, bool binary) {
   if (token == "<DropoutPerFrame>") {
     ReadBasicType(is, binary, &dropout_per_frame_);  // read dropout mode
     ReadToken(is, binary, &token);
-    KALDI_ASSERT(token == "</DropoutComponent>");
   } else {
     dropout_per_frame_ = false;
+  }
+  if (token == "<TestMode>") {
+    ReadBasicType(is, binary, &test_mode_);  // read test mode
+    ExpectToken(is, binary, "</DropoutComponent>");
+  } else {
+    test_mode_ = false;
     KALDI_ASSERT(token == "</DropoutComponent>");
   }
 }
@@ -203,6 +215,8 @@ void DropoutComponent::Write(std::ostream &os, bool binary) const {
   WriteBasicType(os, binary, dropout_proportion_);
   WriteToken(os, binary, "<DropoutPerFrame>");
   WriteBasicType(os, binary, dropout_per_frame_);
+  WriteToken(os, binary, "<TestMode>");
+  WriteBasicType(os, binary, test_mode_);
   WriteToken(os, binary, "</DropoutComponent>");
 }
 
diff --git a/src/nnet3/nnet-test-utils.cc b/src/nnet3/nnet-test-utils.cc
index a3387f94125..e1d58b34428 100644
--- a/src/nnet3/nnet-test-utils.cc
+++ b/src/nnet3/nnet-test-utils.cc
@@ -1583,8 +1583,10 @@ static void GenerateRandomComponentConfig(std::string *component_type,
     }
     case 29: {
       *component_type = "DropoutComponent";
+      bool test_mode = (RandInt(0, 1) == 0);
       os << "dim=" << RandInt(1, 200)
-         << " dropout-proportion=" << RandUniform();
+         << " dropout-proportion=" << RandUniform() << " test-mode="
+         << (test_mode ? "true" : "false");
       break;
     }
     case 30: {
diff --git a/src/nnet3/nnet-utils.cc b/src/nnet3/nnet-utils.cc
index f710f94ebfa..fe9e9f91997 100644
--- a/src/nnet3/nnet-utils.cc
+++ b/src/nnet3/nnet-utils.cc
@@ -470,7 +470,7 @@ void SetDropoutProportion(BaseFloat dropout_proportion,
 }
 
 
-void SetTestMode(bool test_mode, Nnet *nnet) {
+void SetBatchnormTestMode(bool test_mode, Nnet *nnet) {
   for (int32 c = 0; c < nnet->NumComponents(); c++) {
     Component *comp = nnet->GetComponent(c);
     BatchNormComponent *bc = dynamic_cast<BatchNormComponent*>(comp);
@@ -479,6 +479,14 @@
   }
 }
 
+void SetDropoutTestMode(bool test_mode, Nnet *nnet) {
+  for (int32 c = 0; c < nnet->NumComponents(); c++) {
+    Component *comp = nnet->GetComponent(c);
+    RandomComponent *rc = dynamic_cast<RandomComponent*>(comp);
+    if (rc != NULL)
+      rc->SetTestMode(test_mode);
+  }
+}
 
 void FindOrphanComponents(const Nnet &nnet, std::vector<int32> *components) {
   int32 num_components = nnet.NumComponents(), num_nodes = nnet.NumNodes();
diff --git a/src/nnet3/nnet-utils.h b/src/nnet3/nnet-utils.h
index 27387d1b8b1..6f645d89c56 100644
--- a/src/nnet3/nnet-utils.h
+++ b/src/nnet3/nnet-utils.h
@@ -164,13 +164,21 @@ std::string NnetInfo(const Nnet &nnet);
 /// dropout_proportion value.
 void SetDropoutProportion(BaseFloat dropout_proportion, Nnet *nnet);
 
-/// This function currently affects only components of type BatchNormComponent.
+/// This function affects only components of type BatchNormComponent.
 /// It sets "test mode" on such components (if you call it with test_mode =
 /// true, otherwise it would set normal mode, but this wouldn't be needed
 /// often).  "test mode" means that instead of using statistics from the batch,
 /// it does a deterministic normalization based on statistics stored at training
 /// time.
-void SetTestMode(bool test_mode, Nnet *nnet);
+void SetBatchnormTestMode(bool test_mode, Nnet *nnet);
+
+/// This function affects components of child classes of RandomComponent
+/// (currently only DropoutComponent and DropoutMaskComponent).  It sets "test
+/// mode" on such components (if you call it with test_mode = true, otherwise
+/// it would set normal mode, but this wouldn't be needed often).  "test mode"
+/// "test mode" means that having a mask containing (1-dropout_prob) in all +/// elements. +void SetDropoutTestMode(bool test_mode, Nnet *nnet); /// This function finds a list of components that are never used, and outputs /// the integer comopnent indexes (you can use these to index diff --git a/src/nnet3bin/nnet3-align-compiled.cc b/src/nnet3bin/nnet3-align-compiled.cc index d8a80f03d8c..69f200ce4e2 100644 --- a/src/nnet3bin/nnet3-align-compiled.cc +++ b/src/nnet3bin/nnet3-align-compiled.cc @@ -112,7 +112,8 @@ int main(int argc, char *argv[]) { trans_model.Read(ki.Stream(), binary); am_nnet.Read(ki.Stream(), binary); } - SetTestMode(true, &(am_nnet.GetNnet())); + SetBatchnormTestMode(true, &(am_nnet.GetNnet())); + SetDropoutTestMode(true, &(am_nnet.GetNnet())); // this compiler object allows caching of computations across // different utterances. CachingOptimizingCompiler compiler(am_nnet.GetNnet(), diff --git a/src/nnet3bin/nnet3-combine.cc b/src/nnet3bin/nnet3-combine.cc index f437b2da8f9..7885bb70b6b 100644 --- a/src/nnet3bin/nnet3-combine.cc +++ b/src/nnet3bin/nnet3-combine.cc @@ -74,7 +74,8 @@ int main(int argc, char *argv[]) { // means we use the freshest batch-norm stats. (Since the batch-norm // stats are not technically parameters, they are not subject to // combination like the rest of the model parameters). - SetTestMode(true, &nnet); + SetBatchnormTestMode(true, &nnet); + SetDropoutTestMode(true, &nnet); std::vector egs; egs.reserve(10000); // reserve a lot of space to minimize the chance of diff --git a/src/nnet3bin/nnet3-compute-prob.cc b/src/nnet3bin/nnet3-compute-prob.cc index 84607dbe820..a67e76976c4 100644 --- a/src/nnet3bin/nnet3-compute-prob.cc +++ b/src/nnet3bin/nnet3-compute-prob.cc @@ -39,7 +39,7 @@ int main(int argc, char *argv[]) { "e.g.: nnet3-compute-prob 0.raw ark:valid.egs\n"; - bool test_mode = true; + bool batchnorm_test_mode = true, dropout_test_mode = true; // This program doesn't support using a GPU, because these probabilities are // used for diagnostics, and you can just compute them with a small enough @@ -49,9 +49,13 @@ int main(int argc, char *argv[]) { ParseOptions po(usage); - po.Register("test-mode", &test_mode, + po.Register("batchnorm-test-mode", &batchnorm_test_mode, "If true, set test-mode to true on any BatchNormComponents."); + po.Register("dropout-test-mode", &dropout_test_mode, + "If true, set test-mode to true on any DropoutComponents and " + "DropoutMaskComponents."); + opts.Register(&po); po.Read(argc, argv); @@ -67,8 +71,11 @@ int main(int argc, char *argv[]) { Nnet nnet; ReadKaldiObject(raw_nnet_rxfilename, &nnet); - if (test_mode) - SetTestMode(true, &nnet); + if (batchnorm_test_mode) + SetBatchnormTestMode(true, &nnet); + + if (dropout_test_mode) + SetDropoutTestMode(true, &nnet); NnetComputeProb prob_computer(opts, nnet); diff --git a/src/nnet3bin/nnet3-compute.cc b/src/nnet3bin/nnet3-compute.cc index fd4de734a7e..be66dba63b1 100644 --- a/src/nnet3bin/nnet3-compute.cc +++ b/src/nnet3bin/nnet3-compute.cc @@ -92,7 +92,8 @@ int main(int argc, char *argv[]) { Nnet nnet; ReadKaldiObject(nnet_rxfilename, &nnet); - SetTestMode(true, &nnet); + SetBatchnormTestMode(true, &nnet); + SetDropoutTestMode(true, &nnet); RandomAccessBaseFloatMatrixReader online_ivector_reader( online_ivector_rspecifier); diff --git a/src/nnet3bin/nnet3-latgen-faster-looped.cc b/src/nnet3bin/nnet3-latgen-faster-looped.cc index ee778082171..c9fbe054d06 100644 --- a/src/nnet3bin/nnet3-latgen-faster-looped.cc +++ b/src/nnet3bin/nnet3-latgen-faster-looped.cc @@ -97,7 
       Input ki(model_in_filename, &binary);
       trans_model.Read(ki.Stream(), binary);
       am_nnet.Read(ki.Stream(), binary);
-      SetTestMode(true, &(am_nnet.GetNnet()));
+      SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
+      SetDropoutTestMode(true, &(am_nnet.GetNnet()));
     }
 
     bool determinize = config.determinize_lattice;
diff --git a/src/nnet3bin/nnet3-latgen-faster-parallel.cc b/src/nnet3bin/nnet3-latgen-faster-parallel.cc
index 95085b3c444..fb21cf235b9 100644
--- a/src/nnet3bin/nnet3-latgen-faster-parallel.cc
+++ b/src/nnet3bin/nnet3-latgen-faster-parallel.cc
@@ -100,7 +100,8 @@ int main(int argc, char *argv[]) {
       Input ki(model_in_filename, &binary);
       trans_model.Read(ki.Stream(), binary);
       am_nnet.Read(ki.Stream(), binary);
-      SetTestMode(true, &(am_nnet.GetNnet()));
+      SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
+      SetDropoutTestMode(true, &(am_nnet.GetNnet()));
     }
 
     bool determinize = config.determinize_lattice;
diff --git a/src/nnet3bin/nnet3-latgen-faster.cc b/src/nnet3bin/nnet3-latgen-faster.cc
index d9ebeb599a1..1921ce6a6e5 100644
--- a/src/nnet3bin/nnet3-latgen-faster.cc
+++ b/src/nnet3bin/nnet3-latgen-faster.cc
@@ -94,7 +94,8 @@ int main(int argc, char *argv[]) {
      Input ki(model_in_filename, &binary);
       trans_model.Read(ki.Stream(), binary);
       am_nnet.Read(ki.Stream(), binary);
-      SetTestMode(true, &(am_nnet.GetNnet()));
+      SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
+      SetDropoutTestMode(true, &(am_nnet.GetNnet()));
     }
 
     bool determinize = config.determinize_lattice;

From e8dc5ef85d9a8ef374ad270d3e9d9e1326249700 Mon Sep 17 00:00:00 2001
From: freewym
Date: Tue, 25 Apr 2017 22:05:08 -0400
Subject: [PATCH 2/2] fix

---
 src/nnet3/nnet-simple-component.cc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc
index 9d986a05254..27482678235 100644
--- a/src/nnet3/nnet-simple-component.cc
+++ b/src/nnet3/nnet-simple-component.cc
@@ -104,6 +104,7 @@ void DropoutComponent::InitFromConfig(ConfigLine *cfl) {
   bool ok = cfl->GetValue("dim", &dim) &&
       cfl->GetValue("dropout-proportion", &dropout_proportion);
   cfl->GetValue("dropout-per-frame", &dropout_per_frame);
+  // It only makes sense to set test-mode in the config for testing purposes.
   cfl->GetValue("test-mode", &test_mode_);
   // for this stage, dropout is hard coded in
   // normal mode if not declared in config
@@ -131,8 +132,8 @@ void* DropoutComponent::Propagate(const ComponentPrecomputedIndexes *indexes,
   BaseFloat dropout = dropout_proportion_;
   KALDI_ASSERT(dropout >= 0.0 && dropout <= 1.0);
   if (test_mode_) {
-    out->Set(1.0 - dropout);
-    out->MulElements(in);
+    out->CopyFromMat(in);
+    out->Scale(1.0 - dropout);
     return NULL;
   }
   if (!dropout_per_frame_) {
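Usage sketch (illustrative; not part of the patches above). The two functions introduced in src/nnet3/nnet-utils.{h,cc} are meant to be called right after reading a model and before running it forward for decoding or diagnostics, which is exactly what the modified binaries do. The helper function name and the include list below are assumptions for the sketch; only ReadKaldiObject(), SetBatchnormTestMode() and SetDropoutTestMode() come from the code shown above.

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "nnet3/nnet-nnet.h"
#include "nnet3/nnet-utils.h"

// Hypothetical helper: read a raw nnet3 model (e.g. "0.raw") and put it into
// 'test mode' so that a forward pass is deterministic.
void PrepareNnetForTest(const std::string &nnet_rxfilename,
                        kaldi::nnet3::Nnet *nnet) {
  using namespace kaldi;
  using namespace kaldi::nnet3;
  ReadKaldiObject(nnet_rxfilename, nnet);
  // Batch-norm components normalize with the statistics stored at training
  // time instead of the statistics of the current minibatch.
  SetBatchnormTestMode(true, nnet);
  // Dropout components stop sampling random masks; with this patch they
  // scale their input by (1 - dropout_proportion) instead.
  SetDropoutTestMode(true, nnet);
}

The --batchnorm-test-mode and --dropout-test-mode options added to nnet3-compute-prob and nnet3-chain-compute-prob (both defaulting to true) expose the same switches on the command line, and test-mode=true can also be set directly on a DropoutComponent or DropoutMaskComponent line in a config file.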