diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index 9e0f215..f10ca38 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -67,6 +67,7 @@ DEFINE_bool( "Whether to use SSL credentials or not. If ssl_cert is specified, " "this is assumed to be true"); DEFINE_bool(speaker_diarization, false, "Flag that controls if speaker diarization is requested"); +DEFINE_int32(diarization_max_speakers, 3, "Max number of speakers to detect when performing speaker diarization"); DEFINE_string(metadata, "", "Comma separated key-value pair(s) of metadata to be sent to server"); DEFINE_int32(start_history, -1, "Value to detect and initiate start of speech utterance"); DEFINE_double( @@ -91,14 +92,14 @@ class RecognizeClient { bool automatic_punctuation, bool separate_recognition_per_channel, bool print_transcripts, std::string output_filename, std::string model_name, bool ctm, bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, - bool speaker_diarization, int32_t start_history, float start_threshold, int32_t stop_history, + bool speaker_diarization, int32_t diarization_max_speakers, int32_t start_history, float start_threshold, int32_t stop_history, int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou, std::string custom_configuration) : stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation), separate_recognition_per_channel_(separate_recognition_per_channel), - speaker_diarization_(speaker_diarization), print_transcripts_(print_transcripts), + speaker_diarization_(speaker_diarization), diarization_max_speakers_(diarization_max_speakers), print_transcripts_(print_transcripts), done_sending_(false), num_requests_(0), num_responses_(0), num_failed_requests_(0), total_audio_processed_(0.), model_name_(model_name), output_filename_(output_filename), verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), @@ -229,6 +230,7 @@ class RecognizeClient { auto speaker_diarization_config = config->mutable_diarization_config(); speaker_diarization_config->set_enable_speaker_diarization(speaker_diarization_); + speaker_diarization_config->set_max_speaker_count(diarization_max_speakers_); if (model_name_ != "") { config->set_model(model_name_); @@ -401,6 +403,7 @@ class RecognizeClient { bool automatic_punctuation_; bool separate_recognition_per_channel_; bool speaker_diarization_; + int32_t diarization_max_speakers_; bool print_transcripts_; @@ -456,6 +459,7 @@ main(int argc, char** argv) str_usage << " --boosted_words_score=" << std::endl; str_usage << " --ssl_cert=" << std::endl; str_usage << " --speaker_diarization=" << std::endl; + str_usage << " --diarization_max_speakers=" << std::endl; str_usage << " --model_name=" << std::endl; str_usage << " --list_models" << std::endl; str_usage << " --metadata=" << std::endl; @@ -529,7 +533,7 @@ main(int argc, char** argv) FLAGS_word_time_offsets, FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_output_filename, FLAGS_model_name, FLAGS_output_ctm, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, - (float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_start_history, + (float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_diarization_max_speakers, FLAGS_start_history, FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou, FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration);