diff --git a/src/nnet3/nnet-am-decodable-simple.cc b/src/nnet3/nnet-am-decodable-simple.cc index 341f87f0372..d66e24830c6 100644 --- a/src/nnet3/nnet-am-decodable-simple.cc +++ b/src/nnet3/nnet-am-decodable-simple.cc @@ -46,7 +46,7 @@ DecodableNnetSimple::DecodableNnetSimple( (feats_.NumRows() + opts_.frame_subsampling_factor - 1) / opts_.frame_subsampling_factor; KALDI_ASSERT(IsSimpleNnet(nnet)); - ComputeSimpleNnetContext(nnet, &nnet_left_context_, &nnet_right_context_); + compiler_.GetSimpleNnetContext(&nnet_left_context_, &nnet_right_context_); KALDI_ASSERT(!(ivector != NULL && online_ivectors != NULL)); KALDI_ASSERT(!(online_ivectors != NULL && online_ivector_period <= 0 && "You need to set the --online-ivector-period option!")); diff --git a/src/nnet3/nnet-example-utils.h b/src/nnet3/nnet-example-utils.h index 81e2882097d..52b2ebbf904 100644 --- a/src/nnet3/nnet-example-utils.h +++ b/src/nnet3/nnet-example-utils.h @@ -150,7 +150,6 @@ struct ExampleGenerationConfig { struct ChunkTimeInfo is used by class UtteranceSplitter to output information about how we split an utterance into chunks. */ - struct ChunkTimeInfo { int32 first_frame; int32 num_frames; diff --git a/src/nnet3/nnet-optimize.cc b/src/nnet3/nnet-optimize.cc index 63a7e833c74..b0eaa4916ae 100644 --- a/src/nnet3/nnet-optimize.cc +++ b/src/nnet3/nnet-optimize.cc @@ -21,6 +21,7 @@ #include #include "nnet3/nnet-optimize.h" #include "nnet3/nnet-optimize-utils.h" +#include "nnet3/nnet-utils.h" #include "base/timer.h" namespace kaldi { @@ -638,7 +639,8 @@ CachingOptimizingCompiler::CachingOptimizingCompiler( seconds_taken_total_(0.0), seconds_taken_compile_(0.0), seconds_taken_optimize_(0.0), seconds_taken_expand_(0.0), seconds_taken_check_(0.0), seconds_taken_indexes_(0.0), - seconds_taken_io_(0.0), cache_(config.cache_capacity) { } + seconds_taken_io_(0.0), cache_(config.cache_capacity), + nnet_left_context_(-1), nnet_right_context_(-1) { } CachingOptimizingCompiler::CachingOptimizingCompiler( const Nnet &nnet, @@ -648,8 +650,18 @@ CachingOptimizingCompiler::CachingOptimizingCompiler( seconds_taken_total_(0.0), seconds_taken_compile_(0.0), seconds_taken_optimize_(0.0), seconds_taken_expand_(0.0), seconds_taken_check_(0.0), seconds_taken_indexes_(0.0), - seconds_taken_io_(0.0), cache_(config.cache_capacity) { } + seconds_taken_io_(0.0), cache_(config.cache_capacity), + nnet_left_context_(-1), nnet_right_context_(-1) { } +void CachingOptimizingCompiler::GetSimpleNnetContext( + int32 *nnet_left_context, int32 *nnet_right_context) { + if (nnet_left_context_ == -1) { + ComputeSimpleNnetContext(nnet_, &nnet_left_context_, + &nnet_right_context_); + } + *nnet_left_context = nnet_left_context_; + *nnet_right_context = nnet_right_context_; +} void CachingOptimizingCompiler::ReadCache(std::istream &is, bool binary) { { diff --git a/src/nnet3/nnet-optimize.h b/src/nnet3/nnet-optimize.h index 78763732469..0804729519d 100644 --- a/src/nnet3/nnet-optimize.h +++ b/src/nnet3/nnet-optimize.h @@ -242,6 +242,16 @@ class CachingOptimizingCompiler { void ReadCache(std::istream &is, bool binary); void WriteCache(std::ostream &os, bool binary); + + // GetSimpleNnetContext() is equivalent to calling: + // ComputeSimpleNnetContext(nnet_, &nnet_left_context, + // &nnet_right_context) + // but it caches it inside this class. This functionality is independent of + // the rest of the functionality of this class; it just happens to be a + // convenient place to put this mechanism. + void GetSimpleNnetContext(int32 *nnet_left_context, + int32 *nnet_right_context); + private: // This function just implements the work of Compile(); it's made a separate @@ -290,6 +300,10 @@ class CachingOptimizingCompiler { double seconds_taken_io_; ComputationCache cache_; + + // These following two variables are only used by the function GetSimpleNnetContext(). + int32 nnet_left_context_; + int32 nnet_right_context_; };