Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions src/pybind/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,37 +34,73 @@ ifeq ($(shell uname),Darwin)
LDFLAGS += -undefined dynamic_lookup
endif

CCFILES = kaldi_pybind.cc \
matrix/matrix_common_pybind.cc matrix/matrix_pybind.cc \
matrix/vector_pybind.cc \
util/table_types_pybind.cc \
feat/wave_reader_pybind.cc
CCFILES = \
chain/chain_pybind.cc \
chain/chain_supervision_pybind.cc \
feat/wave_reader_pybind.cc \
fst/arc_pybind.cc \
fst/compile_pybind.cc \
fst/fst_pybind.cc \
fst/symbol_table_pybind.cc \
fst/vector_fst_pybind.cc \
fst/weight_pybind.cc \
fstext/kaldi_fst_io_pybind.cc \
kaldi_pybind.cc \
matrix/matrix_common_pybind.cc \
matrix/matrix_pybind.cc \
matrix/vector_pybind.cc \
nnet3/nnet3_pybind.cc \
nnet3/nnet_chain_example_pybind.cc \
nnet3/nnet_common_pybind.cc \
nnet3/nnet_example_pybind.cc \
util/table_types_pybind.cc

CCFILES_OBJS := $(CCFILES:%.cc=%.o)

LIBNAME = kaldi_pybind


ADDLIBS = ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../feat/kaldi-feat.a
ADDLIBS = \
../base/kaldi-base.a \
../chain/kaldi-chain.a \
../feat/kaldi-feat.a \
../fstext/kaldi-fstext.a \
../matrix/kaldi-matrix.a \
../nnet3/kaldi-nnet3.a \
../util/kaldi-util.a


EXTRA_LDLIBS += $(foreach dep,$(ADDLIBS), $(dir $(dep))lib$(notdir $(basename $(dep))).so)
EXTRA_LDLIBS += ../../tools/openfst/lib/libfstscript.so

LDFLAGS += -Wl,-rpath=$(CURDIR)/../../tools/openfst/lib

LIBFILE=$(LIBNAME)$(LIBFILE_EXTENSION)


.PHONY: all clean test

all: $(LIBFILE)

$(LIBFILE): $(ADDLIBS) $(CCFILES)
$(CXX) $(CXXFLAGS) -shared -o $@ $(CCFILES) -Wl,--no-whole-archive -Wl,-rpath=$(CURDIR)/../lib $(LDFLAGS) $(LDLIBS) $(EXTRA_LDLIBS)
%.o: %.cc
$(CXX) -c $(CXXFLAGS) -o $@ $^

$(LIBFILE): $(ADDLIBS) $(CCFILES_OBJS)
$(CXX) $(CXXFLAGS) -shared -o $@ $(CCFILES_OBJS) -Wl,--no-whole-archive -Wl,-rpath=$(CURDIR)/../lib $(LDFLAGS) $(LDLIBS) $(EXTRA_LDLIBS)
python3 -c 'import kaldi_pybind' # this line is a test.

clean:
-rm -f *.so
-rm -rf __pycache__
-rm -f $(CCFILES_OBJS)

test: all
python3 tests/test_kaldi_pybind.py
python3 tests/test_matrix.py
python3 tests/test_wave.py
make -C fst test
make -C chain test
make -C nnet3 test

# valgrind-python.supp is from http://svn.python.org/projects/python/trunk/Misc/valgrind-python.supp
# since we do not compile Python from souce, we follow the comment in valgrind-python.supp
Expand Down
4 changes: 4 additions & 0 deletions src/pybind/chain/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

test:
python3 ./chain_supervision_pybind_test.py

29 changes: 29 additions & 0 deletions src/pybind/chain/chain_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// pybind/chain/chain_pybind.cc

// Copyright 2019 Mobvoi AI Lab, Beijing, China
// (author: Fangjun Kuang, Yaguang Hu, Jian Wang)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "chain/chain_pybind.h"

#include "chain/chain_supervision_pybind.h"

void pybind_chain(py::module& _m) {
py::module m = _m.def_submodule("chain", "chain pybind for Kaldi");

pybind_chain_supervision(m);
}
26 changes: 26 additions & 0 deletions src/pybind/chain/chain_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// pybind/chain/chain_pybind.h

// Copyright 2019 Mobvoi AI Lab, Beijing, China
// (author: Fangjun Kuang, Yaguang Hu, Jian Wang)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_CHAIN_CHAIN_PYBIND_H_
#define KALDI_PYBIND_CHAIN_CHAIN_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_chain(py::module& m);

#endif // KALDI_PYBIND_CHAIN_CHAIN_PYBIND_H_
121 changes: 121 additions & 0 deletions src/pybind/chain/chain_supervision_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// pybind/chain/chain_supervision_pybind.cc

// Copyright 2019 Mobvoi AI Lab, Beijing, China
// (author: Fangjun Kuang, Yaguang Hu, Jian Wang)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "chain/chain_pybind.h"

#include "chain/chain-supervision.h"

using namespace kaldi::chain;

void pybind_chain_supervision(py::module& m) {
{
using PyClass = Supervision;
py::class_<PyClass>(m, "Supervision",
"struct Supervision is the fully-processed supervision "
"information for a whole utterance or (after "
"splitting) part of an utterance. It contains the "
"time limits on phones encoded into the FST.")
.def(py::init<>())
.def(py::init<const PyClass&>(), py::arg("other"))
.def("Swap", &PyClass::Swap)
.def_readwrite("weight", &PyClass::weight,
"The weight of this example (will usually be 1.0).")
.def_readwrite("num_sequences", &PyClass::num_sequences,
"num_sequences will be 1 if you create a Supervision "
"object from a single lattice or alignment, but if you "
"combine multiple Supevision objects the "
"'num_sequences' is the number of objects that were "
"combined (the FSTs get appended).")
.def_readwrite("frames_per_sequence", &PyClass::frames_per_sequence,
"the number of frames in each sequence of appended "
"objects. num_frames * num_sequences must equal the "
"path length of any path in the FST. Technically this "
"information is redundant with the FST, but it's "
"convenient to have it separately.")
.def_readwrite("label_dim", &PyClass::label_dim,
"the maximum possible value of the labels in 'fst' "
"(which go from 1 to label_dim). For fully-processed "
"examples this will equal the NumPdfs() in the "
"TransitionModel object, but for newer-style "
"'unconstrained' examples that have been output by "
"chain-get-supervision but not yet processed by "
"nnet3-chain-get-egs, it will be the NumTransitionIds() "
"of the TransitionModel object.")
.def_readwrite(
"fst", &PyClass::fst,
"This is an epsilon-free unweighted acceptor that is sorted in "
"increasing order of frame index (this implies it's topologically "
"sorted but it's a stronger condition). The labels will normally "
"be pdf-ids plus one (to avoid epsilons, since pdf-ids are "
"zero-based), but for newer-style 'unconstrained' examples that "
"have been output by chain-get-supervision but not yet processed "
"by nnet3-chain-get-egs, they will be transition-ids. Each "
"successful path in 'fst' has exactly 'frames_per_sequence * "
"num_sequences' arcs on it (first 'frames_per_sequence' arcs for "
"the first sequence; then 'frames_per_sequence' arcs for the "
"second sequence, and so on).")
.def_readwrite(
"e2e_fsts", &PyClass::e2e_fsts,
"'e2e_fsts' may be set as an alternative to 'fst'. These FSTs are "
"used when the numerator computation will be done with 'full "
"forward_backward' instead of constrained in time. (The "
"'constrained in time' fsts are how we described it in the "
"original LF-MMI paper, where each phone can only occur at the "
"same time it occurred in the lattice, extended by a tolerance)."
"\n"
"This 'e2e_fsts' is an array of FSTs, one per sequence, that are "
"acceptors with (pdf_id + 1) on the labels, just like 'fst', but "
"which are cyclic FSTs. Unlike with 'fst', it is not the case with "
"'e2e_fsts' that each arc corresponds to a specific frame)."
"\n"
"There are two situations 'e2e_fsts' might be set. The first is in "
"'end-to-end' training, where we train without a tree from a flat "
"start. The function responsible for creating this object in that "
"case is TrainingGraphToSupervision(); to find out more about "
"end-to-end training, see chain-generic-numerator.h The second "
"situation is where we create the supervision from lattices, and "
"split them into chunks using the time marks in the lattice, but "
"then make a cyclic FST, and don't enforce the times on the "
"lattice inside the chunk. [Code location TBD].")
.def_readwrite("alignment_pdfs", &PyClass::alignment_pdfs,
"This member is only set to a nonempty value if we are "
"creating 'unconstrained' egs. These are egs that are "
"split into chunks using the lattice alignments, but "
"then within the chunks we remove the frame-level "
"constraints on which phones can appear when, and use "
"the 'e2e_fsts' member."
"\n"
"It is only required in order to accumulate the LDA "
"stats using `nnet3-chain-acc-lda-stats`, and it is not "
"merged by nnet3-chain-merge-egs; it will only be "
"present for un-merged egs.")
.def("__str__",
[](const PyClass& sup) {
std::ostringstream os;
os << "weight: " << sup.weight << "\n"
<< "num_sequences: " << sup.num_sequences << "\n"
<< "frames_per_sequence: " << sup.frames_per_sequence << "\n"
<< "label_dim: " << sup.label_dim << "\n";
return os.str();
})
// TODO(fangjun): Check, Write and Read are not wrapped
;
}
}
26 changes: 26 additions & 0 deletions src/pybind/chain/chain_supervision_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// pybind/chain/chain_superversion_pybind.h

// Copyright 2019 Mobvoi AI Lab, Beijing, China
// (author: Fangjun Kuang, Yaguang Hu, Jian Wang)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_CHAIN_CHAIN_SUPERVISION_PYBIND_H_
#define KALDI_PYBIND_CHAIN_CHAIN_SUPERVISION_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_chain_supervision(py::module& m);

#endif // KALDI_PYBIND_CHAIN_CHAIN_SUPERVISION_PYBIND_H_
25 changes: 25 additions & 0 deletions src/pybind/chain/chain_supervision_pybind_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env python3

# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir))

import unittest

import numpy as np

import kaldi_pybind.chain as chain


class TestChainSupervision(unittest.TestCase):

def test_chain_supervision(self):
supervision = chain.Supervision()
print(supervision)


if __name__ == '__main__':
unittest.main()
4 changes: 4 additions & 0 deletions src/pybind/fst/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

test:
python3 ./symbol_table_pybind_test.py
python3 ./vector_fst_pybind_test.py
50 changes: 50 additions & 0 deletions src/pybind/fst/arc_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// pybind/fst/arc_pybind.cc

// Copyright 2019 Mobvoi AI Lab, Beijing, China
// (author: Fangjun Kuang, Yaguang Hu, Jian Wang)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "fst/arc_pybind.h"

#include "fst/arc.h"

void pybind_arc(py::module& m) {
{
using PyClass = fst::StdArc;
using Weight = PyClass::Weight;
using Label = int;
using StateId = int;

py::class_<PyClass>(m, "StdArc")
.def(py::init<>())
.def(py::init<Label, Label, Weight, StateId>(), py::arg("ilabel"),
py::arg("olabel"), py::arg("weight"), py::arg("nextstate"))
.def(py::init<const PyClass&>(), py::arg("weight"))
.def_readwrite("ilabel", &PyClass::ilabel)
.def_readwrite("olabel", &PyClass::olabel)
.def_readwrite("weight", &PyClass::weight)
.def_readwrite("nextstate", &PyClass::nextstate)
.def("__str__",
[](const PyClass& arc) {
std::ostringstream os;
os << "(ilabel: " << arc.ilabel << ", "
<< "olabel: " << arc.olabel << ", "
<< "weight: " << arc.weight.Value() << ", "
<< "nextstate: " << arc.nextstate << ")";
return os.str();
})
.def_static("Type", &PyClass::Type);
}
}
Loading