Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions src/online2bin/online2-tcp-nnet3-decode-faster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,20 @@ std::string LatticeToString(const Lattice &lat, const fst::SymbolTable &word_sym
return msg.str();
}

// Builds the "<begin> <end>" prefix used when timing output is requested.
//   t_beg, t_end : interval boundaries, expressed as frame counts.
//   time_unit    : factor that converts one frame count into the printed unit
//                  (callers in this file pass frame_shift * frame_subsampling).
// Returns both boundaries scaled by time_unit, each printed with two decimals
// and separated by a single space.
std::string GetTimeString(int32 t_beg, int32 t_end, BaseFloat time_unit) {
  // Scale the frame counts into time values before formatting.
  const double begin_time = t_beg * time_unit;
  const double end_time = t_end * time_unit;

  char text[100];
  snprintf(text, sizeof(text), "%.2f %.2f", begin_time, end_time);
  return text;
}

// Returns the time index that LatticeStateTimes() assigns to the last state
// of 'lat'. The caller in this file top-sorts the lattice before calling,
// which LatticeStateTimes() requires.
// Returns 0 when the lattice has no states, so an empty best path yields a
// zero-length span instead of reading past an empty vector.
int32 GetLatticeTimeSpan(const Lattice& lat) {
  // An empty lattice has no state times to query; bail out early.
  if (lat.NumStates() == 0)
    return 0;

  std::vector<int32> times;
  LatticeStateTimes(lat, &times);

  // back() on an empty vector is undefined behavior, so keep it guarded.
  return times.empty() ? 0 : times.back();
}

std::string LatticeToString(const CompactLattice &clat, const fst::SymbolTable &word_syms) {
if (clat.NumStates() == 0) {
KALDI_WARN << "Empty lattice.";
Expand Down Expand Up @@ -132,6 +146,7 @@ int main(int argc, char *argv[]) {
BaseFloat samp_freq = 16000.0;
int port_num = 5050;
int read_timeout = 3;
bool produce_time = false;

po.Register("samp-freq", &samp_freq,
"Sampling frequency of the input signal (coded as 16-bit slinear).");
Expand All @@ -145,6 +160,8 @@ int main(int argc, char *argv[]) {
"Number of seconds of timout for TCP audio data to appear on the stream. Use -1 for blocking.");
po.Register("port-num", &port_num,
"Port number the server will listen on.");
po.Register("produce-time", &produce_time,
"Prepend begin/end times between endpoints (e.g. '5.46 6.81 <text_output>', in seconds)");

feature_opts.Register(&po);
decodable_opts.Register(&po);
Expand All @@ -164,6 +181,9 @@ int main(int argc, char *argv[]) {

OnlineNnet2FeaturePipelineInfo feature_info(feature_opts);

BaseFloat frame_shift = feature_info.FrameShiftInSeconds();
int32 frame_subsampling = decodable_opts.frame_subsampling_factor;

KALDI_VLOG(1) << "Loading AM...";

TransitionModel trans_model;
Expand Down Expand Up @@ -239,6 +259,15 @@ int main(int argc, char *argv[]) {
CompactLattice lat;
decoder.GetLattice(true, &lat);
std::string msg = LatticeToString(lat, *word_syms);

// get time-span from previous endpoint to end of audio,
if (produce_time) {
int32 t_beg = frame_offset - decoder.NumFramesDecoded();
int32 t_end = frame_offset;
msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " " + msg;
}

KALDI_VLOG(1) << "EndOfAudio, sending message: " << msg;
server.WriteLn(msg);
} else
server.Write("\n");
Expand All @@ -265,7 +294,17 @@ int main(int argc, char *argv[]) {
if (decoder.NumFramesDecoded() > 0) {
Lattice lat;
decoder.GetBestPath(false, &lat);
TopSort(&lat); // for LatticeStateTimes(),
std::string msg = LatticeToString(lat, *word_syms);

// get time-span after previous endpoint,
if (produce_time) {
int32 t_beg = frame_offset;
int32 t_end = frame_offset + GetLatticeTimeSpan(lat);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if TopSort and LatticeStateTimes are really needed here. Can't we compute t_end using decoder.NumFramesDecoded() and/or samp_count/check_count?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, these are good points; the PR has now been changed accordingly.

msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " " + msg;
}

KALDI_VLOG(1) << "Temporary transcript: " << msg;
server.WriteLn(msg, "\r");
}
check_count += check_period;
Expand All @@ -277,8 +316,17 @@ int main(int argc, char *argv[]) {
CompactLattice lat;
decoder.GetLattice(true, &lat);
std::string msg = LatticeToString(lat, *word_syms);

// get time-span between endpoints,
if (produce_time) {
int32 t_beg = frame_offset - decoder.NumFramesDecoded();
int32 t_end = frame_offset;
msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " " + msg;
}

KALDI_VLOG(1) << "Endpoint, sending message: " << msg;
server.WriteLn(msg);
break;
break; // while (true)
}
}
}
Expand Down Expand Up @@ -439,4 +487,4 @@ void TcpServer::Disconnect() {
client_desc_ = -1;
}
}
} // namespace kaldi
} // namespace kaldi