From de04b3b9bfc6d48a8ac340e00083d9fd5411b81e Mon Sep 17 00:00:00 2001 From: ivan provalov Date: Sat, 13 Jul 2024 07:30:47 -0700 Subject: [PATCH] Allow modify model config at decode time for ASR (#1124) --- sherpa-onnx/c-api/c-api.cc | 45 ++++++++++++++----- sherpa-onnx/c-api/c-api.h | 11 ++++- .../csrc/offline-recognizer-ctc-impl.h | 5 +++ sherpa-onnx/csrc/offline-recognizer-impl.cc | 4 ++ sherpa-onnx/csrc/offline-recognizer-impl.h | 4 ++ .../csrc/offline-recognizer-paraformer-impl.h | 4 ++ .../csrc/offline-recognizer-transducer-impl.h | 5 +++ .../offline-recognizer-transducer-nemo-impl.h | 4 ++ .../csrc/offline-recognizer-whisper-impl.h | 11 +++++ sherpa-onnx/csrc/offline-recognizer.cc | 8 ++++ sherpa-onnx/csrc/offline-recognizer.h | 9 ++++ sherpa-onnx/csrc/offline-stream.h | 4 +- sherpa-onnx/csrc/offline-whisper-decoder.h | 6 +++ .../offline-whisper-greedy-search-decoder.cc | 11 +++++ .../offline-whisper-greedy-search-decoder.h | 3 +- 15 files changed, 121 insertions(+), 13 deletions(-) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index eb9ec8752..cda5832e2 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -308,8 +308,27 @@ struct SherpaOnnxOfflineStream { : impl(std::move(p)) {} }; +static sherpa_onnx::OfflineRecognizerConfig convertConfig( + const SherpaOnnxOfflineRecognizerConfig *config); SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( const SherpaOnnxOfflineRecognizerConfig *config) { + sherpa_onnx::OfflineRecognizerConfig recognizer_config = + convertConfig(config); + + if (!recognizer_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config"); + return nullptr; + } + + SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer; + + recognizer->impl = + std::make_unique(recognizer_config); + + return recognizer; +} +sherpa_onnx::OfflineRecognizerConfig convertConfig( + const SherpaOnnxOfflineRecognizerConfig *config) { sherpa_onnx::OfflineRecognizerConfig recognizer_config; recognizer_config.feat_config.sampling_rate = @@ -398,17 +417,15 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( SHERPA_ONNX_LOGE("%s", recognizer_config.ToString().c_str()); } - if (!recognizer_config.Validate()) { - SHERPA_ONNX_LOGE("Errors in config"); - return nullptr; - } - - SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer; - - recognizer->impl = - std::make_unique(recognizer_config); + return recognizer_config; +} - return recognizer; +void SherpaOnnxOfflineRecognizerSetConfig( + const SherpaOnnxOfflineRecognizer *recognizer, + const SherpaOnnxOfflineRecognizerConfig *config){ + sherpa_onnx::OfflineRecognizerConfig recognizer_config = + convertConfig(config); + recognizer->impl->SetConfig(recognizer_config); } void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) { @@ -461,6 +478,13 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult( pText[text.size()] = 0; r->text = pText; + //lang + const auto &lang = result.lang; + char *c_lang = new char[lang.size() + 1]; + std::copy(lang.begin(), lang.end(), c_lang); + c_lang[lang.size()] = '\0'; + r->lang = c_lang; + // copy json std::string json = result.AsJsonString(); char *pJson = new char[json.size() + 1]; @@ -517,6 +541,7 @@ void DestroyOfflineRecognizerResult( delete[] r->tokens; delete[] r->tokens_arr; delete[] r->json; + delete[] r->lang; delete r; } } diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 4beba2a73..e6d8ae272 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -428,6 +428,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineStream SherpaOnnxOfflineStream; SHERPA_ONNX_API SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( const SherpaOnnxOfflineRecognizerConfig *config); +/// @param config Config for the recognizer. +SHERPA_ONNX_API void SherpaOnnxOfflineRecognizerSetConfig( + const SherpaOnnxOfflineRecognizer *recognizer, + const SherpaOnnxOfflineRecognizerConfig *config); + /// Free a pointer returned by CreateOfflineRecognizer() /// /// @param p A pointer returned by CreateOfflineRecognizer() @@ -491,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams( SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { const char *text; - // Pointer to continuous memory which holds timestamps + // Pointer to continuous memory which holds timestamps // // It is NULL if the model does not support timestamps float *timestamps; @@ -519,6 +524,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { * } */ const char *json; + + //return recognized language + const char *lang; + } SherpaOnnxOfflineRecognizerResult; /// Get the result of the offline stream. diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index cbe9a9e88..9c7252a06 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -212,6 +212,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { } } + OfflineRecognizerConfig GetConfig() const override { + return config_; + } + + private: // Decode a single stream. // Some models do not support batch size > 1, e.g., WeNet CTC models. diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 80a6766ce..dd96f2b8a 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -431,4 +431,8 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( return text; } +void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) { + config_ = config; +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.h b/sherpa-onnx/csrc/offline-recognizer-impl.h index 1ba268c11..32010bf70 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-impl.h @@ -48,6 +48,10 @@ class OfflineRecognizerImpl { virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0; + virtual void SetConfig(const OfflineRecognizerConfig &config); + + virtual OfflineRecognizerConfig GetConfig() const = 0; + std::string ApplyInverseTextNormalization(std::string text) const; private: diff --git a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h index a0d4af3b6..13240cc01 100644 --- a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h @@ -211,6 +211,10 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { } } + OfflineRecognizerConfig GetConfig() const override { + return config_; + } + private: std::vector ApplyLFR(const std::vector &in) const { int32_t lfr_window_size = model_->LfrWindowSize(); diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index c439319eb..05759ac5b 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -246,6 +246,11 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { } } + OfflineRecognizerConfig GetConfig() const override { + return config_; + } + + void InitHotwords() { // each line in hotwords_file contains space-separated words diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h index d5902b05b..2f5b9e2a2 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h @@ -139,6 +139,10 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { } } + OfflineRecognizerConfig GetConfig() const override { + return config_; + } + private: void PostInit() { config_.feat_config.nemo_normalize_type = diff --git a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h index e56f07550..023700e77 100644 --- a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h @@ -45,6 +45,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, } r.text = text; + r.lang = src.lang; return r; } @@ -100,8 +101,18 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { } } + void SetConfig(const OfflineRecognizerConfig &config) override { + config_.model_config.whisper = config.model_config.whisper; + } + + OfflineRecognizerConfig GetConfig() const override { + return config_; + } + private: void DecodeStream(OfflineStream *s) const { + decoder_->SetConfig(config_.model_config.whisper); + int32_t max_num_frames = 3000; auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index 1285a5cd3..f73e35ad6 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -156,4 +156,12 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const { impl_->DecodeStreams(ss, n); } +void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) { + impl_->SetConfig(config); +} + +OfflineRecognizerConfig OfflineRecognizer::GetConfig() const { + return impl_->GetConfig(); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index 9290a53b5..8f0b47a08 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -119,6 +119,15 @@ class OfflineRecognizer { */ void DecodeStreams(OfflineStream **ss, int32_t n) const; + /** Onnxruntime Session objects are not affected by this method. + * The exact behavior can be defined by a specific recognizer impl. + * For instance, for the whisper recognizer, you can retrieve the language and task from + * the config and ignore any remaining fields in `config`. + */ + void SetConfig(const OfflineRecognizerConfig &config); + + OfflineRecognizerConfig GetConfig() const; + private: std::unique_ptr impl_; }; diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h index e3c346fc4..0bc7b4a9b 100644 --- a/sherpa-onnx/csrc/offline-stream.h +++ b/sherpa-onnx/csrc/offline-stream.h @@ -26,7 +26,9 @@ struct OfflineRecognitionResult { // For instance, for BPE-based models it consists of a list of BPE tokens. std::vector tokens; - /// timestamps.size() == tokens.size() + std::string lang; + + /// timestamps.size() == tokens.size() /// timestamps[i] records the time in seconds when tokens[i] is decoded. std::vector timestamps; diff --git a/sherpa-onnx/csrc/offline-whisper-decoder.h b/sherpa-onnx/csrc/offline-whisper-decoder.h index c9367eafd..3babb3824 100644 --- a/sherpa-onnx/csrc/offline-whisper-decoder.h +++ b/sherpa-onnx/csrc/offline-whisper-decoder.h @@ -6,14 +6,17 @@ #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ #include +#include #include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-whisper-model-config.h" namespace sherpa_onnx { struct OfflineWhisperDecoderResult { /// The decoded token IDs std::vector tokens; + std::string lang; }; class OfflineWhisperDecoder { @@ -31,6 +34,9 @@ class OfflineWhisperDecoder { */ virtual std::vector Decode( Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; + + virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0; + }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc index 15eacb62b..96bb9d971 100644 --- a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc @@ -12,6 +12,10 @@ namespace sherpa_onnx { +void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) { + config_ = config; +} + std::vector OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, Ort::Value cross_v) { @@ -129,6 +133,13 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, std::vector ans(1); + const auto &id2lang = model_->GetID2Lang(); + if (id2lang.count(initial_tokens[1])) { + ans[0].lang = id2lang.at(initial_tokens[1]); + } else { + ans[0].lang = ""; + } + ans[0].tokens = std::move(predicted_tokens); return ans; diff --git a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h index 5f2b41680..9692d90d8 100644 --- a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h @@ -8,7 +8,6 @@ #include #include "sherpa-onnx/csrc/offline-whisper-decoder.h" -#include "sherpa-onnx/csrc/offline-whisper-model-config.h" #include "sherpa-onnx/csrc/offline-whisper-model.h" namespace sherpa_onnx { @@ -22,6 +21,8 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { std::vector Decode(Ort::Value cross_k, Ort::Value cross_v) override; + void SetConfig(const OfflineWhisperModelConfig &config) override; + private: OfflineWhisperModelConfig config_; OfflineWhisperModel *model_; // not owned