Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep model internal states for streaming CTC zipformer models on endpointing #781

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions c-api-examples/asr-microphone-example/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,18 @@
# Microphone examples built on the C API. Each one is compiled together with
# alsa.cc and linked against sherpa-onnx-c-api and cargs.
add_executable(c-api-alsa c-api-alsa.cc alsa.cc)
target_link_libraries(c-api-alsa sherpa-onnx-c-api cargs)

add_executable(hlg-c-api-alsa hlg-c-api-alsa.cc alsa.cc)
target_link_libraries(hlg-c-api-alsa sherpa-onnx-c-api cargs)

set(exes
  c-api-alsa
  hlg-c-api-alsa
)

# Link libasound exactly once per executable. When the environment variable
# SHERPA_ONNX_ALSA_LIB_DIR is defined, it is passed as an extra -L search
# directory; otherwise the linker's default search path is used.
foreach(exe IN LISTS exes)
  if(DEFINED ENV{SHERPA_ONNX_ALSA_LIB_DIR})
    target_link_libraries(${exe} -L$ENV{SHERPA_ONNX_ALSA_LIB_DIR} -lasound)
  else()
    target_link_libraries(${exe} asound)
  endif()
endforeach()
151 changes: 151 additions & 0 deletions c-api-examples/asr-microphone-example/hlg-c-api-alsa.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// c-api-examples/asr-microphone-example/hlg-c-api-alsa.cc
// Copyright (c) 2022-2024 Xiaomi Corporation

#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <algorithm>
#include <cctype>  // std::tolower
#include <cstdint>
#include <string>
#include <vector>

#include "c-api-examples/asr-microphone-example/alsa.h"
#include "sherpa-onnx/c-api/c-api.h"

// Set to non-zero by the SIGINT handler; polled by the main loop.
// volatile sig_atomic_t is the type the C and C++ standards allow a signal
// handler to write to safely.
volatile sig_atomic_t stop = 0;

// SIGINT handler: requests shutdown of the main loop and prints a notice.
//
// Only async-signal-safe operations are used here, so the notice is emitted
// with write() instead of fprintf().
static void Handler(int /*sig*/) {
  stop = 1;

  static const char kMessage[] = "\nCaught Ctrl + C. Exiting...\n";
  // The result is captured only to silence warn_unused_result; nothing useful
  // can be done about a failed write inside a signal handler.
  const ssize_t ignored = write(STDERR_FILENO, kMessage, sizeof(kMessage) - 1);
  (void)ignored;
}

// Streams audio from an ALSA capture device into a sherpa-onnx online
// recognizer and prints the recognized text whenever it changes.
//
// Usage: <program> <device_name>   (e.g. plughw:3,0; see `arecord -l`)
//
// Returns 0 once the SIGINT handler has stopped the loop, and -1 when the
// device name argument is missing. Calls exit(-1) when the recognizer cannot
// be created or when alsa.GetExpectedSampleRate() is not 16000.
int32_t main(int32_t argc, char *argv[]) {
  signal(SIGINT, Handler);

  if (argc != 2) {
    fprintf(stderr, R"(Please provide the device name.

The device name specifies which microphone to use in case there are several
on your system. You can use

arecord -l

to find all available microphones on your computer. For instance, if it outputs

**** List of CAPTURE Hardware Devices ****
card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio]
Subdevices: 1/1
Subdevice #0: subdevice #0

and if you want to select card 3 and device 0 on that card, please use:

plughw:3,0

as the device_name.
)");
    return -1;
  }

  // clang-format off
  //
  // Please download the model from
  // https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  //
  // NOTE(review): the URL above names the ...-ctc-small-2024-03-18 archive,
  // while `model` and `tokens` point into ...-ctc-multi-zh-hans-2023-12-13.
  // Confirm which model is intended.
  const char *model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx";
  const char *tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt";
  const char *graph = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst";
  // NOTE(review): this assignment discards the HLG.fst path set just above,
  // so an empty graph path is what reaches the recognizer config. Confirm
  // whether this is intentional or a debugging leftover.
  graph = "";
  // clang-format on

  SherpaOnnxOnlineRecognizerConfig config;

  // Zero every field first so that anything not set below is 0 / null.
  memset(&config, 0, sizeof(config));
  config.feat_config.sample_rate = 16000;
  config.feat_config.feature_dim = 80;
  config.model_config.zipformer2_ctc.model = model;
  config.model_config.tokens = tokens;
  config.model_config.num_threads = 1;
  config.model_config.provider = "cpu";
  config.model_config.debug = 0;
  config.ctc_fst_decoder_config.graph = graph;

  config.enable_endpoint = 1;
  config.rule1_min_trailing_silence = 2.4;
  config.rule2_min_trailing_silence = 1.2;
  config.rule3_min_utterance_length = 300;

  const SherpaOnnxOnlineRecognizer *recognizer =
      CreateOnlineRecognizer(&config);
  if (!recognizer) {
    fprintf(stderr, "Failed to create recognizer\n");
    exit(-1);
  }

  const SherpaOnnxOnlineStream *stream = CreateOnlineStream(recognizer);

  const SherpaOnnxDisplay *display = CreateDisplay(50);

  // please use arecord -l to find your device
  const char *device_name = argv[1];
  sherpa_onnx::Alsa alsa(device_name);
  fprintf(stderr, "Use recording device: %s\n", device_name);
  fprintf(stderr,
          "Please \033[32m\033[1mspeak\033[0m! Press \033[31m\033[1mCtrl + "
          "C\033[0m to exit\n");

  int32_t expected_sample_rate = 16000;

  if (alsa.GetExpectedSampleRate() != expected_sample_rate) {
    fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(),
            expected_sample_rate);
    exit(-1);
  }

  // Number of samples requested per Read(): 0.1 * the actual sample rate.
  const int32_t chunk =
      static_cast<int32_t>(0.1 * alsa.GetActualSampleRate());

  // Most recently displayed (pre-lowercasing) text; used to avoid reprinting
  // an unchanged result.
  std::string last_text;

  int32_t segment_index = 0;

  while (!stop) {
    const std::vector<float> &samples = alsa.Read(chunk);
    AcceptWaveform(stream, expected_sample_rate, samples.data(),
                   samples.size());
    while (IsOnlineStreamReady(recognizer, stream)) {
      DecodeOnlineStream(recognizer, stream);
    }

    const SherpaOnnxOnlineRecognizerResult *r =
        GetOnlineStreamResult(recognizer, stream);

    // Copy the text out before the result object is destroyed.
    std::string text = r->text;
    DestroyOnlineRecognizerResult(r);

    if (!text.empty() && last_text != text) {
      last_text = text;

      // std::tolower has undefined behaviour for negative arguments other
      // than EOF, so each byte is taken as unsigned char. This matters for
      // the non-ASCII (>= 0x80) bytes of UTF-8 text.
      std::transform(text.begin(), text.end(), text.begin(),
                     [](unsigned char c) {
                       return static_cast<char>(std::tolower(c));
                     });

      SherpaOnnxPrint(display, segment_index, text.c_str());
      fflush(stderr);
    }

    if (IsEndpoint(recognizer, stream)) {
      // Only advance to a new display segment if this one produced text.
      if (!text.empty()) {
        ++segment_index;
      }
      Reset(recognizer, stream);
    }
  }

  // free allocated resources
  DestroyDisplay(display);
  DestroyOnlineStream(stream);
  DestroyOnlineRecognizer(recognizer);
  fprintf(stderr, "\n");

  return 0;
}
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ void OnlineCtcGreedySearchDecoder::Decode(
auto &r = (*results)[b];

int32_t prev_id = -1;
if (!r.tokens.empty()) {
prev_id = r.tokens.back();
}

for (int32_t t = 0; t != num_frames; ++t, p += vocab_size) {
int32_t y = static_cast<int32_t>(std::distance(
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
s->SetCtcResult({});

// clear states
s->SetStates(model_->GetInitStates());
// s->SetStates(model_->GetInitStates());

// Note: We only update counters. The underlying audio samples
// are not discarded.
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/speaker-embedding-manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
}

float SpeakerEmbeddingManager::Score(const std::string &name,
const float *p) const {
const float *p) const {
return impl_->Score(name, p);
}

Expand Down
Loading