Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept max_speaker_count as input in Riva ASR Client #98

Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions riva/clients/asr/riva_asr_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ DEFINE_bool(
"Whether to use SSL credentials or not. If ssl_cert is specified, "
"this is assumed to be true");
DEFINE_bool(speaker_diarization, false, "Flag that controls if speaker diarization is requested");
DEFINE_int32(diarization_max_speakers, 3, "Max number of speakers to detect when performing speaker diarization");
DEFINE_string(metadata, "", "Comma separated key-value pair(s) of metadata to be sent to server");
DEFINE_int32(start_history, -1, "Value to detect and initiate start of speech utterance");
DEFINE_double(
Expand All @@ -91,14 +92,14 @@ class RecognizeClient {
bool automatic_punctuation, bool separate_recognition_per_channel, bool print_transcripts,
std::string output_filename, std::string model_name, bool ctm, bool verbatim_transcripts,
const std::string& boosted_phrases_file, float boosted_phrases_score,
bool speaker_diarization, int32_t start_history, float start_threshold, int32_t stop_history,
bool speaker_diarization, int32_t diarization_max_speakers, int32_t start_history, float start_threshold, int32_t stop_history,
int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou,
std::string custom_configuration)
: stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code),
max_alternatives_(max_alternatives), profanity_filter_(profanity_filter),
word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation),
separate_recognition_per_channel_(separate_recognition_per_channel),
speaker_diarization_(speaker_diarization), print_transcripts_(print_transcripts),
speaker_diarization_(speaker_diarization), max_speaker_count_(diarization_max_speakers), print_transcripts_(print_transcripts),
done_sending_(false), num_requests_(0), num_responses_(0), num_failed_requests_(0),
total_audio_processed_(0.), model_name_(model_name), output_filename_(output_filename),
verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score),
Expand Down Expand Up @@ -229,6 +230,7 @@ class RecognizeClient {

auto speaker_diarization_config = config->mutable_diarization_config();
speaker_diarization_config->set_enable_speaker_diarization(speaker_diarization_);
speaker_diarization_config->set_max_speaker_count(max_speaker_count_);

if (model_name_ != "") {
config->set_model(model_name_);
Expand Down Expand Up @@ -401,6 +403,7 @@ class RecognizeClient {
bool automatic_punctuation_;
bool separate_recognition_per_channel_;
bool speaker_diarization_;
int32_t max_speaker_count_;
rmittal-github marked this conversation as resolved.
Show resolved Hide resolved
bool print_transcripts_;


Expand Down Expand Up @@ -456,6 +459,7 @@ main(int argc, char** argv)
str_usage << " --boosted_words_score=<float>" << std::endl;
str_usage << " --ssl_cert=<filename>" << std::endl;
str_usage << " --speaker_diarization=<true|false>" << std::endl;
str_usage << " --diarization_max_speakers=<int>" << std::endl;
str_usage << " --model_name=<model>" << std::endl;
str_usage << " --list_models" << std::endl;
str_usage << " --metadata=<key,value,...>" << std::endl;
Expand Down Expand Up @@ -529,7 +533,7 @@ main(int argc, char** argv)
FLAGS_word_time_offsets, FLAGS_automatic_punctuation,
/* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_output_filename,
FLAGS_model_name, FLAGS_output_ctm, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file,
(float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_start_history,
(float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_diarization_max_speakers, FLAGS_start_history,
FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou, FLAGS_stop_threshold,
FLAGS_stop_threshold_eou, FLAGS_custom_configuration);

Expand Down