From e4da298e6b010ae1bd0df551fcb3f7d0a5ac73a8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 17 Mar 2018 17:53:43 -0400 Subject: [PATCH 1/3] [src] Make CachingOptimizingCompiler thread safe. Thanks: Arseniy Gorin. --- src/nnet3/natural-gradient-online.h | 1 - src/nnet3/nnet-am-decodable-simple.cc | 2 +- src/nnet3/nnet-chain-diagnostics.cc | 2 +- src/nnet3/nnet-chain-training.cc | 2 +- src/nnet3/nnet-component-itf.h | 1 - src/nnet3/nnet-computation.cc | 17 ++ src/nnet3/nnet-computation.h | 16 ++ src/nnet3/nnet-diagnostics.cc | 20 +- src/nnet3/nnet-discriminative-diagnostics.cc | 3 +- src/nnet3/nnet-discriminative-training.cc | 2 +- src/nnet3/nnet-optimize-utils.cc | 121 ++++++++++ src/nnet3/nnet-optimize-utils.h | 68 +++++- src/nnet3/nnet-optimize.cc | 237 +++++++------------ src/nnet3/nnet-optimize.h | 79 ++----- src/nnet3/nnet-training.cc | 2 +- src/rnnlm/rnnlm-core-compute.cc | 2 +- src/rnnlm/rnnlm-core-training.cc | 12 +- 17 files changed, 348 insertions(+), 239 deletions(-) diff --git a/src/nnet3/natural-gradient-online.h b/src/nnet3/natural-gradient-online.h index 0b05948977e..b49769da540 100644 --- a/src/nnet3/natural-gradient-online.h +++ b/src/nnet3/natural-gradient-online.h @@ -21,7 +21,6 @@ #define KALDI_NNET3_NATURAL_GRADIENT_ONLINE_H_ #include -#include #include "base/kaldi-common.h" #include "matrix/matrix-lib.h" #include "cudamatrix/cu-matrix-lib.h" diff --git a/src/nnet3/nnet-am-decodable-simple.cc b/src/nnet3/nnet-am-decodable-simple.cc index 35b1506336e..341f87f0372 100644 --- a/src/nnet3/nnet-am-decodable-simple.cc +++ b/src/nnet3/nnet-am-decodable-simple.cc @@ -248,7 +248,7 @@ void DecodableNnetSimple::DoNnetComputation( request.outputs.resize(1); request.outputs[0].Swap(&output_spec); - const NnetComputation *computation = compiler_.Compile(request); + std::shared_ptr computation = compiler_.Compile(request); Nnet *nnet_to_update = NULL; // we're not doing any update. NnetComputer computer(opts_.compute_config, *computation, nnet_, nnet_to_update); diff --git a/src/nnet3/nnet-chain-diagnostics.cc b/src/nnet3/nnet-chain-diagnostics.cc index 084b33347df..a7e60a5e0c4 100644 --- a/src/nnet3/nnet-chain-diagnostics.cc +++ b/src/nnet3/nnet-chain-diagnostics.cc @@ -100,7 +100,7 @@ void NnetChainComputeProb::Compute(const NnetChainExample &chain_eg) { GetChainComputationRequest(nnet_, chain_eg, need_model_derivative, store_component_stats, use_xent_regularization, use_xent_derivative, &request); - const NnetComputation *computation = compiler_.Compile(request); + std::shared_ptr computation = compiler_.Compile(request); NnetComputer computer(nnet_config_.compute_config, *computation, nnet_, deriv_nnet_); // give the inputs to the computer object. diff --git a/src/nnet3/nnet-chain-training.cc b/src/nnet3/nnet-chain-training.cc index 844fb82d32a..1d149b6f193 100644 --- a/src/nnet3/nnet-chain-training.cc +++ b/src/nnet3/nnet-chain-training.cc @@ -68,7 +68,7 @@ void NnetChainTrainer::Train(const NnetChainExample &chain_eg) { nnet_config.store_component_stats, use_xent_regularization, need_model_derivative, &request); - const NnetComputation *computation = compiler_.Compile(request); + std::shared_ptr computation = compiler_.Compile(request); if (nnet_config.backstitch_training_scale > 0.0 && num_minibatches_processed_ % nnet_config.backstitch_training_interval == diff --git a/src/nnet3/nnet-component-itf.h b/src/nnet3/nnet-component-itf.h index 79a1f1a5602..01697353308 100644 --- a/src/nnet3/nnet-component-itf.h +++ b/src/nnet3/nnet-component-itf.h @@ -23,7 +23,6 @@ #define KALDI_NNET3_NNET_COMPONENT_ITF_H_ #include -#include #include "nnet3/nnet-common.h" #include "nnet3/nnet-parse.h" #include "base/kaldi-error.h" diff --git a/src/nnet3/nnet-computation.cc b/src/nnet3/nnet-computation.cc index bb0e7c917fc..520d296ee21 100644 --- a/src/nnet3/nnet-computation.cc +++ b/src/nnet3/nnet-computation.cc @@ -1199,6 +1199,23 @@ size_t IoSpecificationHasher::operator () ( (io_spec.has_deriv ? 4261 : 0); } +// ComputationRequests are distinguished by the names and indexes +// of inputs and outputs +size_t ComputationRequestHasher::operator() ( + const ComputationRequest *cr) const noexcept { + size_t ans = 0; + size_t p1 = 4111, p2 = 26951; + IoSpecificationHasher io_hasher; + std::vector::const_iterator itr = cr->inputs.begin(), + end = cr->inputs.end(); + for (; itr != end; ++itr) + ans = ans * p1 + io_hasher(*itr); + itr = cr->outputs.begin(); + end = cr->outputs.end(); + for (; itr != end; ++itr) + ans = ans * p2 + io_hasher(*itr); + return ans; +} diff --git a/src/nnet3/nnet-computation.h b/src/nnet3/nnet-computation.h index aefcb94c465..97d8b9045ea 100644 --- a/src/nnet3/nnet-computation.h +++ b/src/nnet3/nnet-computation.h @@ -157,6 +157,22 @@ struct ComputationRequest { bool operator== (const ComputationRequest &other) const; }; +// Hash function for ComputationRequest. It converts +// ComputationRequest to hash code by looking at input +// and output IoSpecifications vectors. +struct ComputationRequestHasher { + size_t operator()(const ComputationRequest *cr) const noexcept; +}; + +// Equality function for ComputationRequest pointer +struct ComputationRequestPtrEqual { + public: + bool operator() (const ComputationRequest* cr1, + const ComputationRequest* cr2) const { + return (*cr1) == (*cr2); + } +}; + /** CommandType is an enum that describes the category of the command used in diff --git a/src/nnet3/nnet-diagnostics.cc b/src/nnet3/nnet-diagnostics.cc index 8f9f1be24e4..2a6cfe5de6a 100644 --- a/src/nnet3/nnet-diagnostics.cc +++ b/src/nnet3/nnet-diagnostics.cc @@ -83,7 +83,7 @@ void NnetComputeProb::Compute(const NnetExample &eg) { GetComputationRequest(nnet_, eg, need_model_derivative, store_component_stats, &request); - const NnetComputation *computation = compiler_.Compile(request); + std::shared_ptr computation = compiler_.Compile(request); NnetComputer computer(config_.compute_config, *computation, nnet_, deriv_nnet_); // give the inputs to the computer object. @@ -122,11 +122,11 @@ void NnetComputeProb::ProcessOutputs(const NnetExample &eg, totals.tot_objective += tot_objf; } // May not be meaningful in non-classification tasks - if (config_.compute_accuracy) { + if (config_.compute_accuracy) { BaseFloat tot_weight, tot_accuracy; PerDimObjectiveInfo &acc_totals = accuracy_info_[io.name]; - if (config_.compute_per_dim_accuracy && + if (config_.compute_per_dim_accuracy && acc_totals.tot_objective_vec.Dim() == 0) { acc_totals.tot_objective_vec.Resize(output.NumCols()); acc_totals.tot_weight_vec.Resize(output.NumCols()); @@ -134,9 +134,9 @@ void NnetComputeProb::ProcessOutputs(const NnetExample &eg, ComputeAccuracy(io.features, output, &tot_weight, &tot_accuracy, - config_.compute_per_dim_accuracy ? + config_.compute_per_dim_accuracy ? &acc_totals.tot_weight_vec : NULL, - config_.compute_per_dim_accuracy ? + config_.compute_per_dim_accuracy ? &acc_totals.tot_objective_vec : NULL); acc_totals.tot_weight += tot_weight; acc_totals.tot_objective += tot_accuracy; @@ -149,7 +149,7 @@ void NnetComputeProb::ProcessOutputs(const NnetExample &eg, bool NnetComputeProb::PrintTotalStats() const { bool ans = false; { // First print regular objectives - unordered_map::const_iterator iter, end; iter = objf_info_.begin(); end = objf_info_.end(); @@ -168,8 +168,8 @@ bool NnetComputeProb::PrintTotalStats() const { ans = true; } } - { - unordered_map::const_iterator iter, end; // now print accuracies. iter = accuracy_info_.begin(); @@ -185,14 +185,14 @@ bool NnetComputeProb::PrintTotalStats() const { Vector accuracy_vec(info.tot_weight_vec.Dim()); for (size_t j = 0; j < info.tot_weight_vec.Dim(); j++) { if (info.tot_weight_vec(j) != 0) { - accuracy_vec(j) = info.tot_objective_vec(j) + accuracy_vec(j) = info.tot_objective_vec(j) / info.tot_weight_vec(j); } else { accuracy_vec(j) = -1.0; } } - KALDI_LOG << "Overall per-dim accuracy vector for '" << name + KALDI_LOG << "Overall per-dim accuracy vector for '" << name << "' is " << accuracy_vec << " per frame" << ", over " << info.tot_weight << " frames."; } diff --git a/src/nnet3/nnet-discriminative-diagnostics.cc b/src/nnet3/nnet-discriminative-diagnostics.cc index f23af549d72..488372be8e1 100644 --- a/src/nnet3/nnet-discriminative-diagnostics.cc +++ b/src/nnet3/nnet-discriminative-diagnostics.cc @@ -78,7 +78,7 @@ void NnetDiscriminativeComputeObjf::Compute(const NnetDiscriminativeExample &eg) store_component_stats, use_xent_regularization, use_xent_derivative, &request); - const NnetComputation *computation = compiler_.Compile(request); + std::shared_ptr computation = compiler_.Compile(request); NnetComputer computer(nnet_config_.compute_config, *computation, nnet_, deriv_nnet_); // give the inputs to the computer object. @@ -206,4 +206,3 @@ const discriminative::DiscriminativeObjectiveInfo* NnetDiscriminativeComputeObjf } // namespace nnet3 } // namespace kaldi - diff --git a/src/nnet3/nnet-discriminative-training.cc b/src/nnet3/nnet-discriminative-training.cc index 0a436b69f8c..91a72c73cca 100644 --- a/src/nnet3/nnet-discriminative-training.cc +++ b/src/nnet3/nnet-discriminative-training.cc @@ -70,7 +70,7 @@ void NnetDiscriminativeTrainer::Train(const NnetDiscriminativeExample &eg) { use_xent_regularization, need_model_derivative, &request); - const NnetComputation *computation = compiler_.Compile(request); + std::shared_ptr computation = compiler_.Compile(request); NnetComputer computer(nnet_config.compute_config, *computation, *nnet_, diff --git a/src/nnet3/nnet-optimize-utils.cc b/src/nnet3/nnet-optimize-utils.cc index 756ea45e894..e587c7ff947 100644 --- a/src/nnet3/nnet-optimize-utils.cc +++ b/src/nnet3/nnet-optimize-utils.cc @@ -4950,5 +4950,126 @@ void OptimizeMemoryCompression(const Nnet &nnet, } +std::shared_ptr ComputationCache::Find( + const ComputationRequest &in_request) { + std::lock_guard lock(mutex_); + + CacheType::iterator iter = computation_cache_.find(&in_request); + if (iter == computation_cache_.end()) { + return NULL; + } else { + std::shared_ptr ans = iter->second.first; + // Update access record by moving the accessed request to the end of the + // access queue, which declares that it's the most recently used. + access_queue_.splice(access_queue_.end(), access_queue_, + iter->second.second); + return ans; + } +} + + +ComputationCache::ComputationCache(int32 cache_capacity): + cache_capacity_(cache_capacity) { + KALDI_ASSERT(cache_capacity > 0); +} + +std::shared_ptr ComputationCache::Insert( + const ComputationRequest &request_in, + const NnetComputation *computation_in) { + + std::lock_guard lock(mutex_); + if (static_cast(computation_cache_.size()) >= cache_capacity_) { + // Cache has reached capacity; purge the least-recently-accessed request + const CacheType::iterator iter = + computation_cache_.find(access_queue_.front()); + KALDI_ASSERT(iter != computation_cache_.end()); + const ComputationRequest *request = iter->first; + computation_cache_.erase(iter); + delete request; + // we don't need to delete the computation in iter->second.first, as the + // shared_ptr takes care of that automatically. + access_queue_.pop_front(); + } + + // Now insert the thing we need to insert. We'll own the pointer 'request' in + // 'computation_cache_', so we need to allocate our own version. + ComputationRequest *request = new ComputationRequest(request_in); + // When we construct this shared_ptr, it takes ownership of the pointer + // 'computation_in'. + std::shared_ptr computation(computation_in); + + AqType::iterator ait = access_queue_.insert(access_queue_.end(), request); + + std::pair p = computation_cache_.insert( + std::make_pair(request, std::make_pair(computation, ait))); + if (!p.second) { + // if p.second is false, this pair was not inserted because + // a computation for the same computation-request already existed in + // the map. This is possible in multi-threaded operations, if two + // threads try to compile the same computation at the same time (only + // one of them will successfully add it). + // We need to erase the access-queue element that we just added, it's + // no longer going to be needed. + access_queue_.erase(ait); + delete request; + } + return computation; +} + + +void ComputationCache::Read(std::istream &is, bool binary) { + // Note: the object on disk doesn't have tokens like "" + // and "" for back-compatibility reasons. + int32 computation_cache_size; + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &computation_cache_size); + KALDI_ASSERT(computation_cache_size >= 0); + computation_cache_.clear(); + access_queue_.clear(); + ExpectToken(is, binary, ""); + for (size_t c = 0; c < computation_cache_size; c++) { + ComputationRequest request; + request.Read(is, binary); + NnetComputation *computation = new NnetComputation(); + computation->Read(is, binary); + Insert(request, computation); + } +} + +void ComputationCache::Check(const Nnet &nnet) const { + CacheType::const_iterator iter = computation_cache_.begin(), + end = computation_cache_.end(); + // We only need to explicitly delete the pointer to the ComputationRequest. + // The pointers to Computation are deleted automatically by std::shared_ptr + // when the reference count goes to zero. + for (; iter != end; ++iter) { + const NnetComputation &computation = *(iter->second.first); + CheckComputationOptions check_config; + ComputationChecker checker(check_config, nnet, computation); + checker.Check(); + } +} + +void ComputationCache::Write(std::ostream &os, bool binary) const { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, static_cast(computation_cache_.size())); + WriteToken(os, binary, ""); + for (CacheType::const_iterator iter = computation_cache_.begin(); + iter != computation_cache_.end(); ++iter) { + iter->first->Write(os, binary); + iter->second.first->Write(os, binary); + } +} + +ComputationCache::~ComputationCache() { + CacheType::const_iterator iter = computation_cache_.begin(), + end = computation_cache_.end(); + // We only need to explicitly delete the pointer to the ComputationRequest. + // The pointers to Computation are deleted automatically by std::shared_ptr + // when the reference count goes to zero. + for (; iter != end; ++iter) + delete iter->first; +} + } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3/nnet-optimize-utils.h b/src/nnet3/nnet-optimize-utils.h index 32adf9e3e19..0a30dcc84cb 100644 --- a/src/nnet3/nnet-optimize-utils.h +++ b/src/nnet3/nnet-optimize-utils.h @@ -20,9 +20,12 @@ #ifndef KALDI_NNET3_NNET_OPTIMIZE_UTILS_H_ #define KALDI_NNET3_NNET_OPTIMIZE_UTILS_H_ +#include +#include #include "nnet3/nnet-compile.h" #include "nnet3/nnet-analyze.h" + namespace kaldi { namespace nnet3 { @@ -613,16 +616,67 @@ void OptimizeLoopedComputation(const Nnet &nnet, void FixGotoLabel(NnetComputation *computation); -/* +/// Class ComputationCache is used inside class CachingOptimizingCompiler to +/// cache previously computed computations. The code was moved from class +/// CachingOptimizingCompiler to this separate class for clarity when adding +/// thread-safety functionality. +/// It's OK to call Find() and Insert() from multiple threads without +/// additional synchronization. +class ComputationCache { + public: + ComputationCache(int32 cache_capacity); - Possible TODO: - optimizations to replace row-by-row copy and add commands with whole-matrix - commands on smaller sub-matrices (if the row-by-row copy commands have certain - regularities). this is a minor issue, we can handle it later. We have to be - careful if this causes sub-matrices to overlap. + // Note: if something fails in Read(), or the written cache was from an older + // format, it will just leave the cache empty. + void Read(std::istream &is, bool binary); + + void Write(std::ostream &os, bool binary) const; + + + // Searches for the computation corresponding to this computation, and returns + // it if cached, or NULL (as std::shared_ptr) if not. (We need shared_ptr to + // handle multi-threaded operation, so that if the computation is ejected from + // the cache by another thread, it won't be deleted while still in use). This + // function also moves this computation to the end of the + // most-recently-accessed queue, which is why it's not const. + std::shared_ptr Find(const ComputationRequest &request); - */ + // Inserts the computation into the cache-- this is assumed to be the + // computation for the computation-request 'request'. Returns a shared_ptr + // which can be used to access the object. This function takes ownership of + // 'computation'. + std::shared_ptr Insert(const ComputationRequest &request, + const NnetComputation *computation); + + ~ComputationCache(); + + // Checks the stored computation for correctness. + void Check(const Nnet &nnet) const; + private: + + std::mutex mutex_; // Read/write mutex. + + int32 cache_capacity_; + + // The access queue for keeping track of the freshness of computation. + // Most-recently-accessed computation is at the end, and + // least-recently-accessed computaiton is at the beginning. Together with + // computation_cache_, this forms a most-recently-used (MRU) cache for + // Computations, indexed by ComputationRequest. The pointers are owned in + // computation_cache_. + typedef std::list AqType; + AqType access_queue_; + + // Map from computation-request to pair of (computation, and position in + // access_queue_). Used for fast lookup of previously compiled computations. + // All pointers are owned here. + typedef unordered_map, AqType::iterator>, + ComputationRequestHasher, + ComputationRequestPtrEqual> CacheType; + CacheType computation_cache_; +}; diff --git a/src/nnet3/nnet-optimize.cc b/src/nnet3/nnet-optimize.cc index ecce196801b..63a7e833c74 100644 --- a/src/nnet3/nnet-optimize.cc +++ b/src/nnet3/nnet-optimize.cc @@ -36,21 +36,19 @@ void NnetOptimizeOptions::Read(std::istream &is, bool binary) { ReadBasicType(is, binary, &propagate_in_place); ExpectToken(is, binary, ""); ReadBasicType(is, binary, &backprop_in_place); - std::string tok; - ReadToken(is, binary, &tok); - if (tok == "") { + if (PeekToken(is, binary) == 'O') { + ExpectToken(is, binary, ""); ReadBasicType(is, binary, &optimize_row_ops); - ReadToken(is, binary, &tok); - } else { - optimize_row_ops = true; } - if (tok == "") { + if (PeekToken(is, binary) == 'S') { + ExpectToken(is, binary, ""); ReadBasicType(is, binary, &split_row_ops); - ReadToken(is, binary, &tok); - } else { - split_row_ops = true; } - KALDI_ASSERT(tok == ""); + if (PeekToken(is, binary) == 'E') { + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &extend_matrices); + } + ExpectToken(is, binary, ""); ReadBasicType(is, binary, &convert_addition); ExpectToken(is, binary, ""); ReadBasicType(is, binary, &remove_assignments); @@ -68,14 +66,19 @@ void NnetOptimizeOptions::Read(std::istream &is, bool binary) { ReadBasicType(is, binary, &min_deriv_time); ExpectToken(is, binary, ""); ReadBasicType(is, binary, &max_deriv_time); - ReadToken(is, binary, &tok); - if (tok == "") { + if (PeekToken(is, binary) == 'M') { + ExpectToken(is, binary, ""); ReadBasicType(is, binary, &max_deriv_time_relative); - ReadToken(is, binary, &tok); } - - - KALDI_ASSERT(tok == ""); + if (PeekToken(is, binary) == 'S') { + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &snip_row_ops); + } + if (PeekToken(is, binary) == 'M') { + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &memory_compression_level); + } + ExpectToken(is, binary, ""); } void NnetOptimizeOptions::Write(std::ostream &os, bool binary) const { @@ -90,6 +93,10 @@ void NnetOptimizeOptions::Write(std::ostream &os, bool binary) const { WriteBasicType(os, binary, backprop_in_place); WriteToken(os, binary, ""); WriteBasicType(os, binary, optimize_row_ops); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, split_row_ops); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, extend_matrices); WriteToken(os, binary, ""); WriteBasicType(os, binary, convert_addition); WriteToken(os, binary, ""); @@ -110,14 +117,20 @@ void NnetOptimizeOptions::Write(std::ostream &os, bool binary) const { WriteBasicType(os, binary, max_deriv_time); WriteToken(os, binary, ""); WriteBasicType(os, binary, max_deriv_time_relative); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, snip_row_ops); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, memory_compression_level); WriteToken(os, binary, ""); } bool NnetOptimizeOptions::operator == (const NnetOptimizeOptions &other) const { - return (other.propagate_in_place == propagate_in_place && - other.optimize == optimize && + return (other.optimize == optimize && other.consolidate_model_update == consolidate_model_update && + other.propagate_in_place == propagate_in_place && other.backprop_in_place == backprop_in_place && + other.optimize_row_ops == optimize_row_ops && + other.split_row_ops == split_row_ops && other.convert_addition == convert_addition && other.remove_assignments == remove_assignments && other.allow_left_merge == allow_left_merge && @@ -127,7 +140,9 @@ bool NnetOptimizeOptions::operator == (const NnetOptimizeOptions &other) const { other.allocate_from_other == allocate_from_other && other.min_deriv_time == min_deriv_time && other.max_deriv_time == max_deriv_time && - other.max_deriv_time_relative == max_deriv_time_relative); + other.max_deriv_time_relative == max_deriv_time_relative && + other.snip_row_ops == snip_row_ops && + other.memory_compression_level == memory_compression_level); } // move commands that resize and zero matrices to as late/early as possible. @@ -613,25 +628,6 @@ void Optimize(const NnetOptimizeOptions &config, KALDI_LOG << "After optimization, max memory use (bytes) = " << GetMaxMemoryUse(*computation); } - -} - -// ComputationRequests are distinguished by the names and indexes -// of inputs and outputs -size_t ComputationRequestHasher::operator() ( - const ComputationRequest *cr) const noexcept { - size_t ans = 0; - size_t p1 = 4111, p2 = 26951; - IoSpecificationHasher io_hasher; - std::vector::const_iterator itr = cr->inputs.begin(), - end = cr->inputs.end(); - for (; itr != end; ++itr) - ans = ans * p1 + io_hasher(*itr); - itr = cr->outputs.begin(); - end = cr->outputs.end(); - for (; itr != end; ++itr) - ans = ans * p2 + io_hasher(*itr); - return ans; } @@ -641,7 +637,8 @@ CachingOptimizingCompiler::CachingOptimizingCompiler( nnet_(nnet), config_(config), seconds_taken_total_(0.0), seconds_taken_compile_(0.0), seconds_taken_optimize_(0.0), seconds_taken_expand_(0.0), - seconds_taken_check_(0.0), seconds_taken_indexes_(0.0) { } + seconds_taken_check_(0.0), seconds_taken_indexes_(0.0), + seconds_taken_io_(0.0), cache_(config.cache_capacity) { } CachingOptimizingCompiler::CachingOptimizingCompiler( const Nnet &nnet, @@ -650,87 +647,41 @@ CachingOptimizingCompiler::CachingOptimizingCompiler( nnet_(nnet), config_(config), opt_config_(opt_config), seconds_taken_total_(0.0), seconds_taken_compile_(0.0), seconds_taken_optimize_(0.0), seconds_taken_expand_(0.0), - seconds_taken_check_(0.0), seconds_taken_indexes_(0.0) { } - -void CachingOptimizingCompiler::UpdateCache(const ComputationRequest *request, - const NnetComputation *computation) { - if (computation_cache_.size() == config_.cache_capacity) { - // full, locate the least-recently-accessed request - const CacheType::iterator it = - computation_cache_.find(access_queue_.front()); - KALDI_ASSERT(it != computation_cache_.end()); - // purge the least-recently-accessed request - const ComputationRequest *r = it->first; - const NnetComputation *c = it->second.first; - computation_cache_.erase(it); - delete r; - delete c; - access_queue_.pop_front(); - } - AqType::iterator ait = access_queue_.insert(access_queue_.end(), request); - computation_cache_.insert(std::make_pair(request, - std::make_pair(computation, ait))); -} + seconds_taken_check_(0.0), seconds_taken_indexes_(0.0), + seconds_taken_io_(0.0), cache_(config.cache_capacity) { } + void CachingOptimizingCompiler::ReadCache(std::istream &is, bool binary) { - NnetOptimizeOptions opt_config_cached; - opt_config_cached.Read(is, binary); - // we won't read cached computations if any optimize option has been changed. - bool read_cache = (opt_config_ == opt_config_cached); - - if (read_cache) { - int32 computation_cache_size; - ExpectToken(is, binary, ""); - ReadBasicType(is, binary, &computation_cache_size); - KALDI_ASSERT(computation_cache_size >= 0); - computation_cache_.clear(); - access_queue_.clear(); - ExpectToken(is, binary, ""); - for (size_t c = 0; c < computation_cache_size; c++) { - ComputationRequest *request = new ComputationRequest(); - request->Read(is, binary); - NnetComputation *computation = new NnetComputation(); - computation->Read(is, binary); - if (GetVerboseLevel() >= 3) { - Timer timer; - CheckComputationOptions check_config; - ComputationChecker checker(check_config, nnet_, *computation); - checker.Check(); - seconds_taken_check_ += timer.Elapsed(); - } - UpdateCache(request, computation); - } + { + Timer timer; + NnetOptimizeOptions opt_config_cached; + opt_config_cached.Read(is, binary); + // we won't read cached computations if any optimize option has been changed. + if (!(opt_config_ == opt_config_cached)) + return; + cache_.Read(is, binary); + seconds_taken_io_ += timer.Elapsed(); + } + if (GetVerboseLevel() >= 2) { + Timer timer; + cache_.Check(nnet_); + seconds_taken_check_ += timer.Elapsed(); + // we consider the check time part of the total time... this is very + // arbitrary but it only affects printed times-taken. + seconds_taken_total_ += timer.Elapsed(); } -} -void CachingOptimizingCompiler::WriteCache(std::ostream &os, bool binary) const { - opt_config_.Write(os, binary); - WriteToken(os, binary, ""); - WriteBasicType(os, binary, static_cast(computation_cache_.size())); - WriteToken(os, binary, ""); - for (CacheType::const_iterator iter = computation_cache_.begin(); - iter != computation_cache_.end(); ++iter) { - iter->first->Write(os, binary); - iter->second.first->Write(os, binary); - } } -void CachingOptimizingCompiler::UpdateAccessQueue(CacheType::iterator &cit) { - // exist, update access record by moving the accessed - // request to the end of the access queue - KALDI_ASSERT(cit != computation_cache_.end()); - access_queue_.splice(access_queue_.end(), access_queue_, - cit->second.second); +void CachingOptimizingCompiler::WriteCache(std::ostream &os, bool binary) { + Timer timer; + opt_config_.Write(os, binary); + cache_.Write(os, binary); + seconds_taken_io_ += timer.Elapsed(); } CachingOptimizingCompiler::~CachingOptimizingCompiler() { - CacheType::const_iterator itr = computation_cache_.begin(), - end = computation_cache_.end(); - for (; itr !=end; ++itr) { - delete itr->first; - delete itr->second.first; - } - if (seconds_taken_total_ > 0.0) { + if (seconds_taken_total_ > 0.0 || seconds_taken_io_ > 0.0) { std::ostringstream os; double seconds_taken_misc = seconds_taken_total_ - seconds_taken_compile_ - seconds_taken_optimize_ - seconds_taken_expand_ @@ -742,52 +693,40 @@ CachingOptimizingCompiler::~CachingOptimizingCompiler() { << seconds_taken_expand_ << " shortcut expansion, " << seconds_taken_check_ << " checking, " << seconds_taken_indexes_ << " computing indexes, " - << seconds_taken_misc << " misc.)"; + << seconds_taken_misc << " misc.) + " + << seconds_taken_io_ << " I/O."; KALDI_LOG << os.str(); // note: the leftover amount is misc things like hashing and == comparisons on // computation-requests, and calling RequestIsDecomposable(). } } -const NnetComputation* CachingOptimizingCompiler::Compile( +std::shared_ptr CachingOptimizingCompiler::Compile( const ComputationRequest &in_request) { Timer timer; - const NnetComputation *ans = CompileInternal(in_request); + std::shared_ptr ans = CompileInternal(in_request); seconds_taken_total_ += timer.Elapsed(); return ans; } -const NnetComputation* CachingOptimizingCompiler::CompileInternal( - const ComputationRequest &in_request) { - const NnetComputation *ans; - // find computation in the cache - CacheType::iterator cit = computation_cache_.find(&in_request); - if (cit == computation_cache_.end()) { - ans = CompileAndCache(in_request); +std::shared_ptr CachingOptimizingCompiler::CompileInternal( + const ComputationRequest &request) { + std::shared_ptr ans = cache_.Find(request); + if (ans != NULL) { + return ans; } else { - // if found, update access queue - const NnetComputation *computation = cit->second.first; - UpdateAccessQueue(cit); - ans = computation; + const NnetComputation *computation = NULL; + if (config_.use_shortcut) + computation = CompileViaShortcut(request); + if (computation == NULL) + computation = CompileNoShortcut(request); + KALDI_ASSERT(computation != NULL); + return cache_.Insert(request, computation); } - return ans; } -const NnetComputation* CachingOptimizingCompiler::CompileAndCache( - const ComputationRequest &in_request) { - // we need to make a copy of ComputationRequest, because it's stored - // as the key in the cache, and we need to own the pointer. - ComputationRequest *request = new ComputationRequest(in_request); - - const NnetComputation *computation = CompileViaShortcut(*request); - if (computation == NULL) - computation = CompileNoShortcut(*request); - UpdateCache(request, computation); - return computation; -} - -const NnetComputation* CachingOptimizingCompiler::CompileNoShortcut( +const NnetComputation *CachingOptimizingCompiler::CompileNoShortcut( const ComputationRequest &request) { Compiler compiler(request, nnet_); @@ -831,12 +770,12 @@ const NnetComputation* CachingOptimizingCompiler::CompileNoShortcut( seconds_taken_optimize_ += timer.Elapsed(); } - if (GetVerboseLevel() >= verbose_cutoff) { std::ostringstream os; computation->Print(os, nnet_); KALDI_LOG << "Optimized computation is: " << os.str(); } + { // check the computation again. Timer timer; CheckComputationOptions check_config; @@ -844,6 +783,7 @@ const NnetComputation* CachingOptimizingCompiler::CompileNoShortcut( checker.Check(); seconds_taken_check_ += timer.Elapsed(); } + { Timer timer; computation->ComputeCudaIndexes(); @@ -853,22 +793,17 @@ const NnetComputation* CachingOptimizingCompiler::CompileNoShortcut( } -const NnetComputation* CachingOptimizingCompiler::CompileViaShortcut( +const NnetComputation *CachingOptimizingCompiler::CompileViaShortcut( const ComputationRequest &request) { - if (!config_.use_shortcut) - return NULL; - int32 num_n_values; ComputationRequest mini_request; if (!RequestIsDecomposable(request, &mini_request, &num_n_values)) return NULL; // By invoking CompileInternal() on the mini request, we go through the same - // caching process as for any externally requested computation. [the only - // difference from Compile() is that it doesn't call the timer code; this - // avoids double-counting the time taken.] This pointer will not have to be - // deleted by this function; it's owned by the class, in the cache. - const NnetComputation *mini_computation = CompileInternal(mini_request); + // caching process as for any externally requested computation. + std::shared_ptr mini_computation = + CompileInternal(mini_request); // note: by default we always create debug_info, even in regular compilation. // (e.g. it defaults to true in CompilerOptions). If it really seems to be a diff --git a/src/nnet3/nnet-optimize.h b/src/nnet3/nnet-optimize.h index a07c5490c5c..3186895838b 100644 --- a/src/nnet3/nnet-optimize.h +++ b/src/nnet3/nnet-optimize.h @@ -23,8 +23,7 @@ #include "nnet3/nnet-compile.h" #include "nnet3/nnet-analyze.h" - -#include +#include "nnet3/nnet-optimize-utils.h" namespace kaldi { namespace nnet3 { @@ -34,6 +33,8 @@ namespace nnet3 { // detected, we can work out which optimization was responsible for the error. // See the Register() function below for option-specific documentation. struct NnetOptimizeOptions { + // Caution: if adding or removing members, the Read and Write functions and + // the == operator should be modified. This relates to computation caching. bool optimize; // setting this false disallow all optimization. bool consolidate_model_update; bool propagate_in_place; @@ -186,22 +187,6 @@ void Optimize(const NnetOptimizeOptions &config, int32 max_output_time_in_request, NnetComputation *computation); -// Hash function for ComputationRequest. It converts -// ComputationRequest to hash code by looking at input -// and output IoSpecifications vectors. -struct ComputationRequestHasher { - size_t operator()(const ComputationRequest *cr) const noexcept; -}; - -// Equality function for ComputationRequest pointer -struct ComputationRequestPtrEqual { - public: - bool operator() (const ComputationRequest* cr1, - const ComputationRequest* cr2) const { - return (*cr1) == (*cr2); - } -}; - struct CachingOptimizingCompilerOptions { @@ -229,6 +214,8 @@ struct CachingOptimizingCompilerOptions { /// This class enables you to do the compilation and optimization in one call, /// and also ensures that if the ComputationRequest is identical to the previous /// one, the compilation process is not repeated. +/// It is safe to call Compile() from multiple parallel threads without additional +/// synchronization; synchronization is managed internally by class ComputationCache. class CachingOptimizingCompiler { public: CachingOptimizingCompiler(const Nnet &nnet, @@ -242,29 +229,34 @@ class CachingOptimizingCompiler { CachingOptimizingCompilerOptions()); ~CachingOptimizingCompiler(); - /// Does the compilation and returns a const pointer to - /// the result, which is owned by this class, not the caller. - /// It calls ComputeCudaIndexes() for you, because you wouldn't - /// be able to do this on a const object. - const NnetComputation* Compile(const ComputationRequest &request); + + /// Does the compilation and returns a const pointer to the result, which is + /// owned by this class, not the caller. It calls ComputeCudaIndexes() for + /// you, because you wouldn't be able to do this on a const object. If you + /// want to preserve thread safety you should hold the result in the same type + /// (std::shared_ptr) while you still need it, but + /// otherwise you can just cast to const NnetComputation*. + std::shared_ptr Compile( + const ComputationRequest &request); void ReadCache(std::istream &is, bool binary); - void WriteCache(std::ostream &os, bool binary) const; + void WriteCache(std::ostream &os, bool binary); + private: // This function just implements the work of Compile(); it's made a separate // function for the convenience of the timer code, to avoid it being called // twice (we also call this function directly from inside the class). - const NnetComputation* CompileInternal(const ComputationRequest &request); + std::shared_ptr CompileInternal(const ComputationRequest &request); // This function, called from CompileInternal(), is called when a // ComputationRequest has been determined not to have already been cached. It // otherwise has the same interface as CompileInternal(), but assumes that // there is nothing cached for this computation as yet. It compiles the // computation and takes care of caching it. - const NnetComputation* CompileAndCache(const ComputationRequest &request); + std::shared_ptr CompileAndCache(const ComputationRequest &request); - // This function, called from CompileAndCache(), tries to compile the + // This function, called from CompileInternal(), tries to compile the // ComputationRequest 'request' via 'shortcut' compilation; if this is // possible, it returns a pointer to a newly allocated computation that it has // compiled this way (note: this computation will not yet have been placed in @@ -273,36 +265,19 @@ class CachingOptimizingCompiler { // request was not decomposable because of too few n values or irregular or // unexpected structure), this function returns NULL and you should compile // via CompileNoShortcut. - const NnetComputation* CompileViaShortcut(const ComputationRequest &request); + const NnetComputation *CompileViaShortcut(const ComputationRequest &request); - // This function, called from CompileAndCache(), tries to compile the + // This function, called from CompileInternal(), tries to compile the // ComputationRequest 'request' via the regular (not shortcut) compilation // process; it returns a pointer to a newly allocated computation that it has // compiled this way (note: this computation will not yet have been placed in // the computation cache). - const NnetComputation* CompileNoShortcut(const ComputationRequest &request); + const NnetComputation *CompileNoShortcut(const ComputationRequest &request); const Nnet &nnet_; CachingOptimizingCompilerOptions config_; NnetOptimizeOptions opt_config_; - // The access queue for keeping track of the freshness of computation. - // Most-recently-accessed computation is at the end, and - // least-recently-accessed computaiton is at the beginning. - // Together with computation_cache_, this forms a most-recently-used (MRU) - // cache for Computations, indexed by ComputationRequest. Pointers - // are owned in computation_cache_. - typedef std::list AqType; - AqType access_queue_; - - // Map from computation-request to pair of (computation, and position in - // access_queue_). Used for fast lookup of previously compiled computations. - // All pointers are owned here. - typedef unordered_map, - ComputationRequestHasher, - ComputationRequestPtrEqual> CacheType; - CacheType computation_cache_; // seconds spent in various phases of compilation-- for diagnostic messages double seconds_taken_total_; @@ -311,15 +286,9 @@ class CachingOptimizingCompiler { double seconds_taken_expand_; double seconds_taken_check_; double seconds_taken_indexes_; + double seconds_taken_io_; - // This function updates the computation cache. It is called within - // CompileInternal(). It takes ownership of the pointers. It inserts the - // request at the end of the queue, and purges the least-recently-accessed - // request from the queue and the cache if the capacity is reached. - void UpdateCache(const ComputationRequest *request, - const NnetComputation *computation); - // This function updates the recently accessed queue. - void UpdateAccessQueue(CacheType::iterator &cit); + ComputationCache cache_; }; diff --git a/src/nnet3/nnet-training.cc b/src/nnet3/nnet-training.cc index 812b66c41b1..49222549e4e 100644 --- a/src/nnet3/nnet-training.cc +++ b/src/nnet3/nnet-training.cc @@ -62,7 +62,7 @@ void NnetTrainer::Train(const NnetExample &eg) { GetComputationRequest(*nnet_, eg, need_model_derivative, config_.store_component_stats, &request); - const NnetComputation *computation = compiler_.Compile(request); + std::shared_ptr computation = compiler_.Compile(request); if (config_.backstitch_training_scale > 0.0 && num_minibatches_processed_ % config_.backstitch_training_interval == diff --git a/src/rnnlm/rnnlm-core-compute.cc b/src/rnnlm/rnnlm-core-compute.cc index d7ec22dc538..f0cf4487c2b 100644 --- a/src/rnnlm/rnnlm-core-compute.cc +++ b/src/rnnlm/rnnlm-core-compute.cc @@ -44,7 +44,7 @@ BaseFloat RnnlmCoreComputer::Compute( store_component_stats, &request); - const NnetComputation *computation = compiler_.Compile(request); + std::shared_ptr computation = compiler_.Compile(request); NnetComputeOptions compute_opts; diff --git a/src/rnnlm/rnnlm-core-training.cc b/src/rnnlm/rnnlm-core-training.cc index ddf4e7b3fb6..63a6dee188d 100644 --- a/src/rnnlm/rnnlm-core-training.cc +++ b/src/rnnlm/rnnlm-core-training.cc @@ -156,7 +156,7 @@ void RnnlmCoreTrainer::Train( store_component_stats, &request); - const NnetComputation *computation = compiler_.Compile(request); + std::shared_ptr computation = compiler_.Compile(request); NnetComputeOptions compute_opts; @@ -178,12 +178,12 @@ void RnnlmCoreTrainer::Train( word_embedding_deriv->AddSmatMat(1.0, derived.input_words_smat, kNoTrans, input_deriv, 1.0); } - // If relevant, add in the part of the gradient that comes from L2 + // If relevant, add in the part of the gradient that comes from L2 // regularization. ApplyL2Regularization(*nnet_, minibatch.num_chunks * config_.l2_regularize_factor, delta_nnet_); - + bool success = UpdateNnetWithMaxChange(*delta_nnet_, config_.max_param_change, 1.0, 1.0 - config_.momentum, nnet_, &num_max_change_per_component_applied_, &num_max_change_global_applied_); @@ -214,7 +214,7 @@ void RnnlmCoreTrainer::TrainBackstitch( store_component_stats, &request); - const NnetComputation *computation = compiler_.Compile(request); + std::shared_ptr computation = compiler_.Compile(request); NnetComputeOptions compute_opts; @@ -259,7 +259,7 @@ void RnnlmCoreTrainer::TrainBackstitch( minibatch.num_chunks * config_.l2_regularize_factor, delta_nnet_); } - + UpdateNnetWithMaxChange(*delta_nnet_, config_.max_param_change, max_change_scale, scale_adding, nnet_, &num_max_change_per_component_applied_, &num_max_change_global_applied_); @@ -309,7 +309,7 @@ void RnnlmCoreTrainer::PrintMaxChangeStats() const { if (num_max_change_global_applied_ > 0) KALDI_LOG << "The global max-change was enforced " << (100.0 * num_max_change_global_applied_) / - (num_minibatches_processed_ * + (num_minibatches_processed_ * (config_.backstitch_training_scale == 0.0 ? 1.0 : 1.0 + 1.0 / config_.backstitch_training_interval)) << "\% of the time."; From 51c87f5ff69a3f8fc8e852ffcd45c4fce35aee55 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Mar 2018 14:28:25 -0400 Subject: [PATCH 2/3] [src] Fix to nnet3-xvector-compute.cc, thx: @gorinars --- src/nnet3bin/nnet3-xvector-compute.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nnet3bin/nnet3-xvector-compute.cc b/src/nnet3bin/nnet3-xvector-compute.cc index e649dc477d5..664b15eb246 100644 --- a/src/nnet3bin/nnet3-xvector-compute.cc +++ b/src/nnet3bin/nnet3-xvector-compute.cc @@ -44,7 +44,7 @@ static void RunNnetComputation(const MatrixBase &features, output_spec.indexes.resize(1); request.outputs.resize(1); request.outputs[0].Swap(&output_spec); - const NnetComputation *computation = compiler->Compile(request); + std::shared_ptr computation = compiler->Compile(request); Nnet *nnet_to_update = NULL; // we're not doing any update. NnetComputer computer(NnetComputeOptions(), *computation, nnet, nnet_to_update); From 71a04174ceea88ac29bbcf80d1493213bdc464f2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Mar 2018 14:30:34 -0400 Subject: [PATCH 3/3] [src] Fix to comment --- src/nnet3/nnet-optimize.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/nnet3/nnet-optimize.h b/src/nnet3/nnet-optimize.h index 3186895838b..aaa1182e1b6 100644 --- a/src/nnet3/nnet-optimize.h +++ b/src/nnet3/nnet-optimize.h @@ -232,10 +232,11 @@ class CachingOptimizingCompiler { /// Does the compilation and returns a const pointer to the result, which is /// owned by this class, not the caller. It calls ComputeCudaIndexes() for - /// you, because you wouldn't be able to do this on a const object. If you - /// want to preserve thread safety you should hold the result in the same type - /// (std::shared_ptr) while you still need it, but - /// otherwise you can just cast to const NnetComputation*. + /// you, because you wouldn't be able to do this on a const object. + /// + /// Note: this used to return 'const NnetComputation*'. If you get a + /// compilation failure, just replace 'const NnetComputation*' with + /// 'std::shared_ptr' in the calling code. std::shared_ptr Compile( const ComputationRequest &request); void ReadCache(std::istream &is, bool binary);