From 8dbe41ce6be91c16d856787d8901d96f5bbe2edc Mon Sep 17 00:00:00 2001 From: Manix Date: Wed, 10 Jul 2024 07:25:32 +0000 Subject: [PATCH] encoder only trt --- .../csrc/online-zipformer2-transducer-model.cc | 14 +++++++++----- .../csrc/online-zipformer2-transducer-model.h | 5 ++++- sherpa-onnx/csrc/session.cc | 17 ++++++++++++++++- sherpa-onnx/csrc/session.h | 3 +++ 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc index e5c448210..f917c0881 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc @@ -33,7 +33,9 @@ namespace sherpa_onnx { OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_WARNING), - sess_opts_(GetSessionOptions(config)), + encoder_sess_opts_(GetSessionOptions(config)), + decoder_sess_opts_(GetSessionOptions(config,"decoder")), + joiner_sess_opts_(GetSessionOptions(config,"joiner")), config_(config), allocator_{} { { @@ -57,7 +59,9 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( AAssetManager *mgr, const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_WARNING), config_(config), - sess_opts_(GetSessionOptions(config)), + encoder_sess_opts_(GetSessionOptions(config)), + decoder_sess_opts_(GetSessionOptions(config)), + joiner_sess_opts_(GetSessionOptions(config)), allocator_{} { { auto buf = ReadFile(mgr, config.transducer.encoder); @@ -79,7 +83,7 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, size_t model_data_length) { encoder_sess_ = std::make_unique(env_, model_data, - model_data_length, sess_opts_); + model_data_length, encoder_sess_opts_); GetInputNames(encoder_sess_.get(), &encoder_input_names_, &encoder_input_names_ptr_); @@ -132,7 +136,7 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, void OnlineZipformer2TransducerModel::InitDecoder(void *model_data, size_t model_data_length) { decoder_sess_ = std::make_unique(env_, model_data, - model_data_length, sess_opts_); + model_data_length, decoder_sess_opts_); GetInputNames(decoder_sess_.get(), &decoder_input_names_, &decoder_input_names_ptr_); @@ -157,7 +161,7 @@ void OnlineZipformer2TransducerModel::InitDecoder(void *model_data, void OnlineZipformer2TransducerModel::InitJoiner(void *model_data, size_t model_data_length) { joiner_sess_ = std::make_unique(env_, model_data, - model_data_length, sess_opts_); + model_data_length, joiner_sess_opts_); GetInputNames(joiner_sess_.get(), &joiner_input_names_, &joiner_input_names_ptr_); diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h index 07c9e9252..aa0f46f81 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h @@ -65,7 +65,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { private: Ort::Env env_; - Ort::SessionOptions sess_opts_; + Ort::SessionOptions encoder_sess_opts_; + Ort::SessionOptions decoder_sess_opts_; + Ort::SessionOptions joiner_sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; std::unique_ptr encoder_sess_; diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index b6fdaaa84..2c3f11bd0 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -94,7 +94,7 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, std::to_string(trt_config.trt_timing_cache_enable); auto trt_dump_subgraphs = std::to_string(trt_config.trt_dump_subgraphs); - + // SHERPA_ONNX_LOGE("max workspace : %s",trt_max_workspace_size.c_str()); std::vector trt_options = { {"device_id", device_id.c_str()}, {"trt_max_workspace_size", trt_max_workspace_size.c_str()}, @@ -223,6 +223,21 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { config.provider_config.provider, &config.provider_config); } +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, + const std::string &model_type) { + /* + Transducer models : Only encoder will run with tensorrt, + decoder and joiner will run with cuda + */ + if(config.provider_config.provider == "trt" && + (model_type == "decoder" || model_type == "joiner")) { + return GetSessionOptionsImpl(config.num_threads, + std::string("cuda"), &config.provider_config); + } + return GetSessionOptionsImpl(config.num_threads, + config.provider_config.provider, &config.provider_config); +} + Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, config.provider); } diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index a4121436a..691a2ff3c 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -24,6 +24,9 @@ namespace sherpa_onnx { Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, + const std::string &model_type); + Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config);