Skip to content

Commit

Permalink
First working version of offline greedy search (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Sep 4, 2022
1 parent db5c5ac commit 6b7180a
Show file tree
Hide file tree
Showing 6 changed files with 409 additions and 1 deletion.
10 changes: 9 additions & 1 deletion sherpa-ncnn/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
# Add the top-level source directory to the include path.
# Presumably this is what lets sources use project-rooted includes such as
# "sherpa-ncnn/csrc/symbol-table.h" -- confirm against the repository layout.
include_directories(${CMAKE_SOURCE_DIR})

# Executable built from online-fbank-test.cc, linked against
# kaldi-native-fbank-core.
add_executable(online-fbank-test online-fbank-test.cc)
target_link_libraries(online-fbank-test kaldi-native-fbank-core)

# The sherpa-ncnn executable and the sources compiled into it.
add_executable(sherpa-ncnn
sherpa-ncnn.cc
symbol-table.cc
wave-reader.cc
)

# Libraries linked into the sherpa-ncnn executable.
target_link_libraries(sherpa-ncnn
ncnn
kaldi-native-fbank-core
)
# NOTE(review): ncnn is already listed in the target_link_libraries() call
# above, so this second call looks redundant (possibly the pre-change line
# retained by the diff view) -- confirm whether it can be dropped.
target_link_libraries(sherpa-ncnn ncnn)
112 changes: 112 additions & 0 deletions sherpa-ncnn/csrc/sherpa-ncnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
* limitations under the License.
*/

#include "kaldi-native-fbank/csrc/online-feature.h"
#include "net.h"
#include "sherpa-ncnn/csrc/symbol-table.h"
#include "sherpa-ncnn/csrc/wave-reader.h"
#include <algorithm>
#include <iostream>

static void InitNet(ncnn::Net &net, const std::string &param,
Expand Down Expand Up @@ -52,11 +56,119 @@ int main() {
std::string joiner_model =
"bar/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin";

std::string wav1 = "./test_wavs/1089-134686-0001.wav";
// wav1 = "./test_wavs/1221-135766-0001.wav";
wav1 = "./test_wavs/1221-135766-0002.wav";

ncnn::Net encoder_net;
encoder_net.opt.use_packing_layout = false;
encoder_net.opt.use_fp16_storage = false;

ncnn::Net decoder_net;
decoder_net.opt.use_packing_layout = false;

ncnn::Net joiner_net;
joiner_net.opt.use_packing_layout = false;

InitNet(encoder_net, encoder_param, encoder_model);
InitNet(decoder_net, decoder_param, decoder_model);
InitNet(joiner_net, joiner_param, joiner_model);

std::vector<float> samples = sherpa_ncnn::ReadWave(wav1, 16000);

knf::FbankOptions opts;
opts.frame_opts.dither = 0;
opts.frame_opts.snip_edges = false;
opts.frame_opts.samp_freq = 16000;

opts.mel_opts.num_bins = 80;

knf::OnlineFbank fbank(opts);
fbank.AcceptWaveform(16000, samples.data(), samples.size());
fbank.InputFinished();

int32_t num_encoder_layers = 12;
int32_t batch_size = 1;
int32_t d_model = 512;
int32_t rnn_hidden_size = 1024;

ncnn::Mat h0;
h0.create(d_model, num_encoder_layers);
ncnn::Mat c0;
c0.create(rnn_hidden_size, num_encoder_layers);
h0.fill(0);
c0.fill(0);

int32_t feature_dim = 80;
ncnn::Mat features;
features.create(feature_dim, fbank.NumFramesReady());

for (int32_t i = 0; i != fbank.NumFramesReady(); ++i) {
const float *f = fbank.GetFrame(i);
std::copy(f, f + feature_dim, features.row(i));
}

ncnn::Mat feature_lengths(1);
feature_lengths[0] = features.h;

ncnn::Extractor encoder_ex = encoder_net.create_extractor();

encoder_ex.input("in0", features);
encoder_ex.input("in1", feature_lengths);
encoder_ex.input("in2", h0);
encoder_ex.input("in3", c0);

ncnn::Mat encoder_out;
encoder_ex.extract("out0", encoder_out);

int32_t context_size = 2;
int32_t blank_id = 0;

std::vector<int32_t> hyp(context_size, blank_id);
ncnn::Mat decoder_input(context_size);
static_cast<int32_t *>(decoder_input)[0] = blank_id + 1;
static_cast<int32_t *>(decoder_input)[1] = blank_id + 2;
decoder_input.fill(blank_id);

ncnn::Extractor decoder_ex = decoder_net.create_extractor();
ncnn::Mat decoder_out;
decoder_ex.input("in0", decoder_input);
decoder_ex.extract("out0", decoder_out);
decoder_out = decoder_out.reshape(decoder_out.w);

ncnn::Mat joiner_out;
for (int32_t t = 0; t != encoder_out.h; ++t) {
ncnn::Mat encoder_out_t(512, encoder_out.row(t));

auto joiner_ex = joiner_net.create_extractor();
joiner_ex.input("in0", encoder_out_t);
joiner_ex.input("in1", decoder_out);

joiner_ex.extract("out0", joiner_out);

auto y = static_cast<int32_t>(
std::distance(static_cast<const float *>(joiner_out),
std::max_element(static_cast<const float *>(joiner_out),
static_cast<const float *>(joiner_out) +
joiner_out.w)));

if (y != blank_id) {
static_cast<int32_t *>(decoder_input)[0] = hyp.back();
static_cast<int32_t *>(decoder_input)[1] = y;
hyp.push_back(y);

decoder_ex = decoder_net.create_extractor();
decoder_ex.input("in0", decoder_input);
decoder_ex.extract("out0", decoder_out);
decoder_out = decoder_out.reshape(decoder_out.w);
}
}
std::string text;
sherpa_ncnn::SymbolTable sym("./tokens.txt");
for (int32_t i = context_size; i != hyp.size(); ++i) {
text += sym[hyp[i]];
}

fprintf(stderr, "%s\n", text.c_str());
return 0;
}
78 changes: 78 additions & 0 deletions sherpa-ncnn/csrc/symbol-table.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "sherpa-ncnn/csrc/symbol-table.h"

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>

namespace sherpa_ncnn {

/// Load a symbol table from the text file `filename`.
///
/// The file is read as a sequence of whitespace-separated `sym id` pairs.
/// A symbol that starts with the UTF-8 encoding of U+2581 (▁), as produced by
/// BPE-based models, has that 3-byte prefix replaced by a single space.
///
/// The input is validated at run time (not with assert(), which is compiled
/// out when NDEBUG is defined). If the file cannot be opened, contains a
/// malformed entry, or repeats a symbol or an ID, an error is printed to
/// std::cerr and the process exits with status -1.
///
/// @param filename Path to the symbol table file.
SymbolTable::SymbolTable(const std::string &filename) {
  std::ifstream is(filename);
  if (!is) {
    std::cerr << "Failed to open symbol table file: " << filename << "\n";
    std::exit(-1);
  }

  std::string sym;
  int32_t id;
  // A successful `is >> sym` always yields a non-empty token, so no separate
  // emptiness check is needed.
  while (is >> sym >> id) {
    // For BPE-based models, we replace ▁ with a space
    // Unicode 9601, hex 0x2581, utf8 0xe29681
    if (sym.compare(0, 3, "\xe2\x96\x81") == 0) {
      sym.replace(0, 3, " ");
    }

    // emplace() reports via `.second` whether the key was newly inserted.
    if (!sym2id_.emplace(sym, id).second) {
      std::cerr << "Duplicate symbol '" << sym << "' in " << filename << "\n";
      std::exit(-1);
    }

    if (!id2sym_.emplace(id, sym).second) {
      std::cerr << "Duplicate ID " << id << " in " << filename << "\n";
      std::exit(-1);
    }
  }

  // The loop must have stopped because the input was exhausted; stopping
  // earlier means an entry could not be parsed (e.g. a non-integer ID).
  if (!is.eof()) {
    std::cerr << "Malformed entry in symbol table file: " << filename << "\n";
    std::exit(-1);
  }
}

/// Render the table as text, one `sym id` pair per line.
///
/// Entries are emitted in the iteration order of the underlying
/// unordered_map, so the line order is unspecified.
std::string SymbolTable::ToString() const {
  std::ostringstream out;
  for (auto it = sym2id_.begin(); it != sym2id_.end(); ++it) {
    out << it->first << ' ' << it->second << "\n";
  }
  return out.str();
}

/// Look up the symbol stored for `id`.
///
/// Uses unordered_map::at(), so an unknown ID raises std::out_of_range.
const std::string &SymbolTable::operator[](int32_t id) const {
  const std::string &symbol = id2sym_.at(id);
  return symbol;
}

/// Look up the ID stored for `sym`.
///
/// Uses unordered_map::at(), so an unknown symbol raises std::out_of_range.
int32_t SymbolTable::operator[](const std::string &sym) const {
  const int32_t symbol_id = sym2id_.at(sym);
  return symbol_id;
}

/// Report whether some symbol is registered under `id`.
bool SymbolTable::contains(int32_t id) const {
  return id2sym_.find(id) != id2sym_.end();
}

/// Report whether `sym` is registered in the table.
bool SymbolTable::contains(const std::string &sym) const {
  return sym2id_.find(sym) != sym2id_.end();
}

/// Stream the textual form produced by SymbolTable::ToString() into `os`.
std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
  os << symbol_table.ToString();
  return os;
}

} // namespace sherpa_ncnn
62 changes: 62 additions & 0 deletions sherpa-ncnn/csrc/symbol-table.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef SHERPA_NCNN_CSRC_SYMBOL_TABLE_H_
#define SHERPA_NCNN_CSRC_SYMBOL_TABLE_H_

#include <string>
#include <unordered_map>

// NOTE(review): this header uses int32_t and std::ostream but only includes
// <string> and <unordered_map>; consider also including <cstdint> and
// <iosfwd> so the header does not rely on transitive includes.
namespace sherpa_ncnn {

/// A bidirectional mapping between symbols (strings) and integer IDs.
class SymbolTable {
public:
/// Construct an empty symbol table.
SymbolTable() = default;
/// Construct a symbol table from a file.
/// Each line in the file contains two fields:
///
/// sym ID
///
/// Fields are separated by space(s).
///
/// A symbol that begins with the UTF-8 bytes of U+2581 (0xe2 0x96 0x81) has
/// that prefix replaced by a single space before it is stored.
explicit SymbolTable(const std::string &filename);

/// Return a string representation of this symbol table: one `sym ID` pair
/// per line, in the (unspecified) iteration order of an unordered_map.
std::string ToString() const;

/// Return the symbol corresponding to the given ID.
/// Throws std::out_of_range if the ID is not in the table.
const std::string &operator[](int32_t id) const;
/// Return the ID corresponding to the given symbol.
/// Throws std::out_of_range if the symbol is not in the table.
int32_t operator[](const std::string &sym) const;

/// Return true if there is a symbol with the given ID.
bool contains(int32_t id) const;

/// Return true if the given symbol is in the symbol table.
bool contains(const std::string &sym) const;

private:
// Symbol -> ID.
std::unordered_map<std::string, int32_t> sym2id_;
// ID -> symbol; populated together with sym2id_ by the file constructor.
std::unordered_map<int32_t, std::string> id2sym_;
};

/// Write symbol_table.ToString() to os and return os.
std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table);

} // namespace sherpa_ncnn

#endif // SHERPA_NCNN_CSRC_SYMBOL_TABLE_H_
107 changes: 107 additions & 0 deletions sherpa-ncnn/csrc/wave-reader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/**
* Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cassert>
#include <fstream>
#include <iostream>
#include <utility>
#include <vector>

#include "sherpa-ncnn/csrc/wave-reader.h"

namespace sherpa_ncnn {
namespace {
// see http://soundfile.sapp.org/doc/WaveFormat/
//
// Note: We assume little endian here
// TODO(fangjun): Support big endian
//
// In-memory image of the canonical 44-byte header of a PCM wave file. The
// struct is filled by reading the first 44 bytes of the file verbatim, so the
// field order and widths must not change (see the static_assert below).
struct WaveHeader {
  // Check that the header describes a 16-bit, single-channel PCM file whose
  // "data" chunk immediately follows the "fmt " chunk.
  //
  // The file content is untrusted input, so the checks are performed at run
  // time rather than with assert() (which is compiled out under NDEBUG).
  //
  // Returns true if every requirement holds. Otherwise prints the first
  // violated requirement to std::cerr and returns false.
  bool Validate() const {
    const auto require = [](bool ok, const char *what) -> bool {
      if (!ok) {
        std::cerr << "Invalid wave header: " << what << "\n";
      }
      return ok;
    };

    // Sizes are compared in 64-bit arithmetic so that garbage values in the
    // header cannot trigger signed integer overflow.
    return
        //                  F F I R
        require(chunk_id == 0x46464952, "chunk_id is not \"RIFF\"") &&
        //                E V A W
        require(format == 0x45564157, "format is not \"WAVE\"") &&
        //                        ' ' t m f
        require(subchunk1_id == 0x20746d66, "subchunk1_id is not \"fmt \"") &&
        require(subchunk1_size == 16, "subchunk1_size is not 16 (PCM)") &&
        require(audio_format == 1, "audio_format is not 1 (PCM)") &&
        // we support only single channel for now
        require(num_channels == 1, "only single-channel audio is supported") &&
        // we support only 16 bits per sample
        require(bits_per_sample == 16, "only 16-bit samples are supported") &&
        require(sample_rate > 0, "sample_rate is not positive") &&
        require(byte_rate == static_cast<int64_t>(sample_rate) * num_channels *
                                 bits_per_sample / 8,
                "byte_rate is inconsistent with the sample format") &&
        require(block_align == num_channels * bits_per_sample / 8,
                "block_align is inconsistent with the sample format") &&
        //                        a t a d
        require(subchunk2_id == 0x61746164, "subchunk2_id is not \"data\"") &&
        require(subchunk2_size >= 0, "subchunk2_size is negative") &&
        require(static_cast<int64_t>(chunk_size) ==
                    36 + static_cast<int64_t>(subchunk2_size),
                "chunk_size is not 36 + subchunk2_size");
  }

  int32_t chunk_id;
  int32_t chunk_size;
  int32_t format;
  int32_t subchunk1_id;
  int32_t subchunk1_size;
  int16_t audio_format;
  int16_t num_channels;
  int32_t sample_rate;
  int32_t byte_rate;
  int16_t block_align;
  int16_t bits_per_sample;
  int32_t subchunk2_id;
  int32_t subchunk2_size;
};
static_assert(sizeof(WaveHeader) == 44, "");

// Read a wave file of mono-channel from `is`.
//
// On success, stores the sample rate from the header in `*sample_rate`,
// replaces the content of `*samples` with the samples normalized to the
// range [-1, 1), and returns true.
//
// On failure (short or invalid header, or fewer data bytes than the header
// declares), prints a message to std::cerr and returns false; `*sample_rate`
// and `*samples` are left untouched.
bool ReadWaveImpl(std::istream &is, float *sample_rate,
                  std::vector<float> *samples) {
  WaveHeader header;
  is.read(reinterpret_cast<char *>(&header), sizeof(header));
  if (!is) {
    std::cerr << "Failed to read the 44-byte wave header\n";
    return false;
  }

  if (!header.Validate()) {
    return false;
  }

  // header.subchunk2_size contains the number of bytes in the data.
  // Each sample occupies two bytes. Only whole samples are read, so a stray
  // trailing byte (odd size) can never be written past the end of `pcm`.
  std::vector<int16_t> pcm(header.subchunk2_size / sizeof(int16_t));
  is.read(reinterpret_cast<char *>(pcm.data()),
          static_cast<std::streamsize>(pcm.size() * sizeof(int16_t)));
  if (!is) {
    std::cerr << "Wave data is shorter than the size declared in the header\n";
    return false;
  }

  *sample_rate = header.sample_rate;

  samples->clear();
  samples->reserve(pcm.size());
  for (const int16_t s : pcm) {
    // Dividing a 16-bit integer by 2^15 is exact in single precision.
    samples->push_back(s / 32768.0f);
  }

  return true;
}

}  // namespace

// Read the mono, 16-bit PCM wave file `filename`.
//
// @param filename             Path to the wave file.
// @param expected_sample_rate Sample rate the caller requires.
// @return The samples normalized to the range [-1, 1).
//
// If the file cannot be opened, is not a supported wave file, or its sample
// rate differs from `expected_sample_rate`, a message is printed to std::cerr
// and the process exits with status -1.
std::vector<float> ReadWave(const std::string &filename,
                            float expected_sample_rate) {
  std::ifstream is(filename, std::ifstream::binary);
  if (!is) {
    std::cerr << "Failed to open wave file: " << filename << "\n";
    exit(-1);
  }

  float sample_rate = 0;
  std::vector<float> samples;
  if (!ReadWaveImpl(is, &sample_rate, &samples)) {
    std::cerr << "Failed to read wave file: " << filename << "\n";
    exit(-1);
  }

  if (expected_sample_rate != sample_rate) {
    std::cerr << "Expected sample rate: " << expected_sample_rate
              << ". Given: " << sample_rate << ".\n";
    exit(-1);
  }
  return samples;
}

}  // namespace sherpa_ncnn
Loading

0 comments on commit 6b7180a

Please sign in to comment.