Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/nnet3/nnet-am-decodable-simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,39 @@ struct NnetSimpleComputationOptions {
ParseOptions compute_opts("computation", opts);
compute_config.Register(&compute_opts);
}

void CheckAndFixConfigs(int32 nnet_modulus) {
static bool warned_frames_per_chunk = false;
if (frame_subsampling_factor < 1 || frames_per_chunk < 1) {
KALDI_ERR << "--frame-subsampling-factor and "
<< "--frames-per-chunk must be > 0";
}
KALDI_ASSERT(nnet_modulus > 0);
int32 n = Lcm(frame_subsampling_factor, nnet_modulus);

if (frames_per_chunk % n != 0) {
// round up to the nearest multiple of n.
int32 new_frames_per_chunk = n * ((frames_per_chunk + n - 1) / n);
if (!warned_frames_per_chunk) {
warned_frames_per_chunk = true;
if (nnet_modulus == 1) {
// simpler error message.
KALDI_LOG << "Increasing --frames-per-chunk from " << frames_per_chunk
<< " to " << new_frames_per_chunk
<< " to make it a multiple of "
<< "--frame-subsampling-factor="
<< frame_subsampling_factor;
} else {
KALDI_LOG << "Increasing --frames-per-chunk from " << frames_per_chunk
<< " to " << new_frames_per_chunk << " due to "
<< "--frame-subsampling-factor=" << frame_subsampling_factor
<< " and "
<< "nnet shift-invariance modulus = " << nnet_modulus;
}
}
frames_per_chunk = new_frames_per_chunk;
}
}
};

/*
Expand Down
52 changes: 6 additions & 46 deletions src/nnet3/nnet-batch-compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,12 @@ NnetBatchComputer::NnetBatchComputer(
log_priors_(priors),
num_full_minibatches_(0) {
log_priors_.ApplyLog();
CheckAndFixConfigs();
ComputeSimpleNnetContext(nnet, &nnet_left_context_,
&nnet_right_context_);
opts_.CheckAndFixConfigs(nnet_.Modulus());
KALDI_ASSERT(opts_.minibatch_size >= 1 && opts_.edge_minibatch_size >= 1 &&
opts_.partial_minibatch_factor < 1.0 &&
opts_.partial_minibatch_factor >= 0.0);

ComputeSimpleNnetContext(nnet, &nnet_left_context_, &nnet_right_context_);
input_dim_ = nnet.InputDim("input");
ivector_dim_ = std::max<int32>(0, nnet.InputDim("ivector"));
output_dim_ = nnet.OutputDim("output");
Expand Down Expand Up @@ -340,49 +343,6 @@ void NnetBatchComputer::GetComputationRequest(
request->outputs.push_back(IoSpecification("output", output_indexes));
}



void NnetBatchComputer::CheckAndFixConfigs() {
static bool warned_frames_per_chunk = false;
int32 nnet_modulus = nnet_.Modulus();
if (opts_.frame_subsampling_factor < 1 ||
opts_.frames_per_chunk < 1) {
KALDI_ERR << "--frame-subsampling-factor and "
<< "--frames-per-chunk must be > 0";
}
KALDI_ASSERT(nnet_modulus > 0);
int32 n = Lcm(opts_.frame_subsampling_factor, nnet_modulus);

if (opts_.frames_per_chunk % n != 0) {
// round up to the nearest multiple of n.
int32 frames_per_chunk = n * ((opts_.frames_per_chunk + n - 1) / n);
if (!warned_frames_per_chunk) {
warned_frames_per_chunk = true;
if (nnet_modulus == 1) {
// simpler error message.
KALDI_LOG << "Increasing --frames-per-chunk from "
<< opts_.frames_per_chunk << " to "
<< frames_per_chunk << " to make it a multiple of "
<< "--frame-subsampling-factor="
<< opts_.frame_subsampling_factor;
} else {
KALDI_LOG << "Increasing --frames-per-chunk from "
<< opts_.frames_per_chunk << " to "
<< frames_per_chunk << " due to "
<< "--frame-subsampling-factor="
<< opts_.frame_subsampling_factor << " and "
<< "nnet shift-invariance modulus = " << nnet_modulus;
}
}
opts_.frames_per_chunk = frames_per_chunk;
}
KALDI_ASSERT(opts_.minibatch_size >= 1 &&
opts_.edge_minibatch_size >= 1 &&
opts_.partial_minibatch_factor < 1.0 &&
opts_.partial_minibatch_factor >= 0.0);
}


void NnetBatchComputer::FormatInputs(
int32 minibatch_size,
const std::vector<NnetInferenceTask*> &tasks,
Expand Down