Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

c++ api for keyword spotter #642

Merged
merged 4 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/wave-writer.h"
#include "sherpa-onnx/csrc/keyword-spotter.h"

struct SherpaOnnxOnlineRecognizer {
std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl;
Expand Down Expand Up @@ -410,6 +411,189 @@ void DestroyOfflineRecognizerResult(
}
}

// ============================================================
// For Keyword Spotting
// ============================================================

// Definition of the handle type used by the keyword-spotting C API
// functions in this file. It owns the underlying C++ KeywordSpotter through
// a unique_ptr, so deleting the handle destroys the spotter as well.
struct SherpaOnnxKeywordSpotter {
std::unique_ptr<sherpa_onnx::KeywordSpotter> impl;
};

// Builds a sherpa_onnx::KeywordSpotterConfig from the C config struct and
// creates a keyword spotter from it.
//
// Every field is passed through SHERPA_ONNX_OR together with a fallback
// value. The macro is defined elsewhere; presumably it yields the fallback
// when the C field is unset (0 / NULL) -- TODO confirm.
//
// @param config C-level configuration. It is dereferenced without a NULL
//               check.
// @return A heap-allocated handle, or nullptr if the converted config fails
//         Validate().
SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
const SherpaOnnxKeywordSpotterConfig* config) {
sherpa_onnx::KeywordSpotterConfig spotter_config;

// Feature extraction settings.
spotter_config.feat_config.sampling_rate =
SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
spotter_config.feat_config.feature_dim =
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);

// Transducer model files.
spotter_config.model_config.transducer.encoder =
SHERPA_ONNX_OR(config->model_config.transducer.encoder, "");
spotter_config.model_config.transducer.decoder =
SHERPA_ONNX_OR(config->model_config.transducer.decoder, "");
spotter_config.model_config.transducer.joiner =
SHERPA_ONNX_OR(config->model_config.transducer.joiner, "");

// Paraformer model files.
spotter_config.model_config.paraformer.encoder =
SHERPA_ONNX_OR(config->model_config.paraformer.encoder, "");
spotter_config.model_config.paraformer.decoder =
SHERPA_ONNX_OR(config->model_config.paraformer.decoder, "");

// Zipformer2 CTC model file.
spotter_config.model_config.zipformer2_ctc.model =
SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, "");

// Settings shared by all model types.
spotter_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
spotter_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
spotter_config.model_config.provider =
SHERPA_ONNX_OR(config->model_config.provider, "cpu");
spotter_config.model_config.model_type =
SHERPA_ONNX_OR(config->model_config.model_type, "");
spotter_config.model_config.debug =
SHERPA_ONNX_OR(config->model_config.debug, 0);

// Keyword-spotting specific settings.
spotter_config.max_active_paths =
SHERPA_ONNX_OR(config->max_active_paths, 4);

spotter_config.num_trailing_blanks =
SHERPA_ONNX_OR(config->num_trailing_blanks , 1);

spotter_config.keywords_score =
SHERPA_ONNX_OR(config->keywords_score, 1.0);

spotter_config.keywords_threshold =
SHERPA_ONNX_OR(config->keywords_threshold, 0.25);

spotter_config.keywords_file =
SHERPA_ONNX_OR(config->keywords_file, "");

// When debugging is requested, log the fully resolved configuration.
if (config->model_config.debug) {
SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str());
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add

  if (!spotter_config.Validate()) {
    SHERPA_ONNX_LOGE("Errors in config!");
    return nullptr;
  }

// Reject an invalid configuration before allocating anything.
if (!spotter_config.Validate()) {
SHERPA_ONNX_LOGE("Errors in config!");
return nullptr;
}

// The handle is allocated with `new`; DestroyKeywordSpotter() deletes it.
SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter;

spotter->impl =
std::make_unique<sherpa_onnx::KeywordSpotter>(spotter_config);

return spotter;
}

// Frees a handle allocated by CreateKeywordSpotter(). Deleting the handle
// also destroys the KeywordSpotter held by its unique_ptr member.
void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) {
delete spotter;
}

// Asks the spotter for a new stream and wraps it in a heap-allocated
// SherpaOnnxOnlineStream handle, which is returned to the caller.
SherpaOnnxOnlineStream* CreateKeywordStream(
    const SherpaOnnxKeywordSpotter* spotter) {
  return new SherpaOnnxOnlineStream(spotter->impl->CreateStream());
}

// Forwards to KeywordSpotter::IsReady() for the given stream and returns
// its result converted to int32_t.
int32_t IsKeywordStreamReady(
    SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream) {
  auto *online_stream = stream->impl.get();
  return spotter->impl->IsReady(online_stream);
}

// Forwards to KeywordSpotter::DecodeStream() for a single stream.
void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter,
                         SherpaOnnxOnlineStream* stream) {
  auto *online_stream = stream->impl.get();
  spotter->impl->DecodeStream(online_stream);
}

// Decodes several streams with a single KeywordSpotter::DecodeStreams()
// call.
//
// @param spotter Handle whose underlying spotter performs the decoding.
// @param streams Array of `n` stream handles.
// @param n       Number of entries in `streams`. Non-positive values make
//                this function a no-op: a negative count would otherwise be
//                converted to a huge size_t by the std::vector constructor
//                and raise a C++ exception across the C API boundary.
void DecodeMultipleKeywordStreams(
    SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams,
    int32_t n) {
  if (n <= 0) {
    return;
  }

  // Unwrap the C handles into the raw stream pointers the C++ API expects.
  std::vector<sherpa_onnx::OnlineStream *> ss(n);
  for (int32_t i = 0; i != n; ++i) {
    ss[i] = streams[i]->impl.get();
  }

  spotter->impl->DecodeStreams(ss.data(), n);
}

// Fetches the current result for `stream` and deep-copies it into a
// heap-allocated SherpaOnnxKeywordResult.
//
// Every pointer field of the returned struct is allocated with new[] (or is
// nullptr) and is released by DestroyKeywordResult().
const SherpaOnnxKeywordResult *GetKeywordResult(
    SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) {
  const sherpa_onnx::KeywordResult &result =
      spotter->impl->GetResult(stream->impl.get());

  // Duplicates a string into a new[]-allocated, NUL-terminated buffer.
  auto clone = [](const std::string &text) -> const char * {
    char *buffer = new char[text.size() + 1];
    std::copy(text.begin(), text.end(), buffer);
    buffer[text.size()] = 0;
    return buffer;
  };

  auto *out = new SherpaOnnxKeywordResult;
  memset(out, 0, sizeof(SherpaOnnxKeywordResult));

  out->start_time = result.start_time;
  out->keyword = clone(result.keyword);
  out->json = clone(result.AsJsonString());

  const auto num_tokens = result.tokens.size();
  if (num_tokens == 0) {
    // Nothing decoded: all token-related fields stay empty.
    out->count = 0;
    out->timestamps = nullptr;
    out->tokens = nullptr;
    out->tokens_arr = nullptr;
    return out;
  }

  // All tokens are packed back to back into one buffer; each token is
  // followed by a '\0' terminator, hence the +1 per token.
  size_t total_length = 0;
  for (const auto &token : result.tokens) {
    total_length += token.size() + 1;
  }

  out->count = num_tokens;

  char *packed = new char[total_length];
  memset(packed, 0, total_length);

  // starts[i] points at the beginning of the i-th token inside `packed`.
  char **starts = new char *[out->count];
  char *cursor = packed;
  for (int32_t i = 0; i < out->count; ++i) {
    const auto &token = result.tokens[i];
    starts[i] = cursor;
    memcpy(cursor, token.c_str(), token.size());
    // Skip the token and its terminator (already zeroed by memset).
    cursor += token.size() + 1;
  }
  out->tokens = packed;
  out->tokens_arr = starts;

  float *times = nullptr;
  if (!result.timestamps.empty()) {
    times = new float[result.timestamps.size()];
    std::copy(result.timestamps.begin(), result.timestamps.end(), times);
  }
  out->timestamps = times;

  return out;
}

// Releases a result allocated by GetKeywordResult(), including every
// new[]-allocated buffer it points to. A null pointer is ignored.
void DestroyKeywordResult(const SherpaOnnxKeywordResult *r) {
  if (r == nullptr) {
    return;
  }

  delete[] r->keyword;
  delete[] r->json;
  delete[] r->tokens;
  delete[] r->tokens_arr;
  delete[] r->timestamps;
  delete r;
}


// ============================================================
// For VAD
// ============================================================
Expand Down
117 changes: 117 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,123 @@ SHERPA_ONNX_API const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
SHERPA_ONNX_API void DestroyOfflineRecognizerResult(
const SherpaOnnxOfflineRecognizerResult *r);

// ============================================================
// For Keyword Spotting
// ============================================================
/// Result of keyword spotting for one stream, as returned by
/// GetKeywordResult(). All pointer members are owned by the result and are
/// released by DestroyKeywordResult().
SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
/// The triggered keyword.
/// For English, it consists of space separated words.
/// For Chinese, it consists of Chinese words without spaces.
/// Example 1: "hello world"
/// Example 2: "你好世界"
const char* keyword;

/// Decoded results at the token level.
/// For instance, for BPE-based models it consists of a list of BPE tokens.
/// The tokens are stored back to back in a single buffer, each one followed
/// by a '\0' terminator. NULL when count is 0.
const char* tokens;

/// tokens_arr[i] points to the start of the i-th token inside `tokens`.
/// It has `count` entries. NULL when count is 0.
const char* const* tokens_arr;

/// Number of tokens in `tokens` / `tokens_arr`.
int32_t count;

/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
/// NULL when there are no tokens or no timestamps.
float* timestamps;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
float* timestamps;
const float* timestamps;


/// Starting time of this segment.
/// When an endpoint is detected, it will change
float start_time;

/** A NUL-terminated json string describing this result.
*
* The string contains:
* {
* "keyword": "The triggered keyword",
* "tokens": [x, x, x],
* "timestamps": [x, x, x],
* "start_time": x,
* }
*/
const char* json;
} SherpaOnnxKeywordResult;

/// Configuration passed to CreateKeywordSpotter().
///
/// NOTE(review): CreateKeywordSpotter() routes every field through
/// SHERPA_ONNX_OR with a fallback value; presumably a field left as 0 / NULL
/// takes that fallback -- confirm against the macro definition.
SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig {
/// Feature extraction settings (sample rate, feature dimension).
SherpaOnnxFeatureConfig feat_config;
/// Model files and runtime settings (tokens, threads, provider, ...).
SherpaOnnxOnlineModelConfig model_config;
/// Fallback used by CreateKeywordSpotter(): 4.
int32_t max_active_paths;
/// Fallback used by CreateKeywordSpotter(): 1.
int32_t num_trailing_blanks;
/// Fallback used by CreateKeywordSpotter(): 1.0.
float keywords_score;
/// Fallback used by CreateKeywordSpotter(): 0.25.
float keywords_threshold;
/// Path to the keywords file. Fallback used by CreateKeywordSpotter(): "".
const char* keywords_file;
} SherpaOnnxKeywordSpotterConfig;

/// Opaque handle to a keyword spotter. Only the struct name is declared
/// here; users obtain a pointer from CreateKeywordSpotter().
SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
SherpaOnnxKeywordSpotter;

/// Create a keyword spotter.
///
/// @param config Config for the keyword spotter.
/// @return Return a pointer to the spotter, or NULL if the config fails
///         validation. The user has to invoke DestroyKeywordSpotter() to
///         free it to avoid memory leak.
SHERPA_ONNX_API SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
const SherpaOnnxKeywordSpotterConfig* config);

/// Free a pointer returned by CreateKeywordSpotter()
///
/// @param spotter A pointer returned by CreateKeywordSpotter()
SHERPA_ONNX_API void DestroyKeywordSpotter(
SherpaOnnxKeywordSpotter* spotter);

/// Create an online stream for accepting wave samples.
///
/// @param spotter A pointer returned by CreateKeywordSpotter()
/// @return Return a pointer to a newly allocated OnlineStream. The user has
///         to invoke DestroyOnlineStream() to free it to avoid memory leak.
SHERPA_ONNX_API SherpaOnnxOnlineStream* CreateKeywordStream(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you prefix every function name with SherpaOnnx?
That is, replace CreateKeywordStream with SherpaOnnxCreateKeywordStream.


I think it was an error to not use SherpaOnnx for online and offline ASR C APIs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after add this many line exceed 80 char, I tried the style is not uniform
and change api name means must modify code of other project

const SherpaOnnxKeywordSpotter* spotter);

/// Return 1 if there are enough feature frames for decoding.
/// Return 0 otherwise.
///
/// @param spotter A pointer returned by CreateKeywordSpotter()
/// @param stream A pointer returned by CreateKeywordStream()
SHERPA_ONNX_API int32_t IsKeywordStreamReady(
SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream);

/// Call this function to run the neural network model and decoding.
///
/// Precondition for this function: IsKeywordStreamReady() MUST return 1.
///
/// @param spotter A pointer returned by CreateKeywordSpotter()
/// @param stream A pointer returned by CreateKeywordStream()
SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter,
SherpaOnnxOnlineStream* stream);

/// This function is similar to DecodeKeywordStream(). It decodes multiple
/// OnlineStream objects in parallel.
///
/// Caution: The caller has to ensure each OnlineStream is ready, i.e.,
/// IsKeywordStreamReady() for that stream should return 1.
///
/// @param spotter A pointer returned by CreateKeywordSpotter()
/// @param streams A pointer array containing pointers returned by
/// CreateKeywordStream()
/// @param n Number of elements in the given streams array.
SHERPA_ONNX_API void DecodeMultipleKeywordStreams(
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams,
int32_t n);

/// Get the decoding results so far for an OnlineStream.
///
/// @param spotter A pointer returned by CreateKeywordSpotter().
/// @param stream A pointer returned by CreateKeywordStream().
/// @return A pointer containing the result. The user has to invoke
/// DestroyKeywordResult() to free the returned pointer to
/// avoid memory leak.
SHERPA_ONNX_API const SherpaOnnxKeywordResult *GetKeywordResult(
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);

/// Destroy the pointer returned by GetKeywordResult().
///
/// @param r A pointer returned by GetKeywordResult(). Passing NULL is a
///          no-op.
SHERPA_ONNX_API void DestroyKeywordResult(
const SherpaOnnxKeywordResult *r);

// ============================================================
// For VAD
// ============================================================
Expand Down
Loading