Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First working version of offline greedy search #2

Merged
merged 1 commit into from
Sep 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion sherpa-ncnn/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
# Make headers addressable relative to the repository root, e.g.
# "sherpa-ncnn/csrc/symbol-table.h".
include_directories(${CMAKE_SOURCE_DIR})

# Small test program for the online fbank feature extractor.
add_executable(online-fbank-test online-fbank-test.cc)
target_link_libraries(online-fbank-test kaldi-native-fbank-core)

# The offline greedy-search decoder.
add_executable(sherpa-ncnn
  sherpa-ncnn.cc
  symbol-table.cc
  wave-reader.cc
)

# ncnn runs the models; kaldi-native-fbank-core computes the features.
# Each library is listed exactly once (ncnn used to be linked a second time
# by a separate, redundant target_link_libraries() call).
target_link_libraries(sherpa-ncnn
  ncnn
  kaldi-native-fbank-core
)
112 changes: 112 additions & 0 deletions sherpa-ncnn/csrc/sherpa-ncnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
* limitations under the License.
*/

#include "kaldi-native-fbank/csrc/online-feature.h"
#include "net.h"
#include "sherpa-ncnn/csrc/symbol-table.h"
#include "sherpa-ncnn/csrc/wave-reader.h"
#include <algorithm>
#include <iostream>

static void InitNet(ncnn::Net &net, const std::string &param,
Expand Down Expand Up @@ -52,11 +56,119 @@ int main() {
std::string joiner_model =
"bar/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin";

std::string wav1 = "./test_wavs/1089-134686-0001.wav";
// wav1 = "./test_wavs/1221-135766-0001.wav";
wav1 = "./test_wavs/1221-135766-0002.wav";

ncnn::Net encoder_net;
encoder_net.opt.use_packing_layout = false;
encoder_net.opt.use_fp16_storage = false;

ncnn::Net decoder_net;
decoder_net.opt.use_packing_layout = false;

ncnn::Net joiner_net;
joiner_net.opt.use_packing_layout = false;

InitNet(encoder_net, encoder_param, encoder_model);
InitNet(decoder_net, decoder_param, decoder_model);
InitNet(joiner_net, joiner_param, joiner_model);

std::vector<float> samples = sherpa_ncnn::ReadWave(wav1, 16000);

knf::FbankOptions opts;
opts.frame_opts.dither = 0;
opts.frame_opts.snip_edges = false;
opts.frame_opts.samp_freq = 16000;

opts.mel_opts.num_bins = 80;

knf::OnlineFbank fbank(opts);
fbank.AcceptWaveform(16000, samples.data(), samples.size());
fbank.InputFinished();

int32_t num_encoder_layers = 12;
int32_t batch_size = 1;
int32_t d_model = 512;
int32_t rnn_hidden_size = 1024;

ncnn::Mat h0;
h0.create(d_model, num_encoder_layers);
ncnn::Mat c0;
c0.create(rnn_hidden_size, num_encoder_layers);
h0.fill(0);
c0.fill(0);

int32_t feature_dim = 80;
ncnn::Mat features;
features.create(feature_dim, fbank.NumFramesReady());

for (int32_t i = 0; i != fbank.NumFramesReady(); ++i) {
const float *f = fbank.GetFrame(i);
std::copy(f, f + feature_dim, features.row(i));
}

ncnn::Mat feature_lengths(1);
feature_lengths[0] = features.h;

ncnn::Extractor encoder_ex = encoder_net.create_extractor();

encoder_ex.input("in0", features);
encoder_ex.input("in1", feature_lengths);
encoder_ex.input("in2", h0);
encoder_ex.input("in3", c0);

ncnn::Mat encoder_out;
encoder_ex.extract("out0", encoder_out);

int32_t context_size = 2;
int32_t blank_id = 0;

std::vector<int32_t> hyp(context_size, blank_id);
ncnn::Mat decoder_input(context_size);
static_cast<int32_t *>(decoder_input)[0] = blank_id + 1;
static_cast<int32_t *>(decoder_input)[1] = blank_id + 2;
decoder_input.fill(blank_id);

ncnn::Extractor decoder_ex = decoder_net.create_extractor();
ncnn::Mat decoder_out;
decoder_ex.input("in0", decoder_input);
decoder_ex.extract("out0", decoder_out);
decoder_out = decoder_out.reshape(decoder_out.w);

ncnn::Mat joiner_out;
for (int32_t t = 0; t != encoder_out.h; ++t) {
ncnn::Mat encoder_out_t(512, encoder_out.row(t));

auto joiner_ex = joiner_net.create_extractor();
joiner_ex.input("in0", encoder_out_t);
joiner_ex.input("in1", decoder_out);

joiner_ex.extract("out0", joiner_out);

auto y = static_cast<int32_t>(
std::distance(static_cast<const float *>(joiner_out),
std::max_element(static_cast<const float *>(joiner_out),
static_cast<const float *>(joiner_out) +
joiner_out.w)));

if (y != blank_id) {
static_cast<int32_t *>(decoder_input)[0] = hyp.back();
static_cast<int32_t *>(decoder_input)[1] = y;
hyp.push_back(y);

decoder_ex = decoder_net.create_extractor();
decoder_ex.input("in0", decoder_input);
decoder_ex.extract("out0", decoder_out);
decoder_out = decoder_out.reshape(decoder_out.w);
}
}
std::string text;
sherpa_ncnn::SymbolTable sym("./tokens.txt");
for (int32_t i = context_size; i != hyp.size(); ++i) {
text += sym[hyp[i]];
}

fprintf(stderr, "%s\n", text.c_str());
return 0;
}
78 changes: 78 additions & 0 deletions sherpa-ncnn/csrc/symbol-table.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "sherpa-ncnn/csrc/symbol-table.h"

#include <cassert>
#include <fstream>
#include <sstream>

namespace sherpa_ncnn {

// Build the table from `filename`, which holds whitespace-separated
// `sym id` pairs. Reading stops at the first pair that cannot be parsed.
SymbolTable::SymbolTable(const std::string &filename) {
  // UTF-8 bytes of U+2581 (▁, decimal 9601).
  static const char kBpeMarker[] = "\xe2\x96\x81";

  std::ifstream input(filename);
  std::string token;
  int32_t index;
  while (input >> token >> index) {
    // For BPE-based models, a leading ▁ is replaced with a space.
    // compare() can only return 0 when the token has at least 3 bytes, so
    // no separate length check is needed.
    if (token.compare(0, 3, kBpeMarker) == 0) {
      token.replace(0, 3, " ");
    }

    assert(!token.empty());
    assert(sym2id_.count(token) == 0);
    assert(id2sym_.count(index) == 0);

    sym2id_.emplace(token, index);
    id2sym_.emplace(index, token);
  }
  assert(input.eof());
}

// Render the table as text: one "sym id" line per entry, in the iteration
// order of the sym -> id map.
std::string SymbolTable::ToString() const {
  std::ostringstream buffer;
  for (auto it = sym2id_.begin(); it != sym2id_.end(); ++it) {
    buffer << it->first << ' ' << it->second << "\n";
  }
  return buffer.str();
}

// Return the symbol stored for `id`.
//
// The lookup goes through at(), so an ID that is not in the table raises
// std::out_of_range instead of inserting a new entry.
const std::string &SymbolTable::operator[](int32_t id) const {
return id2sym_.at(id);
}

// Return the ID stored for `sym`.
//
// The lookup goes through at(), so a symbol that is not in the table raises
// std::out_of_range instead of inserting a new entry.
int32_t SymbolTable::operator[](const std::string &sym) const {
return sym2id_.at(sym);
}

// Return true if some symbol is registered under `id`.
bool SymbolTable::contains(int32_t id) const {
  return id2sym_.find(id) != id2sym_.end();
}

// Return true if `sym` is registered in the table.
bool SymbolTable::contains(const std::string &sym) const {
  return sym2id_.find(sym) != sym2id_.end();
}

// Stream the textual form produced by SymbolTable::ToString() and hand the
// stream back so that calls can be chained.
std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
  os << symbol_table.ToString();
  return os;
}

} // namespace sherpa_ncnn
62 changes: 62 additions & 0 deletions sherpa-ncnn/csrc/symbol-table.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef SHERPA_NCNN_CSRC_SYMBOL_TABLE_H_
#define SHERPA_NCNN_CSRC_SYMBOL_TABLE_H_

#include <string>
#include <unordered_map>

namespace sherpa_ncnn {

/// A two-way mapping between symbols (strings) and integer IDs.
///
/// The table is filled once by the file constructor; every other member
/// function is a const query.
class SymbolTable {
public:
/// Construct an empty symbol table.
SymbolTable() = default;
/// Construct a symbol table from a file.
/// Each line in the file contains two fields:
///
/// sym ID
///
/// Fields are separated by space(s).
///
/// A symbol that starts with the UTF-8 encoding of U+2581 (▁) is stored
/// with that prefix replaced by a single space.
explicit SymbolTable(const std::string &filename);

/// Return a string representation of this symbol table:
/// one "sym ID" line per entry.
std::string ToString() const;

/// Return the symbol corresponding to the given ID.
/// Throws std::out_of_range if the ID is not in the table.
const std::string &operator[](int32_t id) const;
/// Return the ID corresponding to the given symbol.
/// Throws std::out_of_range if the symbol is not in the table.
int32_t operator[](const std::string &sym) const;

/// Return true if there is a symbol with the given ID.
bool contains(int32_t id) const;

/// Return true if the given symbol is in the symbol table.
bool contains(const std::string &sym) const;

private:
/// Lookup from symbol to ID.
std::unordered_map<std::string, int32_t> sym2id_;
/// Lookup from ID to symbol.
std::unordered_map<int32_t, std::string> id2sym_;
};

/// Write the text returned by SymbolTable::ToString() to `os` and return `os`.
std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table);

} // namespace sherpa_ncnn

#endif // SHERPA_NCNN_CSRC_SYMBOL_TABLE_H_
107 changes: 107 additions & 0 deletions sherpa-ncnn/csrc/wave-reader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/**
* Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <utility>
#include <vector>

#include "sherpa-ncnn/csrc/wave-reader.h"

namespace sherpa_ncnn {
namespace {

// Print `what` to stderr and terminate the process if `ok` is false.
//
// Unlike assert(), this check stays active when NDEBUG is defined, so a
// malformed file is also rejected in release builds.
void CheckOrDie(bool ok, const char *what) {
  if (!ok) {
    std::cerr << "Invalid wave file: " << what << "\n";
    std::exit(-1);
  }
}

// The canonical 44-byte header of a PCM wave file.
// see http://soundfile.sapp.org/doc/WaveFormat/
//
// Note: We assume little endian here
// TODO(fangjun): Support big endian
struct WaveHeader {
  // Check the header against the only layout this reader supports:
  // RIFF/WAVE, a 16-byte PCM "fmt " chunk, one channel, 16 bits per sample.
  // Terminates the process with a diagnostic on the first mismatch.
  //
  // Note: subchunk2_id is not inspected.
  void Validate() const {
    //                      F F I R
    CheckOrDie(chunk_id == 0x46464952, "chunk_id is not \"RIFF\"");
    // The size is later used as an element count, so it must not be negative.
    CheckOrDie(subchunk2_size >= 0, "subchunk2_size is negative");
    // Computed in 64 bits so that a huge subchunk2_size cannot overflow.
    CheckOrDie(static_cast<int64_t>(chunk_size) ==
                   36 + static_cast<int64_t>(subchunk2_size),
               "chunk_size != 36 + subchunk2_size");
    //                    E V A W
    CheckOrDie(format == 0x45564157, "format is not \"WAVE\"");
    //                          ' ' t m f
    CheckOrDie(subchunk1_id == 0x20746d66, "subchunk1_id is not \"fmt \"");
    CheckOrDie(subchunk1_size == 16, "subchunk1_size != 16 (PCM)");
    CheckOrDie(audio_format == 1, "audio_format != 1 (PCM)");
    // we support only single channel for now
    CheckOrDie(num_channels == 1, "num_channels != 1");
    CheckOrDie(byte_rate == static_cast<int64_t>(sample_rate) * num_channels *
                                bits_per_sample / 8,
               "byte_rate is inconsistent with the sample format");
    CheckOrDie(block_align == num_channels * bits_per_sample / 8,
               "block_align is inconsistent with the sample format");
    // we support only 16 bits per sample
    CheckOrDie(bits_per_sample == 16, "bits_per_sample != 16");
  }

  int32_t chunk_id;
  int32_t chunk_size;
  int32_t format;
  int32_t subchunk1_id;
  int32_t subchunk1_size;
  int16_t audio_format;
  int16_t num_channels;
  int32_t sample_rate;
  int32_t byte_rate;
  int16_t block_align;
  int16_t bits_per_sample;
  int32_t subchunk2_id;
  int32_t subchunk2_size;
};
// The struct is read straight from the file, so it must not contain padding.
static_assert(sizeof(WaveHeader) == 44, "");

// Read a mono-channel, 16-bit PCM wave file from `is`.
//
// @param is           Stream positioned at the first byte of the header.
// @param sample_rate  On return, the sample rate stored in the header.
// @return The samples normalized to the range [-1, 1).
//
// Terminates the process with a diagnostic if the header cannot be read, if
// it fails WaveHeader::Validate(), or if the sample data is truncated.
std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) {
  WaveHeader header{};
  is.read(reinterpret_cast<char *>(&header), sizeof(header));
  CheckOrDie(static_cast<bool>(is), "failed to read the 44-byte header");

  header.Validate();

  *sample_rate = header.sample_rate;

  // header.subchunk2_size is the number of bytes of sample data and every
  // sample occupies two bytes. Only whole samples are read: requesting
  // subchunk2_size bytes would write one byte past the buffer when the
  // size is odd.
  std::vector<int16_t> samples(header.subchunk2_size / 2);
  is.read(reinterpret_cast<char *>(samples.data()),
          static_cast<std::streamsize>(samples.size() * sizeof(int16_t)));
  CheckOrDie(static_cast<bool>(is), "failed to read the sample data");

  std::vector<float> ans;
  ans.reserve(samples.size());
  for (const int16_t s : samples) {
    ans.push_back(s / 32768.0f);
  }

  return ans;
}

}  // namespace

// Read a mono-channel wave file and check its sample rate.
//
// @param filename              Path of the wave file to read.
// @param expected_sample_rate  Sample rate the caller requires.
// @return The samples produced by ReadWaveImpl().
//
// Prints a message to stderr and terminates the process with exit(-1) if
// the file cannot be opened or if its sample rate differs from
// `expected_sample_rate`.
std::vector<float> ReadWave(const std::string &filename,
                            float expected_sample_rate) {
  std::ifstream is(filename, std::ifstream::binary);
  // Report a missing or unreadable file here, where its name is known,
  // instead of letting the header read fail later.
  if (!is) {
    std::cerr << "Failed to open wave file: " << filename << "\n";
    exit(-1);
  }

  float sample_rate = 0;
  auto samples = ReadWaveImpl(is, &sample_rate);
  if (expected_sample_rate != sample_rate) {
    std::cerr << "Expected sample rate: " << expected_sample_rate
              << ". Given: " << sample_rate << ".\n";
    exit(-1);
  }
  return samples;
}

} // namespace sherpa_ncnn
Loading