Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/nnet3/natural-gradient-online.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#define KALDI_NNET3_NATURAL_GRADIENT_ONLINE_H_

#include <iostream>
#include <mutex>
#include "base/kaldi-common.h"
#include "matrix/matrix-lib.h"
#include "cudamatrix/cu-matrix-lib.h"
Expand Down
2 changes: 1 addition & 1 deletion src/nnet3/nnet-am-decodable-simple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ void DecodableNnetSimple::DoNnetComputation(
request.outputs.resize(1);
request.outputs[0].Swap(&output_spec);

const NnetComputation *computation = compiler_.Compile(request);
std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please apply the same fix in nnet3-xvector-compute.cc:47

Nnet *nnet_to_update = NULL; // we're not doing any update.
NnetComputer computer(opts_.compute_config, *computation,
nnet_, nnet_to_update);
Expand Down
2 changes: 1 addition & 1 deletion src/nnet3/nnet-chain-diagnostics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ void NnetChainComputeProb::Compute(const NnetChainExample &chain_eg) {
GetChainComputationRequest(nnet_, chain_eg, need_model_derivative,
store_component_stats, use_xent_regularization,
use_xent_derivative, &request);
const NnetComputation *computation = compiler_.Compile(request);
std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
NnetComputer computer(nnet_config_.compute_config, *computation,
nnet_, deriv_nnet_);
// give the inputs to the computer object.
Expand Down
2 changes: 1 addition & 1 deletion src/nnet3/nnet-chain-training.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void NnetChainTrainer::Train(const NnetChainExample &chain_eg) {
nnet_config.store_component_stats,
use_xent_regularization, need_model_derivative,
&request);
const NnetComputation *computation = compiler_.Compile(request);
std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);

if (nnet_config.backstitch_training_scale > 0.0 && num_minibatches_processed_
% nnet_config.backstitch_training_interval ==
Expand Down
1 change: 0 additions & 1 deletion src/nnet3/nnet-component-itf.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#define KALDI_NNET3_NNET_COMPONENT_ITF_H_

#include <iostream>
#include <mutex>
#include "nnet3/nnet-common.h"
#include "nnet3/nnet-parse.h"
#include "base/kaldi-error.h"
Expand Down
17 changes: 17 additions & 0 deletions src/nnet3/nnet-computation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,23 @@ size_t IoSpecificationHasher::operator () (
(io_spec.has_deriv ? 4261 : 0);
}

// ComputationRequests are distinguished by the names and indexes
// of inputs and outputs.
size_t ComputationRequestHasher::operator() (
    const ComputationRequest *cr) const noexcept {
  // Two different multipliers so that swapping an input spec with an
  // output spec changes the hash.
  const size_t input_factor = 4111, output_factor = 26951;
  IoSpecificationHasher io_hasher;
  size_t hash = 0;
  for (const IoSpecification &spec : cr->inputs)
    hash = hash * input_factor + io_hasher(spec);
  for (const IoSpecification &spec : cr->outputs)
    hash = hash * output_factor + io_hasher(spec);
  return hash;
}



Expand Down
16 changes: 16 additions & 0 deletions src/nnet3/nnet-computation.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,22 @@ struct ComputationRequest {
bool operator== (const ComputationRequest &other) const;
};

// Hash function for ComputationRequest*.  Hashes the pointed-to request
// (not the pointer value) by combining the hashes of its input and output
// IoSpecification vectors, so that unordered containers keyed by
// ComputationRequest* (together with ComputationRequestPtrEqual) compare
// requests by content rather than by address.
struct ComputationRequestHasher {
  size_t operator()(const ComputationRequest *cr) const noexcept;
};

// Equality function for ComputationRequest pointers: compares the
// pointed-to requests by value, not the pointer addresses.
struct ComputationRequestPtrEqual {
  bool operator() (const ComputationRequest *lhs,
                   const ComputationRequest *rhs) const {
    return *lhs == *rhs;
  }
};


/**
CommandType is an enum that describes the category of the command used in
Expand Down
20 changes: 10 additions & 10 deletions src/nnet3/nnet-diagnostics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ void NnetComputeProb::Compute(const NnetExample &eg) {
GetComputationRequest(nnet_, eg, need_model_derivative,
store_component_stats,
&request);
const NnetComputation *computation = compiler_.Compile(request);
std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
NnetComputer computer(config_.compute_config, *computation,
nnet_, deriv_nnet_);
// give the inputs to the computer object.
Expand Down Expand Up @@ -122,21 +122,21 @@ void NnetComputeProb::ProcessOutputs(const NnetExample &eg,
totals.tot_objective += tot_objf;
}
// May not be meaningful in non-classification tasks
if (config_.compute_accuracy) {
if (config_.compute_accuracy) {
BaseFloat tot_weight, tot_accuracy;
PerDimObjectiveInfo &acc_totals = accuracy_info_[io.name];

if (config_.compute_per_dim_accuracy &&
if (config_.compute_per_dim_accuracy &&
acc_totals.tot_objective_vec.Dim() == 0) {
acc_totals.tot_objective_vec.Resize(output.NumCols());
acc_totals.tot_weight_vec.Resize(output.NumCols());
}

ComputeAccuracy(io.features, output,
&tot_weight, &tot_accuracy,
config_.compute_per_dim_accuracy ?
config_.compute_per_dim_accuracy ?
&acc_totals.tot_weight_vec : NULL,
config_.compute_per_dim_accuracy ?
config_.compute_per_dim_accuracy ?
&acc_totals.tot_objective_vec : NULL);
acc_totals.tot_weight += tot_weight;
acc_totals.tot_objective += tot_accuracy;
Expand All @@ -149,7 +149,7 @@ void NnetComputeProb::ProcessOutputs(const NnetExample &eg,
bool NnetComputeProb::PrintTotalStats() const {
bool ans = false;
{ // First print regular objectives
unordered_map<std::string, SimpleObjectiveInfo,
unordered_map<std::string, SimpleObjectiveInfo,
StringHasher>::const_iterator iter, end;
iter = objf_info_.begin();
end = objf_info_.end();
Expand All @@ -168,8 +168,8 @@ bool NnetComputeProb::PrintTotalStats() const {
ans = true;
}
}
{
unordered_map<std::string, PerDimObjectiveInfo,
{
unordered_map<std::string, PerDimObjectiveInfo,
StringHasher>::const_iterator iter, end;
// now print accuracies.
iter = accuracy_info_.begin();
Expand All @@ -185,14 +185,14 @@ bool NnetComputeProb::PrintTotalStats() const {
Vector<BaseFloat> accuracy_vec(info.tot_weight_vec.Dim());
for (size_t j = 0; j < info.tot_weight_vec.Dim(); j++) {
if (info.tot_weight_vec(j) != 0) {
accuracy_vec(j) = info.tot_objective_vec(j)
accuracy_vec(j) = info.tot_objective_vec(j)
/ info.tot_weight_vec(j);
} else {
accuracy_vec(j) = -1.0;
}
}

KALDI_LOG << "Overall per-dim accuracy vector for '" << name
KALDI_LOG << "Overall per-dim accuracy vector for '" << name
<< "' is " << accuracy_vec << " per frame"
<< ", over " << info.tot_weight << " frames.";
}
Expand Down
3 changes: 1 addition & 2 deletions src/nnet3/nnet-discriminative-diagnostics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ void NnetDiscriminativeComputeObjf::Compute(const NnetDiscriminativeExample &eg)
store_component_stats,
use_xent_regularization, use_xent_derivative,
&request);
const NnetComputation *computation = compiler_.Compile(request);
std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);
NnetComputer computer(nnet_config_.compute_config, *computation,
nnet_, deriv_nnet_);
// give the inputs to the computer object.
Expand Down Expand Up @@ -206,4 +206,3 @@ const discriminative::DiscriminativeObjectiveInfo* NnetDiscriminativeComputeObjf

} // namespace nnet3
} // namespace kaldi

2 changes: 1 addition & 1 deletion src/nnet3/nnet-discriminative-training.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void NnetDiscriminativeTrainer::Train(const NnetDiscriminativeExample &eg) {
use_xent_regularization,
need_model_derivative,
&request);
const NnetComputation *computation = compiler_.Compile(request);
std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);

NnetComputer computer(nnet_config.compute_config, *computation,
*nnet_,
Expand Down
121 changes: 121 additions & 0 deletions src/nnet3/nnet-optimize-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4950,5 +4950,126 @@ void OptimizeMemoryCompression(const Nnet &nnet,
}


std::shared_ptr<const NnetComputation> ComputationCache::Find(
const ComputationRequest &in_request) {
std::lock_guard<std::mutex> lock(mutex_);

CacheType::iterator iter = computation_cache_.find(&in_request);
if (iter == computation_cache_.end()) {
return NULL;
} else {
std::shared_ptr<const NnetComputation> ans = iter->second.first;
// Update access record by moving the accessed request to the end of the
// access queue, which declares that it's the most recently used.
access_queue_.splice(access_queue_.end(), access_queue_,
iter->second.second);
return ans;
}
}


// Constructs an empty cache that will hold at most 'cache_capacity' entries.
ComputationCache::ComputationCache(int32 cache_capacity):
    cache_capacity_(cache_capacity) {
  // A non-positive capacity would make the cache useless (every Insert
  // would immediately evict), so treat it as a caller error.
  KALDI_ASSERT(cache_capacity_ > 0);
}

// Inserts the computation for 'request_in' into the cache, first evicting
// the least-recently-used entry if the cache is at capacity.  Takes
// ownership of 'computation_in' (it will be freed via shared_ptr).
// Returns a shared_ptr through which the caller can keep using the
// computation even if it is later evicted.  Thread-safe.
std::shared_ptr<const NnetComputation> ComputationCache::Insert(
    const ComputationRequest &request_in,
    const NnetComputation *computation_in) {

  std::lock_guard<std::mutex> lock(mutex_);
  if (static_cast<int32>(computation_cache_.size()) >= cache_capacity_) {
    // Cache has reached capacity; purge the least-recently-accessed request,
    // which is the one at the front of access_queue_.
    const CacheType::iterator iter =
        computation_cache_.find(access_queue_.front());
    KALDI_ASSERT(iter != computation_cache_.end());
    const ComputationRequest *request = iter->first;
    computation_cache_.erase(iter);
    delete request;
    // we don't need to delete the computation in iter->second.first, as the
    // shared_ptr takes care of that automatically.
    access_queue_.pop_front();
  }

  // Now insert the thing we need to insert.  We'll own the pointer 'request'
  // in 'computation_cache_', so we need to allocate our own version.
  ComputationRequest *request = new ComputationRequest(request_in);
  // When we construct this shared_ptr, it takes ownership of the pointer
  // 'computation_in'.
  std::shared_ptr<const NnetComputation> computation(computation_in);

  // Append to the access queue first so we have the iterator to store in
  // the map entry.
  AqType::iterator ait = access_queue_.insert(access_queue_.end(), request);

  std::pair<CacheType::iterator, bool> p = computation_cache_.insert(
      std::make_pair(request, std::make_pair(computation, ait)));
  if (!p.second) {
    // if p.second is false, this pair was not inserted because
    // a computation for the same computation-request already existed in
    // the map.  This is possible in multi-threaded operations, if two
    // threads try to compile the same computation at the same time (only
    // one of them will successfully add it).
    // We need to erase the access-queue element that we just added, it's
    // no longer going to be needed.
    access_queue_.erase(ait);
    delete request;
    // Note: in this race case we still return our own 'computation' (not
    // the previously-cached one); the shared_ptr keeps it alive for the
    // caller and frees it afterwards.
  }
  return computation;
}


void ComputationCache::Read(std::istream &is, bool binary) {
// Note: the object on disk doesn't have tokens like "<ComputationCache>"
// and "</ComputationCache>" for back-compatibility reasons.
int32 computation_cache_size;
ExpectToken(is, binary, "<ComputationCacheSize>");
ReadBasicType(is, binary, &computation_cache_size);
KALDI_ASSERT(computation_cache_size >= 0);
computation_cache_.clear();
access_queue_.clear();
ExpectToken(is, binary, "<ComputationCache>");
for (size_t c = 0; c < computation_cache_size; c++) {
ComputationRequest request;
request.Read(is, binary);
NnetComputation *computation = new NnetComputation();
computation->Read(is, binary);
Insert(request, computation);
}
}

// Runs the ComputationChecker over every cached computation as a
// self-test against 'nnet'.  NOTE(review): this does not take mutex_, so
// presumably it is intended to be called only when no other thread is
// mutating the cache — confirm at call sites.
void ComputationCache::Check(const Nnet &nnet) const {
  CacheType::const_iterator iter = computation_cache_.begin(),
      end = computation_cache_.end();
  for (; iter != end; ++iter) {
    const NnetComputation &computation = *(iter->second.first);
    CheckComputationOptions check_config;
    ComputationChecker checker(check_config, nnet, computation);
    checker.Check();
  }
}

void ComputationCache::Write(std::ostream &os, bool binary) const {
WriteToken(os, binary, "<ComputationCacheSize>");
WriteBasicType(os, binary, static_cast<int32>(computation_cache_.size()));
WriteToken(os, binary, "<ComputationCache>");
for (CacheType::const_iterator iter = computation_cache_.begin();
iter != computation_cache_.end(); ++iter) {
iter->first->Write(os, binary);
iter->second.first->Write(os, binary);
}
}

ComputationCache::~ComputationCache() {
  // We only need to explicitly delete the raw ComputationRequest keys;
  // the computations are freed automatically by std::shared_ptr when
  // their reference counts reach zero.
  for (const auto &entry : computation_cache_)
    delete entry.first;
}

} // namespace nnet3
} // namespace kaldi
68 changes: 61 additions & 7 deletions src/nnet3/nnet-optimize-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
#ifndef KALDI_NNET3_NNET_OPTIMIZE_UTILS_H_
#define KALDI_NNET3_NNET_OPTIMIZE_UTILS_H_

#include <mutex>
#include <list>
#include "nnet3/nnet-compile.h"
#include "nnet3/nnet-analyze.h"


namespace kaldi {
namespace nnet3 {

Expand Down Expand Up @@ -613,16 +616,67 @@ void OptimizeLoopedComputation(const Nnet &nnet,
void FixGotoLabel(NnetComputation *computation);


/*
/// Class ComputationCache is used inside class CachingOptimizingCompiler to
/// cache previously computed computations. The code was moved from class
/// CachingOptimizingCompiler to this separate class for clarity when adding
/// thread-safety functionality.
/// It's OK to call Find() and Insert() from multiple threads without
/// additional synchronization.
class ComputationCache {
public:
ComputationCache(int32 cache_capacity);

Possible TODO:
optimizations to replace row-by-row copy and add commands with whole-matrix
commands on smaller sub-matrices (if the row-by-row copy commands have certain
regularities). this is a minor issue, we can handle it later. We have to be
careful if this causes sub-matrices to overlap.
// Note: if something fails in Read(), or the written cache was from an older
// format, it will just leave the cache empty.
void Read(std::istream &is, bool binary);

void Write(std::ostream &os, bool binary) const;


// Searches for the computation corresponding to this computation, and returns
// it if cached, or NULL (as std::shared_ptr) if not. (We need shared_ptr to
// handle multi-threaded operation, so that if the computation is ejected from
// the cache by another thread, it won't be deleted while still in use). This
// function also moves this computation to the end of the
// most-recently-accessed queue, which is why it's not const.
std::shared_ptr<const NnetComputation> Find(const ComputationRequest &request);

*/

// Inserts the computation into the cache-- this is assumed to be the
// computation for the computation-request 'request'. Returns a shared_ptr
// which can be used to access the object. This function takes ownership of
// 'computation'.
std::shared_ptr<const NnetComputation> Insert(const ComputationRequest &request,
const NnetComputation *computation);

~ComputationCache();

// Checks the stored computation for correctness.
void Check(const Nnet &nnet) const;
private:

std::mutex mutex_; // Read/write mutex.

int32 cache_capacity_;

// The access queue for keeping track of the freshness of computation.
// Most-recently-accessed computation is at the end, and
// least-recently-accessed computaiton is at the beginning. Together with
// computation_cache_, this forms a most-recently-used (MRU) cache for
// Computations, indexed by ComputationRequest. The pointers are owned in
// computation_cache_.
typedef std::list<const ComputationRequest*> AqType;
AqType access_queue_;

// Map from computation-request to pair of (computation, and position in
// access_queue_). Used for fast lookup of previously compiled computations.
// All pointers are owned here.
typedef unordered_map<const ComputationRequest*,
std::pair<std::shared_ptr<const NnetComputation>, AqType::iterator>,
ComputationRequestHasher,
ComputationRequestPtrEqual> CacheType;
CacheType computation_cache_;
};



Expand Down
Loading