diff --git a/.github/workflows/c-api-from-buffer.yaml b/.github/workflows/c-api-from-buffer.yaml index f1284625f..7d3495791 100644 --- a/.github/workflows/c-api-from-buffer.yaml +++ b/.github/workflows/c-api-from-buffer.yaml @@ -183,4 +183,36 @@ jobs: ./streaming-ctc-buffered-tokens-hotwords-c-api - rm -rf sherpa-onnx-streaming-ctc-* \ No newline at end of file + rm -rf sherpa-onnx-streaming-ctc-* + + - name: Test keywords spotting with tokens and hotwords loaded from buffers + shell: bash + run: | + gcc -o keywords-spotter-buffered-tokens-keywords-c-api ./c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c \ + -I ./build/install/include \ + -L ./build/install/lib/ \ + -l sherpa-onnx-c-api \ + -l onnxruntime + + ls -lh keywords-spotter-buffered-tokens-keywords-c-api + + if [[ ${{ matrix.os }} == ubuntu-latest ]]; then + ldd ./keywords-spotter-buffered-tokens-keywords-c-api + echo "----" + readelf -d ./keywords-spotter-buffered-tokens-keywords-c-api + fi + + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 + tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 + rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 + + ls -lh sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile + echo "---" + ls -lh sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/test_wavs + + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH + + ./keywords-spotter-buffered-tokens-keywords-c-api + + rm -rf sherpa-onnx-kws-zipformer-* \ No newline at end of file diff --git a/c-api-examples/CMakeLists.txt b/c-api-examples/CMakeLists.txt index e2af2d5fb..6d8e22405 100644 --- a/c-api-examples/CMakeLists.txt +++ b/c-api-examples/CMakeLists.txt @@ -60,6 +60,9 @@ add_executable(streaming-ctc-buffered-tokens-hotwords-c-api streaming-ctc-buffered-tokens-hotwords-c-api.c) target_link_libraries(streaming-ctc-buffered-tokens-hotwords-c-api sherpa-onnx-c-api) +add_executable(keywords-spotter-buffered-tokens-keywords-c-api + keywords-spotter-buffered-tokens-keywords-c-api.c) +target_link_libraries(keywords-spotter-buffered-tokens-keywords-c-api sherpa-onnx-c-api) if(SHERPA_ONNX_HAS_ALSA) add_subdirectory(./asr-microphone-example) diff --git a/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c b/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c new file mode 100644 index 000000000..9e0861c40 --- /dev/null +++ b/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c @@ -0,0 +1,196 @@ +// c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Luo Xiao + +// +// This file demonstrates how to use keywords spotter with sherpa-onnx's C +// API and with tokens and keywords loaded from buffered strings instead of from +// external files API. +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-onnx/c-api/c-api.h" + +static size_t ReadFile(const char *filename, const char **buffer_out) { + FILE *file = fopen(filename, "r"); + if (file == NULL) { + fprintf(stderr, "Failed to open %s\n", filename); + return -1; + } + fseek(file, 0L, SEEK_END); + long size = ftell(file); + rewind(file); + *buffer_out = malloc(size); + if (*buffer_out == NULL) { + fclose(file); + fprintf(stderr, "Memory error\n"); + return -1; + } + size_t read_bytes = fread(*buffer_out, 1, size, file); + if (read_bytes != size) { + printf("Errors occured in reading the file %s\n", filename); + free((void *)*buffer_out); + *buffer_out = NULL; + fclose(file); + return -1; + } + fclose(file); + return read_bytes; +} + +int32_t main() { + const char *wav_filename = + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/test_wavs/" + "6.wav"; + const char *encoder_filename = + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"; + const char *decoder_filename = + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "decoder-epoch-12-avg-2-chunk-16-left-64.onnx"; + const char *joiner_filename = + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"; + const char *provider = "cpu"; + const char *tokens_filename = + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/tokens.txt"; + const char *keywords_filename = + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "keywords.txt"; + const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // reading tokens and hotwords to buffers + const char *tokens_buf; + size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); + if (token_buf_size < 1) { + fprintf(stderr, "Please check your tokens.txt!\n"); + free((void *)tokens_buf); + return -1; + } + const char *keywords_buf; + size_t keywords_buf_size = ReadFile(keywords_filename, &keywords_buf); + if (keywords_buf_size < 1) { + fprintf(stderr, "Please check your hotwords.txt!\n"); + free((void *)keywords_buf); + return -1; + } + + // Zipformer config + SherpaOnnxOnlineTransducerModelConfig zipformer_config; + memset(&zipformer_config, 0, sizeof(zipformer_config)); + zipformer_config.encoder = encoder_filename; + zipformer_config.decoder = decoder_filename; + zipformer_config.joiner = joiner_filename; + + // Online model config + SherpaOnnxOnlineModelConfig online_model_config; + memset(&online_model_config, 0, sizeof(online_model_config)); + online_model_config.debug = 1; + online_model_config.num_threads = 1; + online_model_config.provider = provider; + online_model_config.tokens_buf = tokens_buf; + online_model_config.tokens_buf_size = token_buf_size; + online_model_config.transducer = zipformer_config; + + // Keywords-spotter config + SherpaOnnxKeywordSpotterConfig keywords_spotter_config; + memset(&keywords_spotter_config, 0, sizeof(keywords_spotter_config)); + keywords_spotter_config.max_active_paths = 4; + keywords_spotter_config.keywords_threshold = 0.1; + keywords_spotter_config.keywords_score = 3.0; + keywords_spotter_config.model_config = online_model_config; + keywords_spotter_config.keywords_buf = keywords_buf; + keywords_spotter_config.keywords_buf_size = keywords_buf_size; + + SherpaOnnxKeywordSpotter *keywords_spotter = + SherpaOnnxCreateKeywordSpotter(&keywords_spotter_config); + + free((void *)tokens_buf); + tokens_buf = NULL; + free((void *)keywords_buf); + keywords_buf = NULL; + + if (keywords_spotter == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaOnnxFreeWave(wave); + return -1; + } + + SherpaOnnxOnlineStream *stream = + SherpaOnnxCreateKeywordStream(keywords_spotter); + + const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50); + int32_t segment_id = 0; + +// simulate streaming. You can choose an arbitrary N +#define N 3200 + + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", + wave->sample_rate, wave->num_samples, + (float)wave->num_samples / wave->sample_rate); + + int32_t k = 0; + while (k < wave->num_samples) { + int32_t start = k; + int32_t end = + (start + N > wave->num_samples) ? wave->num_samples : (start + N); + k += N; + + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, + wave->samples + start, end - start); + while (SherpaOnnxIsKeywordStreamReady(keywords_spotter, stream)) { + SherpaOnnxDecodeKeywordStream(keywords_spotter, stream); + } + + const SherpaOnnxKeywordResult *r = + SherpaOnnxGetKeywordResult(keywords_spotter, stream); + + if (strlen(r->keyword)) { + SherpaOnnxPrint(display, segment_id, r->keyword); + } + + SherpaOnnxDestroyKeywordResult(r); + } + + // add some tail padding + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + 4800); + + SherpaOnnxFreeWave(wave); + + SherpaOnnxOnlineStreamInputFinished(stream); + while (SherpaOnnxIsKeywordStreamReady(keywords_spotter, stream)) { + SherpaOnnxDecodeKeywordStream(keywords_spotter, stream); + } + + const SherpaOnnxKeywordResult *r = + SherpaOnnxGetKeywordResult(keywords_spotter, stream); + + if (strlen(r->keyword)) { + SherpaOnnxPrint(display, segment_id, r->keyword); + } + + SherpaOnnxDestroyKeywordResult(r); + + SherpaOnnxDestroyDisplay(display); + SherpaOnnxDestroyOnlineStream(stream); + SherpaOnnxDestroyKeywordSpotter(keywords_spotter); + fprintf(stderr, "\n"); + + return 0; +} diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 6b5d6f73a..176557c75 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -667,6 +667,12 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( spotter_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); + if (config->model_config.tokens_buf && + config->model_config.tokens_buf_size > 0) { + spotter_config.model_config.tokens_buf = std::string( + config->model_config.tokens_buf, config->model_config.tokens_buf_size); + } + spotter_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); spotter_config.model_config.provider_config.provider = @@ -691,6 +697,10 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( SHERPA_ONNX_OR(config->keywords_threshold, 0.25); spotter_config.keywords_file = SHERPA_ONNX_OR(config->keywords_file, ""); + if (config->keywords_buf && config->keywords_buf_size > 0) { + spotter_config.keywords_buf = + std::string(config->keywords_buf, config->keywords_buf_size); + } if (config->model_config.debug) { SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 3be5a19cd..58615fe48 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -88,8 +88,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { // - cjkchar+bpe const char *modeling_unit; const char *bpe_vocab; - /// if non-null, loading the tokens from the buffered string directly in - /// prioriy + /// if non-null, loading the tokens from the buffer instead of from the + /// "tokens" file const char *tokens_buf; /// byte size excluding the trailing '\0' int32_t tokens_buf_size; @@ -637,6 +637,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { float keywords_score; float keywords_threshold; const char *keywords_file; + /// if non-null, loading the keywords from the buffer instead of from the + /// keywords_file + const char *keywords_buf; + /// byte size excluding the trailing '\0' + int32_t keywords_buf_size; } SherpaOnnxKeywordSpotterConfig; SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index 2300839f3..759639184 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -66,15 +66,25 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { public: explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config) : config_(config), - model_(OnlineTransducerModel::Create(config.model_config)), - sym_(config.model_config.tokens) { + model_(OnlineTransducerModel::Create(config.model_config)) { + if (!config.model_config.tokens_buf.empty()) { + sym_ = SymbolTable(config.model_config.tokens_buf, false); + } else { + /// assuming tokens_buf and tokens are guaranteed not being both empty + sym_ = SymbolTable(config.model_config.tokens, true); + } + if (sym_.Contains("")) { unk_id_ = sym_[""]; } model_->SetFeatureDim(config.feat_config.feature_dim); - InitKeywords(); + if (config.keywords_buf.empty()) { + InitKeywords(); + } else { + InitKeywordsFromBufStr(); + } decoder_ = std::make_unique( model_.get(), config_.max_active_paths, config_.num_trailing_blanks, @@ -305,6 +315,12 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { } #endif + void InitKeywordsFromBufStr() { + // keywords_buf's content is supposed to be same as the keywords_file's + std::istringstream is(config_.keywords_buf); + InitKeywords(is); + } + void InitOnlineStream(OnlineStream *stream) const { auto r = decoder_->GetEmptyResult(); SHERPA_ONNX_CHECK_EQ(r.hyps.Size(), 1); diff --git a/sherpa-onnx/csrc/keyword-spotter.cc b/sherpa-onnx/csrc/keyword-spotter.cc index 1110ee584..d1bf6d63b 100644 --- a/sherpa-onnx/csrc/keyword-spotter.cc +++ b/sherpa-onnx/csrc/keyword-spotter.cc @@ -89,8 +89,17 @@ void KeywordSpotterConfig::Register(ParseOptions *po) { } bool KeywordSpotterConfig::Validate() const { - if (keywords_file.empty()) { - SHERPA_ONNX_LOGE("Please provide --keywords-file."); + if (!keywords_file.empty() && !keywords_buf.empty()) { + SHERPA_ONNX_LOGE( + "you can not provide a keywords_buf and a keywords file: '%s', " + "at the same time, which is confusing", + keywords_file.c_str()); + return false; + } + + if (keywords_file.empty() && keywords_buf.empty()) { + SHERPA_ONNX_LOGE( + "Please provide either a keywords-file or the keywords-buf"); return false; } @@ -99,7 +108,7 @@ bool KeywordSpotterConfig::Validate() const { // keywords file will be packaged into the sherpa-onnx-wasm-kws-main.data file // Solution: take keyword_file variable is directly // parsed as a string of keywords - if (!std::ifstream(keywords_file.c_str()).good()) { + if (keywords_buf.empty() && !std::ifstream(keywords_file.c_str()).good()) { SHERPA_ONNX_LOGE("Keywords file '%s' does not exist.", keywords_file.c_str()); return false; diff --git a/sherpa-onnx/csrc/keyword-spotter.h b/sherpa-onnx/csrc/keyword-spotter.h index 3d7935cc3..f0c31bdb4 100644 --- a/sherpa-onnx/csrc/keyword-spotter.h +++ b/sherpa-onnx/csrc/keyword-spotter.h @@ -69,6 +69,11 @@ struct KeywordSpotterConfig { std::string keywords_file; + /// if keywords_buf is non-empty, + /// the keywords will be loaded from the buffer instead of from the + /// "keywrods_file" + std::string keywords_buf; + KeywordSpotterConfig() = default; KeywordSpotterConfig(const FeatureExtractorConfig &feat_config, diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index a2aaae038..a920512d8 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -46,8 +46,8 @@ struct OnlineModelConfig { std::string bpe_vocab; /// if tokens_buf is non-empty, - /// the tokens will be loaded from the buffered string instead of from the - /// ${tokens} file + /// the tokens will be loaded from the buffer instead of from the + /// "tokens" file std::string tokens_buf; OnlineModelConfig() = default; diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index eedd30b21..45e0f4237 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -107,8 +107,8 @@ struct OnlineRecognizerConfig { std::string rule_fars; /// used only for modified_beam_search, if hotwords_buf is non-empty, - /// the hotwords will be loaded from the buffered string instead of from - /// ${hotwords_file} + /// the hotwords will be loaded from the buffered string instead of from the + /// "hotwords_file" std::string hotwords_buf; OnlineRecognizerConfig() = default;