src/rnnlm/rnnlm-training.cc (109 changes: 30 additions & 79 deletions)
@@ -42,9 +42,6 @@ RnnlmTrainer::RnnlmTrainer(bool train_embedding,
embedding_trainer_(NULL),
word_feature_mat_(word_feature_mat),
num_minibatches_processed_(0),
end_of_input_(false),
previous_minibatch_empty_(1),
current_minibatch_empty_(1),
srand_seed_(RandInt(0, 100000)) {


@@ -75,13 +72,6 @@ RnnlmTrainer::RnnlmTrainer(bool train_embedding,
<< embedding_mat_->NumRows() << " (mismatch).";
}
}

// Start a thread that calls run_background_thread(this).
// That thread will be responsible for computing derived variables of
// the minibatch, since that can be done independently of the main
// training process.
background_thread_ = std::thread(run_background_thread, this);

}


@@ -92,25 +82,40 @@ void RnnlmTrainer::Train(RnnlmExample *minibatch) {
<< VocabSize() << ", got "
<< minibatch->vocab_size;

// hand over 'minibatch' to the background thread to have its derived variable
// computed, via the class variable 'current_minibatch_'.
current_minibatch_empty_.Wait();
current_minibatch_.Swap(minibatch);
current_minibatch_full_.Signal();
num_minibatches_processed_++;
if (num_minibatches_processed_ == 1) {
return; // The first time this function is called, return immediately
// because there is no previous minibatch to train on.
RnnlmExampleDerived derived;
CuArray<int32> active_words_cuda;
CuSparseMatrix<BaseFloat> active_word_features;
CuSparseMatrix<BaseFloat> active_word_features_trans;

if (!current_minibatch_.sampled_words.empty()) {
std::vector<int32> active_words;
RenumberRnnlmExample(&current_minibatch_, &active_words);
active_words_cuda.CopyFromVec(active_words);

if (word_feature_mat_ != NULL) {
active_word_features.SelectRows(active_words_cuda,
*word_feature_mat_);
active_word_features_trans.CopyFromSmat(active_word_features,
kTrans);
}
}
previous_minibatch_full_.Wait();
GetRnnlmExampleDerived(current_minibatch_, train_embedding_,
&derived);

derived_.Swap(&derived);
active_words_.Swap(&active_words_cuda);
active_word_features_.Swap(&active_word_features);
active_word_features_trans_.Swap(&active_word_features_trans);

TrainInternal();
previous_minibatch_empty_.Signal();
}


void RnnlmTrainer::GetWordEmbedding(CuMatrix<BaseFloat> *word_embedding_storage,
CuMatrix<BaseFloat> **word_embedding) {
RnnlmExample &minibatch = previous_minibatch_;
RnnlmExample &minibatch = current_minibatch_;
bool sampling = !minibatch.sampled_words.empty();

if (word_feature_mat_ == NULL) {
@@ -148,7 +153,7 @@ void RnnlmTrainer::GetWordEmbedding(CuMatrix<BaseFloat> *word_embedding_storage,

void RnnlmTrainer::TrainWordEmbedding(
CuMatrixBase<BaseFloat> *word_embedding_deriv) {
RnnlmExample &minibatch = previous_minibatch_;
RnnlmExample &minibatch = current_minibatch_;
bool sampling = !minibatch.sampled_words.empty();

if (word_feature_mat_ == NULL) {
@@ -186,7 +191,7 @@ void RnnlmTrainer::TrainWordEmbedding(
void RnnlmTrainer::TrainBackstitchWordEmbedding(
bool is_backstitch_step1,
CuMatrixBase<BaseFloat> *word_embedding_deriv) {
RnnlmExample &minibatch = previous_minibatch_;
RnnlmExample &minibatch = current_minibatch_;
bool sampling = !minibatch.sampled_words.empty();

if (word_feature_mat_ == NULL) {
@@ -239,21 +244,21 @@ void RnnlmTrainer::TrainInternal() {
srand_seed_ % core_config_.backstitch_training_interval) {
bool is_backstitch_step1 = true;
srand(srand_seed_ + num_minibatches_processed_);
core_trainer_->TrainBackstitch(is_backstitch_step1, previous_minibatch_,
core_trainer_->TrainBackstitch(is_backstitch_step1, current_minibatch_,
derived_, *word_embedding,
(train_embedding_ ? &word_embedding_deriv : NULL));
if (train_embedding_)
TrainBackstitchWordEmbedding(is_backstitch_step1, &word_embedding_deriv);

is_backstitch_step1 = false;
srand(srand_seed_ + num_minibatches_processed_);
core_trainer_->TrainBackstitch(is_backstitch_step1, previous_minibatch_,
core_trainer_->TrainBackstitch(is_backstitch_step1, current_minibatch_,
derived_, *word_embedding,
(train_embedding_ ? &word_embedding_deriv : NULL));
if (train_embedding_)
TrainBackstitchWordEmbedding(is_backstitch_step1, &word_embedding_deriv);
} else {
core_trainer_->Train(previous_minibatch_, derived_, *word_embedding,
core_trainer_->Train(current_minibatch_, derived_, *word_embedding,
(train_embedding_ ? &word_embedding_deriv : NULL));
if (train_embedding_)
TrainWordEmbedding(&word_embedding_deriv);
@@ -265,61 +270,7 @@ int32 RnnlmTrainer::VocabSize() {
else return embedding_mat_->NumRows();
}

void RnnlmTrainer::RunBackgroundThread() {
while (true) {
current_minibatch_full_.Wait();
if (end_of_input_)
return;
RnnlmExampleDerived derived;
CuArray<int32> active_words_cuda;
CuSparseMatrix<BaseFloat> active_word_features;
CuSparseMatrix<BaseFloat> active_word_features_trans;

if (!current_minibatch_.sampled_words.empty()) {
std::vector<int32> active_words;
RenumberRnnlmExample(&current_minibatch_, &active_words);
active_words_cuda.CopyFromVec(active_words);

if (word_feature_mat_ != NULL) {
active_word_features.SelectRows(active_words_cuda,
*word_feature_mat_);
active_word_features_trans.CopyFromSmat(active_word_features,
kTrans);
}
}
GetRnnlmExampleDerived(current_minibatch_, train_embedding_,
&derived);

// Wait until the main thread is not currently processing
// previous_minibatch_; once we get this semaphore we are free to write to
// it and other related variables such as 'derived_'.
previous_minibatch_empty_.Wait();
previous_minibatch_.Swap(&current_minibatch_);
derived_.Swap(&derived);
active_words_.Swap(&active_words_cuda);
active_word_features_.Swap(&active_word_features);
active_word_features_trans_.Swap(&active_word_features_trans);

// The following statement signals that 'previous_minibatch_'
// and related variables have been written to by this thread.
previous_minibatch_full_.Signal();
// The following statement signals that 'current_minibatch_'
// has been consumed by this thread and is no longer needed.
current_minibatch_empty_.Signal();
}
}

RnnlmTrainer::~RnnlmTrainer() {
// Train on the last minibatch, because Train() always trains on the previously
// provided one (for threading reasons).
if (num_minibatches_processed_ > 0) {
previous_minibatch_full_.Wait();
TrainInternal();
}
end_of_input_ = true;
current_minibatch_full_.Signal();
background_thread_.join();

// Note: the following delete statements may cause some diagnostics to be
// issued, from the destructors of those classes.
if (core_trainer_)
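With the background thread removed, training is fully synchronous: each call to Train() computes the derived variables for the minibatch it is given and trains on it before returning, instead of deferring training by one minibatch. The driver loop below is an illustrative sketch only; the helper function and the way the examples are obtained are assumptions, not part of this change.

// Illustrative sketch, not part of this PR: feed a batch of examples to the
// trainer.  After this change, Train() blocks until the given minibatch has
// been trained on, so there is no longer a pending minibatch to flush in the
// destructor.
#include <vector>
#include "rnnlm/rnnlm-example.h"
#include "rnnlm/rnnlm-training.h"

void TrainOnExamples(kaldi::rnnlm::RnnlmTrainer *trainer,
                     std::vector<kaldi::rnnlm::RnnlmExample> *egs) {
  for (auto &eg : *egs) {
    // Train() acquires 'eg' destructively via Swap(), so the vector's
    // elements are left in an unspecified-but-valid state afterwards.
    trainer->Train(&eg);
  }
}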
src/rnnlm/rnnlm-training.h (78 changes: 5 additions & 73 deletions)
@@ -20,7 +20,6 @@
#ifndef KALDI_RNNLM_RNNLM_TRAINING_H_
#define KALDI_RNNLM_RNNLM_TRAINING_H_

#include <thread>
#include "rnnlm/rnnlm-core-training.h"
#include "rnnlm/rnnlm-embedding-training.h"
#include "rnnlm/rnnlm-utils.h"
@@ -79,10 +78,7 @@ class RnnlmTrainer {


// Train on one example. The example is provided as a pointer because we
// acquire it destructively, via Swap(). Note: this function doesn't
// actually train on this eg; what it does is to train on the previous
// example, and provide this eg to the background thread that computes the
// derived parameters of the eg.
// acquire it destructively, via Swap().
void Train(RnnlmExample *minibatch);


@@ -129,16 +125,6 @@
bool is_backstitch_step1,
CuMatrixBase<BaseFloat> *word_embedding_deriv);

/// This is the function-call that's run as the background thread which
/// computes the derived parameters for each minibatch.
void RunBackgroundThread();

/// This function is invoked by the newly created background thread.
static void run_background_thread(RnnlmTrainer *trainer) {
trainer->RunBackgroundThread();
}


bool train_embedding_; // true if we are training the embedding.
const RnnlmCoreTrainerOptions &core_config_;
const RnnlmEmbeddingTrainerOptions &embedding_config_;
@@ -173,32 +159,14 @@
// it's needed.
CuSparseMatrix<BaseFloat> word_feature_mat_transpose_;


// num_minibatches_processed_ starts at zero and is incremented each time we
// provide an example to the background thread for computing the derived
// parameters.
int32 num_minibatches_processed_;

// 'current_minibatch_' is where the Train() function puts the minibatch that
// is provided to Train(), so that the background thread can work on it.
RnnlmExample current_minibatch_;
// View 'end_of_input_' as part of a unit with current_minibatch_, for threading/access
// purposes. It is set by the foreground thread from the destructor, while
// incrementing the current_minibatch_full_ semaphore; and when the background
// thread decrements the semaphore and notices that end_of_input_ is true, it will
// exit.
bool end_of_input_;


// previous_minibatch_ is the previous minibatch that was provided to Train(),
// but it is the minibatch that we're currently training on.
RnnlmExample previous_minibatch_;
// The variables derived_ and active_words_ [and more that I'll add, TODO] are in the same
// group as previous_minibatch_ from the point of view
// of threading and access control.
RnnlmExampleDerived derived_;

// The variables derived_ and active_words_ below correspond to 'current_minibatch_'.
RnnlmExampleDerived derived_;
// Only if we are doing subsampling (depends on the eg), active_words_
// contains the list of active words for the minibatch 'previous_minibatch_';
// contains the list of active words for the minibatch 'current_minibatch_';
// it is a CUDA version of the 'active_words' output by
// RenumberRnnlmExample(). Otherwise it is empty.
CuArray<int32> active_words_;
@@ -212,42 +180,6 @@
// This is a derived quantity computed by the background thread.
CuSparseMatrix<BaseFloat> active_word_features_trans_;


// The 'previous_minibatch_full_' semaphore is incremented by the background
// thread once it has written to 'previous_minibatch_' and
// 'derived_previous_', to let the Train() function know that they are ready
// to be trained on. The Train() function waits on this semaphore.
Semaphore previous_minibatch_full_;

// The 'previous_minibatch_empty_' semaphore is incremented by the foreground
// thread when it has done processing previous_minibatch_ and
// derived_ and active_words_ (and hence, it is safe for the background thread to write
// to these variables). The background thread waits on this semaphore once it
// has finished computing the derived variables; and when it successfully
// decrements it, it will write to those variables (quickly, via Swap()).
Semaphore previous_minibatch_empty_;


// The 'current_minibatch_full_' semaphore is incremented by the foreground
// thread from Train(), when it has written the just-provided minibatch to
// 'current_minibatch_' (it's also incremented by the destructor, together
// with setting end_of_input_). The background thread waits on this semaphore
// before either processing current_minibatch_ (if !end_of_input_), or exiting
// (if end_of_input_).
Semaphore current_minibatch_full_;

// The 'current_minibatch_empty_' semaphore is incremented by the background
// thread when it is done processing current_minibatch_
// (so it is safe for the foreground thread to write
// to that variable). The foreground thread waits on this semaphore before
// writing to 'current_minibatch_' (in practice it should get the semaphore
// immediately since we expect that the foreground thread will have more to
// do than the background thread).
Semaphore current_minibatch_empty_;

std::thread background_thread_; // Background thread for computing 'derived'
// parameters of a minibatch.

// This value is used in backstitch training when we need to ensure
// consistent dropout masks. It's set to a value derived from rand()
// when the class is initialized.
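For reference, the comments deleted above describe a classic double-buffered producer/consumer hand-off between Train() (the foreground thread) and the background thread, coordinated by two empty/full semaphore pairs. Below is a condensed, illustrative sketch of that scheme, assuming Kaldi's Semaphore class from util/kaldi-semaphore.h (Wait()/Signal(), initial count passed to the constructor); 'Work' and the class name are hypothetical stand-ins, not the actual RnnlmTrainer code.

// Condensed sketch of the removed double-buffering scheme (illustrative only).
#include <thread>
#include <utility>
#include "util/kaldi-semaphore.h"

struct Work {                        // stand-in for RnnlmExample + derived vars
  void Swap(Work *other) { std::swap(data_, other->data_); }
  int data_ = 0;
};

class DoubleBuffer {
 public:
  DoubleBuffer():
      current_empty_(1), current_full_(0),
      previous_empty_(1), previous_full_(0),
      end_of_input_(false), num_provided_(0),
      background_(&DoubleBuffer::BackgroundLoop, this) { }

  void Provide(Work *w) {            // analogue of the old Train()
    current_empty_.Wait();           // wait until current_ may be overwritten
    current_.Swap(w);
    current_full_.Signal();          // background thread may now read it
    if (++num_provided_ == 1)
      return;                        // nothing previous to train on yet
    previous_full_.Wait();           // wait for previous_ + derived variables
    TrainOnPrevious();
    previous_empty_.Signal();        // background thread may overwrite them
  }

  ~DoubleBuffer() {
    if (num_provided_ > 0) {         // train on the final, still-pending minibatch
      previous_full_.Wait();
      TrainOnPrevious();
    }
    end_of_input_ = true;            // ordering w.r.t. the background thread is
    current_full_.Signal();          // provided by this Signal()/Wait() pair
    background_.join();
  }

 private:
  void BackgroundLoop() {
    while (true) {
      current_full_.Wait();          // wait for a new minibatch (or shutdown)
      if (end_of_input_) return;
      ComputeDerived(&current_);
      previous_empty_.Wait();        // wait until previous_ may be overwritten
      previous_.Swap(&current_);
      previous_full_.Signal();       // foreground thread may now train on it
      current_empty_.Signal();       // foreground thread may provide the next one
    }
  }
  void TrainOnPrevious() { /* train on previous_ */ }
  void ComputeDerived(Work *w) { /* renumbering, derived variables, ... */ }

  Work current_, previous_;
  kaldi::Semaphore current_empty_, current_full_, previous_empty_, previous_full_;
  bool end_of_input_;
  int num_provided_;
  std::thread background_;           // declared (and therefore started) last
};

This is exactly the choreography the deleted comments describe: the foreground thread overlaps training on minibatch N-1 with the background thread's computation of the derived variables for minibatch N, which is the pipelining this PR trades away for simpler, single-threaded control flow.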