diff --git a/src/nnet3/nnet-am-decodable-simple.h b/src/nnet3/nnet-am-decodable-simple.h index e83b9e4bab2..74a1e75b59a 100644 --- a/src/nnet3/nnet-am-decodable-simple.h +++ b/src/nnet3/nnet-am-decodable-simple.h @@ -103,6 +103,39 @@ struct NnetSimpleComputationOptions { ParseOptions compute_opts("computation", opts); compute_config.Register(&compute_opts); } + + void CheckAndFixConfigs(int32 nnet_modulus) { + static bool warned_frames_per_chunk = false; + if (frame_subsampling_factor < 1 || frames_per_chunk < 1) { + KALDI_ERR << "--frame-subsampling-factor and " + << "--frames-per-chunk must be > 0"; + } + KALDI_ASSERT(nnet_modulus > 0); + int32 n = Lcm(frame_subsampling_factor, nnet_modulus); + + if (frames_per_chunk % n != 0) { + // round up to the nearest multiple of n. + int32 new_frames_per_chunk = n * ((frames_per_chunk + n - 1) / n); + if (!warned_frames_per_chunk) { + warned_frames_per_chunk = true; + if (nnet_modulus == 1) { + // simpler error message. + KALDI_LOG << "Increasing --frames-per-chunk from " << frames_per_chunk + << " to " << new_frames_per_chunk + << " to make it a multiple of " + << "--frame-subsampling-factor=" + << frame_subsampling_factor; + } else { + KALDI_LOG << "Increasing --frames-per-chunk from " << frames_per_chunk + << " to " << new_frames_per_chunk << " due to " + << "--frame-subsampling-factor=" << frame_subsampling_factor + << " and " + << "nnet shift-invariance modulus = " << nnet_modulus; + } + } + frames_per_chunk = new_frames_per_chunk; + } + } }; /* diff --git a/src/nnet3/nnet-batch-compute.cc b/src/nnet3/nnet-batch-compute.cc index 9d71a021f05..0e07834ed3d 100644 --- a/src/nnet3/nnet-batch-compute.cc +++ b/src/nnet3/nnet-batch-compute.cc @@ -38,9 +38,12 @@ NnetBatchComputer::NnetBatchComputer( log_priors_(priors), num_full_minibatches_(0) { log_priors_.ApplyLog(); - CheckAndFixConfigs(); - ComputeSimpleNnetContext(nnet, &nnet_left_context_, - &nnet_right_context_); + opts_.CheckAndFixConfigs(nnet_.Modulus()); + KALDI_ASSERT(opts_.minibatch_size >= 1 && opts_.edge_minibatch_size >= 1 && + opts_.partial_minibatch_factor < 1.0 && + opts_.partial_minibatch_factor >= 0.0); + + ComputeSimpleNnetContext(nnet, &nnet_left_context_, &nnet_right_context_); input_dim_ = nnet.InputDim("input"); ivector_dim_ = std::max(0, nnet.InputDim("ivector")); output_dim_ = nnet.OutputDim("output"); @@ -340,49 +343,6 @@ void NnetBatchComputer::GetComputationRequest( request->outputs.push_back(IoSpecification("output", output_indexes)); } - - -void NnetBatchComputer::CheckAndFixConfigs() { - static bool warned_frames_per_chunk = false; - int32 nnet_modulus = nnet_.Modulus(); - if (opts_.frame_subsampling_factor < 1 || - opts_.frames_per_chunk < 1) { - KALDI_ERR << "--frame-subsampling-factor and " - << "--frames-per-chunk must be > 0"; - } - KALDI_ASSERT(nnet_modulus > 0); - int32 n = Lcm(opts_.frame_subsampling_factor, nnet_modulus); - - if (opts_.frames_per_chunk % n != 0) { - // round up to the nearest multiple of n. - int32 frames_per_chunk = n * ((opts_.frames_per_chunk + n - 1) / n); - if (!warned_frames_per_chunk) { - warned_frames_per_chunk = true; - if (nnet_modulus == 1) { - // simpler error message. - KALDI_LOG << "Increasing --frames-per-chunk from " - << opts_.frames_per_chunk << " to " - << frames_per_chunk << " to make it a multiple of " - << "--frame-subsampling-factor=" - << opts_.frame_subsampling_factor; - } else { - KALDI_LOG << "Increasing --frames-per-chunk from " - << opts_.frames_per_chunk << " to " - << frames_per_chunk << " due to " - << "--frame-subsampling-factor=" - << opts_.frame_subsampling_factor << " and " - << "nnet shift-invariance modulus = " << nnet_modulus; - } - } - opts_.frames_per_chunk = frames_per_chunk; - } - KALDI_ASSERT(opts_.minibatch_size >= 1 && - opts_.edge_minibatch_size >= 1 && - opts_.partial_minibatch_factor < 1.0 && - opts_.partial_minibatch_factor >= 0.0); -} - - void NnetBatchComputer::FormatInputs( int32 minibatch_size, const std::vector &tasks,