Skip to content

Commit

Permalink
cc file outline added
Browse files Browse the repository at this point in the history
  • Loading branch information
sangeet2020 committed May 19, 2024
1 parent ca4bfe8 commit 2bb7d7e
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 0 deletions.
156 changes: 156 additions & 0 deletions sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

static void UseCachedDecoderOut(
const std::vector<OnlineTransducerDecoderResult> &results,
Ort::Value *decoder_out) {
std::vector<int64_t> shape =
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
float *dst = decoder_out->GetTensorMutableData<float>();
for (const auto &r : results) {
if (r.decoder_out) {
const float *src = r.decoder_out.GetTensorData<float>();
std::copy(src, src + shape[1], dst);
}
dst += shape[1];
}
}

static void UpdateCachedDecoderOut(
OrtAllocator *allocator, const Ort::Value *decoder_out,
std::vector<OnlineTransducerDecoderResult> *results) {
std::vector<int64_t> shape =
decoder_out->GetTensorTypeAndShapeInfo().GetShape();
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 2> v_shape{1, shape[1]};

const float *src = decoder_out->GetTensorData<float>();
for (auto &r : *results) {
if (!r.decoder_out) {
r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),
v_shape.size());
}

float *dst = r.decoder_out.GetTensorMutableData<float>();
std::copy(src, src + shape[1], dst);
src += shape[1];
}
}

static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(
int32_t token, OrtAllocator *allocator) {
std::array<int64_t, 2> shape{1, 1};

Ort::Value decoder_input =
Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());

std::array<int64_t, 1> length_shape{1};
Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>(
allocator, length_shape.data(), length_shape.size());

int32_t *p = decoder_input.GetTensorMutableData<int32_t>();

int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>();

p[0] = token;

p_length[0] = 1;

return {std::move(decoder_input), std::move(decoder_input_length)};
}

static OnlineTransducerDecoderResult DecodeOne(
const float *p, int32_t num_rows, int32_t num_cols,
OnlineTransducerNeMoModel *model, float blank_penalty) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

OnlineTransducerDecoderResult ans;

int32_t vocab_size = model->VocabSize();
int32_t blank_id = vocab_size - 1;

auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());

std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_input_pair.second),
model->GetDecoderInitStates(1));

std::array<int64_t, 3> encoder_shape{1, num_cols, 1};

for (int32_t t = 0; t != num_rows; ++t) {
Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
encoder_shape.data(), encoder_shape.size());

Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
View(&decoder_output_pair.first));

float *p_logit = logit.GetTensorMutableData<float>();
if (blank_penalty > 0) {

p_logit[blank_id] -= blank_penalty;
}

auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));

if (y != blank_id) {
ans.tokens.push_back(y);
ans.timestamps.push_back(t);

decoder_input_pair = BuildDecoderInput(y, model->Allocator());

decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_input_pair.second),
std::move(decoder_output_pair.second));
} // if (y != blank_id)
} // for (int32_t i = 0; i != num_rows; ++i)

return ans;
}

std::vector<OnlineTransducerDecoderResult>
OnlineTransducerGreedySearchNeMoDecoder::Decode(
Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) {
auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();

int32_t batch_size = static_cast<int32_t>(shape[0]);
int32_t dim1 = static_cast<int32_t>(shape[1]);
int32_t dim2 = static_cast<int32_t>(shape[2]);

const float *p = encoder_out.GetTensorData<float>();

// checking for non-null elements in results

// create a new tensor with modified shape based on
// the first element of result and use cached decoder_out
// values if available.

// For each frame (num of frames is given by dim2), compute logits,
// determine tokens, and update results,
// then regenerate decoder output
// if tokens are emitted.

// call UpdateCachedDecoderOut and update frame offset
}

} // namespace sherpa_onnx
33 changes: 33 additions & 0 deletions sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_

#include <vector>

#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"

namespace sherpa_onnx {

class OnlineTransducerGreedySearchNeMoDecoder
: public OnlineTransducerDecoder {
public:
OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model,
float blank_penalty)
: model_(model), blank_penalty_(blank_penalty) {}

std::vector<OnlineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length,
OnlineStream **ss = nullptr, int32_t n = 0) override;

private:
OnlineTransducerNeMoModel *model_; // Not owned
float blank_penalty_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_

0 comments on commit 2bb7d7e

Please sign in to comment.