From 3fcba5cec3a592355ceeaf81aa3c7e257d66c308 Mon Sep 17 00:00:00 2001 From: xinhecuican Date: Thu, 7 Mar 2024 16:08:04 +0800 Subject: [PATCH 1/4] c++ api for keyword spotter --- sherpa-onnx/c-api/c-api.cc | 179 +++++++++++++++++++++++++++++++++++++ sherpa-onnx/c-api/c-api.h | 117 ++++++++++++++++++++++++ 2 files changed, 296 insertions(+) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 407b359ac..00d1e1992 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -17,6 +17,7 @@ #include "sherpa-onnx/csrc/online-recognizer.h" #include "sherpa-onnx/csrc/voice-activity-detector.h" #include "sherpa-onnx/csrc/wave-writer.h" +#include "sherpa-onnx/csrc/keyword-spotter.h" struct SherpaOnnxOnlineRecognizer { std::unique_ptr impl; @@ -410,6 +411,184 @@ void DestroyOfflineRecognizerResult( } } +// ============================================================ +// For Keyword Spot +// ============================================================ + +struct SherpaOnnxKeywordSpotter { + std::unique_ptr impl; +}; + +SherpaOnnxKeywordSpotter* CreateKeywordSpotter( + const SherpaOnnxKeywordSpotterConfig* config) { + sherpa_onnx::KeywordSpotterConfig spotter_config; + + spotter_config.feat_config.sampling_rate = + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); + spotter_config.feat_config.feature_dim = + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); + + spotter_config.model_config.transducer.encoder = + SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); + spotter_config.model_config.transducer.decoder = + SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); + spotter_config.model_config.transducer.joiner = + SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); + + spotter_config.model_config.paraformer.encoder = + SHERPA_ONNX_OR(config->model_config.paraformer.encoder, ""); + spotter_config.model_config.paraformer.decoder = + SHERPA_ONNX_OR(config->model_config.paraformer.decoder, ""); + + 
spotter_config.model_config.zipformer2_ctc.model = + SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, ""); + + spotter_config.model_config.tokens = + SHERPA_ONNX_OR(config->model_config.tokens, ""); + spotter_config.model_config.num_threads = + SHERPA_ONNX_OR(config->model_config.num_threads, 1); + spotter_config.model_config.provider = + SHERPA_ONNX_OR(config->model_config.provider, "cpu"); + spotter_config.model_config.model_type = + SHERPA_ONNX_OR(config->model_config.model_type, ""); + spotter_config.model_config.debug = + SHERPA_ONNX_OR(config->model_config.debug, 0); + + spotter_config.max_active_paths = + SHERPA_ONNX_OR(config->max_active_paths, 4); + + spotter_config.num_trailing_blanks = + SHERPA_ONNX_OR(config->num_trailing_blanks , 1); + + spotter_config.keywords_score = + SHERPA_ONNX_OR(config->keywords_score, 1.0); + + spotter_config.keywords_threshold = + SHERPA_ONNX_OR(config->keywords_threshold, 0.25); + + spotter_config.keywords_file = + SHERPA_ONNX_OR(config->keywords_file, ""); + + if (config->model_config.debug) { + SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); + } + + SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter; + + spotter->impl = + std::make_unique(spotter_config); + + return spotter; +} + +void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) { + delete spotter; +} + +SherpaOnnxOnlineStream* CreateKeywordStream( + const SherpaOnnxKeywordSpotter* spotter) { + SherpaOnnxOnlineStream* stream = + new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); + return stream; +} + +int32_t IsKeywordStreamReady( + SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream) { + return spotter->impl->IsReady(stream->impl.get()); +} + +void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter, + SherpaOnnxOnlineStream* stream) { + return spotter->impl->DecodeStream(stream->impl.get()); +} + +void DecodeMultipleKeywordStreams( + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream 
**streams, + int32_t n) { + std::vector ss(n); + for(int32_t i=0; i!=n; ++i) { + ss[i] = streams[i]->impl.get(); + } + spotter->impl->DecodeStreams(ss.data(), n); +} + +const SherpaOnnxKeywordResult *GetKeywordResult( + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) { + sherpa_onnx::KeywordResult result = + spotter->impl->GetResult(stream->impl.get()); + const auto &keyword = result.keyword; + + auto r = new SherpaOnnxKeywordResult; + memset(r, 0, sizeof(SherpaOnnxKeywordResult)); + + r->start_time = result.start_time; + + // copy keyword + r->keyword = new char[keyword.size() + 1]; + std::copy(keyword.begin(), keyword.end(), const_cast(r->keyword)); + const_cast(r->keyword)[keyword.size()] = 0; + + // copy json + const auto &json = result.AsJsonString(); + r->json = new char[json.size() + 1]; + std::copy(json.begin(), json.end(), const_cast(r->json)); + const_cast(r->json)[json.size()] = 0; + + // copy tokens + auto count = result.tokens.size(); + if (count > 0) { + size_t total_length = 0; + for (const auto &token : result.tokens) { + // +1 for the null character at the end of each token + total_length += token.size() + 1; + } + + r->count = count; + // Each token ends with a null character + r->tokens = new char[total_length]; + memset(reinterpret_cast(const_cast(r->tokens)), 0, + total_length); + char **tokens_temp = new char *[r->count]; + int32_t pos = 0; + for (int32_t i = 0; i < r->count; ++i) { + tokens_temp[i] = const_cast(r->tokens) + pos; + memcpy(reinterpret_cast(const_cast(r->tokens + pos)), + result.tokens[i].c_str(), result.tokens[i].size()); + // +1 to move past the null character + pos += result.tokens[i].size() + 1; + } + r->tokens_arr = tokens_temp; + + if (!result.timestamps.empty()) { + r->timestamps = new float[r->count]; + std::copy(result.timestamps.begin(), result.timestamps.end(), + r->timestamps); + } else { + r->timestamps = nullptr; + } + + } else { + r->count = 0; + r->timestamps = nullptr; + r->tokens = nullptr; + 
r->tokens_arr = nullptr; + } + + return r; +} + +void DestroyKeywordResult(const SherpaOnnxKeywordResult *r) { + if (r) { + delete[] r->keyword; + delete[] r->json; + delete[] r->tokens; + delete[] r->tokens_arr; + delete[] r->timestamps; + delete r; + } +} + + // ============================================================ // For VAD // ============================================================ diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 62b2f4dcd..4af9727df 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -473,6 +473,123 @@ SHERPA_ONNX_API const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult( SHERPA_ONNX_API void DestroyOfflineRecognizerResult( const SherpaOnnxOfflineRecognizerResult *r); +// ============================================================ +// For Keyword Spot +// ============================================================ +SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { + /// The triggered keyword. + /// For English, it consists of space separated words. + /// For Chinese, it consists of Chinese words without spaces. + /// Example 1: "hello world" + /// Example 2: "你好世界" + const char* keyword; + + /// Decoded results at the token level. + /// For instance, for BPE-based models it consists of a list of BPE tokens. + const char* tokens; + + const char* const* tokens_arr; + + int32_t count; + + /// timestamps.size() == tokens.size() + /// timestamps[i] records the time in seconds when tokens[i] is decoded. + float* timestamps; + + /// Starting time of this segment. + /// When an endpoint is detected, it will change + float start_time; + + /** Return a json string. 
+ * + * The returned string contains: + * { + * "keyword": "The triggered keyword", + * "tokens": [x, x, x], + * "timestamps": [x, x, x], + * "start_time": x, + * } + */ + const char* json; +}SherpaOnnxKeywordResult ; + +SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { + SherpaOnnxFeatureConfig feat_config; + SherpaOnnxOnlineModelConfig model_config; + int32_t max_active_paths; + int32_t num_trailing_blanks; + float keywords_score; + float keywords_threshold; + const char* keywords_file; +}SherpaOnnxKeywordSpotterConfig ; + +SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter + SherpaOnnxKeywordSpotter; + +/// @param config Config for the keyword spotter. +/// @return Return a pointer to the spotter. The user has to invoke +/// DestroyKeywordSpotter() to free it to avoid memory leak. +SHERPA_ONNX_API SherpaOnnxKeywordSpotter* CreateKeywordSpotter( + const SherpaOnnxKeywordSpotterConfig* config); + +/// Free a pointer returned by CreateKeywordSpotter() +/// +/// @param spotter A pointer returned by CreateKeywordSpotter() +SHERPA_ONNX_API void DestroyKeywordSpotter( + SherpaOnnxKeywordSpotter* spotter); + +/// Create an online stream for accepting wave samples. +/// +/// @param spotter A pointer returned by CreateKeywordSpotter() +/// @return Return a pointer to an OnlineStream. The user has to invoke +/// DestroyOnlineStream() to free it to avoid memory leak. +SHERPA_ONNX_API SherpaOnnxOnlineStream* CreateKeywordStream( + const SherpaOnnxKeywordSpotter* spotter); + +/// Return 1 if there are enough feature frames for decoding. +/// Return 0 otherwise. +/// +/// @param spotter A pointer returned by CreateKeywordSpotter() +/// @param stream A pointer returned by CreateKeywordStream() +SHERPA_ONNX_API int32_t IsKeywordStreamReady( + SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream); + +/// Call this function to run the neural network model and decoding. +/// +/// Precondition for this function: IsKeywordStreamReady() MUST return 1. 
+SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter, + SherpaOnnxOnlineStream* stream); + +/// This function is similar to DecodeKeywordStream(). It decodes multiple +/// OnlineStreams in parallel. +/// +/// Caution: The caller has to ensure each OnlineStream is ready, i.e., +/// IsKeywordStreamReady() for that stream should return 1. +/// +/// @param spotter A pointer returned by CreateKeywordSpotter() +/// @param streams A pointer array containing pointers returned by +/// CreateKeywordStream() +/// @param n Number of elements in the given streams array. +SHERPA_ONNX_API void DecodeMultipleKeywordStreams( + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams, + int32_t n); + +/// Get the decoding results so far for an OnlineStream. +/// +/// @param spotter A pointer returned by CreateKeywordSpotter(). +/// @param stream A pointer returned by CreateKeywordStream(). +/// @return A pointer containing the result. The user has to invoke +/// DestroyKeywordResult() to free the returned pointer to +/// avoid memory leak. +SHERPA_ONNX_API const SherpaOnnxKeywordResult *GetKeywordResult( + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); + +/// Destroy the pointer returned by GetKeywordResult(). 
+/// +/// @param r A pointer returned by GetKeywordResult() +SHERPA_ONNX_API void DestroyKeywordResult( + const SherpaOnnxKeywordResult *r); + // ============================================================ // For VAD // ============================================================ From a5b2599ebdd3be4fe3d0c5c42c1969c15dd379d9 Mon Sep 17 00:00:00 2001 From: xinhecuican Date: Thu, 7 Mar 2024 16:58:30 +0800 Subject: [PATCH 2/4] fix format --- sherpa-onnx/c-api/c-api.cc | 9 +++++++-- sherpa-onnx/c-api/c-api.h | 6 +++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 00d1e1992..3378e13e0 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -473,6 +473,11 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter( SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); } + if (!spotter_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config!"); + return nullptr; + } + SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter; spotter->impl = @@ -506,8 +511,8 @@ void DecodeMultipleKeywordStreams( SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams, int32_t n) { std::vector ss(n); - for(int32_t i=0; i!=n; ++i) { - ss[i] = streams[i]->impl.get(); + for (int32_t i = 0; i != n; ++i) { + ss[i] = streams[i]->impl.get(); } spotter->impl->DecodeStreams(ss.data(), n); } diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 4af9727df..cce7ef6a5 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -494,7 +494,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { /// timestamps.size() == tokens.size() /// timestamps[i] records the time in seconds when tokens[i] is decoded. - float* timestamps; + const float* timestamps; /// Starting time of this segment. 
/// When an endpoint is detected, it will change @@ -511,7 +511,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { * } */ const char* json; -}SherpaOnnxKeywordResult ; +} SherpaOnnxKeywordResult; SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { SherpaOnnxFeatureConfig feat_config; @@ -521,7 +521,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { float keywords_score; float keywords_threshold; const char* keywords_file; -}SherpaOnnxKeywordSpotterConfig ; +} SherpaOnnxKeywordSpotterConfig ; SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter SherpaOnnxKeywordSpotter; From a0d5fa2c07dfc583273a979ba75a9ca60cfa248f Mon Sep 17 00:00:00 2001 From: xinhecuican Date: Thu, 7 Mar 2024 18:34:43 +0800 Subject: [PATCH 3/4] bug fix --- sherpa-onnx/c-api/c-api.cc | 4 ++-- sherpa-onnx/c-api/c-api.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 3378e13e0..2de9c280a 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -519,7 +519,7 @@ void DecodeMultipleKeywordStreams( const SherpaOnnxKeywordResult *GetKeywordResult( SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) { - sherpa_onnx::KeywordResult result = + const sherpa_onnx::KeywordResult& result = spotter->impl->GetResult(stream->impl.get()); const auto &keyword = result.keyword; @@ -565,7 +565,7 @@ const SherpaOnnxKeywordResult *GetKeywordResult( r->tokens_arr = tokens_temp; if (!result.timestamps.empty()) { - r->timestamps = new float[r->count]; + r->timestamps = new float[result.timestamps.size()]; std::copy(result.timestamps.begin(), result.timestamps.end(), r->timestamps); } else { diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index cce7ef6a5..1c16f6579 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -494,7 +494,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { /// timestamps.size() == tokens.size() 
/// timestamps[i] records the time in seconds when tokens[i] is decoded. - const float* timestamps; + float* timestamps; /// Starting time of this segment. /// When an endpoint is detected, it will change From cb52ae17c7e0091981fe25a734b77896f47898fb Mon Sep 17 00:00:00 2001 From: xinhecuican Date: Fri, 8 Mar 2024 11:02:36 +0800 Subject: [PATCH 4/4] remove whitespace --- sherpa-onnx/c-api/c-api.cc | 6 +++--- sherpa-onnx/c-api/c-api.h | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 2de9c280a..4a30dae94 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -473,9 +473,9 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter( SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); } - if (!spotter_config.Validate()) { - SHERPA_ONNX_LOGE("Errors in config!"); - return nullptr; + if (!spotter_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config!"); + return nullptr; } SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter; diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 1c16f6579..a6a7389c2 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -521,7 +521,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { float keywords_score; float keywords_threshold; const char* keywords_file; -} SherpaOnnxKeywordSpotterConfig ; +} SherpaOnnxKeywordSpotterConfig; SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter SherpaOnnxKeywordSpotter;