Skip to content

Commit 4532d26

Browse files
Add list_models option for ASR clients
1 parent 71421c5 commit 4532d26

File tree

4 files changed

+56
-3
lines changed

4 files changed

+56
-3
lines changed

riva/clients/asr/riva_asr_client.cc

+27
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ DEFINE_int32(num_parallel_requests, 10, "Number of parallel requests to keep in
5252
DEFINE_bool(print_transcripts, true, "Print final transcripts");
5353
DEFINE_string(output_filename, "", "Filename to write output transcripts");
5454
DEFINE_string(model_name, "", "Name of the TRTIS model to use");
55+
DEFINE_bool(list_models, false, "List available models on server");
5556
DEFINE_bool(output_ctm, false, "If true, output format should be NIST CTM");
5657
DEFINE_string(language_code, "en-US", "Language code of the model to use");
5758
DEFINE_string(boosted_words_file, "", "File with a list of words to boost. One line per word.");
@@ -455,6 +456,8 @@ main(int argc, char** argv)
455456
str_usage << " --boosted_words_score=<float>" << std::endl;
456457
str_usage << " --ssl_cert=<filename>" << std::endl;
457458
str_usage << " --speaker_diarization=<true|false>" << std::endl;
459+
str_usage << " --model_name=<model>" << std::endl;
460+
str_usage << " --list_models" << std::endl;
458461
str_usage << " --metadata=<key,value,...>" << std::endl;
459462
str_usage << " --start_history=<int>" << std::endl;
460463
str_usage << " --start_threshold=<float>" << std::endl;
@@ -503,6 +506,30 @@ main(int argc, char** argv)
503506
return 1;
504507
}
505508

509+
if (FLAGS_list_models) {
510+
std::unique_ptr<nr_asr::RivaSpeechRecognition::Stub> asr_stub_(
511+
nr_asr::RivaSpeechRecognition::NewStub(grpc_channel));
512+
grpc::ClientContext asr_context;
513+
nr_asr::RivaSpeechRecognitionConfigRequest asr_request;
514+
nr_asr::RivaSpeechRecognitionConfigResponse asr_response;
515+
asr_stub_->GetRivaSpeechRecognitionConfig(&asr_context, asr_request, &asr_response);
516+
517+
std::multimap<std::string, std::string> model_map;
518+
for (int i = 0; i < asr_response.model_config_size(); i++) {
519+
if (asr_response.model_config(i).parameters().find("type")->second == "offline") {
520+
model_map.insert(std::make_pair(
521+
asr_response.model_config(i).parameters().find("language_code")->second,
522+
asr_response.model_config(i).model_name()));
523+
}
524+
}
525+
526+
for (auto& m : model_map) {
527+
std::cout << "'" << m.first << "': '" << m.second << "'" << std::endl;
528+
}
529+
530+
return 0;
531+
}
532+
506533
RecognizeClient recognize_client(
507534
grpc_channel, FLAGS_language_code, FLAGS_max_alternatives, FLAGS_profanity_filter,
508535
FLAGS_word_time_offsets, FLAGS_automatic_punctuation,

riva/clients/asr/riva_streaming_asr_client.cc

+27-1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ DEFINE_string(
6363
output_filename, "final_transcripts.json",
6464
"Filename of .json file containing output transcripts");
6565
DEFINE_string(model_name, "", "Name of the TRTIS model to use");
66+
DEFINE_bool(list_models, false, "List available models on server");
6667
DEFINE_string(language_code, "en-US", "Language code of the model to use");
6768
DEFINE_string(boosted_words_file, "", "File with a list of words to boost. One line per word.");
6869
DEFINE_double(boosted_words_score, 10., "Score by which to boost the boosted words");
@@ -133,6 +134,8 @@ main(int argc, char** argv)
133134
str_usage << " --boosted_words_file=<string>" << std::endl;
134135
str_usage << " --boosted_words_score=<float>" << std::endl;
135136
str_usage << " --ssl_cert=<filename>" << std::endl;
137+
str_usage << " --model_name=<model>" << std::endl;
138+
str_usage << " --list_models" << std::endl;
136139
str_usage << " --metadata=<key,value,...>" << std::endl;
137140
str_usage << " --start_history=<int>" << std::endl;
138141
str_usage << " --start_threshold=<float>" << std::endl;
@@ -182,6 +185,29 @@ main(int argc, char** argv)
182185
return 1;
183186
}
184187

188+
if (FLAGS_list_models) {
189+
std::unique_ptr<nr_asr::RivaSpeechRecognition::Stub> asr_stub_(
190+
nr_asr::RivaSpeechRecognition::NewStub(grpc_channel));
191+
grpc::ClientContext asr_context;
192+
nr_asr::RivaSpeechRecognitionConfigRequest asr_request;
193+
nr_asr::RivaSpeechRecognitionConfigResponse asr_response;
194+
asr_stub_->GetRivaSpeechRecognitionConfig(&asr_context, asr_request, &asr_response);
195+
196+
std::multimap<std::string, std::string> model_map;
197+
for (int i = 0; i < asr_response.model_config_size(); i++) {
198+
if (asr_response.model_config(i).parameters().find("type")->second == "online") {
199+
model_map.insert(std::make_pair(
200+
asr_response.model_config(i).parameters().find("language_code")->second,
201+
asr_response.model_config(i).model_name()));
202+
}
203+
}
204+
205+
for (auto& m : model_map) {
206+
std::cout << "'" << m.first << "': '" << m.second << "'" << std::endl;
207+
}
208+
return 0;
209+
}
210+
185211
StreamingRecognizeClient recognize_client(
186212
grpc_channel, FLAGS_num_parallel_requests, FLAGS_language_code, FLAGS_max_alternatives,
187213
FLAGS_profanity_filter, FLAGS_word_time_offsets, FLAGS_automatic_punctuation,
@@ -230,4 +256,4 @@ main(int argc, char** argv)
230256
}
231257

232258
return 0;
233-
}
259+
}

riva/clients/nmt/riva_nmt_streaming_s2s_client.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ main(int argc, char** argv)
276276
nr_nmt::AvailableLanguageRequest request;
277277
nr_nmt::AvailableLanguageResponse response;
278278

279-
request.set_model("s2s_model"); // this is optional, if empty returns all available models/languages
279+
request.set_model("s2s_model"); // get only S2S supported languages
280280
nmt_s2s->ListSupportedLanguagePairs(&context, request, &response);
281281
std::cout << response.DebugString() << std::endl;
282282
return 0;

riva/clients/nmt/riva_nmt_streaming_s2t_client.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ main(int argc, char** argv)
157157
nr_nmt::AvailableLanguageRequest request;
158158
nr_nmt::AvailableLanguageResponse response;
159159

160-
request.set_model("s2t_model"); // this is optional, if empty returns all available models/languages
160+
request.set_model("s2t_model"); // get only S2T supported languages
161161
nmt_s2t->ListSupportedLanguagePairs(&context, request, &response);
162162
std::cout << response.DebugString() << std::endl;
163163
return 0;

0 commit comments

Comments
 (0)