From ce9726ae1b0123e2ff8850c42051889ecf638076 Mon Sep 17 00:00:00 2001 From: Prabhsimran Singh Date: Tue, 17 Sep 2024 16:16:18 +0530 Subject: [PATCH 1/4] fix: accept input for max_speaker_count in riva_asr_client --- riva/clients/asr/riva_asr_client.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index 9e0f215..7ece75a 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -67,6 +67,7 @@ DEFINE_bool( "Whether to use SSL credentials or not. If ssl_cert is specified, " "this is assumed to be true"); DEFINE_bool(speaker_diarization, false, "Flag that controls if speaker diarization is requested"); +DEFINE_int32(max_speaker_count, -1, "Max number of speakers to detect when performing speaker diarization"); DEFINE_string(metadata, "", "Comma separated key-value pair(s) of metadata to be sent to server"); DEFINE_int32(start_history, -1, "Value to detect and initiate start of speech utterance"); DEFINE_double( @@ -91,14 +92,14 @@ class RecognizeClient { bool automatic_punctuation, bool separate_recognition_per_channel, bool print_transcripts, std::string output_filename, std::string model_name, bool ctm, bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, - bool speaker_diarization, int32_t start_history, float start_threshold, int32_t stop_history, + bool speaker_diarization, int32_t max_speaker_count, int32_t start_history, float start_threshold, int32_t stop_history, int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou, std::string custom_configuration) : stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation), separate_recognition_per_channel_(separate_recognition_per_channel), - speaker_diarization_(speaker_diarization), print_transcripts_(print_transcripts), + speaker_diarization_(speaker_diarization), max_speaker_count_(max_speaker_count), print_transcripts_(print_transcripts), done_sending_(false), num_requests_(0), num_responses_(0), num_failed_requests_(0), total_audio_processed_(0.), model_name_(model_name), output_filename_(output_filename), verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), @@ -229,6 +230,7 @@ class RecognizeClient { auto speaker_diarization_config = config->mutable_diarization_config(); speaker_diarization_config->set_enable_speaker_diarization(speaker_diarization_); + speaker_diarization_config->set_max_speaker_count(max_speaker_count_); if (model_name_ != "") { config->set_model(model_name_); @@ -401,6 +403,7 @@ class RecognizeClient { bool automatic_punctuation_; bool separate_recognition_per_channel_; bool speaker_diarization_; + int32_t max_speaker_count_; bool print_transcripts_; @@ -456,6 +459,7 @@ main(int argc, char** argv) str_usage << " --boosted_words_score=" << std::endl; str_usage << " --ssl_cert=" << std::endl; str_usage << " --speaker_diarization=" << std::endl; + str_usage << " --max_speaker_count=" << std::endl; str_usage << " --model_name=" << std::endl; str_usage << " --list_models" << std::endl; str_usage << " --metadata=" << std::endl; @@ -529,7 +533,7 @@ main(int argc, char** argv) FLAGS_word_time_offsets, FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_output_filename, FLAGS_model_name, FLAGS_output_ctm, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, - (float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_start_history, + (float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_max_speaker_count, FLAGS_start_history, FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou, FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration); From 7a45776e6b57fb4e93cd7f5930b78b86a6b92a2f Mon Sep 17 00:00:00 2001 From: Prabhsimran Singh Date: Wed, 18 Sep 2024 19:35:08 +0530 Subject: [PATCH 2/4] update: default max_speaker_count to 3 --- riva/clients/asr/riva_asr_client.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index 7ece75a..aa14db7 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -67,7 +67,7 @@ DEFINE_bool( "Whether to use SSL credentials or not. If ssl_cert is specified, " "this is assumed to be true"); DEFINE_bool(speaker_diarization, false, "Flag that controls if speaker diarization is requested"); -DEFINE_int32(max_speaker_count, -1, "Max number of speakers to detect when performing speaker diarization"); +DEFINE_int32(max_speaker_count, 3, "Max number of speakers to detect when performing speaker diarization"); DEFINE_string(metadata, "", "Comma separated key-value pair(s) of metadata to be sent to server"); DEFINE_int32(start_history, -1, "Value to detect and initiate start of speech utterance"); DEFINE_double( From 64dcf91922148ab25bf47a776b1fcd21ddceca81 Mon Sep 17 00:00:00 2001 From: Prabhsimran Singh Date: Thu, 19 Sep 2024 13:24:14 +0530 Subject: [PATCH 3/4] fix: rename input field to diarization_max_speakers --- riva/clients/asr/riva_asr_client.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index aa14db7..84981c1 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -67,7 +67,7 @@ DEFINE_bool( "Whether to use SSL credentials or not. If ssl_cert is specified, " "this is assumed to be true"); DEFINE_bool(speaker_diarization, false, "Flag that controls if speaker diarization is requested"); -DEFINE_int32(max_speaker_count, 3, "Max number of speakers to detect when performing speaker diarization"); +DEFINE_int32(diarization_max_speakers, 3, "Max number of speakers to detect when performing speaker diarization"); DEFINE_string(metadata, "", "Comma separated key-value pair(s) of metadata to be sent to server"); DEFINE_int32(start_history, -1, "Value to detect and initiate start of speech utterance"); DEFINE_double( @@ -92,14 +92,14 @@ class RecognizeClient { bool automatic_punctuation, bool separate_recognition_per_channel, bool print_transcripts, std::string output_filename, std::string model_name, bool ctm, bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, - bool speaker_diarization, int32_t max_speaker_count, int32_t start_history, float start_threshold, int32_t stop_history, + bool speaker_diarization, int32_t diarization_max_speakers, int32_t start_history, float start_threshold, int32_t stop_history, int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou, std::string custom_configuration) : stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation), separate_recognition_per_channel_(separate_recognition_per_channel), - speaker_diarization_(speaker_diarization), max_speaker_count_(max_speaker_count), print_transcripts_(print_transcripts), + speaker_diarization_(speaker_diarization), max_speaker_count_(diarization_max_speakers), print_transcripts_(print_transcripts), done_sending_(false), num_requests_(0), num_responses_(0), num_failed_requests_(0), total_audio_processed_(0.), model_name_(model_name), output_filename_(output_filename), verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), @@ -459,7 +459,7 @@ main(int argc, char** argv) str_usage << " --boosted_words_score=" << std::endl; str_usage << " --ssl_cert=" << std::endl; str_usage << " --speaker_diarization=" << std::endl; - str_usage << " --max_speaker_count=" << std::endl; + str_usage << " --diarization_max_speakers=" << std::endl; str_usage << " --model_name=" << std::endl; str_usage << " --list_models" << std::endl; str_usage << " --metadata=" << std::endl; @@ -533,7 +533,7 @@ main(int argc, char** argv) FLAGS_word_time_offsets, FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_output_filename, FLAGS_model_name, FLAGS_output_ctm, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, - (float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_max_speaker_count, FLAGS_start_history, + (float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_diarization_max_speakers, FLAGS_start_history, FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou, FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration); From 62a49943f30807ef7d592245943b13acd0171e6f Mon Sep 17 00:00:00 2001 From: Prabhsimran Singh Date: Thu, 19 Sep 2024 14:22:30 +0530 Subject: [PATCH 4/4] fix: var name for consistency --- riva/clients/asr/riva_asr_client.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index 84981c1..f10ca38 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -99,7 +99,7 @@ class RecognizeClient { max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation), separate_recognition_per_channel_(separate_recognition_per_channel), - speaker_diarization_(speaker_diarization), max_speaker_count_(diarization_max_speakers), print_transcripts_(print_transcripts), + speaker_diarization_(speaker_diarization), diarization_max_speakers_(diarization_max_speakers), print_transcripts_(print_transcripts), done_sending_(false), num_requests_(0), num_responses_(0), num_failed_requests_(0), total_audio_processed_(0.), model_name_(model_name), output_filename_(output_filename), verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), @@ -230,7 +230,7 @@ class RecognizeClient { auto speaker_diarization_config = config->mutable_diarization_config(); speaker_diarization_config->set_enable_speaker_diarization(speaker_diarization_); - speaker_diarization_config->set_max_speaker_count(max_speaker_count_); + speaker_diarization_config->set_max_speaker_count(diarization_max_speakers_); if (model_name_ != "") { config->set_model(model_name_); @@ -403,7 +403,7 @@ class RecognizeClient { bool automatic_punctuation_; bool separate_recognition_per_channel_; bool speaker_diarization_; - int32_t max_speaker_count_; + int32_t diarization_max_speakers_; bool print_transcripts_;