From cdca4e65e9c55ab6d536e20cbccd824d1c69f1e4 Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Fri, 17 May 2024 12:44:49 +0200 Subject: [PATCH] adding online nemo transducer model files --- .../csrc/online-transducer-nemo-model.cc | 434 ++++++++++++++++++ .../csrc/online-transducer-nemo-model.h | 151 ++++++ 2 files changed, 585 insertions(+) create mode 100644 sherpa-onnx/csrc/online-transducer-nemo-model.cc create mode 100644 sherpa-onnx/csrc/online-transducer-nemo-model.h diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc new file mode 100644 index 000000000..685a5ed27 --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -0,0 +1,434 @@ +// sherpa-onnx/csrc/online-transducer-nemo-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-transducer-nemo-model.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-transducer-decoder.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/transpose.h" +#include "sherpa-onnx/csrc/unbind.h" + +namespace sherpa_onnx { + +class OnlineTransducerNeMoModel::Impl { + public: + explicit Impl(const OnlineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_WARNING), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const OnlineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_WARNING), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.transducer.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } + } +#endif + + std::vector StackStates( + const std::vector> &states) const { + int32_t batch_size = static_cast(states.size()); + int32_t num_encoders = static_cast(num_encoder_layers_.size()); + + std::vector buf(batch_size); + + std::vector ans; + int32_t num_states = static_cast(states[0].size()); + ans.reserve(num_states); + + for (int32_t i = 0; i != (num_states - 2) / 6; ++i) { + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][6 * i]; + } + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][6 * i + 1]; + } + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][6 * i + 2]; + } + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][6 * i + 3]; + } + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][6 * i + 4]; + } + auto v = Cat(allocator_, buf, 0); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][6 * i + 5]; + } + auto v = Cat(allocator_, buf, 0); + ans.push_back(std::move(v)); + } + } + + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][num_states - 2]; + } + auto v = Cat(allocator_, buf, 0); + ans.push_back(std::move(v)); + } + + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][num_states - 1]; + } + auto v = Cat(allocator_, buf, 0); + ans.push_back(std::move(v)); + } + return ans; + } + + std::vector>UnStackStates( + const std::vector &states) const { + int32_t m = std::accumulate(num_encoder_layers_.begin(), + num_encoder_layers_.end(), 0); + assert(states.size() == m * 6 + 2); + + int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; + int32_t num_encoders = num_encoder_layers_.size(); + + std::vector> ans; + ans.resize(batch_size); + + for (int32_t i = 0; i != m; ++i) { + { + auto v = Unbind(allocator_, &states[i * 6], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, &states[i * 6 + 1], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, &states[i * 6 + 2], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, &states[i * 6 + 3], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, &states[i * 6 + 4], 0); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, &states[i * 6 + 5], 0); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + } + + { + auto v = Unbind(allocator_, &states[m * 6], 0); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, &states[m * 6 + 1], 0); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + return ans; + } + + std::pair>RunEncoder(Ort::Value features, + std::vector states, + Ort::Value /* processed_frames */) { + std::vector encoder_inputs; + encoder_inputs.reserve(1 + states.size()); + + encoder_inputs.push_back(std::move(features)); + for (auto &v : states) { + encoder_inputs.push_back(std::move(v)); + } + + auto encoder_out = encoder_sess_->Run( + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), + encoder_inputs.size(), encoder_output_names_ptr_.data(), + encoder_output_names_ptr_.size()); + + std::vector next_states; + next_states.reserve(states.size()); + + for (int32_t i = 1; i != static_cast(encoder_out.size()); ++i) { + next_states.push_back(std::move(encoder_out[i])); + } + return {std::move(encoder_out[0]), std::move(next_states)}; + } + + Ort::Value RunDecoder(Ort::Value decoder_input) { + auto decoder_out = decoder_sess_->Run( + {}, decoder_input_names_ptr_.data(), &decoder_input, 1, + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); + return std::move(decoder_out[0]); + } + + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { + std::array joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = + joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(), + joiner_input.size(), joiner_output_names_ptr_.data(), + joiner_output_names_ptr_.size()); + + return std::move(logit[0]); +} + + std::vector GetDecoderInitStates(int32_t batch_size) const { + std::array s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; + Ort::Value s0 = Ort::Value::CreateTensor(allocator_, s0_shape.data(), + s0_shape.size()); + + Fill(&s0, 0); + + std::array s1_shape{pred_rnn_layers_, batch_size, pred_hidden_}; + + Ort::Value s1 = Ort::Value::CreateTensor(allocator_, s1_shape.data(), + s1_shape.size()); + + Fill(&s1, 0); + + std::vector states; + + states.reserve(2); + states.push_back(std::move(s0)); + states.push_back(std::move(s1)); + + return states; + } + + int32_t SubsamplingFactor() const { return subsampling_factor_; } + int32_t VocabSize() const { return vocab_size_; } + + OrtAllocator *Allocator() const { return allocator_; } + + std::string FeatureNormalizationMethod() const { return normalize_type_; } + +private: + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + + // need to increase by 1 since the blank token is not included in computing + // vocab_size in NeMo. + vocab_size_ += 1; + + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); + SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type"); + SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers"); + SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden"); + + if (normalize_type_ == "NA") { + normalize_type_ = ""; + } + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + } + + void InitJoiner(void *model_data, size_t model_data_length) { + joiner_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + } + + private: + OnlineModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + int32_t vocab_size_ = 0; + int32_t subsampling_factor_ = 8; + std::string normalize_type_; + int32_t pred_rnn_layers_ = -1; + int32_t pred_hidden_ = -1; +}; + +OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( + const OnlineModelConfig &config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( + AAssetManager *mgr, const OnlineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default; + +int32_t ChunkLength() const { return window_size_; } + +int32_t ChunkShift() const { return chunk_shift_; } + +int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +int32_t OnlineTransducerNeMoModel::VocabSize() const { + return impl_->VocabSize(); +} + +OrtAllocator *OnlineTransducerNeMoModel::Allocator() const { + return impl_->Allocator(); +} + +std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const { + return impl_->FeatureNormalizationMethod(); +} + +} // namespace sherpa_onnx \ No newline at end of file diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.h b/sherpa-onnx/csrc/online-transducer-nemo-model.h new file mode 100644 index 000000000..e502136d4 --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.h @@ -0,0 +1,151 @@ +// sherpa-onnx/csrc/online-transducer-nemo-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_ + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-model-config.h" + +namespace sherpa_onnx { + +// see +// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40 +// Its decoder is stateful, not stateless. +class OnlineTransducerNeMoModel { + public: + explicit OnlineTransducerNeMoModel(const OnlineModelConfig &config); + +#if __ANDROID_API__ >= 9 + OnlineTransducerNeMoModel(AAssetManager *mgr, + const OfflineModelConfig &config); +#endif + + ~OnlineTransducerNeMoModel(); + + /** Stack a list of individual states into a batch. + * + * It is the inverse operation of `UnStackStates`. + * + * @param states states[i] contains the state for the i-th utterance. + * @return Return a single value representing the batched state. + */ + std::vector StackStates( + const std::vector> &states) const; + + /** Unstack a batch state into a list of individual states. + * + * It is the inverse operation of `StackStates`. + * + * @param states A batched state. + * @return ans[i] contains the state for the i-th utterance. + */ + std::vector> UnStackStates( + const std::vector &states) const; + + // /** Get the initial encoder states. + // * + // * @return Return the initial encoder state. + // */ + // std::vector GetEncoderInitStates() = 0; + + /** Run the encoder. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param states Encoder state of the previous chunk. It is changed in-place. + * @param processed_frames Processed frames before subsampling. It is a 1-D + * tensor with data type int64_t. + * + * @return Return a tuple containing: + * - encoder_out, a tensor of shape (N, T', encoder_out_dim) + * - next_states Encoder state for the next chunk. + */ + std::pair> RunEncoder( + Ort::Value features, std::vector states, + Ort::Value processed_frames) const; // NOLINT + + /** Run the decoder network. + * + * @param targets A int32 tensor of shape (batch_size, 1) + * @param targets_length A int32 tensor of shape (batch_size,) + * @param states The states for the decoder model. + * @return Return a vector: + * - ans[0] is the decoder_out (a float tensor) + * - ans[1] is the decoder_out_length (a int32 tensor) + * - ans[2:] is the states_next + */ + std::pair> RunDecoder( + Ort::Value targets, Ort::Value targets_length, + std::vector states) const; + + std::vector GetDecoderInitStates(int32_t batch_size) const; + + /** Run the joint network. + * + * @param encoder_out Output of the encoder network. + * @param decoder_out Output of the decoder network. + * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits. + */ + virtual Ort::Value RunJoiner( Ort::Value encoder_out, + Ort::Value decoder_out) const; + + // cache_last_time_dim3 in the model meta_data + int32_t ContextSize() const; + + /** We send this number of feature frames to the encoder at a time. */ + int32_t ChunkSize() const; + + /** Number of input frames to discard after each call to RunEncoder. + * + * For instance, if we have 30 frames, chunk_size=8, chunk_shift=6. + * + * In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8. + * Then we discard frame 0~5 since chunk_shift is 6. + * In the second call of RunEncoder, we use frames 6~13; and then we discard + * frames 6~11. + * In the third call of RunEncoder, we use frames 12~19; and then we discard + * frames 12~16. + * + * Note: ChunkSize() - ChunkShift() == right context size + */ + int32_t ChunkShift() const; + + /** Return the subsampling factor of the model. + */ + int32_t SubsamplingFactor() const; + + int32_t VocabSize() const; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const; + + // Possible values: + // - per_feature + // - all_features (not implemented yet) + // - fixed_mean (not implemented) + // - fixed_std (not implemented) + // - or just leave it to empty + // See + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 + // for details + std::string FeatureNormalizationMethod() const; + + private: + class Impl; + std::unique_ptr impl_; + }; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_