From 9eaed8f81880e32733ada50fa7bc2cd5cca44013 Mon Sep 17 00:00:00 2001 From: Navneeth K Date: Thu, 20 Apr 2017 02:55:47 -0400 Subject: [PATCH] Online silence weighting - adding frame subsampling factor --- src/online2/online-ivector-feature.cc | 42 ++++++++++++------- src/online2/online-ivector-feature.h | 21 ++++++++-- .../online2-wav-nnet3-latgen-faster.cc | 3 +- 3 files changed, 45 insertions(+), 21 deletions(-) diff --git a/src/online2/online-ivector-feature.cc b/src/online2/online-ivector-feature.cc index cdfc5948571..1048bd1caa8 100644 --- a/src/online2/online-ivector-feature.cc +++ b/src/online2/online-ivector-feature.cc @@ -164,11 +164,10 @@ void OnlineIvectorFeature::UpdateFrameWeights( // elements from top (lower-numbered frames) to bottom (higher-numbered // frames) should be most efficient, assuming it's a heap internally. So we // go forward not backward in delta_weights while adding. - int32 num_frames_ready = NumFramesReady(); for (size_t i = 0; i < delta_weights.size(); i++) { delta_weights_.push(delta_weights[i]); int32 frame = delta_weights[i].first; - KALDI_ASSERT(frame >= 0 && frame < num_frames_ready); + KALDI_ASSERT(frame >= 0); if (frame > most_recent_frame_with_weight_) most_recent_frame_with_weight_ = frame; } @@ -221,7 +220,7 @@ void OnlineIvectorFeature::UpdateStatsUntilFrameWeighted(int32 frame) { delta_weights_provided_ && ! updated_with_no_delta_weights_ && frame <= most_recent_frame_with_weight_); - bool debug_weights = true; + bool debug_weights = false; int32 ivector_period = info_.ivector_period; int32 num_cg_iters = info_.num_cg_iters; @@ -241,8 +240,6 @@ void OnlineIvectorFeature::UpdateStatsUntilFrameWeighted(int32 frame) { if (current_frame_weight_debug_.size() <= frame) current_frame_weight_debug_.resize(frame + 1, 0.0); current_frame_weight_debug_[frame] += weight; - KALDI_ASSERT(current_frame_weight_debug_[frame] >= -0.01 && - current_frame_weight_debug_[frame] <= 1.01); } } if ((!info_.use_most_recent_ivector && t % ivector_period == 0) || @@ -384,9 +381,12 @@ BaseFloat OnlineIvectorFeature::ObjfImprPerFrame() const { OnlineSilenceWeighting::OnlineSilenceWeighting( const TransitionModel &trans_model, - const OnlineSilenceWeightingConfig &config): + const OnlineSilenceWeightingConfig &config, + int32 frame_subsampling_factor): trans_model_(trans_model), config_(config), + frame_subsampling_factor_(frame_subsampling_factor), num_frames_output_and_correct_(0) { + KALDI_ASSERT(frame_subsampling_factor_ >= 1); std::vector silence_phones; SplitStringToIntegers(config.silence_phones_str, ":,", false, &silence_phones); @@ -497,8 +497,15 @@ int32 OnlineSilenceWeighting::GetBeginFrame() { } void OnlineSilenceWeighting::GetDeltaWeights( - int32 num_frames_ready, + int32 num_frames_ready_in, std::vector > *delta_weights) { + // num_frames_ready_in is at the feature frame-rate, most of the code + // in this function is at the decoder frame-rate. + // round up, so we are sure to get weights for at least the frame + // 'num_frames_ready_in - 1', and maybe one or two frames afterward. + int32 fs = frame_subsampling_factor_, + num_frames_ready = (num_frames_ready_in + fs - 1) / fs; + const int32 max_state_duration = config_.max_state_duration; const BaseFloat silence_weight = config_.silence_weight; @@ -515,11 +522,11 @@ void OnlineSilenceWeighting::GetDeltaWeights( // frames_out is the number of frames we will output. KALDI_ASSERT(frames_out >= 0); std::vector frame_weight(frames_out, 1.0); - // we will frame_weight to the value silence_weight for silence frames and for - // transition-ids that repeat with duration > max_state_duration. Frames newer - // than the most recent traceback will get a weight equal to the weight for the - // most recent frame in the traceback; or the silence weight, if there is no - // traceback at all available yet. + // we will set frame_weight to the value silence_weight for silence frames and + // for transition-ids that repeat with duration > max_state_duration. Frames + // newer than the most recent traceback will get a weight equal to the weight + // for the most recent frame in the traceback; or the silence weight, if there + // is no traceback at all available yet. // First treat some special cases. if (frames_out == 0) // Nothing to output. @@ -578,10 +585,13 @@ void OnlineSilenceWeighting::GetDeltaWeights( // Even if the delta-weight is zero for the last frame, we provide it, // because the identity of the most recent frame with a weight is used in // some debugging/checking code. - if (weight_diff != 0.0 || offset + 1 == frames_out) - delta_weights->push_back(std::make_pair(frame, weight_diff)); - } - + if (weight_diff != 0.0 || offset + 1 == frames_out) { + for(int32 i = 0; i < frame_subsampling_factor_; i++) { + int32 input_frame = (frame * frame_subsampling_factor_) + i; + delta_weights->push_back(std::make_pair(input_frame, weight_diff)); + } + } + } } } // namespace kaldi diff --git a/src/online2/online-ivector-feature.h b/src/online2/online-ivector-feature.h index 5ba289aa79d..942cb387bbb 100644 --- a/src/online2/online-ivector-feature.h +++ b/src/online2/online-ivector-feature.h @@ -442,8 +442,14 @@ class OnlineSilenceWeighting { public: // Note: you would initialize a new copy of this object for each new // utterance. + // The frame-subsampling-factor is used for newer nnet3 models, especially + // chain models, when the frame-rate of the decoder is different from the + // frame-rate of the input features. E.g. you might set it to 3 for such + // models. + OnlineSilenceWeighting(const TransitionModel &trans_model, - const OnlineSilenceWeightingConfig &config); + const OnlineSilenceWeightingConfig &config, + int32 frame_subsampling_factor = 1); bool Active() const { return config_.Active(); } @@ -456,7 +462,7 @@ class OnlineSilenceWeighting { // the stats... the output format is (frame-index, delta-weight). The // num_frames_ready argument is the number of frames available at the input // (or equivalently, output) of the online iVector extractor class, which may - // be more than the currently availabl decoder traceback. How many frames + // be more than the currently available decoder traceback. How many frames // of weights it outputs depends on how much "num_frames_ready" increased // since last time we called this function, and whether the decoder traceback // changed. Negative delta_weights might occur if frames previously @@ -466,17 +472,19 @@ class OnlineSilenceWeighting { // this output to class OnlineIvectorFeature by calling its function // UpdateFrameWeights with the output. void GetDeltaWeights( - int32 num_frames_ready, + int32 num_frames_ready_in, std::vector > *delta_weights); private: const TransitionModel &trans_model_; const OnlineSilenceWeightingConfig &config_; + int32 frame_subsampling_factor_; + unordered_set silence_phones_; struct FrameInfo { - //The only reason we need the token pointer is to know far back we have to + // The only reason we need the token pointer is to know far back we have to // trace before the traceback is the same as what we previously traced back. void *token; int32 transition_id; @@ -494,6 +502,11 @@ class OnlineSilenceWeighting { // max_state_duration is relevant. int32 GetBeginFrame(); + // This contains information about any previously computed traceback; + // when the traceback changes we use this variable to compare it with the + // previous traceback. + // It's indexed at the frame-rate of the decoder (may be different + // by 'frame_subsampling_factor_' from the frame-rate of the features. std::vector frame_info_; // This records how many frames have been output and that currently reflect diff --git a/src/online2bin/online2-wav-nnet3-latgen-faster.cc b/src/online2bin/online2-wav-nnet3-latgen-faster.cc index 62204460159..f8fd1f9ef71 100644 --- a/src/online2bin/online2-wav-nnet3-latgen-faster.cc +++ b/src/online2bin/online2-wav-nnet3-latgen-faster.cc @@ -209,7 +209,8 @@ int main(int argc, char *argv[]) { OnlineSilenceWeighting silence_weighting( trans_model, - feature_info.silence_weighting_config); + feature_info.silence_weighting_config, + decodable_opts.frame_subsampling_factor); SingleUtteranceNnet3Decoder decoder(decoder_opts, trans_model, decodable_info,