From 47eb098d6124f150c8d9d088aec3a6a338a219c8 Mon Sep 17 00:00:00 2001 From: jluitjens Date: Sat, 19 Oct 2019 10:07:53 -0700 Subject: [PATCH] Write all output to a single lattice writer instead of one per iteration. We are seeing that runs on small corpora are dominated by writer Open/Close. A better solution is to modify the key with the iteration number and then handle the modified keys in scoring. Note that if you have just a single iteration the key is not modified and thus behavior is as expected. --- src/cudadecoderbin/batched-wav-nnet3-cuda.cc | 34 +++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/src/cudadecoderbin/batched-wav-nnet3-cuda.cc b/src/cudadecoderbin/batched-wav-nnet3-cuda.cc index 2cec6f90deb..bfe8d8a2ce6 100644 --- a/src/cudadecoderbin/batched-wav-nnet3-cuda.cc +++ b/src/cudadecoderbin/batched-wav-nnet3-cuda.cc @@ -90,13 +90,13 @@ void FinishOneDecode(const std::string &utt, const std::string &key, int64 *num_frames, double *tot_like, CompactLatticeWriter *clat_writer, std::mutex *clat_writer_mutex, std::mutex *stdout_mutex, - CompactLattice &clat) { + const bool write_lattice, CompactLattice &clat) { nvtxRangePushA("FinishOneDecode"); GetDiagnosticsAndPrintOutput(utt, word_syms, clat, stdout_mutex, num_frames, tot_like); - { + if (write_lattice) { std::lock_guard lk(*clat_writer_mutex); - clat_writer->Write(utt, clat); + clat_writer->Write(key, clat); } nvtxRangePop(); @@ -178,8 +178,8 @@ int main(int argc, char *argv[]) { SetDropoutTestMode(true, &(am_nnet.GetNnet())); nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(am_nnet.GetNnet())); - std::vector clat_writers(iterations); - std::vector clat_write_mutexs(iterations); + CompactLatticeWriter clat_writer; + std::mutex clat_write_mutex; fst::Fst *decode_fst = fst::ReadFstKaldiGeneric(fst_rxfilename); @@ -203,6 +203,7 @@ int main(int argc, char *argv[]) { int num_groups_done = 0; + clat_writer.Open(clat_wspecifier); // starting timer here so we // can measure 
throughput // without allocation @@ -215,13 +216,6 @@ int main(int argc, char *argv[]) { num_task_submitted = 0; SequentialTableReader wav_reader(wav_rspecifier); - std::mutex *clat_writer_mutex = &clat_write_mutexs[iter]; - CompactLatticeWriter *clat_writer = &clat_writers[iter]; - - stringstream filename; - filename << clat_wspecifier << "-" << iter; - clat_writer->Open(filename.str()); - for (; !wav_reader.Done(); wav_reader.Next()) { nvtxRangePushA("Utterance Iteration"); @@ -232,8 +226,10 @@ int main(int argc, char *argv[]) { std::string utt = wav_reader.Key(); std::string key = utt; - // make key unique for each iteration - key = key + "-" + std::to_string(iter); + if (iterations > 0) { + // make key unique for each iteration + key = std::to_string(iter) + "-" + key; + } const WaveData &wave_data = wav_reader.Value(); @@ -247,9 +243,10 @@ int main(int argc, char *argv[]) { auto finish_one_decode_lamba = [ // Capturing the arguments that will change by copy - utt, key, clat_writer_mutex, clat_writer, + utt, key, // Capturing the const/global args by reference &word_syms, &cuda_pipeline, &stdout_mutex, &num_frames, + &clat_write_mutex, &clat_writer, &write_lattice, &tot_like] // The callback function receive the compact lattice as argument // if determinize_lattice is true, it is a determinized lattice @@ -262,7 +259,8 @@ int main(int argc, char *argv[]) { // Captured arguments used to specialize FinishOneDecode for // this task utt, key, word_syms, &cuda_pipeline, &num_frames, &tot_like, - clat_writer, clat_writer_mutex, &stdout_mutex, + &clat_writer, &clat_write_mutex, &stdout_mutex, + write_lattice, // Generated lattice that will be passed once the task is // complete clat_in); @@ -292,7 +290,6 @@ int main(int argc, char *argv[]) { << " Audio: " << total_audio * (iter + 1) << " RealTimeX: " << total_audio * (iter + 1) / total_time; num_groups_done++; - clat_writers[iter].Close(); } } // end iterations loop @@ -309,7 +306,6 @@ int main(int argc, char *argv[]) 
{ << " Audio: " << total_audio * (iter + 1) << " RealTimeX: " << total_audio * (iter + 1) / total_time; num_groups_done++; - clat_writers[iter].Close(); } // number of seconds elapsed since the creation of timer @@ -328,7 +324,7 @@ int main(int argc, char *argv[]) { cuda_pipeline.Finalize(); cudaDeviceSynchronize(); - + delete word_syms; // will delete if non-NULL. return 0;