diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
new file mode 100644
index 000000000..d3cba31d7
--- /dev/null
+++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
@@ -0,0 +1,167 @@
+// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
+//
+// Copyright (c)  2024  Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
+
+#include <algorithm>
+#include <array>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+// Copy the cached decoder output of each result (if present) into the
+// corresponding row of decoder_out, whose shape is (batch_size, dim).
+// TODO(review): not yet called from Decode(); kept for the planned
+// batched decoding path that reuses cached decoder outputs.
+static void UseCachedDecoderOut(
+    const std::vector<OnlineTransducerDecoderResult> &results,
+    Ort::Value *decoder_out) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  float *dst = decoder_out->GetTensorMutableData<float>();
+  for (const auto &r : results) {
+    if (r.decoder_out) {
+      const float *src = r.decoder_out.GetTensorData<float>();
+      std::copy(src, src + shape[1], dst);
+    }
+    dst += shape[1];
+  }
+}
+
+// Save each row of decoder_out into the per-stream cache r.decoder_out,
+// allocating the (1, dim) cache tensor on first use.
+// TODO(review): not yet called from Decode(); see UseCachedDecoderOut.
+static void UpdateCachedDecoderOut(
+    OrtAllocator *allocator, const Ort::Value *decoder_out,
+    std::vector<OnlineTransducerDecoderResult> *results) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+  std::array<int64_t, 2> v_shape{1, shape[1]};
+
+  const float *src = decoder_out->GetTensorData<float>();
+  for (auto &r : *results) {
+    if (!r.decoder_out) {
+      r.decoder_out = Ort::Value::CreateTensor<float>(
+          allocator, v_shape.data(), v_shape.size());
+    }
+
+    float *dst = r.decoder_out.GetTensorMutableData<float>();
+    std::copy(src, src + shape[1], dst);
+    src += shape[1];
+  }
+}
+
+// Build decoder inputs for a single token.
+// Returns {decoder_input of shape (1, 1), decoder_input_length of shape (1)},
+// both holding int32 values.
+static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(
+    int32_t token, OrtAllocator *allocator) {
+  std::array<int64_t, 2> shape{1, 1};
+
+  Ort::Value decoder_input =
+      Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());
+
+  std::array<int64_t, 1> length_shape{1};
+  Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>(
+      allocator, length_shape.data(), length_shape.size());
+
+  decoder_input.GetTensorMutableData<int32_t>()[0] = token;
+  decoder_input_length.GetTensorMutableData<int32_t>()[0] = 1;
+
+  return {std::move(decoder_input), std::move(decoder_input_length)};
+}
+
+// Greedy-search decode of a single utterance.
+//
+// @param p Encoder output of shape (num_rows, num_cols): num_rows frames,
+//          each a num_cols-dimensional vector.
+// @param model The NeMo transducer model. Not owned.
+// @param blank_penalty Subtracted from the blank logit; <= 0 disables it.
+static OnlineTransducerDecoderResult DecodeOne(
+    const float *p, int32_t num_rows, int32_t num_cols,
+    OnlineTransducerNeMoModel *model, float blank_penalty) {
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  OnlineTransducerDecoderResult ans;
+
+  int32_t vocab_size = model->VocabSize();
+  // NeMo transducer models put the blank token at the end of the vocabulary.
+  int32_t blank_id = vocab_size - 1;
+
+  auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
+      model->RunDecoder(std::move(decoder_input_pair.first),
+                        std::move(decoder_input_pair.second),
+                        model->GetDecoderInitStates(1));
+
+  std::array<int64_t, 3> encoder_shape{1, num_cols, 1};
+
+  for (int32_t t = 0; t != num_rows; ++t) {
+    Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
+        memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
+        encoder_shape.data(), encoder_shape.size());
+
+    Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
+                                        View(&decoder_output_pair.first));
+
+    float *p_logit = logit.GetTensorMutableData<float>();
+    if (blank_penalty > 0) {
+      p_logit[blank_id] -= blank_penalty;
+    }
+
+    auto y = static_cast<int32_t>(std::distance(
+        static_cast<const float *>(p_logit),
+        std::max_element(static_cast<const float *>(p_logit),
+                         static_cast<const float *>(p_logit) + vocab_size)));
+
+    if (y != blank_id) {
+      ans.tokens.push_back(y);
+      ans.timestamps.push_back(t);
+
+      // Re-run the decoder (prediction network) only when a non-blank token
+      // is emitted, carrying its states forward.
+      decoder_input_pair = BuildDecoderInput(y, model->Allocator());
+
+      decoder_output_pair =
+          model->RunDecoder(std::move(decoder_input_pair.first),
+                            std::move(decoder_input_pair.second),
+                            std::move(decoder_output_pair.second));
+    }  // if (y != blank_id)
+  }  // for (int32_t t = 0; t != num_rows; ++t)
+
+  return ans;
+}
+
+std::vector<OnlineTransducerDecoderResult>
+OnlineTransducerGreedySearchNeMoDecoder::Decode(
+    Ort::Value encoder_out, Ort::Value /*encoder_out_length*/,
+    OnlineStream ** /*ss*/, int32_t /*n*/) {
+  auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
+
+  int32_t batch_size = static_cast<int32_t>(shape[0]);
+  // NOTE(review): assumes encoder_out is laid out as
+  // (batch_size, num_frames, encoder_dim) -- confirm against the model.
+  int32_t num_frames = static_cast<int32_t>(shape[1]);
+  int32_t encoder_dim = static_cast<int32_t>(shape[2]);
+
+  const float *p = encoder_out.GetTensorData<float>();
+
+  std::vector<OnlineTransducerDecoderResult> ans(batch_size);
+
+  for (int32_t i = 0; i != batch_size; ++i) {
+    const float *this_p = p + num_frames * encoder_dim * i;
+    ans[i] = DecodeOne(this_p, num_frames, encoder_dim, model_,
+                       blank_penalty_);
+  }
+
+  return ans;
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
new file mode 100644
index 000000000..d465f94c3
--- /dev/null
+++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
@@ -0,0 +1,37 @@
+// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
+//
+// Copyright (c)  2024  Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
+#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
+
+#include <vector>
+
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
+
+namespace sherpa_onnx {
+
+// Greedy-search decoder for streaming NeMo transducer models.
+class OnlineTransducerGreedySearchNeMoDecoder
+    : public OnlineTransducerDecoder {
+ public:
+  // @param model The NeMo transducer model. Not owned; must outlive
+  //              this decoder.
+  // @param blank_penalty Penalty subtracted from the blank logit.
+  OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model,
+                                          float blank_penalty)
+      : model_(model), blank_penalty_(blank_penalty) {}
+
+  std::vector<OnlineTransducerDecoderResult> Decode(
+      Ort::Value encoder_out, Ort::Value encoder_out_length,
+      OnlineStream **ss = nullptr, int32_t n = 0) override;
+
+ private:
+  OnlineTransducerNeMoModel *model_;  // Not owned
+  float blank_penalty_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_