Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions src/nnet3/decodable-simple-looped.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,16 @@ void DecodableNnetSimpleLoopedInfo::Init(
if (has_ivectors_)
ModifyNnetIvectorPeriod(ivector_period, nnet);

ComputationRequest request1, request2, request3;
int32 num_sequences = 1; // we're processing one utterance at a time.
int32 extra_right_context = 0;
CreateLoopedComputationRequestSimple(*nnet, frames_per_chunk_,
opts_.frame_subsampling_factor,
ivector_period, opts.extra_left_context_initial,
extra_right_context,
num_sequences,
&request1, &request2, &request3);
&request1_, &request2_, &request3_);

CompileLooped(*nnet, opts_.optimize_config, request1, request2, request3,
CompileLooped(*nnet, opts_.optimize_config, request1_, request2_, request3_,
&computation_);
computation_.ComputeCudaIndexes();
if (GetVerboseLevel() >= 3) {
Expand Down Expand Up @@ -172,11 +171,25 @@ void DecodableNnetSimpleLooped::AdvanceChunk() {
computer_.AcceptInput("input", &feats_chunk);

if (info_.has_ivectors_) {
KALDI_ASSERT(info_.request1_.inputs.size() == 2);
// all but the 1st chunk should have 1 iVector, but no need
// to assume this.
int32 num_ivectors = (num_chunks_computed_ == 0 ?
info_.request1_.inputs[1].indexes.size() :
info_.request2_.inputs[1].indexes.size());
KALDI_ASSERT(num_ivectors > 0);

Vector<BaseFloat> ivector;
// we just get the iVector from the last input frame we needed...
// we don't bother trying to be 'accurate' in getting the iVectors
// for their 'correct' frames, because in general using the
// iVector from as large 't' as possible will be better.
GetCurrentIvector(end_input_frame, &ivector);
CuMatrix<BaseFloat> cu_ivector(1, ivector.Dim());
cu_ivector.Row(0).CopyFromVec(ivector);
computer_.AcceptInput("ivector", &cu_ivector);
Matrix<BaseFloat> ivectors(num_ivectors,
ivector.Dim());
ivectors.CopyRowsFromVec(ivector);
CuMatrix<BaseFloat> cu_ivectors(ivectors);
computer_.AcceptInput("ivector", &cu_ivectors);
}
computer_.Run();

Expand Down
5 changes: 5 additions & 0 deletions src/nnet3/decodable-simple-looped.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ class DecodableNnetSimpleLoopedInfo {
// to accept the iVectors
bool has_ivectors_;

// The 3 computation requests that are used to create the looped
// computation are stored in the class, as we need them to work out
// exactly shich iVectors are needed.
ComputationRequest request1_, request2_, request3_;

// The compiled, 'looped' computation.
NnetComputation computation_;
};
Expand Down
14 changes: 8 additions & 6 deletions src/nnet3/nnet-compile-looped.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ int32 GetChunkSize(const Nnet &nnet,
/// for negative a is not specified (except by relation with the division '/'
/// operator), but in practice it would be <= 0 for almost all implementations.
template<class I> I Mod(I m, I n) {
if (m >= 0) return m % n;
else return -((-m) % n);
I ans = m % n;
if (ans < 0) ans += n;
return ans;
}


Expand Down Expand Up @@ -171,15 +172,16 @@ void CreateLoopedComputationRequestSimple(const Nnet &nnet,
}
for (int32 t = chunk2_input_begin_t; t < chunk2_input_end_t; t++) {
int32 ivector_t = t - Mod(t, ivector_period);
if (ivector_times1.count(ivector_t) == 0)
if (ivector_times2.count(ivector_t) == 0 &&
ivector_times1.count(ivector_t) == 0)
ivector_times2.insert(ivector_t);
}
for (int32 t = chunk3_input_begin_t; t < chunk3_input_end_t; t++) {
int32 ivector_t = t - Mod(t, ivector_period);
if (ivector_times1.count(ivector_t) == 0 &&
ivector_times2.count(ivector_t) == 0) {
if (ivector_times3.count(ivector_t) == 0 &&
ivector_times2.count(ivector_t) == 0 &&
ivector_times1.count(ivector_t) == 0)
ivector_times3.insert(ivector_t);
}
}
}

Expand Down