diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 407b359ac..4a30dae94 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -17,6 +17,7 @@ #include "sherpa-onnx/csrc/online-recognizer.h" #include "sherpa-onnx/csrc/voice-activity-detector.h" #include "sherpa-onnx/csrc/wave-writer.h" +#include "sherpa-onnx/csrc/keyword-spotter.h" struct SherpaOnnxOnlineRecognizer { std::unique_ptr impl; @@ -410,6 +411,189 @@ void DestroyOfflineRecognizerResult( } } +// ============================================================ +// For Keyword Spot +// ============================================================ + +struct SherpaOnnxKeywordSpotter { + std::unique_ptr impl; +}; + +SherpaOnnxKeywordSpotter* CreateKeywordSpotter( + const SherpaOnnxKeywordSpotterConfig* config) { + sherpa_onnx::KeywordSpotterConfig spotter_config; + + spotter_config.feat_config.sampling_rate = + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); + spotter_config.feat_config.feature_dim = + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); + + spotter_config.model_config.transducer.encoder = + SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); + spotter_config.model_config.transducer.decoder = + SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); + spotter_config.model_config.transducer.joiner = + SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); + + spotter_config.model_config.paraformer.encoder = + SHERPA_ONNX_OR(config->model_config.paraformer.encoder, ""); + spotter_config.model_config.paraformer.decoder = + SHERPA_ONNX_OR(config->model_config.paraformer.decoder, ""); + + spotter_config.model_config.zipformer2_ctc.model = + SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, ""); + + spotter_config.model_config.tokens = + SHERPA_ONNX_OR(config->model_config.tokens, ""); + spotter_config.model_config.num_threads = + SHERPA_ONNX_OR(config->model_config.num_threads, 1); + spotter_config.model_config.provider = + 
SHERPA_ONNX_OR(config->model_config.provider, "cpu"); + spotter_config.model_config.model_type = + SHERPA_ONNX_OR(config->model_config.model_type, ""); + spotter_config.model_config.debug = + SHERPA_ONNX_OR(config->model_config.debug, 0); + + spotter_config.max_active_paths = + SHERPA_ONNX_OR(config->max_active_paths, 4); + + spotter_config.num_trailing_blanks = + SHERPA_ONNX_OR(config->num_trailing_blanks , 1); + + spotter_config.keywords_score = + SHERPA_ONNX_OR(config->keywords_score, 1.0); + + spotter_config.keywords_threshold = + SHERPA_ONNX_OR(config->keywords_threshold, 0.25); + + spotter_config.keywords_file = + SHERPA_ONNX_OR(config->keywords_file, ""); + + if (config->model_config.debug) { + SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); + } + + if (!spotter_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config!"); + return nullptr; + } + + SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter; + + spotter->impl = + std::make_unique(spotter_config); + + return spotter; +} + +void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) { + delete spotter; +} + +SherpaOnnxOnlineStream* CreateKeywordStream( + const SherpaOnnxKeywordSpotter* spotter) { + SherpaOnnxOnlineStream* stream = + new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); + return stream; +} + +int32_t IsKeywordStreamReady( + SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream) { + return spotter->impl->IsReady(stream->impl.get()); +} + +void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter, + SherpaOnnxOnlineStream* stream) { + return spotter->impl->DecodeStream(stream->impl.get()); +} + +void DecodeMultipleKeywordStreams( + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams, + int32_t n) { + std::vector ss(n); + for (int32_t i = 0; i != n; ++i) { + ss[i] = streams[i]->impl.get(); + } + spotter->impl->DecodeStreams(ss.data(), n); +} + +const SherpaOnnxKeywordResult *GetKeywordResult( + SherpaOnnxKeywordSpotter 
*spotter, SherpaOnnxOnlineStream *stream) { + const sherpa_onnx::KeywordResult& result = + spotter->impl->GetResult(stream->impl.get()); + const auto &keyword = result.keyword; + + auto r = new SherpaOnnxKeywordResult; + memset(r, 0, sizeof(SherpaOnnxKeywordResult)); + + r->start_time = result.start_time; + + // copy keyword + r->keyword = new char[keyword.size() + 1]; + std::copy(keyword.begin(), keyword.end(), const_cast(r->keyword)); + const_cast(r->keyword)[keyword.size()] = 0; + + // copy json + const auto &json = result.AsJsonString(); + r->json = new char[json.size() + 1]; + std::copy(json.begin(), json.end(), const_cast(r->json)); + const_cast(r->json)[json.size()] = 0; + + // copy tokens + auto count = result.tokens.size(); + if (count > 0) { + size_t total_length = 0; + for (const auto &token : result.tokens) { + // +1 for the null character at the end of each token + total_length += token.size() + 1; + } + + r->count = count; + // Each token ends with a null character + r->tokens = new char[total_length]; + memset(reinterpret_cast(const_cast(r->tokens)), 0, + total_length); + char **tokens_temp = new char *[r->count]; + int32_t pos = 0; + for (int32_t i = 0; i < r->count; ++i) { + tokens_temp[i] = const_cast(r->tokens) + pos; + memcpy(reinterpret_cast(const_cast(r->tokens + pos)), + result.tokens[i].c_str(), result.tokens[i].size()); + // +1 to move past the null character + pos += result.tokens[i].size() + 1; + } + r->tokens_arr = tokens_temp; + + if (!result.timestamps.empty()) { + r->timestamps = new float[result.timestamps.size()]; + std::copy(result.timestamps.begin(), result.timestamps.end(), + r->timestamps); + } else { + r->timestamps = nullptr; + } + + } else { + r->count = 0; + r->timestamps = nullptr; + r->tokens = nullptr; + r->tokens_arr = nullptr; + } + + return r; +} + +void DestroyKeywordResult(const SherpaOnnxKeywordResult *r) { + if (r) { + delete[] r->keyword; + delete[] r->json; + delete[] r->tokens; + delete[] r->tokens_arr; + delete[] 
r->timestamps; + delete r; + } +} + + // ============================================================ // For VAD // ============================================================ diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 62b2f4dcd..a6a7389c2 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -473,6 +473,123 @@ SHERPA_ONNX_API const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult( SHERPA_ONNX_API void DestroyOfflineRecognizerResult( const SherpaOnnxOfflineRecognizerResult *r); +// ============================================================ +// For Keyword Spot +// ============================================================ +SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { + /// The triggered keyword. + /// For English, it consists of space separated words. + /// For Chinese, it consists of Chinese words without spaces. + /// Example 1: "hello world" + /// Example 2: "你好世界" + const char* keyword; + + /// Decoded results at the token level. + /// For instance, for BPE-based models it consists of a list of BPE tokens. + const char* tokens; + + const char* const* tokens_arr; + + int32_t count; + + /// timestamps.size() == tokens.size() + /// timestamps[i] records the time in seconds when tokens[i] is decoded. + float* timestamps; + + /// Starting time of this segment. + /// When an endpoint is detected, it will change + float start_time; + + /** Return a json string. 
+ * + * The returned string contains: + * { + * "keyword": "The triggered keyword", + * "tokens": [x, x, x], + * "timestamps": [x, x, x], + * "start_time": x, + * } + */ + const char* json; +} SherpaOnnxKeywordResult; + +SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { + SherpaOnnxFeatureConfig feat_config; + SherpaOnnxOnlineModelConfig model_config; + int32_t max_active_paths; + int32_t num_trailing_blanks; + float keywords_score; + float keywords_threshold; + const char* keywords_file; +} SherpaOnnxKeywordSpotterConfig; + +SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter + SherpaOnnxKeywordSpotter; + +/// @param config Config for the keyword spotter. +/// @return Return a pointer to the spotter. The user has to invoke +/// DestroyKeywordSpotter() to free it to avoid memory leak. +SHERPA_ONNX_API SherpaOnnxKeywordSpotter* CreateKeywordSpotter( + const SherpaOnnxKeywordSpotterConfig* config); + +/// Free a pointer returned by CreateKeywordSpotter() +/// +/// @param spotter A pointer returned by CreateKeywordSpotter() +SHERPA_ONNX_API void DestroyKeywordSpotter( + SherpaOnnxKeywordSpotter* spotter); + +/// Create an online stream for accepting wave samples. +/// +/// @param spotter A pointer returned by CreateKeywordSpotter() +/// @return Return a pointer to an OnlineStream. The user has to invoke +/// DestroyOnlineStream() to free it to avoid memory leak. +SHERPA_ONNX_API SherpaOnnxOnlineStream* CreateKeywordStream( + const SherpaOnnxKeywordSpotter* spotter); + +/// Return 1 if there are enough feature frames for decoding. +/// Return 0 otherwise. +/// +/// @param spotter A pointer returned by CreateKeywordSpotter +/// @param stream A pointer returned by CreateKeywordStream +SHERPA_ONNX_API int32_t IsKeywordStreamReady( + SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream); + +/// Call this function to run the neural network model and decoding. +/// +/// Precondition for this function: IsKeywordStreamReady() MUST return 1. 
+SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter, + SherpaOnnxOnlineStream* stream); + +/// This function is similar to DecodeKeywordStream(). It decodes multiple +/// OnlineStreams in parallel. +/// +/// Caution: The caller has to ensure each OnlineStream is ready, i.e., +/// IsKeywordStreamReady() for that stream should return 1. +/// +/// @param spotter A pointer returned by CreateKeywordSpotter() +/// @param streams A pointer array containing pointers returned by +/// CreateKeywordStream() +/// @param n Number of elements in the given streams array. +SHERPA_ONNX_API void DecodeMultipleKeywordStreams( + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams, + int32_t n); + +/// Get the decoding results so far for an OnlineStream. +/// +/// @param spotter A pointer returned by CreateKeywordSpotter(). +/// @param stream A pointer returned by CreateKeywordStream(). +/// @return A pointer containing the result. The user has to invoke +/// DestroyKeywordResult() to free the returned pointer to +/// avoid memory leak. +SHERPA_ONNX_API const SherpaOnnxKeywordResult *GetKeywordResult( + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); + +/// Destroy the pointer returned by GetKeywordResult(). +/// +/// @param r A pointer returned by GetKeywordResult() +SHERPA_ONNX_API void DestroyKeywordResult( + const SherpaOnnxKeywordResult *r); + // ============================================================ // For VAD // ============================================================