Skip to content

Commit

Permalink
Allow modify model config at decode time for ASR (#1124)
Browse files Browse the repository at this point in the history
  • Loading branch information
iprovalo authored Jul 13, 2024
1 parent ab71c39 commit de04b3b
Show file tree
Hide file tree
Showing 15 changed files with 121 additions and 13 deletions.
45 changes: 35 additions & 10 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,27 @@ struct SherpaOnnxOfflineStream {
: impl(std::move(p)) {}
};

// Forward declaration; the definition appears later in this file.
// Builds a sherpa_onnx::OfflineRecognizerConfig from the C API config struct.
static sherpa_onnx::OfflineRecognizerConfig convertConfig(
const SherpaOnnxOfflineRecognizerConfig *config);
/// Create an offline recognizer from a C API config.
///
/// The config is converted with convertConfig() and validated before the
/// recognizer is constructed.
///
/// @param config Config for the recognizer. Must not be NULL.
/// @return A pointer to a newly allocated recognizer, or nullptr if `config`
///         is NULL or fails validation. Free it with
///         DestroyOfflineRecognizer().
SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
    const SherpaOnnxOfflineRecognizerConfig *config) {
  // convertConfig() dereferences `config`, so reject NULL up front.
  if (config == nullptr) {
    SHERPA_ONNX_LOGE("config is nullptr");
    return nullptr;
  }

  sherpa_onnx::OfflineRecognizerConfig recognizer_config =
      convertConfig(config);

  if (!recognizer_config.Validate()) {
    SHERPA_ONNX_LOGE("Errors in config");
    return nullptr;
  }

  // Keep the wrapper owned by a unique_ptr while the implementation is being
  // constructed, so that it is not leaked if that construction throws.
  auto recognizer = std::make_unique<SherpaOnnxOfflineRecognizer>();

  recognizer->impl =
      std::make_unique<sherpa_onnx::OfflineRecognizer>(recognizer_config);

  // Ownership passes to the caller of the C API.
  return recognizer.release();
}
sherpa_onnx::OfflineRecognizerConfig convertConfig(
const SherpaOnnxOfflineRecognizerConfig *config) {
sherpa_onnx::OfflineRecognizerConfig recognizer_config;

recognizer_config.feat_config.sampling_rate =
Expand Down Expand Up @@ -398,17 +417,15 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
SHERPA_ONNX_LOGE("%s", recognizer_config.ToString().c_str());
}

if (!recognizer_config.Validate()) {
SHERPA_ONNX_LOGE("Errors in config");
return nullptr;
}

SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer;

recognizer->impl =
std::make_unique<sherpa_onnx::OfflineRecognizer>(recognizer_config);
return recognizer_config;
}

return recognizer;
void SherpaOnnxOfflineRecognizerSetConfig(
const SherpaOnnxOfflineRecognizer *recognizer,
const SherpaOnnxOfflineRecognizerConfig *config){
sherpa_onnx::OfflineRecognizerConfig recognizer_config =
convertConfig(config);
recognizer->impl->SetConfig(recognizer_config);
}

void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) {
Expand Down Expand Up @@ -461,6 +478,13 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
pText[text.size()] = 0;
r->text = pText;

// copy lang
const auto &lang = result.lang;
char *c_lang = new char[lang.size() + 1];
std::copy(lang.begin(), lang.end(), c_lang);
c_lang[lang.size()] = '\0';
r->lang = c_lang;

// copy json
std::string json = result.AsJsonString();
char *pJson = new char[json.size() + 1];
Expand Down Expand Up @@ -517,6 +541,7 @@ void DestroyOfflineRecognizerResult(
delete[] r->tokens;
delete[] r->tokens_arr;
delete[] r->json;
delete[] r->lang;
delete r;
}
}
Expand Down
11 changes: 10 additions & 1 deletion sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineStream SherpaOnnxOfflineStream;
SHERPA_ONNX_API SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
const SherpaOnnxOfflineRecognizerConfig *config);

/// Update the config of an existing offline recognizer.
///
/// @param recognizer A pointer returned by CreateOfflineRecognizer().
/// @param config Config for the recognizer.
SHERPA_ONNX_API void SherpaOnnxOfflineRecognizerSetConfig(
const SherpaOnnxOfflineRecognizer *recognizer,
const SherpaOnnxOfflineRecognizerConfig *config);

/// Free a pointer returned by CreateOfflineRecognizer()
///
/// @param p A pointer returned by CreateOfflineRecognizer()
Expand Down Expand Up @@ -491,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
const char *text;

// Pointer to continuous memory which holds timestamps
// Pointer to continuous memory which holds timestamps
//
// It is NULL if the model does not support timestamps
float *timestamps;
Expand Down Expand Up @@ -519,6 +524,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
* }
*/
const char *json;

// Recognized language, as a NUL-terminated string
const char *lang;

} SherpaOnnxOfflineRecognizerResult;

/// Get the result of the offline stream.
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
}

// Return a copy of the config held by this recognizer.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}


private:
// Decode a single stream.
// Some models do not support batch size > 1, e.g., WeNet CTC models.
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,4 +431,8 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
return text;
}

// Default implementation: replace the stored config with `config`.
// NOTE(review): confirm that the `config_` assigned here is the same member
// that the concrete recognizers return from GetConfig(); verify against the
// class definitions.
void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) {
config_ = config;
}

} // namespace sherpa_onnx
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class OfflineRecognizerImpl {

virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;

// Replace the config used by this recognizer. Implementations may override
// this to apply only part of `config`.
virtual void SetConfig(const OfflineRecognizerConfig &config);

// Return the config currently held by the implementation.
virtual OfflineRecognizerConfig GetConfig() const = 0;

std::string ApplyInverseTextNormalization(std::string text) const;

private:
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
}
}

// Return a copy of the config held by this recognizer.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}

private:
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
int32_t lfr_window_size = model_->LfrWindowSize();
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,11 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
}

// Return a copy of the config held by this recognizer.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}


void InitHotwords() {
// each line in hotwords_file contains space-separated words

Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
}
}

// Return a copy of the config held by this recognizer.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}

private:
void PostInit() {
config_.feat_config.nemo_normalize_type =
Expand Down
11 changes: 11 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
}

r.text = text;
r.lang = src.lang;

return r;
}
Expand Down Expand Up @@ -100,8 +101,18 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
}
}

// Only the whisper model config is taken from `config`; every other field
// of the stored config is left unchanged.
void SetConfig(const OfflineRecognizerConfig &config) override {
config_.model_config.whisper = config.model_config.whisper;
}

// Return a copy of the config held by this recognizer.
OfflineRecognizerConfig GetConfig() const override {
return config_;
}

private:
void DecodeStream(OfflineStream *s) const {
decoder_->SetConfig(config_.model_config.whisper);

int32_t max_num_frames = 3000;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,12 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}

// Forward the new config to the underlying implementation, which decides
// which fields of `config` it applies.
void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) {
impl_->SetConfig(config);
}

// Return the config reported by the underlying implementation.
OfflineRecognizerConfig OfflineRecognizer::GetConfig() const {
return impl_->GetConfig();
}

} // namespace sherpa_onnx
9 changes: 9 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ class OfflineRecognizer {
*/
void DecodeStreams(OfflineStream **ss, int32_t n) const;

/** Update the config of this recognizer.
*
* Onnxruntime Session objects are not affected by this method.
* The exact behavior can be defined by a specific recognizer impl.
* For instance, the whisper recognizer copies only
* `config.model_config.whisper` and ignores any remaining fields in `config`.
*
* @param config The new config.
*/
void SetConfig(const OfflineRecognizerConfig &config);

/// Return the config reported by the underlying recognizer impl.
OfflineRecognizerConfig GetConfig() const;

private:
std::unique_ptr<OfflineRecognizerImpl> impl_;
};
Expand Down
4 changes: 3 additions & 1 deletion sherpa-onnx/csrc/offline-stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ struct OfflineRecognitionResult {
// For instance, for BPE-based models it consists of a list of BPE tokens.
std::vector<std::string> tokens;

/// timestamps.size() == tokens.size()
std::string lang;

/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;

Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/offline-whisper-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_

#include <vector>
#include <string>

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"

namespace sherpa_onnx {

struct OfflineWhisperDecoderResult {
/// The decoded token IDs
std::vector<int32_t> tokens;
/// Language reported by the decoder. The greedy search decoder leaves it
/// empty when its lookup in the model's ID-to-language table fails.
std::string lang;
};

class OfflineWhisperDecoder {
Expand All @@ -31,6 +34,9 @@ class OfflineWhisperDecoder {
*/
virtual std::vector<OfflineWhisperDecoderResult> Decode(
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;

/// Set the whisper model config that the decoder should use.
virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;

};

} // namespace sherpa_onnx
Expand Down
11 changes: 11 additions & 0 deletions sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

namespace sherpa_onnx {

// Store `config` as the decoder's whisper model config, replacing the
// previous one.
void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) {
config_ = config;
}

std::vector<OfflineWhisperDecoderResult>
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
Ort::Value cross_v) {
Expand Down Expand Up @@ -129,6 +133,13 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,

std::vector<OfflineWhisperDecoderResult> ans(1);

const auto &id2lang = model_->GetID2Lang();
if (id2lang.count(initial_tokens[1])) {
ans[0].lang = id2lang.at(initial_tokens[1]);
} else {
ans[0].lang = "";
}

ans[0].tokens = std::move(predicted_tokens);

return ans;
Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <vector>

#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/csrc/offline-whisper-model.h"

namespace sherpa_onnx {
Expand All @@ -22,6 +21,8 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
Ort::Value cross_v) override;

// Replace the stored config_ with `config`.
void SetConfig(const OfflineWhisperModelConfig &config) override;

private:
OfflineWhisperModelConfig config_;
OfflineWhisperModel *model_; // not owned
Expand Down

0 comments on commit de04b3b

Please sign in to comment.