Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions src/nnetbin/nnet-train-frmshuff.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,21 +159,24 @@ int main(int argc, char *argv[]) {

CuMatrix<BaseFloat> feats_transf, nnet_out, obj_diff;

Timer time;
Timer time, time_io;
KALDI_LOG << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
<< " STARTED";

int32 num_done = 0,
num_no_tgt_mat = 0,
num_other_error = 0;

double time_io_accu = 0.0;

// main loop,
while (!feature_reader.Done()) {
#if HAVE_CUDA == 1
// check that GPU computes accurately,
CuDevice::Instantiate().CheckGpuHealth();
#endif
// fill the randomizer,
time_io.Reset();
for ( ; !feature_reader.Done(); feature_reader.Next()) {
if (feature_randomizer.IsFull()) {
// break the loop without calling Next(),
Expand Down Expand Up @@ -219,6 +222,10 @@ int main(int argc, char *argv[]) {
weights.Scale(w);
}

// accumulate the I/O time,
time_io_accu += time_io.Elapsed();
time_io.Reset(); // to be sure we don't count 2x,

// skip too long utterances (or we run out of memory),
if (mat.NumRows() > max_frames) {
KALDI_WARN << "Utterance too long, skipping! " << utt
Expand Down Expand Up @@ -299,13 +306,7 @@ int main(int argc, char *argv[]) {
weights_randomizer.AddData(weights);
num_done++;

// report the speed,
if (num_done % 5000 == 0) {
double time_now = time.Elapsed();
KALDI_VLOG(1) << "After " << num_done << " utterances: "
<< "time elapsed = " << time_now / 60 << " min; "
<< "processed " << total_frames / time_now << " frames per sec.";
}
time_io.Reset(); // reset before reading next feature matrix,
}

// randomize,
Expand Down Expand Up @@ -350,11 +351,11 @@ int main(int argc, char *argv[]) {

// 1st mini-batch : show what happens in network,
if (total_frames == 0) {
KALDI_VLOG(1) << "### After " << total_frames << " frames,";
KALDI_VLOG(1) << nnet.InfoPropagate();
KALDI_LOG << "### After " << total_frames << " frames,";
KALDI_LOG << nnet.InfoPropagate();
if (!crossvalidate) {
KALDI_VLOG(1) << nnet.InfoBackPropagate();
KALDI_VLOG(1) << nnet.InfoGradient();
KALDI_LOG << nnet.InfoBackPropagate();
KALDI_LOG << nnet.InfoGradient();
}
}

Expand All @@ -380,11 +381,11 @@ int main(int argc, char *argv[]) {
} // main loop,

// after last mini-batch : show what happens in network,
KALDI_VLOG(1) << "### After " << total_frames << " frames,";
KALDI_VLOG(1) << nnet.InfoPropagate();
KALDI_LOG << "### After " << total_frames << " frames,";
KALDI_LOG << nnet.InfoPropagate();
if (!crossvalidate) {
KALDI_VLOG(1) << nnet.InfoBackPropagate();
KALDI_VLOG(1) << nnet.InfoGradient();
KALDI_LOG << nnet.InfoBackPropagate();
KALDI_LOG << nnet.InfoGradient();
}

if (!crossvalidate) {
Expand All @@ -397,7 +398,8 @@ int main(int argc, char *argv[]) {
<< "[" << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING")
<< ", " << (randomize ? "RANDOMIZED" : "NOT-RANDOMIZED")
<< ", " << time.Elapsed() / 60 << " min, processing "
<< total_frames / time.Elapsed() << " frames per sec.]";
<< total_frames / time.Elapsed() << " frames per sec;"
<< " i/o time " << 100.*time_io_accu/time.Elapsed() << "%]";

if (objective_function == "xent") {
KALDI_LOG << xent.ReportPerClass();
Expand Down
36 changes: 11 additions & 25 deletions src/nnetbin/nnet-train-multistream-perutt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ int main(int argc, char *argv[]) {
os << frame_num_utt[i] << " ";
}
os << "]";

KALDI_LOG << "frame_num_utt[" << frame_num_utt.size() << "]" << os.str();
}
// Reset all the streams (we have new sentences),
Expand All @@ -301,31 +300,20 @@ int main(int argc, char *argv[]) {

// 1st model update : show what happens in network,
if (total_frames == 0) {
KALDI_VLOG(1) << "### After " << total_frames << " frames,";
KALDI_VLOG(1) << nnet.Info();
KALDI_VLOG(1) << nnet.InfoPropagate();
KALDI_LOG << "### After " << total_frames << " frames,";
KALDI_LOG << nnet.Info();
KALDI_LOG << nnet.InfoPropagate();
if (!crossvalidate) {
KALDI_VLOG(1) << nnet.InfoBackPropagate();
KALDI_VLOG(1) << nnet.InfoGradient();
KALDI_LOG << nnet.InfoBackPropagate();
KALDI_LOG << nnet.InfoGradient();
}
}

int32 tmp_done = num_done;
kaldi::int64 tmp_frames = total_frames;

num_done += frame_num_utt.size();
total_frames += std::accumulate(frame_num_utt.begin(), frame_num_utt.end(), 0);

// report the speed,
int32 N = 5000;
if (tmp_done / N != num_done / N) {
double time_now = time.Elapsed();
KALDI_VLOG(1) << "After " << num_done << " utterances, "
<< "(" << total_frames/360000.0 << "h), "
<< "time elapsed = " << time_now / 60 << " min; "
<< "processed " << total_frames / time_now << " frames per sec.";
}

// monitor the NN training (--verbose=2),
int32 F = 25000;
if (GetVerboseLevel() >= 3) {
Expand All @@ -343,14 +331,12 @@ int main(int argc, char *argv[]) {
}

// after last model update : show what happens in network,
if (GetVerboseLevel() >= 1) { // vlog-1
KALDI_VLOG(1) << "### After " << total_frames << " frames,";
KALDI_VLOG(1) << nnet.Info();
KALDI_VLOG(1) << nnet.InfoPropagate();
if (!crossvalidate) {
KALDI_VLOG(1) << nnet.InfoBackPropagate();
KALDI_VLOG(1) << nnet.InfoGradient();
}
KALDI_LOG << "### After " << total_frames << " frames,";
KALDI_LOG << nnet.Info();
KALDI_LOG << nnet.InfoPropagate();
if (!crossvalidate) {
KALDI_LOG << nnet.InfoBackPropagate();
KALDI_LOG << nnet.InfoGradient();
}

if (!crossvalidate) {
Expand Down
32 changes: 10 additions & 22 deletions src/nnetbin/nnet-train-multistream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ int main(int argc, char *argv[]) {
os << frame_num_utt[i] << " ";
}
os << "]";

KALDI_LOG << "frame_num_utt[" << frame_num_utt.size() << "]" << os.str();
}

Expand All @@ -387,31 +386,20 @@ int main(int argc, char *argv[]) {

// 1st minibatch : show what happens in network,
if (total_frames == 0) {
KALDI_VLOG(1) << "### After " << total_frames << " frames,";
KALDI_VLOG(1) << nnet.Info();
KALDI_VLOG(1) << nnet.InfoPropagate();
KALDI_LOG << "### After " << total_frames << " frames,";
KALDI_LOG << nnet.Info();
KALDI_LOG << nnet.InfoPropagate();
if (!crossvalidate) {
KALDI_VLOG(1) << nnet.InfoBackPropagate();
KALDI_VLOG(1) << nnet.InfoGradient();
KALDI_LOG << nnet.InfoBackPropagate();
KALDI_LOG << nnet.InfoGradient();
}
}

int32 tmp_done = num_done;
kaldi::int64 tmp_frames = total_frames;

num_done += std::accumulate(new_utt_flags.begin(), new_utt_flags.end(), 0);
total_frames += std::accumulate(frame_num_utt.begin(), frame_num_utt.end(), 0);

// report the speed,
int32 N = 5000;
if (tmp_done / N != num_done / N) {
double time_now = time.Elapsed();
KALDI_VLOG(1) << "After " << num_done << " utterances, "
<< "(" << total_frames/360000.0 << "h), "
<< "time elapsed = " << time_now / 60 << " min; "
<< "processed " << total_frames / time_now << " frames per sec.";
}

// monitor the NN training (--verbose=2),
int32 F = 25000;
if (GetVerboseLevel() >= 2) {
Expand All @@ -429,12 +417,12 @@ int main(int argc, char *argv[]) {
}

// after last minibatch : show what happens in network,
KALDI_VLOG(1) << "### After " << total_frames << " frames,";
KALDI_VLOG(1) << nnet.Info();
KALDI_VLOG(1) << nnet.InfoPropagate();
KALDI_LOG << "### After " << total_frames << " frames,";
KALDI_LOG << nnet.Info();
KALDI_LOG << nnet.InfoPropagate();
if (!crossvalidate) {
KALDI_VLOG(1) << nnet.InfoBackPropagate();
KALDI_VLOG(1) << nnet.InfoGradient();
KALDI_LOG << nnet.InfoBackPropagate();
KALDI_LOG << nnet.InfoGradient();
}

if (!crossvalidate) {
Expand Down
29 changes: 8 additions & 21 deletions src/nnetbin/nnet-train-perutt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,11 @@ int main(int argc, char *argv[]) {

// 1st minibatch : show what happens in network,
if (total_frames == 0) {
KALDI_VLOG(1) << "### After " << total_frames << " frames,";
KALDI_VLOG(1) << nnet.InfoPropagate();
KALDI_LOG << "### After " << total_frames << " frames,";
KALDI_LOG << nnet.InfoPropagate();
if (!crossvalidate) {
KALDI_VLOG(1) << nnet.InfoBackPropagate();
KALDI_VLOG(1) << nnet.InfoGradient();
KALDI_LOG << nnet.InfoBackPropagate();
KALDI_LOG << nnet.InfoGradient();
}
}

Expand All @@ -265,27 +265,14 @@ int main(int argc, char *argv[]) {

num_done++;
total_frames += frm_weights.Sum();

// do this every 5000 utterances,
if (num_done % 5000 == 0) {
// report the speed,
double time_now = time.Elapsed();
KALDI_VLOG(1) << "After " << num_done << " utterances: "
<< "time elapsed = " << time_now / 60 << " min; "
<< "processed " << total_frames / time_now << " frames per sec.";
#if HAVE_CUDA == 1
// check that GPU computes accurately,
CuDevice::Instantiate().CheckGpuHealth();
#endif
}
} // main loop,

// after last minibatch : show what happens in network,
KALDI_VLOG(1) << "### After " << total_frames << " frames,";
KALDI_VLOG(1) << nnet.InfoPropagate();
KALDI_LOG << "### After " << total_frames << " frames,";
KALDI_LOG << nnet.InfoPropagate();
if (!crossvalidate) {
KALDI_VLOG(1) << nnet.InfoBackPropagate();
KALDI_VLOG(1) << nnet.InfoGradient();
KALDI_LOG << nnet.InfoBackPropagate();
KALDI_LOG << nnet.InfoGradient();
}

if (!crossvalidate) {
Expand Down