Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions src/online2/online-ivector-feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,10 @@ void OnlineIvectorFeature::UpdateFrameWeights(
// elements from top (lower-numbered frames) to bottom (higher-numbered
// frames) should be most efficient, assuming it's a heap internally. So we
// go forward not backward in delta_weights while adding.
int32 num_frames_ready = NumFramesReady();
for (size_t i = 0; i < delta_weights.size(); i++) {
delta_weights_.push(delta_weights[i]);
int32 frame = delta_weights[i].first;
KALDI_ASSERT(frame >= 0 && frame < num_frames_ready);
KALDI_ASSERT(frame >= 0);
if (frame > most_recent_frame_with_weight_)
most_recent_frame_with_weight_ = frame;
}
Expand Down Expand Up @@ -221,7 +220,7 @@ void OnlineIvectorFeature::UpdateStatsUntilFrameWeighted(int32 frame) {
delta_weights_provided_ &&
! updated_with_no_delta_weights_ &&
frame <= most_recent_frame_with_weight_);
bool debug_weights = true;
bool debug_weights = false;

int32 ivector_period = info_.ivector_period;
int32 num_cg_iters = info_.num_cg_iters;
Expand All @@ -241,8 +240,6 @@ void OnlineIvectorFeature::UpdateStatsUntilFrameWeighted(int32 frame) {
if (current_frame_weight_debug_.size() <= frame)
current_frame_weight_debug_.resize(frame + 1, 0.0);
current_frame_weight_debug_[frame] += weight;
KALDI_ASSERT(current_frame_weight_debug_[frame] >= -0.01 &&
current_frame_weight_debug_[frame] <= 1.01);
}
}
if ((!info_.use_most_recent_ivector && t % ivector_period == 0) ||
Expand Down Expand Up @@ -384,9 +381,12 @@ BaseFloat OnlineIvectorFeature::ObjfImprPerFrame() const {

OnlineSilenceWeighting::OnlineSilenceWeighting(
const TransitionModel &trans_model,
const OnlineSilenceWeightingConfig &config):
const OnlineSilenceWeightingConfig &config,
int32 frame_subsampling_factor):
trans_model_(trans_model), config_(config),
frame_subsampling_factor_(frame_subsampling_factor),
num_frames_output_and_correct_(0) {
KALDI_ASSERT(frame_subsampling_factor_ >= 1);
std::vector<int32> silence_phones;
SplitStringToIntegers(config.silence_phones_str, ":,", false,
&silence_phones);
Expand Down Expand Up @@ -497,8 +497,15 @@ int32 OnlineSilenceWeighting::GetBeginFrame() {
}

void OnlineSilenceWeighting::GetDeltaWeights(
int32 num_frames_ready,
int32 num_frames_ready_in,
std::vector<std::pair<int32, BaseFloat> > *delta_weights) {
// num_frames_ready_in is at the feature frame-rate, whereas most of the code
// in this function is at the decoder frame-rate.
// We round up, so we are sure to get weights for at least the frame
// 'num_frames_ready_in - 1', and maybe one or two frames afterward.
int32 fs = frame_subsampling_factor_,
num_frames_ready = (num_frames_ready_in + fs - 1) / fs;

const int32 max_state_duration = config_.max_state_duration;
const BaseFloat silence_weight = config_.silence_weight;

Expand All @@ -515,11 +522,11 @@ void OnlineSilenceWeighting::GetDeltaWeights(
// frames_out is the number of frames we will output.
KALDI_ASSERT(frames_out >= 0);
std::vector<BaseFloat> frame_weight(frames_out, 1.0);
// we will frame_weight to the value silence_weight for silence frames and for
// transition-ids that repeat with duration > max_state_duration. Frames newer
// than the most recent traceback will get a weight equal to the weight for the
// most recent frame in the traceback; or the silence weight, if there is no
// traceback at all available yet.
// we will set frame_weight to the value silence_weight for silence frames and
// for transition-ids that repeat with duration > max_state_duration. Frames
// newer than the most recent traceback will get a weight equal to the weight
// for the most recent frame in the traceback; or the silence weight, if there
// is no traceback at all available yet.

// First treat some special cases.
if (frames_out == 0) // Nothing to output.
Expand Down Expand Up @@ -578,10 +585,13 @@ void OnlineSilenceWeighting::GetDeltaWeights(
// Even if the delta-weight is zero for the last frame, we provide it,
// because the identity of the most recent frame with a weight is used in
// some debugging/checking code.
if (weight_diff != 0.0 || offset + 1 == frames_out)
delta_weights->push_back(std::make_pair(frame, weight_diff));
}

if (weight_diff != 0.0 || offset + 1 == frames_out) {
for(int32 i = 0; i < frame_subsampling_factor_; i++) {
int32 input_frame = (frame * frame_subsampling_factor_) + i;
delta_weights->push_back(std::make_pair(input_frame, weight_diff));
}
}
}
}

} // namespace kaldi
21 changes: 17 additions & 4 deletions src/online2/online-ivector-feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,14 @@ class OnlineSilenceWeighting {
public:
// Note: you would initialize a new copy of this object for each new
// utterance.
// The frame-subsampling-factor is used for newer nnet3 models, especially
// chain models, when the frame-rate of the decoder is different from the
// frame-rate of the input features. E.g. you might set it to 3 for such
// models.

OnlineSilenceWeighting(const TransitionModel &trans_model,
const OnlineSilenceWeightingConfig &config);
const OnlineSilenceWeightingConfig &config,
int32 frame_subsampling_factor = 1);

bool Active() const { return config_.Active(); }

Expand All @@ -456,7 +462,7 @@ class OnlineSilenceWeighting {
// the stats... the output format is (frame-index, delta-weight). The
// num_frames_ready_in argument is the number of frames available at the input
// (or equivalently, output) of the online iVector extractor class, which may
// be more than the currently availabl decoder traceback. How many frames
// be more than the currently available decoder traceback. How many frames
// of weights it outputs depends on how much "num_frames_ready_in" increased
// since last time we called this function, and whether the decoder traceback
// changed. Negative delta_weights might occur if frames previously
Expand All @@ -466,17 +472,19 @@ class OnlineSilenceWeighting {
// this output to class OnlineIvectorFeature by calling its function
// UpdateFrameWeights with the output.
void GetDeltaWeights(
int32 num_frames_ready,
int32 num_frames_ready_in,
std::vector<std::pair<int32, BaseFloat> > *delta_weights);

private:
const TransitionModel &trans_model_;
const OnlineSilenceWeightingConfig &config_;

int32 frame_subsampling_factor_;

unordered_set<int32> silence_phones_;

struct FrameInfo {
//The only reason we need the token pointer is to know far back we have to
// The only reason we need the token pointer is to know how far back we have to
// trace before the traceback is the same as what we previously traced back.
void *token;
int32 transition_id;
Expand All @@ -494,6 +502,11 @@ class OnlineSilenceWeighting {
// max_state_duration is relevant.
int32 GetBeginFrame();

// This contains information about any previously computed traceback;
// when the traceback changes we use this variable to compare it with the
// previous traceback.
// It's indexed at the frame-rate of the decoder (may be different
// by 'frame_subsampling_factor_' from the frame-rate of the features).
std::vector<FrameInfo> frame_info_;

// This records how many frames have been output and that currently reflect
Expand Down
3 changes: 2 additions & 1 deletion src/online2bin/online2-wav-nnet3-latgen-faster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ int main(int argc, char *argv[]) {

OnlineSilenceWeighting silence_weighting(
trans_model,
feature_info.silence_weighting_config);
feature_info.silence_weighting_config,
decodable_opts.frame_subsampling_factor);

SingleUtteranceNnet3Decoder decoder(decoder_opts, trans_model,
decodable_info,
Expand Down