diff --git a/src/pybind/Makefile b/src/pybind/Makefile index 5db8182c368..8eef865de4e 100644 --- a/src/pybind/Makefile +++ b/src/pybind/Makefile @@ -34,37 +34,73 @@ ifeq ($(shell uname),Darwin) LDFLAGS += -undefined dynamic_lookup endif -CCFILES = kaldi_pybind.cc \ - matrix/matrix_common_pybind.cc matrix/matrix_pybind.cc \ - matrix/vector_pybind.cc \ - util/table_types_pybind.cc \ - feat/wave_reader_pybind.cc +CCFILES = \ +chain/chain_pybind.cc \ +chain/chain_supervision_pybind.cc \ +feat/wave_reader_pybind.cc \ +fst/arc_pybind.cc \ +fst/compile_pybind.cc \ +fst/fst_pybind.cc \ +fst/symbol_table_pybind.cc \ +fst/vector_fst_pybind.cc \ +fst/weight_pybind.cc \ +fstext/kaldi_fst_io_pybind.cc \ +kaldi_pybind.cc \ +matrix/matrix_common_pybind.cc \ +matrix/matrix_pybind.cc \ +matrix/vector_pybind.cc \ +nnet3/nnet3_pybind.cc \ +nnet3/nnet_chain_example_pybind.cc \ +nnet3/nnet_common_pybind.cc \ +nnet3/nnet_example_pybind.cc \ +util/table_types_pybind.cc + +CCFILES_OBJS := $(CCFILES:%.cc=%.o) LIBNAME = kaldi_pybind -ADDLIBS = ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../feat/kaldi-feat.a +ADDLIBS = \ +../base/kaldi-base.a \ +../chain/kaldi-chain.a \ +../feat/kaldi-feat.a \ +../fstext/kaldi-fstext.a \ +../matrix/kaldi-matrix.a \ +../nnet3/kaldi-nnet3.a \ +../util/kaldi-util.a + EXTRA_LDLIBS += $(foreach dep,$(ADDLIBS), $(dir $(dep))lib$(notdir $(basename $(dep))).so) +EXTRA_LDLIBS += ../../tools/openfst/lib/libfstscript.so + +LDFLAGS += -Wl,-rpath=$(CURDIR)/../../tools/openfst/lib LIBFILE=$(LIBNAME)$(LIBFILE_EXTENSION) + .PHONY: all clean test all: $(LIBFILE) -$(LIBFILE): $(ADDLIBS) $(CCFILES) - $(CXX) $(CXXFLAGS) -shared -o $@ $(CCFILES) -Wl,--no-whole-archive -Wl,-rpath=$(CURDIR)/../lib $(LDFLAGS) $(LDLIBS) $(EXTRA_LDLIBS) +%.o: %.cc # compile each binding source into its own object; $< is the .cc file + $(CXX) -c $(CXXFLAGS) -o $@ $< + +$(LIBFILE): $(ADDLIBS) $(CCFILES_OBJS) + $(CXX) $(CXXFLAGS) -shared -o $@ $(CCFILES_OBJS) -Wl,--no-whole-archive -Wl,-rpath=$(CURDIR)/../lib $(LDFLAGS) $(LDLIBS) $(EXTRA_LDLIBS)
python3 -c 'import kaldi_pybind' # this line is a test. clean: -rm -f *.so -rm -rf __pycache__ + -rm -f $(CCFILES_OBJS) test: all python3 tests/test_kaldi_pybind.py python3 tests/test_matrix.py python3 tests/test_wave.py + $(MAKE) -C fst test + $(MAKE) -C chain test + $(MAKE) -C nnet3 test # valgrind-python.supp is from http://svn.python.org/projects/python/trunk/Misc/valgrind-python.supp # since we do not compile Python from souce, we follow the comment in valgrind-python.supp diff --git a/src/pybind/chain/Makefile b/src/pybind/chain/Makefile new file mode 100644 index 00000000000..f02ee0f0d74 --- /dev/null +++ b/src/pybind/chain/Makefile @@ -0,0 +1,4 @@ + +test: + python3 ./chain_supervision_pybind_test.py + diff --git a/src/pybind/chain/chain_pybind.cc b/src/pybind/chain/chain_pybind.cc new file mode 100644 index 00000000000..0dbe46f9707 --- /dev/null +++ b/src/pybind/chain/chain_pybind.cc @@ -0,0 +1,29 @@ +// pybind/chain/chain_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+ +#include "chain/chain_pybind.h" + +#include "chain/chain_supervision_pybind.h" + +void pybind_chain(py::module& _m) { + py::module m = _m.def_submodule("chain", "chain pybind for Kaldi"); + + pybind_chain_supervision(m); +} diff --git a/src/pybind/chain/chain_pybind.h b/src/pybind/chain/chain_pybind.h new file mode 100644 index 00000000000..0871e6e6e41 --- /dev/null +++ b/src/pybind/chain/chain_pybind.h @@ -0,0 +1,26 @@ +// pybind/chain/chain_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_PYBIND_CHAIN_CHAIN_PYBIND_H_ +#define KALDI_PYBIND_CHAIN_CHAIN_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_chain(py::module& m); + +#endif // KALDI_PYBIND_CHAIN_CHAIN_PYBIND_H_ diff --git a/src/pybind/chain/chain_supervision_pybind.cc b/src/pybind/chain/chain_supervision_pybind.cc new file mode 100644 index 00000000000..ac72b81a1e9 --- /dev/null +++ b/src/pybind/chain/chain_supervision_pybind.cc @@ -0,0 +1,121 @@ +// pybind/chain/chain_supervision_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "chain/chain_supervision_pybind.h" + +#include "chain/chain-supervision.h" + +using namespace kaldi::chain; + +void pybind_chain_supervision(py::module& m) { + { + using PyClass = Supervision; + py::class_(m, "Supervision", + "struct Supervision is the fully-processed supervision " + "information for a whole utterance or (after " + "splitting) part of an utterance. It contains the " + "time limits on phones encoded into the FST.") + .def(py::init<>()) + .def(py::init(), py::arg("other")) + .def("Swap", &PyClass::Swap) + .def_readwrite("weight", &PyClass::weight, + "The weight of this example (will usually be 1.0).") + .def_readwrite("num_sequences", &PyClass::num_sequences, + "num_sequences will be 1 if you create a Supervision " + "object from a single lattice or alignment, but if you " + "combine multiple Supevision objects the " + "'num_sequences' is the number of objects that were " + "combined (the FSTs get appended).") + .def_readwrite("frames_per_sequence", &PyClass::frames_per_sequence, + "the number of frames in each sequence of appended " + "objects. num_frames * num_sequences must equal the " + "path length of any path in the FST. Technically this " + "information is redundant with the FST, but it's " + "convenient to have it separately.") + .def_readwrite("label_dim", &PyClass::label_dim, + "the maximum possible value of the labels in 'fst' " + "(which go from 1 to label_dim).
For fully-processed " + "examples this will equal the NumPdfs() in the " + "TransitionModel object, but for newer-style " + "'unconstrained' examples that have been output by " + "chain-get-supervision but not yet processed by " + "nnet3-chain-get-egs, it will be the NumTransitionIds() " + "of the TransitionModel object.") + .def_readwrite( + "fst", &PyClass::fst, + "This is an epsilon-free unweighted acceptor that is sorted in " + "increasing order of frame index (this implies it's topologically " + "sorted but it's a stronger condition). The labels will normally " + "be pdf-ids plus one (to avoid epsilons, since pdf-ids are " + "zero-based), but for newer-style 'unconstrained' examples that " + "have been output by chain-get-supervision but not yet processed " + "by nnet3-chain-get-egs, they will be transition-ids. Each " + "successful path in 'fst' has exactly 'frames_per_sequence * " + "num_sequences' arcs on it (first 'frames_per_sequence' arcs for " + "the first sequence; then 'frames_per_sequence' arcs for the " + "second sequence, and so on).") + .def_readwrite( + "e2e_fsts", &PyClass::e2e_fsts, + "'e2e_fsts' may be set as an alternative to 'fst'. These FSTs are " + "used when the numerator computation will be done with 'full " + "forward_backward' instead of constrained in time. (The " + "'constrained in time' fsts are how we described it in the " + "original LF-MMI paper, where each phone can only occur at the " + "same time it occurred in the lattice, extended by a tolerance)." + "\n" + "This 'e2e_fsts' is an array of FSTs, one per sequence, that are " + "acceptors with (pdf_id + 1) on the labels, just like 'fst', but " + "which are cyclic FSTs. Unlike with 'fst', it is not the case with " + "'e2e_fsts' that each arc corresponds to a specific frame)." + "\n" + "There are two situations 'e2e_fsts' might be set. The first is in " + "'end-to-end' training, where we train without a tree from a flat " + "start. 
The function responsible for creating this object in that " + "case is TrainingGraphToSupervision(); to find out more about " + "end-to-end training, see chain-generic-numerator.h The second " + "situation is where we create the supervision from lattices, and " + "split them into chunks using the time marks in the lattice, but " + "then make a cyclic FST, and don't enforce the times on the " + "lattice inside the chunk. [Code location TBD].") + .def_readwrite("alignment_pdfs", &PyClass::alignment_pdfs, + "This member is only set to a nonempty value if we are " + "creating 'unconstrained' egs. These are egs that are " + "split into chunks using the lattice alignments, but " + "then within the chunks we remove the frame-level " + "constraints on which phones can appear when, and use " + "the 'e2e_fsts' member." + "\n" + "It is only required in order to accumulate the LDA " + "stats using `nnet3-chain-acc-lda-stats`, and it is not " + "merged by nnet3-chain-merge-egs; it will only be " + "present for un-merged egs.") + .def("__str__", + [](const PyClass& sup) { + std::ostringstream os; + os << "weight: " << sup.weight << "\n" + << "num_sequences: " << sup.num_sequences << "\n" + << "frames_per_sequence: " << sup.frames_per_sequence << "\n" + << "label_dim: " << sup.label_dim << "\n"; + return os.str(); + }) + // TODO(fangjun): Check, Write and Read are not wrapped + ; + } +} diff --git a/src/pybind/chain/chain_supervision_pybind.h b/src/pybind/chain/chain_supervision_pybind.h new file mode 100644 index 00000000000..cd4a5a7ba67 --- /dev/null +++ b/src/pybind/chain/chain_supervision_pybind.h @@ -0,0 +1,26 @@ +// pybind/chain/chain_supervision_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS
PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_PYBIND_CHAIN_CHAIN_SUPERVISION_PYBIND_H_ +#define KALDI_PYBIND_CHAIN_CHAIN_SUPERVISION_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_chain_supervision(py::module& m); + +#endif // KALDI_PYBIND_CHAIN_CHAIN_SUPERVISION_PYBIND_H_ diff --git a/src/pybind/chain/chain_supervision_pybind_test.py b/src/pybind/chain/chain_supervision_pybind_test.py new file mode 100755 index 00000000000..30339069cbf --- /dev/null +++ b/src/pybind/chain/chain_supervision_pybind_test.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 + +# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir)) + +import unittest + +import numpy as np + +import kaldi_pybind.chain as chain + + +class TestChainSupervision(unittest.TestCase): + + def test_chain_supervision(self): + supervision = chain.Supervision() + print(supervision) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/pybind/fst/Makefile b/src/pybind/fst/Makefile new file mode 100644 index 00000000000..e55eb359379 --- /dev/null +++ b/src/pybind/fst/Makefile @@ -0,0 +1,4 @@ + +test: + python3 ./symbol_table_pybind_test.py + python3 ./vector_fst_pybind_test.py diff --git a/src/pybind/fst/arc_pybind.cc b/src/pybind/fst/arc_pybind.cc new file mode 100644 index 00000000000..bfd2e368c46 --- /dev/null +++ b/src/pybind/fst/arc_pybind.cc @@ -0,0 +1,50 @@ +// pybind/fst/arc_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification 
regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "fst/arc_pybind.h" + +#include "fst/arc.h" + +void pybind_arc(py::module& m) { + { + using PyClass = fst::StdArc; + using Weight = PyClass::Weight; + using Label = int; + using StateId = int; + + py::class_(m, "StdArc") + .def(py::init<>()) + .def(py::init(), py::arg("ilabel"), + py::arg("olabel"), py::arg("weight"), py::arg("nextstate")) + .def(py::init(), py::arg("weight")) + .def_readwrite("ilabel", &PyClass::ilabel) + .def_readwrite("olabel", &PyClass::olabel) + .def_readwrite("weight", &PyClass::weight) + .def_readwrite("nextstate", &PyClass::nextstate) + .def("__str__", + [](const PyClass& arc) { + std::ostringstream os; + os << "(ilabel: " << arc.ilabel << ", " + << "olabel: " << arc.olabel << ", " + << "weight: " << arc.weight.Value() << ", " + << "nextstate: " << arc.nextstate << ")"; + return os.str(); + }) + .def_static("Type", &PyClass::Type); + } +} diff --git a/src/pybind/fst/arc_pybind.h b/src/pybind/fst/arc_pybind.h new file mode 100644 index 00000000000..f2cbed77db0 --- /dev/null +++ b/src/pybind/fst/arc_pybind.h @@ -0,0 +1,26 @@ +// pybind/fst/arc_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT 
WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_PYBIND_FST_ARC_PYBIND_H_ +#define KALDI_PYBIND_FST_ARC_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_arc(py::module& m); + +#endif // KALDI_PYBIND_FST_ARC_PYBIND_H_ diff --git a/src/pybind/fst/compile_pybind.cc b/src/pybind/fst/compile_pybind.cc new file mode 100644 index 00000000000..3b6459c86be --- /dev/null +++ b/src/pybind/fst/compile_pybind.cc @@ -0,0 +1,49 @@ +// pybind/fst/compile_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fst/compile_pybind.h" + +#include "fst/script/compile.h" + +void pybind_compile(py::module& m) { + m.def( + "CompileFst", + [](std::string& text_fst_str, const std::string& out_binary_fst_filename, + const std::string& source = "standard input", + const string& fst_type = "vector", const string& arc_type = "standard", + const fst::SymbolTable* isyms = nullptr, + const fst::SymbolTable* osyms = nullptr, + const fst::SymbolTable* ssyms = nullptr, bool accep = false, + bool ikeep = false, bool okeep = false, bool nkeep = false, + bool allow_negative_labels = false) { + // (fangjun): parameter `source` is only for debugging ! + std::stringstream strm; + strm << text_fst_str; + fst::script::CompileFst(strm, source, out_binary_fst_filename, fst_type, + arc_type, isyms, osyms, ssyms, accep, ikeep, + okeep, nkeep, allow_negative_labels); + }, + "the fst is written to out_binary_fst_filename", py::arg("text_fst_str"), + py::arg("out_binary_fst_filename"), py::arg("source") = "standard input", + py::arg("fst_type") = "vector", py::arg("arc_type") = "standard", + py::arg("isymbols") = nullptr, py::arg("osymbols") = nullptr, + py::arg("ssymbols") = nullptr, py::arg("acceptor") = false, + py::arg("keep_isymbols") = false, py::arg("keep_osymbols") = false, + py::arg("keep_state_numbering") = false, + py::arg("allow_negative_labels") = false); +} diff --git a/src/pybind/fst/compile_pybind.h b/src/pybind/fst/compile_pybind.h new file mode 100644 index 00000000000..d531d2f6b78 --- /dev/null +++ b/src/pybind/fst/compile_pybind.h @@ -0,0 +1,26 @@ +// pybind/fst/compile_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER
EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_PYBIND_FST_COMPILE_PYBIND_H_ +#define KALDI_PYBIND_FST_COMPILE_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_compile(py::module& m); + +#endif // KALDI_PYBIND_FST_COMPILE_PYBIND_H_ diff --git a/src/pybind/fst/fst_pybind.cc b/src/pybind/fst/fst_pybind.cc new file mode 100644 index 00000000000..f8cda4787a0 --- /dev/null +++ b/src/pybind/fst/fst_pybind.cc @@ -0,0 +1,353 @@ +// pybind/fst/fst_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fst/fst_pybind.h" + +#include "fst/fst.h" + +#include "fst/arc_pybind.h" +#include "fst/compile_pybind.h" +#include "fst/symbol_table_pybind.h" +#include "fst/vector_fst_pybind.h" +#include "fst/weight_pybind.h" +#include "fstext/kaldi_fst_io_pybind.h" + +template +using overload_cast_ = py::detail::overload_cast_impl; + +namespace { + +void _pybind_fst(py::module& m) { + m.attr("kNoLabel") = fst::kNoLabel; + m.attr("kNoStateId") = fst::kNoStateId; + { + using PyClass = fst::FstHeader; + py::class_(m, "FstHeader") + .def(py::init<>()) + .def("FstType", &PyClass::FstType, py::return_value_policy::reference) + .def("ArcType", &PyClass::ArcType, py::return_value_policy::reference) + .def("Version", &PyClass::Version) + .def("GetFlags", &PyClass::GetFlags) + .def("Properties", &PyClass::Properties) + .def("Start", &PyClass::Start) + .def("NumStates", &PyClass::NumStates) + .def("NumArcs", &PyClass::NumArcs) + .def("SetFstType", &PyClass::SetFstType, py::arg("type")) + .def("SetArcType", &PyClass::SetArcType, py::arg("type")) + .def("SetVersion", &PyClass::SetVersion, py::arg("version")) + .def("SetFlags", &PyClass::SetFlags, py::arg("flags")) + .def("SetProperties", &PyClass::SetProperties, py::arg("properties")) + .def("SetStart", &PyClass::SetStart, py::arg("start")) + .def("SetNumStates", &PyClass::SetNumStates, py::arg("numstates")) + .def("SetNumArcs", &PyClass::SetNumArcs, py::arg("numarcs")) + .def("Read", &PyClass::Read, py::arg("strm"), py::arg("source"), + py::arg("rewind") = false) + .def("Write", &PyClass::Write, py::arg("strm"), py::arg("source")) + .def("DebugString", &PyClass::DebugString); + } + { + using PyClass = fst::FstWriteOptions; + + py::class_(m, "FstWriteOptions") + .def_readwrite("source", &PyClass::source, "Where you're writing to.") + .def_readwrite("write_header", &PyClass::write_header, + "Write the header?") + .def_readwrite("write_isymbols", &PyClass::write_isymbols, + "Write input symbols?") +
.def_readwrite("write_osymbols", &PyClass::write_osymbols, + "Write output symbols?") + .def_readwrite("align", &PyClass::align, + "Write data aligned (may fail on pipes)?") + .def_readwrite("stream_write", &PyClass::stream_write, + "Avoid seek operations in writing.") + .def( + py::init(), + py::arg("source") = "", py::arg("write_header") = true, + py::arg("write_isymbols") = true, py::arg("write_osymbols") = true, + py::arg("align") = FLAGS_fst_align, py::arg("stream_write") = false) + .def("__str__", [](const PyClass& opt) { + std::ostringstream os; + os << "source: " << opt.source << "\n" + << "write_header: " << opt.write_header << "\n" + << "write_isymbols: " << opt.write_isymbols << "\n" + << "write_osymbols: " << opt.write_osymbols << "\n" + << "align: " << opt.align << "\n" + << "stream_write: " << opt.stream_write << "\n"; + return os.str(); + }); + } + + auto fst_read_options = + py::class_(m, "FstReadOptions") + .def(py::init(), + py::arg("source") = "", py::arg("header") = nullptr, + py::arg("isymbols") = nullptr, py::arg("osymbols") = nullptr) + .def(py::init(), + py::arg("source"), py::arg("isymbols") = nullptr, + py::arg("osymbols") = nullptr) + .def_readwrite("source", &fst::FstReadOptions::source, + "Where you're reading from.") + .def_readwrite("header", &fst::FstReadOptions::header, + "Pointer to FST header; if non-zero, use this info " + "(don't read a stream header).") + .def_readwrite("isymbols", &fst::FstReadOptions::isymbols, + "Pointer to input symbols; if non-zero, use this info " + "(read and skip stream isymbols)") + .def_readwrite("osymbols", &fst::FstReadOptions::osymbols, + "Pointer to output symbols; if non-zero, use this " + "info (read and skip stream osymbols)") + .def_readwrite("mode", &fst::FstReadOptions::mode, + "Read or map files (advisory, if possible)") + .def_readwrite("read_isymbols", &fst::FstReadOptions::read_isymbols, + "Read isymbols, if any (default: true).") + .def_readwrite("read_osymbols",
&fst::FstReadOptions::read_osymbols, + "Read osymbols, if any (default: true).") + .def_static("ReadMode", &fst::FstReadOptions::ReadMode, + "Helper function to convert strings FileReadModes into " + "their enum value.", + py::arg("mode")) + .def("DebugString", &fst::FstReadOptions::DebugString, + "Outputs a debug string for the FstReadOptions object."); + + py::enum_( + fst_read_options, "FileReadMode", py::arithmetic(), + "FileReadMode(s) are advisory, there are " + "many conditions than prevent a\n" + "file from being mapped, READ mode will " + "be selected in these cases with\n" + "a warning indicating why it was chosen.") + .value("READ", fst::FstReadOptions::FileReadMode::READ) + .value("MAP", fst::FstReadOptions::FileReadMode::MAP) + .export_values(); + + py::enum_(m, "MatchType", py::arithmetic(), + "Specifies matcher action.") + .value("MATCH_INPUT", fst::MatchType::MATCH_INPUT, "Match input label.") + .value("MATCH_OUTPUT", fst::MatchType::MATCH_OUTPUT, + "Match output label.") + .value("MATCH_BOTH", fst::MatchType::MATCH_BOTH, + "Match input or output label.") + .value("MATCH_NONE", fst::MatchType::MATCH_NONE, "Match nothing.") + .value("MATCH_UNKNOWN", fst::MatchType::MATCH_UNKNOWN, + "match type unknown.") + .export_values(); + { + using PyClass = fst::StateIteratorBase; + py::class_(m, "StdArcStateIteratorBase") + .def("Done", &PyClass::Done, "End of iterator?") + .def("Value", &PyClass::Value, "Returns current state (when !Done()).") + .def("Next", &PyClass::Next, "Advances to next state (when !Done()).") + .def("Reset", &PyClass::Reset, "Resets to initial condition."); + } + + { + using PyClass = fst::StateIteratorData; + py::class_>( + m, "StdArcStateIteratorData") + .def(py::init<>()) + .def_readwrite("base", &PyClass::base, + "Specialized iterator if non-zero.") + .def_readwrite("nstates", &PyClass::nstates, + "Otherwise, the total number of states."); + } + + { + using PyClass = fst::ArcIteratorBase; + py::class_(m, "StdArcArcIteratorBase") + 
.def("Done", &PyClass::Done, "End of iterator?") + .def("Value", &PyClass::Value, "Returns current arc (when !Done()).", + py::return_value_policy::reference) + .def("Next", &PyClass::Next, "Advances to next arc (when !Done()).") + .def("Position", &PyClass::Position, "Returns current position.") + .def("Reset", &PyClass::Reset, "Resets to initial condition.") + .def("Seek", &PyClass::Seek, "Advances to arbitrary arc by position.") + .def("Flags", &PyClass::Flags, "Returns current behavorial flags.") + .def("SetFlags", &PyClass::SetFlags, "Sets behavorial flags."); + } + + { + using PyClass = fst::ArcIteratorData; + py::class_>( + m, "StdArcArcIteratorData") + .def(py::init<>()) + .def_readwrite("base", &PyClass::base, + "Specialized iterator if non-zero.") + .def_readwrite("arcs", &PyClass::arcs, "Otherwise arcs pointer") + .def_readwrite("narcs", &PyClass::narcs, "arc count") + .def_readwrite("ref_count", &PyClass::ref_count, + "reference count if non-zero."); + } + + { + using PyClass = fst::StdFst; + using Arc = PyClass::Arc; + using StateId = PyClass::StateId; + using Weight = PyClass::Weight; + + auto fst_state_iterator = + py::class_>(m, "StdFstStateIterator"); + auto fst_arc_iterator = + py::class_>(m, "StdFstArcIterator"); + + py::class_( + m, "StdFst", + "A generic FST, templated on the arc definition, with \n" + "common-demoninator methods (use StateIterator and \n" + "ArcIterator to iterate over its states and arcs).") + .def("Start", &PyClass::Start, "Initial state.") + .def("Final", &PyClass::Final, "State's final weight.") + .def("NumArcs", &PyClass::NumArcs, "State's arc count.") + .def("NumInputEpsilons", &PyClass::NumInputEpsilons, + "State's input epsilon count.") + .def("Properties", &PyClass::Properties, + "Property bits. If test = false, return stored properties bits " + "for mask\n" + "(some possibly unknown); if test = true, return property bits " + "for mask\n" + "(computing o.w.
unknown).", + py::arg("mask"), py::arg("test")) + .def("Type", &PyClass::Type, "FST typename", + py::return_value_policy::reference) + .def( + "Copy", &PyClass::Copy, + "Gets a copy of this Fst. The copying behaves as follows:\n" + "\n" + "(1) The copying is constant time if safe = false or if safe = " + "true and is on an otherwise unaccessed FST.\n" + "\n" + "(2) If safe = true, the copy is thread-safe in that the original\n" + "and copy can be safely accessed (but not necessarily mutated) by\n" + "separate threads. For some FST types, 'Copy(true)' should only\n" + "be called on an FST that has not otherwise been accessed.\n" + "Behavior is otherwise undefined.\n" + "\n" + "(3) If a MutableFst is copied and then mutated, then the original" + "\n" + "is unmodified and vice versa (often by a copy-on-write on the \n" + "initial mutation, which may not be constant time).", + py::arg("safe") = false, py::return_value_policy::take_ownership) + .def_static( + "Read", + // clang-format off + overload_cast_()(&PyClass::Read), + // clang-format on + "Reads an FST from an input stream; returns nullptr on error.", + py::arg("strm"), py::arg("opts"), + py::return_value_policy::take_ownership) + .def_static( + "Read", overload_cast_()(&PyClass::Read), + "Reads an FST from a file; returns nullptr on error.
An empty\n" + "filename results in reading from standard input.", + py::arg("filename"), py::return_value_policy::take_ownership) + .def("Write", + // clang-format off + (bool (PyClass::*)(std::ostream&, const fst::FstWriteOptions&)const)&PyClass::Write, + // clang-format on + "Writes an FST to an output stream; returns false on error.", + py::arg("strm"), py::arg("opts")) + .def("Write", + (bool (PyClass::*)(const fst::string&) const) & PyClass::Write, + "Writes an FST to a file; returns false on error; an empty\n" + "filename results in writing to standard output.", + py::arg("filename")) + .def("InputSymbols", &PyClass::InputSymbols, + "Returns input label symbol table; return nullptr if not " + "specified.", + py::return_value_policy::reference) + .def("OutputSymbols", &PyClass::OutputSymbols, + "Returns output label symbol table; return nullptr if not " + "specified.", + py::return_value_policy::reference) + .def("InitStateIterator", &PyClass::InitStateIterator, + "For generic state iterator construction (not normally called " + "directly by users). Does not copy the FST.", + py::arg("data")) + .def("InitArcIterator", &PyClass::InitArcIterator, + "For generic arc iterator construction (not normally called " + "directly by users). Does not copy the FST.", + py::arg("s"), py::arg("data")) +#if 0 + // TODO(fangjun): what is the use of InitMatcher? + .def("InitMatcher", &PyClass::InitMatcher, + "For generic matcher construction (not normally called directly " + "by users).", + py::arg("match_type")) // TODO(fangjun): reference semantics ? 
+#endif + ; + fst_state_iterator.def(py::init(), py::arg("fst")) + .def("Done", &fst::StateIterator::Done) + .def("Value", &fst::StateIterator::Value) + .def("Next", &fst::StateIterator::Next) + .def("Reset", &fst::StateIterator::Reset); + + fst_arc_iterator + .def(py::init(), py::arg("fst"), py::arg("s")) + .def("Done", &fst::ArcIterator::Done) + .def("Value", &fst::ArcIterator::Value, + py::return_value_policy::reference) + .def("Next", &fst::ArcIterator::Next) + .def("Reset", &fst::ArcIterator::Reset) + .def("Seek", &fst::ArcIterator::Seek, py::arg("a")) + .def("Position", &fst::ArcIterator::Position) + .def("Flags", &fst::ArcIterator::Flags) + .def("SetFlags", &fst::ArcIterator::SetFlags); + ; + } + + m.def("TestProperties", &fst::TestProperties, py::arg("fst"), + py::arg("mask"), py::arg("known")); + + m.def("FstToString", + // clang-format off + (fst::string (*)(const fst::StdFst&, const fst::FstWriteOptions&))&fst::FstToString, + // clang-format on + py::arg("fst"), + py::arg("options") = fst::FstWriteOptions("FstToString")); + + m.def("FstToString", + // clang-format off + (void (*)(const fst::StdFst&, fst::string*))&fst::FstToString, + // clang-format on + py::arg("fst"), py::arg("result")); + + m.def("FstToString", + // clang-format off + (void (*)(const fst::StdFst&, fst::string*, const fst::FstWriteOptions&))&fst::FstToString, + // clang-format on + py::arg("fst"), py::arg("result"), py::arg("options")); + + m.def("StringToFst", &fst::StringToFst, py::arg("s")); +} + +} // namespace + +void pybind_fst(py::module& _m) { + py::module m = _m.def_submodule("fst", "FST pybind for Kaldi"); + + // WARNING(fangjun): do NOT sort the following in alphabetic order! 
+ pybind_weight(m); + pybind_arc(m); + pybind_symbol_table(m); + + _pybind_fst(m); + pybind_vector_fst(m); + pybind_kaldi_fst_io(m); + pybind_compile(m); +} diff --git a/src/pybind/fst/fst_pybind.h b/src/pybind/fst/fst_pybind.h new file mode 100644 index 00000000000..dadde21da90 --- /dev/null +++ b/src/pybind/fst/fst_pybind.h @@ -0,0 +1,26 @@ +// pybind/fst/fst_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_PYBIND_FST_FST_PYBIND_H_ +#define KALDI_PYBIND_FST_FST_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_fst(py::module& m); + +#endif // KALDI_PYBIND_FST_FST_PYBIND_H_ diff --git a/src/pybind/fst/symbol_table_pybind.cc b/src/pybind/fst/symbol_table_pybind.cc new file mode 100644 index 00000000000..96b351a6497 --- /dev/null +++ b/src/pybind/fst/symbol_table_pybind.cc @@ -0,0 +1,199 @@ +// pybind/fst/symbol_table_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fst/symbol_table_pybind.h" + +#include "fst/symbol-table.h" + +template +using overload_cast_ = py::detail::overload_cast_impl; + +void pybind_symbol_table(py::module& m) { + m.attr("kNoSymbol") = fst::kNoSymbol; + + { + using PyClass = fst::SymbolTableReadOptions; + py::class_(m, "SymbolTableReadOptions") + .def( + py::init>, const fst::string>(), + py::arg("string_hash_ranges"), py::arg("source")) + .def_readwrite("string_hash_ranges", &PyClass::string_hash_ranges) + .def_readwrite("source", &PyClass::source); + } + { + using PyClass = fst::SymbolTableTextOptions; + py::class_(m, "SymbolTableTextOptions") + .def(py::init(), py::arg("allow_negative_labels") = false) + .def_readwrite("allow_negative_labels", &PyClass::allow_negative_labels) + .def_readwrite("fst_field_separator", &PyClass::fst_field_separator); + } + { + using PyClass = fst::SymbolTable; + py::class_( + m, "SymbolTable", + "Symbol (string) to integer (and reverse) mapping.\n" + "\n" + "The SymbolTable implements the mappings of labels to strings and " + "reverse. SymbolTables are used to describe the alphabet of the input " + "and output abels for arcs in a Finite State Transducer." + "\n" + "SymbolTables are reference-counted and can therefore be shared across " + "multiple machines. For example a language model grammar G, with a " + "SymbolTable for the words in the language model can share this symbol " + "table with the lexical representation L o G.") + .def(py::init(), + "Constructs symbol table with an optional name.", + py::arg("name") = "") + .def_static("ReadText", + overload_cast_()( + &PyClass::ReadText), + "Reads a text representation of the symbol table from an " + "istream. 
Pass a name to give the resulting SymbolTable.", + py::arg("strm"), py::arg("name"), + py::arg("opts") = fst::SymbolTableTextOptions()) + .def_static("ReadText", + overload_cast_()( + &PyClass::ReadText), + "Reads a text representation of the symbol table", + py::arg("filename"), + py::arg("opts") = fst::SymbolTableTextOptions()) + .def_static( + "Read", + overload_cast_()( + &PyClass::Read), + "WARNING: Reading via symbol table read options should not be " + "used. This is a temporary work-around.", + py::arg("strm"), py::arg("opts") = fst::SymbolTableReadOptions()) + .def_static("Read", overload_cast_()( + &PyClass::Read), + "Reads a binary dump of the symbol table from a stream.", + py::arg("strm"), py::arg("source")) + .def_static( + "Read", overload_cast_()(&PyClass::Read), + "Reads a binary dump of the symbol table.", py::arg("filename")) + .def("Copy", &PyClass::Copy, "Creates a reference counted copy.") + .def("AddSymbol", + // clang-format off + (int64 (PyClass::*)(const fst::string&, int64)) &PyClass::AddSymbol, + // clang-format on + "Adds a symbol with given key to table. A symbol table also keeps " + "track of the last available key (highest key value in the symbol " + "table).", + py::arg("symbol"), py::arg("key")) + .def("AddSymbol", + (int64 (PyClass::*)(const fst::string&)) & PyClass::AddSymbol, + "Adds a symbol to the table. The associated value key is " + "automatically assigned by the symbol table.", + py::arg("symbol")) + .def("AddTable", &PyClass::AddTable, + "Adds another symbol table to this table. All key values will be " + "offset" + "by the current available key (highest key value in the symbol " + "table)." + "Note string symbols with the same key value will still have the " + "same" + "key value after the symbol table has been merged, but a different" + "value. 
Adding symbol tables do not result in changes in the base " + "table.", + py::arg("table")) + .def("RemoveSymbol", &PyClass::RemoveSymbol, py::arg("key")) + .def("Name", &PyClass::Name, "Returns the name of the symbol table.") + .def("SetName", &PyClass::SetName, "Sets the name of the symbol table.") + .def("CheckSum", &PyClass::CheckSum, + "Return the label-agnostic MD5 check-sum for this table. All new " + "symbols added to the table will result in an updated checksum. " + "Deprecated.") + .def("LabeledCheckSum", &PyClass::LabeledCheckSum, + "Same as CheckSum(), but returns an label-dependent version.") + .def("Write", (bool (PyClass::*)(std::ostream&) const) & PyClass::Write, + py::arg("strm")) + .def("Write", + (bool (PyClass::*)(const fst::string&) const) & PyClass::Write, + py::arg("filename")) + .def("WriteText", + // clang-format off + (bool (PyClass::*)(std::ostream&, const fst::SymbolTableTextOptions&) const) &PyClass::WriteText, + // clang-format on + "Dump a text representation of the symbol table via a stream.", + py::arg("strm"), py::arg("opts") = fst::SymbolTableTextOptions()) + .def("WriteText", + (bool (PyClass::*)(const fst::string&) const) & PyClass::WriteText, + "Dump a text representation of the symbol table.", + py::arg("filename")) + .def("Find", (fst::string (PyClass::*)(int64) const) & PyClass::Find, + "Returns the string associated with the key; if the key is out of" + "range (<0, >max), returns an empty string.", + py::arg("key")) + .def("Find", + (int64 (PyClass::*)(const fst::string&) const) & PyClass::Find, + "Returns the key associated with the symbol; if the symbol does " + "not exist, kNoSymbol is returned.", + py::arg("symbol")) + .def("Find", (int64 (PyClass::*)(const char*) const) & PyClass::Find, + "Returns the key associated with the symbol; if the symbol does " + "not exist," + "kNoSymbol is returned.", + py::arg("symbol")) + .def("Member", (bool (PyClass::*)(int64) const) & PyClass::Member, + py::arg("key")) + .def("Member", 
+ (bool (PyClass::*)(const fst::string&) const) & PyClass::Member, + py::arg("symbol")) + .def("AvailableKey", &PyClass::AvailableKey, + "Returns the current available key (i.e., highest key + 1) in the " + "symbol table.") + .def("NumSymbols", &PyClass::NumSymbols, + "Returns the current number of symbols in table (not necessarily " + "equal to AvailableKey()).") + .def("GetNthKey", &PyClass::GetNthKey, py::arg("pos")) + .def("__str__", [](const PyClass& sym) { + std::ostringstream os; + sym.WriteText(os); + return os.str(); + }); + } + { + using PyClass = fst::SymbolTableIterator; + py::class_(m, "SymbolTableIterator") + .def(py::init(), py::arg("table")) + .def("Done", &PyClass::Done, "Returns whether iterator is done.") + .def("Value", &PyClass::Value, "Return the key of the current symbol.") + .def("Symbol", &PyClass::Symbol, + "Return the string of the current symbol.") + .def("Next", &PyClass::Next, "Advances iterator.") + .def("Reset", &PyClass::Reset, "Resets iterator."); + } + m.def("RelabelSymbolTable", &fst::RelabelSymbolTable, + "Relabels a symbol table as specified by the input vector of pairs " + "(old label, new label). The new symbol table only retains symbols for " + "which a relabeling is explicitly specified.", + py::arg("table"), py::arg("pairs")); + + m.def("CompatSymbols", &fst::CompatSymbols, + "Returns true if the two symbol tables have equal checksums. 
Passing " + "in nullptr for either table always returns true.", + py::arg("sysm1"), py::arg("syms2"), py::arg("warning") = true); + + m.def("SymbolTableToString", &fst::SymbolTableToString, py::arg("table"), + py::arg("result")); + + m.def("StringToSymbolTable", &fst::StringToSymbolTable, py::arg("str"), + py::return_value_policy::take_ownership); +} diff --git a/src/pybind/fst/symbol_table_pybind.h b/src/pybind/fst/symbol_table_pybind.h new file mode 100644 index 00000000000..f7da83e9f98 --- /dev/null +++ b/src/pybind/fst/symbol_table_pybind.h @@ -0,0 +1,26 @@ +// pybind/fst/symbol_table_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_PYBIND_FST_SYMBOL_TABLE_PYBIND_H_ +#define KALDI_PYBIND_FST_SYMBOL_TABLE_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_symbol_table(py::module& m); + +#endif // KALDI_PYBIND_FST_SYMBOL_TABLE_PYBIND_H_ diff --git a/src/pybind/fst/symbol_table_pybind_test.py b/src/pybind/fst/symbol_table_pybind_test.py new file mode 100755 index 00000000000..a17c3d08570 --- /dev/null +++ b/src/pybind/fst/symbol_table_pybind_test.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 + +# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir)) + +import unittest + +import numpy as np + +import kaldi_pybind.fst as fst +import kaldi + + +class TestSymbolTable(unittest.TestCase): + + def test_symbol_table(self): + self.assertEqual(fst.kNoSymbol, -1) + + # the name can be arbitrary string, or it can simply be omitted + words = fst.SymbolTable(name='words.txt') + self.assertEqual(words.Name(), 'words.txt') + + # 0 1 2 3 4 5 6 + text = ' hello OpenFST in Python with Pybind11'.split() + indices = [words.AddSymbol(w) for w in text] + + for i in range(len(text)): + self.assertEqual(words.Find(key=i), text[i]) + self.assertEqual(words.Find(symbol=text[i]), i) + self.assertTrue(words.Member(key=i)) + self.assertTrue(words.Member(symbol=text[i])) + + self.assertEqual(words.Find('Kaldi'), fst.kNoSymbol) + self.assertEqual(words.AvailableKey(), len(text)) + self.assertEqual(words.NumSymbols(), len(text)) + + self.assertEqual(words.GetNthKey(pos=5), 5) + + symbol_table_iterator = fst.SymbolTableIterator(words) + i = 0 + while not symbol_table_iterator.Done(): + index = symbol_table_iterator.Value() + symbol = symbol_table_iterator.Symbol() + self.assertEqual(index, i) + self.assertEqual(symbol, text[i]) + symbol_table_iterator.Next() + i += 1 + + # the following is more pythonic for iteration + i = 0 + kaldi_symbol_iterator = 
kaldi.SymbolTableIterator(words) + for index, symbol in kaldi_symbol_iterator: + self.assertEqual(index, i) + self.assertEqual(symbol, text[i]) + i += 1 + + # to use the iterator again, we must reset it manually + kaldi_symbol_iterator.Reset() + i = 0 + for index, symbol in kaldi_symbol_iterator: + self.assertEqual(index, i) + self.assertEqual(symbol, text[i]) + i += 1 + + # after removing the word 'with' whose index is 5 + + words.RemoveSymbol(key=5) + self.assertEqual(words.Find(key=5), '') + self.assertEqual(words.Find(symbol='with'), fst.kNoSymbol) + self.assertEqual(words.AvailableKey(), len(text)) # still 7 + self.assertEqual(words.NumSymbols(), len(text) - 1) # now 6 = 7-1 + + # at pos 5, we have the word `Pybind11` which has index 6 + self.assertEqual(words.GetNthKey(pos=5), 6) + + words.AddSymbol(symbol='with', key=5) + self.assertEqual(words.Find(key=5), 'with') + self.assertEqual(words.Find(symbol='with'), 5) + self.assertEqual(words.AvailableKey(), len(text)) # still 7 + self.assertEqual(words.NumSymbols(), len(text)) # now 7 + + self.assertEqual(words.GetNthKey(pos=5), 6) # it's still 6 ! 
+ + # test I/O + # to control the field separator, we can use + # fst::SymbolTableTextOptions::fst_field_separator, + # the default separator is controlled by FLAGS_fst_field_separator + # whose default value is '\t ', e.g., a tab and a space + filename = 'words.txt' + words.WriteText(filename=filename) + + words_txt_read_back = fst.SymbolTable.ReadText(filename=filename) + + self.assertEqual(words.CheckSum(), words_txt_read_back.CheckSum()) + self.assertTrue(fst.CompatSymbols(words, words_txt_read_back)) + + # now for binary + filename = 'words.bin' + words.Write(filename=filename) + words_bin_read_back = fst.SymbolTable.Read(filename=filename) + + self.assertEqual(words.CheckSum(), words_bin_read_back.CheckSum()) + self.assertTrue(fst.CompatSymbols(words, words_bin_read_back)) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/pybind/fst/vector_fst_pybind.cc b/src/pybind/fst/vector_fst_pybind.cc new file mode 100644 index 00000000000..a5d148db7be --- /dev/null +++ b/src/pybind/fst/vector_fst_pybind.cc @@ -0,0 +1,311 @@ +// pybind/fst/vector_fst_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fst/vector_fst_pybind.h" + +#include "fst/script/info-impl.h" + +#include "fst/script/fst-class.h" +#include "fst/script/print-impl.h" +#include "fst/vector-fst.h" + +template +using overload_cast_ = py::detail::overload_cast_impl; + +namespace { +// this following function is copied from openfst/src/script/info-impl.cc +void PrintFstInfoImpl(const fst::FstInfo& fstinfo, std::ostream& ostrm) { + using namespace fst; + ostrm.setf(std::ios::left); + ostrm.width(50); + ostrm << "fst type" << fstinfo.FstType() << std::endl; + ostrm.width(50); + ostrm << "arc type" << fstinfo.ArcType() << std::endl; + ostrm.width(50); + ostrm << "input symbol table" << fstinfo.InputSymbols() << std::endl; + ostrm.width(50); + ostrm << "output symbol table" << fstinfo.OutputSymbols() << std::endl; + if (!fstinfo.LongInfo()) { + return; + } + ostrm.width(50); + ostrm << "# of states" << fstinfo.NumStates() << std::endl; + ostrm.width(50); + ostrm << "# of arcs" << fstinfo.NumArcs() << std::endl; + ostrm.width(50); + ostrm << "initial state" << fstinfo.Start() << std::endl; + ostrm.width(50); + ostrm << "# of final states" << fstinfo.NumFinal() << std::endl; + ostrm.width(50); + ostrm << "# of input/output epsilons" << fstinfo.NumEpsilons() << std::endl; + ostrm.width(50); + ostrm << "# of input epsilons" << fstinfo.NumInputEpsilons() << std::endl; + ostrm.width(50); + ostrm << "# of output epsilons" << fstinfo.NumOutputEpsilons() << std::endl; + ostrm.width(50); + ostrm << "input label multiplicity" << fstinfo.InputLabelMultiplicity() + << std::endl; + ostrm.width(50); + ostrm << "output label multiplicity" << fstinfo.OutputLabelMultiplicity() + << std::endl; + ostrm.width(50); + string arc_type = ""; + if (fstinfo.ArcFilterType() == "epsilon") + arc_type = "epsilon "; + else if (fstinfo.ArcFilterType() == "iepsilon") + arc_type = "input-epsilon "; + else if (fstinfo.ArcFilterType() == "oepsilon") + arc_type = "output-epsilon "; + const auto accessible_label = "# of " + 
arc_type + "accessible states"; + ostrm.width(50); + ostrm << accessible_label << fstinfo.NumAccessible() << std::endl; + const auto coaccessible_label = "# of " + arc_type + "coaccessible states"; + ostrm.width(50); + ostrm << coaccessible_label << fstinfo.NumCoAccessible() << std::endl; + const auto connected_label = "# of " + arc_type + "connected states"; + ostrm.width(50); + ostrm << connected_label << fstinfo.NumConnected() << std::endl; + const auto numcc_label = "# of " + arc_type + "connected components"; + ostrm.width(50); + ostrm << numcc_label << fstinfo.NumCc() << std::endl; + const auto numscc_label = "# of " + arc_type + "strongly conn components"; + ostrm.width(50); + ostrm << numscc_label << fstinfo.NumScc() << std::endl; + ostrm.width(50); + ostrm << "input matcher" + << (fstinfo.InputMatchType() == MATCH_INPUT + ? 'y' + : fstinfo.InputMatchType() == MATCH_NONE ? 'n' : '?') + << std::endl; + ostrm.width(50); + ostrm << "output matcher" + << (fstinfo.OutputMatchType() == MATCH_OUTPUT + ? 'y' + : fstinfo.OutputMatchType() == MATCH_NONE ? 'n' : '?') + << std::endl; + ostrm.width(50); + ostrm << "input lookahead" << (fstinfo.InputLookAhead() ? 'y' : 'n') + << std::endl; + ostrm.width(50); + ostrm << "output lookahead" << (fstinfo.OutputLookAhead() ? 
'y' : 'n') + << std::endl; + uint64 prop = 1; + for (auto i = 0; i < 64; ++i, prop <<= 1) { + if (prop & kBinaryProperties) { + char value = 'n'; + if (fstinfo.Properties() & prop) value = 'y'; + ostrm.width(50); + ostrm << PropertyNames[i] << value << std::endl; + } else if (prop & kPosTrinaryProperties) { + char value = '?'; + if (fstinfo.Properties() & prop) + value = 'y'; + else if (fstinfo.Properties() & prop << 1) + value = 'n'; + ostrm.width(50); + ostrm << PropertyNames[i] << value << std::endl; + } + } +} +} + +void pybind_vector_fst(py::module& m) { + { + using PyClass = fst::StdVectorFst; + using Arc = PyClass::Arc; + using StateId = PyClass::StateId; + using State = PyClass::State; + + py::class_(m, "StdVectorFst") + .def(py::init<>()) + .def(py::init(), py::arg("fst")) + .def(py::init(), py::arg("fst"), + py::arg("safe") = false) + .def("Start", &PyClass::Start) + .def("Final", &PyClass::Final, py::arg("s")) + .def("SetStart", &PyClass::SetStart, py::arg("s")) + .def("SetFinal", &PyClass::SetFinal, py::arg("s"), py::arg("weight")) + .def("SetProperties", &PyClass::SetProperties, py::arg("props"), + py::arg("mask")) + .def("AddState", (StateId (PyClass::*)()) & PyClass::AddState) + .def("AddArc", &PyClass::AddArc, py::arg("s"), py::arg("arc")) + .def("DeleteStates", (void (PyClass::*)(const std::vector&)) & + PyClass::DeleteStates, + py::arg("dstates")) + .def("DeleteStates", (void (PyClass::*)()) & PyClass::DeleteStates, + "Delete all states") + .def("DeleteArcs", + (void (PyClass::*)(StateId, size_t)) & PyClass::DeleteArcs, + py::arg("state"), py::arg("n")) + .def("DeleteArcs", (void (PyClass::*)(StateId)) & PyClass::DeleteArcs, + py::arg("s")) + .def("ReserveStates", &PyClass::ReserveStates, py::arg("s")) + .def("ReserveArcs", &PyClass::ReserveArcs, py::arg("s"), py::arg("n")) + .def("InputSymbols", &PyClass::InputSymbols, + "Returns input label symbol table; return nullptr if not " + "specified.", + py::return_value_policy::reference) + 
.def("OutputSymbols", &PyClass::OutputSymbols, + "Returns output label symbol table; return nullptr if not " + "specified.", + py::return_value_policy::reference) + .def("MutableInputSymbols", &PyClass::MutableInputSymbols, + "Returns input label symbol table; return nullptr if not " + "specified.", + py::return_value_policy::reference) + .def("MutableOutputSymbols", &PyClass::MutableOutputSymbols, + "Returns output label symbol table; return nullptr if not " + "specified.", + py::return_value_policy::reference) + .def("SetInputSymbols", &PyClass::SetInputSymbols, py::arg("isyms")) + .def("SetOutputSymbols", &PyClass::SetOutputSymbols, py::arg("osyms")) + .def("NumStates", &PyClass::NumStates) + .def("NumArcs", &PyClass::NumArcs, py::arg("s")) + .def("NumInputEpsilons", &PyClass::NumInputEpsilons, py::arg("s")) + .def("NumOutputEpsilons", &PyClass::NumOutputEpsilons, py::arg("s")) + .def("Properties", &PyClass::Properties, py::arg("mask"), + py::arg("test")) + .def("Type", &PyClass::Type, "FST typename", + py::return_value_policy::reference) + .def("Copy", &PyClass::Copy, + "Get a copy of this VectorFst. 
See Fst<>::Copy() for further " + "doc.", + py::arg("safe") = false, py::return_value_policy::take_ownership) + .def_static("Read", + // clang-format off + overload_cast_()(&PyClass::Read), + // clang-format on + "Reads a VectorFst from an input stream, returning nullptr " + "on error.", + py::arg("strm"), py::arg("opts"), + py::return_value_policy::take_ownership) + .def_static( + "Read", overload_cast_()(&PyClass::Read), + "Read a VectorFst from a file, returning nullptr on error; " + "empty " + "filename reads from standard input.", + py::arg("filename"), py::return_value_policy::take_ownership) + .def("Write", + // clang-format off + (bool (PyClass::*)(std::ostream&, const fst::FstWriteOptions&)const)&PyClass::Write, + // clang-format on + "Writes an FST to an output stream; returns false on error.", + py::arg("strm"), py::arg("opts")) + .def("Write", + (bool (PyClass::*)(const fst::string&) const) & PyClass::Write, + "Writes an FST to a file; returns false on error; an empty\n" + "filename results in writing to standard output.", + py::arg("filename")) + .def_static("WriteFst", &PyClass::WriteFst, + py::arg("fst"), py::arg("strm"), py::arg("opts")) + .def("InitStateIterator", &PyClass::InitStateIterator, + "For generic state iterator construction (not normally called " + "directly by users). Does not copy the FST.", + py::arg("data")) + .def("InitArcIterator", &PyClass::InitArcIterator, + "For generic arc iterator construction (not normally called " + "directly by users). 
Does not copy the FST.", + py::arg("s"), py::arg("data")) + .def("info", + [](const PyClass& vector_fst) -> std::string { + std::ostringstream os; + auto _fst = fst::script::FstClass(vector_fst); + auto fst_info = fst::FstInfo(*_fst.GetFst(), true); + PrintFstInfoImpl(fst_info); + return os.str(); + }) + .def("__str__", + [](const PyClass& vector_fst) -> std::string { + std::ostringstream os; + auto _fst = fst::script::FstClass(vector_fst); + fst::FstPrinter( + *_fst.GetFst(), _fst.InputSymbols(), + _fst.OutputSymbols(), + nullptr, // state symbol table, ssyms + false, // false means not in acceptor format + false, // false means not to show weight one + " ", // fst field separator, 6 spaces + "" // missing symbol + ) + .Print(&os, "standard output"); + return os.str(); + }) + .def("ToString", + [](const PyClass& vector_fst, bool is_acceptor = false, + bool show_weight_one = false, + const std::string& fst_field_separator = " ", + const std::string& missing_symbol = "", + const std::string& dest = "stardard output") { + std::ostringstream os; + auto _fst = fst::script::FstClass(vector_fst); + fst::FstPrinter(*_fst.GetFst(), _fst.InputSymbols(), + _fst.OutputSymbols(), nullptr, is_acceptor, + show_weight_one, fst_field_separator, + missing_symbol) + .Print(&os, dest); + return os.str(); + }, + "see fstprint for help, e.g., fstprint --help", + py::arg("is_acceptor") = false, py::arg("show_weight_one") = false, + py::arg("fst_field_separator") = " ", + py::arg("missing_symbol") = "", + py::arg("dest") = "stardard output"); + } + { + using PyClass = fst::StateIterator; + py::class_(m, "StdVectorFstStateIterator") + .def(py::init(), py::arg("fst")) + .def("Done", &PyClass::Done) + .def("Value", &PyClass::Value) + .def("Next", &PyClass::Next) + .def("Reset", &PyClass::Reset); + } + + { + using PyClass = fst::ArcIterator; + using StateId = PyClass::StateId; + py::class_(m, "StdVectorFstArcIterator") + .def(py::init(), py::arg("fst"), + py::arg("s")) + .def("Done", 
&PyClass::Done) + .def("Value", &PyClass::Value, py::return_value_policy::reference) + .def("Next", &PyClass::Next) + .def("Reset", &PyClass::Reset) + .def("Seek", &PyClass::Seek, py::arg("a")) + .def("Position", &PyClass::Position) + .def("Flags", &PyClass::Flags) + .def("SetFlags", &PyClass::SetFlags); + } + + { + using PyClass = fst::MutableArcIterator; + using StateId = PyClass::StateId; + py::class_(m, "StdVectorFstMutableArcIterator") + .def(py::init(), py::arg("fst"), + py::arg("s")) + .def("Done", &PyClass::Done) + .def("Value", &PyClass::Value, py::return_value_policy::reference) + .def("SetValue", &PyClass::SetValue, py::arg("arc")) + .def("Next", &PyClass::Next) + .def("Reset", &PyClass::Reset) + .def("Seek", &PyClass::Seek, py::arg("a")) + .def("Position", &PyClass::Position) + .def("Flags", &PyClass::Flags) + .def("SetFlags", &PyClass::SetFlags); + } +} diff --git a/src/pybind/fst/vector_fst_pybind.h b/src/pybind/fst/vector_fst_pybind.h new file mode 100644 index 00000000000..61db2e3c77f --- /dev/null +++ b/src/pybind/fst/vector_fst_pybind.h @@ -0,0 +1,26 @@ +// pybind/fst/vector_fst_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_PYBIND_FST_VECTOR_FST_PYBIND_H_ +#define KALDI_PYBIND_FST_VECTOR_FST_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_vector_fst(py::module& m); + +#endif // KALDI_PYBIND_FST_VECTOR_FST_PYBIND_H_ diff --git a/src/pybind/fst/vector_fst_pybind_test.py b/src/pybind/fst/vector_fst_pybind_test.py new file mode 100755 index 00000000000..7836ded6268 --- /dev/null +++ b/src/pybind/fst/vector_fst_pybind_test.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 + +# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir)) + +import unittest + +import numpy as np + +import kaldi_pybind.fst as fst +import kaldi + + +class TestStdVectorFst(unittest.TestCase): + + def test_std_vector_fst(self): + vector_fst = fst.StdVectorFst() + + # create the same FST from + # http://www.openfst.org/twiki/bin/view/FST/FstQuickTour#Creating%20FSTs%20Using%20Constructors + # 1st state will be state 0 (returned by AddState) + vector_fst.AddState() + vector_fst.SetStart(0) + vector_fst.AddArc(0, fst.StdArc(1, 1, fst.TropicalWeight(0.5), 1)) + vector_fst.AddArc(0, fst.StdArc(2, 2, fst.TropicalWeight(1.5), 1)) + + vector_fst.AddState() + vector_fst.AddArc(1, fst.StdArc(3, 3, fst.TropicalWeight(2.5), 2)) + + vector_fst.AddState() + vector_fst.SetFinal(2, fst.TropicalWeight(3.5)) + + # fstprint with default options + print(vector_fst) + + print('-' * 20) + print('fstprint with customized options (default options)') + print( + vector_fst.ToString(is_acceptor=False, + show_weight_one=False, + fst_field_separator=" " * 6, + missing_symbol="")) + # now build the symbol table + input_words = ' a b c'.split() + output_words = ' x y z'.split() + + isymbol_table = fst.SymbolTable() + for w in input_words: + isymbol_table.AddSymbol(w) + + osymbol_table = fst.SymbolTable() + for w in output_words: + osymbol_table.AddSymbol(w) + + 
vector_fst.SetInputSymbols(isyms=isymbol_table) + vector_fst.SetOutputSymbols(osyms=osymbol_table) + print(vector_fst) + + # now for I/O + fst_filename = 'test.fst' + vector_fst.Write(filename=fst_filename) + + read_back_fst = fst.StdVectorFst.Read(filename=fst_filename) + print('fst after reading back is:') + print(read_back_fst) + + # TODO(fangjun): check that the two fsts are the same: start/final/states/arcs/symbol tables + # TODO(fangjun): add fstdraw support + # TODO(fangjun): test fstcompile + + text_fst_str = read_back_fst.ToString() + + compiled_filename = "compiled.fst" + fst.CompileFst(text_fst_str=text_fst_str, + out_binary_fst_filename=compiled_filename, + isymbols=isymbol_table, + osymbols=osymbol_table, + keep_isymbols=True, + keep_osymbols=True) + + read_back_compiled_fst = fst.StdVectorFst.Read( + filename=compiled_filename) + print('-' * 20) + print('read back compiled fst is:') + print(read_back_compiled_fst) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/pybind/fst/weight_pybind.cc b/src/pybind/fst/weight_pybind.cc new file mode 100644 index 00000000000..18ed7d06cad --- /dev/null +++ b/src/pybind/fst/weight_pybind.cc @@ -0,0 +1,67 @@ +// pybind/fst/weight_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fst/weight_pybind.h" + +#include "fst/float-weight.h" + +void pybind_weight(py::module& m) { + { + using PyClass = fst::FloatWeight; + py::class_(m, "FloatWeight") + .def(py::init<>()) + .def(py::init(), py::arg("f")) + .def(py::init(), py::arg("weight")) + .def("Value", &PyClass::Value, py::return_value_policy::reference) + .def("Hash", &PyClass::Hash) + .def("__eq__", + [](const PyClass& w1, const PyClass& w2) { return w1 == w2; }) + .def("__str__", [](const PyClass& w) { + std::ostringstream os; + os << w.Value(); + return os.str(); + }); + } + { + using PyClass = fst::TropicalWeight; + py::class_(m, "TropicalWeight") + .def(py::init<>()) + .def(py::init(), py::arg("f")) + .def(py::init(), py::arg("weight")) + .def("Member", &PyClass::Member) + .def("Quantize", &PyClass::Quantize, py::arg("delta") = fst::kDelta) + .def("Reverse", &PyClass::Reverse) + .def_static("Zero", &PyClass::Zero) + .def_static("One", &PyClass::One) + .def_static("NoWeight", &PyClass::NoWeight) + .def_static("Type", &PyClass::Type) + .def_static("Properties", &PyClass::Properties); + + m.def("Plus", [](const PyClass& w1, const PyClass& w2) { + return fst::Plus(w1, w2); + }); + + m.def("Times", [](const PyClass& w1, const PyClass& w2) { + return fst::Times(w1, w2); + }); + + m.def("Divide", [](const PyClass& w1, const PyClass& w2) { + return fst::Divide(w1, w2); + }); + } +} diff --git a/src/pybind/fst/weight_pybind.h b/src/pybind/fst/weight_pybind.h new file mode 100644 index 00000000000..f02d672f5fc --- /dev/null +++ b/src/pybind/fst/weight_pybind.h @@ -0,0 +1,26 @@ +// pybind/fst/weight_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// 
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_PYBIND_FST_WEIGHT_PYBIND_H_ +#define KALDI_PYBIND_FST_WEIGHT_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_weight(py::module& m); + +#endif // KALDI_PYBIND_FST_WEIGHT_PYBIND_H_ diff --git a/src/pybind/fstext/kaldi_fst_io_pybind.cc b/src/pybind/fstext/kaldi_fst_io_pybind.cc new file mode 100644 index 00000000000..c7a9e7616f2 --- /dev/null +++ b/src/pybind/fstext/kaldi_fst_io_pybind.cc @@ -0,0 +1,94 @@ +// pybind/fstext/kaldi_fst_io_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "fstext/kaldi_fst_io_pybind.h" + +#include "fstext/kaldi-fst-io.h" + +void pybind_kaldi_fst_io(py::module& m) { + m.def("ReadFstKaldi", (fst::StdVectorFst * (*)(std::string))fst::ReadFstKaldi, + "Read a binary FST using Kaldi I/O mechanisms (pipes, etc.) On error, " + "throws using KALDI_ERR. 
Note: this doesn't support the text-mode " + "option that we generally like to support.", + py::arg("rxfilename"), py::return_value_policy::reference); + + m.def("ReadFstKaldiGeneric", fst::ReadFstKaldiGeneric, + "Read a binary FST using Kaldi I/O mechanisms (pipes, etc.) If it " + "can't read the FST, if throw_on_err == true it throws using " + "KALDI_ERR; otherwise it prints a warning and returns. Note:this " + "doesn't support the text-mode option that we generally like to " + "support. This version currently supports ConstFst or " + "VectorFst (const-fst can give better performance for " + "decoding).", + py::arg("rxfilename"), py::arg("throw_on_err") = true, + py::return_value_policy::reference); + + m.def("CastOrConvertToVectorFst", &fst::CastOrConvertToVectorFst, + "This function attempts to dynamic_cast the pointer 'fst' (which will " + "likely have been returned by ReadFstGeneric()), to the more derived " + "type VectorFst. If this succeeds, it returns the same " + "pointer; if it fails, it converts the FST type (by creating a new " + "VectorFst initialized by 'fst'), prints a warning, and " + "deletes 'fst'.", + py::arg("fst"), py::return_value_policy::reference); + + m.def("ReadFstKaldi", + (void (*)(std::string, fst::StdVectorFst*)) & fst::ReadFstKaldi, + "Version of ReadFstKaldi() that writes to a pointer. Assumes the FST " + "is binary with no binary marker. Crashes on error.", + py::arg("rxfilename"), py::arg("ofst")); + + m.def("WriteFstKaldi", + (void (*)(const fst::StdVectorFst&, std::string)) & fst::WriteFstKaldi, + "Write an FST using Kaldi I/O mechanisms (pipes, etc.) On error, " + "throws using KALDI_ERR. 
For use only in code in fstbin/, as it " + "doesn't support the text-mode option.", + py::arg("fst"), py::arg("wxfilename")); + + m.def("WriteFstKaldi", + (void (*)(std::ostream&, bool, const fst::StdVectorFst&)) & + fst::WriteFstKaldi, + "This is a more general Kaldi-type-IO mechanism of writing FSTs to " + "streams, supporting binary or text-mode writing. (note: we just " + "write the integers, symbol tables are not supported). On error, " + "throws using KALDI_ERR.", + py::arg("os"), py::arg("binary"), py::arg("fst")); + + m.def("ReadFstKaldi", + (void (*)(std::istream&, bool, fst::StdVectorFst*)) & fst::ReadFstKaldi, + "A generic Kaldi-type-IO mechanism of reading FSTs from streams, " + "supporting binary or text-mode reading/writing.", + py::arg("is"), py::arg("binary"), py::arg("fst")); + m.def("ReadAndPrepareLmFst", &fst::ReadAndPrepareLmFst, + "Read an FST file for LM (G.fst) and make it an acceptor, and make " + "sure it is sorted on labels", + py::arg("rxfilename"), py::return_value_policy::reference); + + { + // fangjun: it should be called StdVectorFstHolder to match the naming + // convention in OpenFst but kaldi uses only StdArc so there is no confusion + // here. 
+ using PyClass = fst::VectorFstHolder; + py::class_(m, "VectorFstHolder") + .def(py::init<>()) + .def_static("Write", &PyClass::Write, py::arg("os"), py::arg("binary"), + py::arg("t")) + .def("Copy", &PyClass::Copy) + .def("Read", &PyClass::Read, "Reads into the holder.", py::arg("is")); + } +} diff --git a/src/pybind/fstext/kaldi_fst_io_pybind.h b/src/pybind/fstext/kaldi_fst_io_pybind.h new file mode 100644 index 00000000000..936e63cd3a4 --- /dev/null +++ b/src/pybind/fstext/kaldi_fst_io_pybind.h @@ -0,0 +1,26 @@ +// pybind/fstext/kaldi_fst_io_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_PYBIND_FSTEXT_KALDI_FST_IO_PYBIND_H_ +#define KALDI_PYBIND_FSTEXT_KALDI_FST_IO_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_kaldi_fst_io(py::module& m); + +#endif // KALDI_PYBIND_FSTEXT_KALDI_FST_IO_PYBIND_H_ diff --git a/src/pybind/kaldi.py b/src/pybind/kaldi.py index d2bdcfa8767..99a5747f972 100644 --- a/src/pybind/kaldi.py +++ b/src/pybind/kaldi.py @@ -1,2 +1,6 @@ from kaldi_pybind import * from matrix_reader import * +from symbol_table import * +from util.table import SequentialNnetChainExampleReader +from util.table import RandomAccessNnetChainExampleReader +from util.table import NnetChainExampleWriter diff --git a/src/pybind/kaldi_pybind.cc b/src/pybind/kaldi_pybind.cc index ef577ec02c2..559ad4d0b4c 100644 --- a/src/pybind/kaldi_pybind.cc +++ b/src/pybind/kaldi_pybind.cc @@ -21,13 +21,17 @@ #include +#include "feat/wave_reader_pybind.h" #include "matrix/matrix_common_pybind.h" #include "matrix/matrix_pybind.h" #include "matrix/vector_pybind.h" #include "util/table_types_pybind.h" -#include "feat/wave_reader_pybind.h" -void pybind_matrix(py::module& m); +#include "fst/fst_pybind.h" + +#include "chain/chain_pybind.h" +#include "nnet3/nnet3_pybind.h" + PYBIND11_MODULE(kaldi_pybind, m) { m.doc() = "pybind11 binding of some things from kaldi's " @@ -39,4 +43,8 @@ PYBIND11_MODULE(kaldi_pybind, m) { pybind_vector(m); pybind_table_types(m); pybind_wave_reader(m); + + pybind_fst(m); + pybind_chain(m); + pybind_nnet3(m); } diff --git a/src/pybind/kaldi_pybind.h b/src/pybind/kaldi_pybind.h index 741e652a0eb..7e317e38f32 100644 --- a/src/pybind/kaldi_pybind.h +++ b/src/pybind/kaldi_pybind.h @@ -20,8 +20,10 @@ #ifndef KALDI_PYBIND_KALDI_PYBIND_H_ #define KALDI_PYBIND_KALDI_PYBIND_H_ +#include #include #include + namespace py = pybind11; #endif // KALDI_PYBIND_KALDI_PYBIND_H_ diff --git a/src/pybind/matrix/sparse_matrix_pybind.cc b/src/pybind/matrix/sparse_matrix_pybind.cc new file mode 100644 index 00000000000..6ee045d5f82 --- /dev/null 
+++ b/src/pybind/matrix/sparse_matrix_pybind.cc @@ -0,0 +1,45 @@ +// pybind/matrix/sparse_matrix_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "matrix/sparse_matrix_pybind.h" + +#include "matrix/sparse-matrix.h" + +using namespace kaldi; + +// in nnet-example.h, class NnetIo contains a field `GeneralMatrix features` +// so we need to wrap GeneralMatrix +// NOTE(review): this file is not listed in CCFILES of pybind/Makefile and is not +// called from kaldi_pybind.cc in this change, so it appears to be unused for now. + +void pybind_sparse_matrix(py::module& m) { + { + using PyClass = GeneralMatrix; + py::class_( + m, "GeneralMatrix", + "This class is a wrapper that enables you to store a matrix in one of " + "three forms: either as a Matrix, or a CompressedMatrix, or " + "a SparseMatrix. It handles the I/O for you, i.e. you read " + "and write a single object type. 
It is useful for neural-net training " + "targets which might be sparse or not, and might be compressed or not.") + .def(py::init<>()) + // TODO(fangjun): wrap other methods when needed + ; + } +} diff --git a/src/pybind/matrix/sparse_matrix_pybind.h b/src/pybind/matrix/sparse_matrix_pybind.h new file mode 100644 index 00000000000..eeebc1f6c16 --- /dev/null +++ b/src/pybind/matrix/sparse_matrix_pybind.h @@ -0,0 +1,25 @@ +// pybind/matrix/sparse_matrix_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_PYBIND_MATRIX_SPARSE_MATRIX_PYBIND_H_ +#define KALDI_PYBIND_MATRIX_SPARSE_MATRIX_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_sparse_matrix(py::module& m); + +#endif // KALDI_PYBIND_MATRIX_SPARSE_MATRIX_PYBIND_H_ diff --git a/src/pybind/matrix/vector_pybind.cc b/src/pybind/matrix/vector_pybind.cc index ca1efd057df..317ad80607e 100644 --- a/src/pybind/matrix/vector_pybind.cc +++ b/src/pybind/matrix/vector_pybind.cc @@ -53,13 +53,12 @@ void pybind_vector(py::module& m) { [](VectorBase& m, int i, float v) { m(i) = v; }); py::class_, VectorBase>(m, "FloatVector", - pybind11::buffer_protocol()) - .def_buffer([](const Vector& v) -> pybind11::buffer_info { - return pybind11::buffer_info( - (void*)v.Data(), sizeof(float), - pybind11::format_descriptor::format(), - 1, // num-axes - {v.Dim()}, {4}); // strides (in chars) + py::buffer_protocol()) + .def_buffer([](const Vector& v) -> py::buffer_info { + return py::buffer_info((void*)v.Data(), sizeof(float), + py::format_descriptor::format(), + 1, // num-axes + {v.Dim()}, {4}); // strides (in chars) }) .def(py::init(), py::arg("size"), py::arg("resize_type") = kSetZero); @@ -76,7 +75,7 @@ void pybind_vector(py::module& m) { KALDI_ERR << "Expected dim: 1\n" << "Current dim: " << info.ndim; } - return new - SubVector(reinterpret_cast(info.ptr), info.shape[0]); + return new SubVector(reinterpret_cast(info.ptr), + info.shape[0]); })); } diff --git a/src/pybind/nnet3/Makefile b/src/pybind/nnet3/Makefile new file mode 100644 index 00000000000..d9fd9059847 --- /dev/null +++ b/src/pybind/nnet3/Makefile @@ -0,0 +1,6 @@ + +# `test` names a command, not a file: declare it phony so that a stray file +# called `test` in this directory cannot make the target look up to date. +.PHONY: test +test: + python3 ./nnet_chain_example_pybind_test.py + diff --git a/src/pybind/nnet3/aishell_test.ark b/src/pybind/nnet3/aishell_test.ark new file mode 100644 index 00000000000..1f07e91e932 Binary files /dev/null and b/src/pybind/nnet3/aishell_test.ark differ diff --git a/src/pybind/nnet3/aishell_test.scp b/src/pybind/nnet3/aishell_test.scp new file mode 100644 index 
00000000000..7c055dc6225 --- /dev/null +++ b/src/pybind/nnet3/aishell_test.scp @@ -0,0 +1,2 @@ +BAC009S0704W0210-87 aishell_test.ark:20 +BAC009S0037W0418-150 aishell_test.ark:9595 diff --git a/src/pybind/nnet3/nnet3_pybind.cc b/src/pybind/nnet3/nnet3_pybind.cc new file mode 100644 index 00000000000..c0ccd5979cb --- /dev/null +++ b/src/pybind/nnet3/nnet3_pybind.cc @@ -0,0 +1,33 @@ +// pybind/nnet3/nnet3_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "nnet3/nnet3_pybind.h" + +#include "nnet3/nnet_chain_example_pybind.h" +#include "nnet3/nnet_common_pybind.h" +#include "nnet3/nnet_example_pybind.h" + +void pybind_nnet3(py::module& _m) { + py::module m = _m.def_submodule("nnet3", "nnet3 pybind for Kaldi"); + // Everything below is registered on the nnet3 submodule m, not on _m itself. + pybind_nnet_common(m); + pybind_nnet_example(m); + pybind_nnet_chain_example(m); +} diff --git a/src/pybind/nnet3/nnet3_pybind.h b/src/pybind/nnet3/nnet3_pybind.h new file mode 100644 index 00000000000..e2498f85ee0 --- /dev/null +++ b/src/pybind/nnet3/nnet3_pybind.h @@ -0,0 +1,26 @@ +// pybind/nnet3/nnet3_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_PYBIND_NNET3_PYBIND_H_ +#define KALDI_PYBIND_NNET3_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_nnet3(py::module& m); + +#endif // KALDI_PYBIND_NNET3_PYBIND_H_ diff --git a/src/pybind/nnet3/nnet_chain_example_pybind.cc b/src/pybind/nnet3/nnet_chain_example_pybind.cc new file mode 100644 index 00000000000..a04b8ef2a27 --- /dev/null +++ b/src/pybind/nnet3/nnet_chain_example_pybind.cc @@ -0,0 +1,112 @@ +// pybind/nnet3/nnet_chain_example_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet3/nnet_chain_example_pybind.h" + +#include "nnet3/nnet-chain-example.h" +#include "util/kaldi_table_pybind.h" + +using namespace kaldi; +using namespace kaldi::nnet3; +using namespace kaldi::chain; + +void pybind_nnet_chain_example(py::module& m) { + { + using PyClass = NnetChainSupervision; + py::class_( + m, "NnetChainSupervision", + "For regular setups we use struct 'NnetIo' as the output. For the " + "'chain' models, the output supervision is a little more complex as it " + "involves a lattice and we need to do forward-backward, so we use a " + "separate struct for it. 
The 'output' name means that it pertains to " + "the output of the network, as opposed to the features which pertain " + "to the input of the network. It actually stores the lattice-like " + "supervision information at the output of the network (which imposes " + "constraints on which frames each phone can be active on") + .def(py::init<>()) + .def_readwrite("name", &PyClass::name, + "the name of the output in the neural net; in simple " + "setups it will just be 'output'.") + .def_readwrite( + "indexes", &PyClass::indexes, + "The indexes that the output corresponds to. The size of this " + "vector will be equal to supervision.num_sequences * " + "supervision.frames_per_sequence. Be careful about the order of " + "these indexes-- it is a little confusing. The indexes in the " + "'index' vector are ordered as: (frame 0 of each sequence); (frame " + "1 of each sequence); and so on. But in the 'supervision' object, " + "the FST contains (sequence 0; sequence 1; ...). So reordering is " + "needed when doing the numerator computation. We order 'indexes' " + "in this way for efficiency in the denominator computation (it " + "helps memory locality), as well as to avoid the need for the nnet " + "to reorder things internally to match the requested output (for " + "layers inside the neural net, the ordering is (frame 0; frame 1 " + "...) as this corresponds to the order you get when you sort a " + "vector of Index).") + .def_readwrite("supervision", &PyClass::supervision, + "The supervision object, containing the FST.") + .def_readwrite( + "deriv_weights", &PyClass::deriv_weights, + "This is a vector of per-frame weights, required to be between 0 " + "and 1, that is applied to the derivative during training (but not " + "during model combination, where the derivatives need to agree " + "with the computed objf values for the optimization code to work). " + " The reason for this is to more exactly handle edge effects and " + "to ensure that no frames are 'double-counted'. 
The order of this " + "vector corresponds to the order of the 'indexes' (i.e. all the " + "first frames, then all the second frames, etc.) If this vector is " + "empty it means we're not applying per-frame weights, so it's " + "equivalent to a vector of all ones. This vector is written to " + "disk compactly as unsigned char.") + .def("CheckDim", &PyClass::CheckDim) + .def("__str__", + [](const PyClass& sup) { + std::ostringstream os; + os << "name: " << sup.name << "\n"; + return os.str(); + }) + // TODO(fangjun): other methods can be wrapped when needed + ; + } + { + using PyClass = NnetChainExample; + py::class_(m, "NnetChainExample") + .def(py::init<>()) + .def_readwrite("inputs", &PyClass::inputs) + .def_readwrite("outputs", &PyClass::outputs) + .def("Compress", &PyClass::Compress, + "Compresses the input features (if not compressed)") + .def("__eq__", + [](const PyClass& a, const PyClass& b) { return a == b; }); + + // (fangjun): we follow the PyKaldi style of prefixing the registered class + // names with an underscore. Users should in general not use them directly; + // instead, they should use the corresponding Python classes, which are + // easier to use. 
+ pybind_sequential_table_reader>( + m, "_SequentialNnetChainExampleReader"); + + pybind_random_access_table_reader>( + m, "_RandomAccessNnetChainExampleReader"); + + pybind_table_writer>(m, + "_NnetChainExampleWriter"); + } +} diff --git a/src/pybind/nnet3/nnet_chain_example_pybind.h b/src/pybind/nnet3/nnet_chain_example_pybind.h new file mode 100644 index 00000000000..a9abedebc91 --- /dev/null +++ b/src/pybind/nnet3/nnet_chain_example_pybind.h @@ -0,0 +1,26 @@ +// pybind/nnet3/nnet_chain_example_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_PYBIND_NNET3_CHAIN_EXAMPLE_PYBIND_H_ +#define KALDI_PYBIND_NNET3_CHAIN_EXAMPLE_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_nnet_chain_example(py::module& m); + +#endif // KALDI_PYBIND_NNET3_CHAIN_EXAMPLE_PYBIND_H_ diff --git a/src/pybind/nnet3/nnet_chain_example_pybind_test.py b/src/pybind/nnet3/nnet_chain_example_pybind_test.py new file mode 100755 index 00000000000..66776f9facd --- /dev/null +++ b/src/pybind/nnet3/nnet_chain_example_pybind_test.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir)) + +import unittest + +import kaldi_pybind + +import kaldi_pybind.chain as chain +import kaldi_pybind.nnet3 as nnet3 +from kaldi import NnetChainExampleWriter +from kaldi import RandomAccessNnetChainExampleReader +from kaldi import SequentialNnetChainExampleReader +from kaldi_pybind.fst import StdVectorFst + + +class TestNnetChainExample(unittest.TestCase): + + def test_nnet_chain_example(self): + + # TODO(fangjun): find a place to store the test data + egs_rspecifier = 'scp:./aishell_test.scp' + reader = SequentialNnetChainExampleReader(egs_rspecifier) + for key, value in reader: + inputs = value.inputs + self.assertEqual(len(inputs), 1) + + nnet_io = inputs[0] + self.assertTrue(isinstance(nnet_io, nnet3.NnetIo)) + self.assertEqual(nnet_io.name, 'input') + # its `features` has not been wrapped yet. 
+ + self.assertTrue(isinstance(key, str)) + self.assertTrue(isinstance(value, nnet3.NnetChainExample)) + outputs = value.outputs + num_outputs = len(outputs) + self.assertEqual(num_outputs, 1) + + nnet_chain_sup = outputs[0] + self.assertTrue( + isinstance(nnet_chain_sup, nnet3.NnetChainSupervision)) + self.assertEqual(nnet_chain_sup.name, 'output') + + sup = nnet_chain_sup.supervision + self.assertTrue(isinstance(sup, chain.Supervision)) + weight = sup.weight + self.assertEqual(sup.weight, 1) + self.assertEqual(sup.num_sequences, 1) + # we have to egs in the ark, with 30 and 50 frames per sequence respectively + self.assertTrue(sup.frames_per_sequence == 30 or + sup.frames_per_sequence == 50) + self.assertEqual(sup.label_dim, 4336) + + # now comes to the FST part !!! + fst = sup.fst + self.assertTrue(isinstance(sup.fst, StdVectorFst)) + # see pybind/fst/vector_fst_pybind_test.py for operations wrapped for fst::StdVectorFst + # TODO(fangjun): finish the test + + +if __name__ == '__main__': + unittest.main() diff --git a/src/pybind/nnet3/nnet_common_pybind.cc b/src/pybind/nnet3/nnet_common_pybind.cc new file mode 100644 index 00000000000..4959b8f174d --- /dev/null +++ b/src/pybind/nnet3/nnet_common_pybind.cc @@ -0,0 +1,59 @@ +// pybind/nnet3/nnet_common_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet3/nnet_common_pybind.h" + +#include "nnet3/nnet-common.h" + +using namespace kaldi; +using namespace kaldi::nnet3; + +void pybind_nnet_common(py::module& m) { + { + // Index is needed by NnetChainSupervision in nnet_chain_example_pybind.cc + using PyClass = Index; + py::class_( + m, "Index", + "struct Index is intended to represent the various indexes by which we " + "number the rows of the matrices that the Components process: mainly " + "'n', the index of the member of the minibatch, 't', used for the " + "frame index in speech recognition, and 'x', which is a catch-all " + "extra index which we might use in convolutional setups or for other " + "reasons. It is possible to extend this by adding new indexes if " + "needed.") + .def(py::init<>()) + .def(py::init(), py::arg("n"), py::arg("t"), + py::arg("x") = 0) + .def_readwrite("n", &PyClass::n, "member-index of minibatch, or zero.") + .def_readwrite("t", &PyClass::t, "time-frame.") + .def_readwrite("x", &PyClass::x, + "this may come in useful in convolutional approaches. 
" + "it is possible to add extra index here, if needed.") + .def("__eq__", + [](const PyClass& a, const PyClass& b) { return a == b; }) + .def("__ne__", + [](const PyClass& a, const PyClass& b) { return a != b; }) + .def("__lt__", [](const PyClass& a, const PyClass& b) { return a < b; }) + .def(py::self + py::self) + .def(py::self += py::self) + // TODO(fangjun): other methods can be wrapped when needed + ; + } +} diff --git a/src/pybind/nnet3/nnet_common_pybind.h b/src/pybind/nnet3/nnet_common_pybind.h new file mode 100644 index 00000000000..ab752a87375 --- /dev/null +++ b/src/pybind/nnet3/nnet_common_pybind.h @@ -0,0 +1,26 @@ +// pybind/nnet3/nnet_common_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_PYBIND_NNET3_NNET_COMMON_PYBIND_H_ +#define KALDI_PYBIND_NNET3_NNET_COMMON_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_nnet_common(py::module& m); + +#endif // KALDI_PYBIND_NNET3_NNET_COMMON_PYBIND_H_ diff --git a/src/pybind/nnet3/nnet_example_pybind.cc b/src/pybind/nnet3/nnet_example_pybind.cc new file mode 100644 index 00000000000..4aa3e6f8a5f --- /dev/null +++ b/src/pybind/nnet3/nnet_example_pybind.cc @@ -0,0 +1,38 @@ +// pybind/nnet3/nnet_example_pybind.cc + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "nnet3/nnet_example_pybind.h" + +#include "nnet3/nnet-example.h" + +using namespace kaldi; +using namespace kaldi::nnet3; + +void pybind_nnet_example(py::module& m) { + { + using PyClass = NnetIo; + py::class_(m, "NnetIo") + .def(py::init<>()) + .def_readwrite("name", &PyClass::name, + "the name of the input in the neural net; in simple " + "setups it will just be 'input'."); + // TODO(fangjun): wrap other constructors, fields and methods when needed + } +} diff --git a/src/pybind/nnet3/nnet_example_pybind.h b/src/pybind/nnet3/nnet_example_pybind.h new file mode 100644 index 00000000000..5fc4d321497 --- /dev/null +++ b/src/pybind/nnet3/nnet_example_pybind.h @@ -0,0 +1,26 @@ +// pybind/nnet3/nnet_example_pybind.h + +// Copyright 2019 Mobvoi AI Lab, Beijing, China +// (author: Fangjun Kuang, Yaguang Hu, Jian Wang) + +// See ../../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_PYBIND_NNET3_NNET_EXAMPLE_PYBIND_H_ +#define KALDI_PYBIND_NNET3_NNET_EXAMPLE_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +void pybind_nnet_example(py::module& m); + +#endif // KALDI_PYBIND_NNET3_NNET_EXAMPLE_PYBIND_H_ diff --git a/src/pybind/symbol_table.py b/src/pybind/symbol_table.py new file mode 100644 index 00000000000..d4684ffd223 --- /dev/null +++ b/src/pybind/symbol_table.py @@ -0,0 +1,14 @@ +# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import kaldi_pybind.fst as fst + + +class SymbolTableIterator(fst.SymbolTableIterator): + # Yields (index, symbol) pairs from the iterator's current position until Done(). + def __iter__(self): + while not self.Done(): + index = self.Value() + symbol = self.Symbol() + yield index, symbol + self.Next() diff --git a/src/pybind/tests/test_arc.py b/src/pybind/tests/test_arc.py new file mode 100755 index 00000000000..71bfe396132 --- /dev/null +++ b/src/pybind/tests/test_arc.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import math # for math.isnan +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir)) + +import unittest + +import numpy as np + +import kaldi_pybind.fst as fst + + +class TestArc(unittest.TestCase): + + def test_std_arc(self): + arc = fst.StdArc() + + self.assertEqual(arc.Type(), 'standard') + self.assertEqual(fst.StdArc.Type(), 'standard') + + ilabel = 0 + olabel = 1 + weight = fst.TropicalWeight.One() + nextstate = 2 + + arc = fst.StdArc(ilabel=ilabel, + olabel=olabel, + weight=weight, + nextstate=nextstate) + self.assertEqual(arc.ilabel, ilabel) + self.assertEqual(arc.olabel, olabel) + self.assertEqual(arc.weight, weight) + self.assertEqual(arc.nextstate, nextstate) + self.assertEqual(str(arc), + '(ilabel: 0, olabel: 1, weight: 0, nextstate: 2)') + + +if __name__ == '__main__': + unittest.main() diff --git a/src/pybind/tests/test_fst.py b/src/pybind/tests/test_fst.py new file mode 100755 index 
00000000000..307e1f8b4b1 --- /dev/null +++ b/src/pybind/tests/test_fst.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 + +# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir)) + +import unittest + +import numpy as np + +import kaldi_pybind.fst as fst + + +class TestArc(unittest.TestCase): + # NOTE(review): named TestArc although this file only tests FstWriteOptions; looks copy-pasted from test_arc.py -- confirm before renaming. + def test_FstWriteOptions(self): + source = "source name" + write_header = True + write_isymbols = False + write_osymbols = True + align = False + stream_write = True + + opt = fst.FstWriteOptions(source=source, + write_header=write_header, + write_isymbols=write_isymbols, + write_osymbols=write_osymbols, + align=align, + stream_write=stream_write) + + self.assertEqual(opt.source, source) + self.assertEqual(opt.write_header, write_header) + self.assertEqual(opt.write_isymbols, write_isymbols) + self.assertEqual(opt.write_osymbols, write_osymbols) + self.assertEqual(opt.align, align) + self.assertEqual(opt.stream_write, stream_write) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/pybind/tests/test_kaldi_pybind.py b/src/pybind/tests/test_kaldi_pybind.py index cad6416a0ff..aebbe4d4997 100755 --- a/src/pybind/tests/test_kaldi_pybind.py +++ b/src/pybind/tests/test_kaldi_pybind.py @@ -1,4 +1,5 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 + import unittest import numpy as np import os @@ -9,7 +10,9 @@ import kaldi_pybind as kp import kaldi + class TestKaldiPybind(unittest.TestCase): + def test_float_vector(self): # test FloatVector print("=====Testing FloatVector in kaldi_pybind=====") @@ -24,9 +27,9 @@ def test_float_vector(self): print("np_array[2:5] = 2.0") np_array[2:] = 2.0 print(np_array) - + gold = np.array([0, 0, 2, 2, 2]) - self.assertTrue((np_array==gold).all()) + self.assertTrue((np_array == gold).all()) def test_float_matrix(self): # test FloatMatrix @@ -42,13 +45,14 @@ def test_float_matrix(self): print("np_matrix[2][3] = 
2.0") np_matrix[2][3] = 2.0 print(np_matrix) - - gold = np.array([[0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 2, 0], - [0, 0, 0, 0, 0],]) - self.assertTrue((np_matrix==gold).all()) - + + gold = np.array([ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 2, 0], + [0, 0, 0, 0, 0], + ]) + self.assertTrue((np_matrix == gold).all()) def test_matrix_reader_writer(self): print("=====Testing Matrix Reader/Writer in kaldi_pybind=====") @@ -67,10 +71,9 @@ def test_matrix_reader_writer(self): self.assertEqual(key, "id_1") value = matrix_reader.Value() - gold = np.array([[0, 0, 0], - [0, 0, 0]]) + gold = np.array([[0, 0, 0], [0, 0, 0]]) print("Read matrix: {}".format(value)) - self.assertTrue((np.array(value, copy=False)==gold).all()) + self.assertTrue((np.array(value, copy=False) == gold).all()) def test_matrix_reader_iterator(self): print("=====Testing Matrix Reader Iterator=====") @@ -84,13 +87,12 @@ def test_matrix_reader_iterator(self): matrix_writer.Close() gold_key_list = ["id_1"] - gold_value_list = [np.array([[0, 0, 0], - [0, 0, 0]])] + gold_value_list = [np.array([[0, 0, 0], [0, 0, 0]])] for (key, value), gold_key, gold_value in zip( - kaldi.ReaderIterator(kaldi.SequentialMatrixReader(rspecifier)), - gold_key_list, gold_value_list): + kaldi.ReaderIterator(kaldi.SequentialMatrixReader(rspecifier)), + gold_key_list, gold_value_list): self.assertEqual(key, gold_key) - self.assertTrue((value==gold_value).all()) + self.assertTrue((value == gold_value).all()) print(key, value) def test_matrix_reader_dict(self): @@ -104,11 +106,11 @@ def test_matrix_reader_dict(self): matrix_writer.Write("id_1", kp_matrix) matrix_writer.Close() - reader_dict = kaldi.ReaderDict(kaldi.RandomAccessMatrixReader(rspecifier)) - gold = np.array([[0, 0, 0], - [0, 0, 0]]) + reader_dict = kaldi.ReaderDict( + kaldi.RandomAccessMatrixReader(rspecifier)) + gold = np.array([[0, 0, 0], [0, 0, 0]]) self.assertTrue("id_1" in reader_dict) - self.assertTrue((np.array(reader_dict["id_1"])==gold).all()) + 
self.assertTrue((np.array(reader_dict["id_1"]) == gold).all()) self.assertFalse("id_2" in reader_dict) diff --git a/src/pybind/tests/test_matrix.py b/src/pybind/tests/test_matrix.py index 5457ddbfaa4..b491c8280f9 100755 --- a/src/pybind/tests/test_matrix.py +++ b/src/pybind/tests/test_matrix.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) # Apache 2.0 diff --git a/src/pybind/tests/test_wave.py b/src/pybind/tests/test_wave.py index aae2a504ebd..6b44d631554 100644 --- a/src/pybind/tests/test_wave.py +++ b/src/pybind/tests/test_wave.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2019 Microsoft Corporation (author: Xingyu Na) # Apache 2.0 @@ -13,6 +13,7 @@ import kaldi + class TestWaveData(unittest.TestCase): def test_duration(self): @@ -20,6 +21,6 @@ def test_duration(self): wave_data = kaldi.WaveData(samp_freq=16000, data=waveform) self.assertEqual(1, wave_data.Duration()) + if __name__ == '__main__': unittest.main() - diff --git a/src/pybind/tests/test_weight.py b/src/pybind/tests/test_weight.py new file mode 100755 index 00000000000..42d7f82a953 --- /dev/null +++ b/src/pybind/tests/test_weight.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import math # for math.isnan +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir)) + +import unittest + +import numpy as np + +import kaldi_pybind.fst as fst + + +class TestWeight(unittest.TestCase): + + def test_float_weight(self): + w = fst.FloatWeight(100) + self.assertEqual(w.Value(), 100) + self.assertEqual(str(w), '100') + + def test_tropical_weight(self): + w = fst.TropicalWeight(100) + self.assertEqual(w.Value(), 100) + self.assertEqual(str(w), '100') + self.assertEqual(w.Type(), 'tropical') + + one = w.One() + self.assertEqual(one.Value(), 0) + + zero = 
fst.TropicalWeight.Zero() + self.assertEqual(zero.Value(), float('inf')) + + self.assertTrue(math.isnan(w.NoWeight().Value())) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/pybind/util/kaldi_table_pybind.h b/src/pybind/util/kaldi_table_pybind.h new file mode 100644 index 00000000000..4cf7b3d7d4d --- /dev/null +++ b/src/pybind/util/kaldi_table_pybind.h @@ -0,0 +1,143 @@ +// pybind/util/kaldi_table_pybind.h + +// Copyright 2019 Daniel Povey +// 2019 Dongji Gao +// 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_PYBIND_UTIL_KALDI_TABLE_PYBIND_H_ +#define KALDI_PYBIND_UTIL_KALDI_TABLE_PYBIND_H_ + +#include "pybind/kaldi_pybind.h" + +#include "util/kaldi-table.h" + +using namespace kaldi; + +template <class Holder> +void pybind_sequential_table_reader(py::module& m, + const std::string& class_name, + const std::string& class_help_doc = "") { + using PyClass = SequentialTableReader<Holder>; + py::class_<PyClass>(m, class_name.c_str(), class_help_doc.c_str()) + .def(py::init<>()) + .def(py::init<const std::string&>(), + "This constructor equivalent to default constructor + 'open', but " + "throws on error.", + py::arg("rspecifier")) + .def("Open", &PyClass::Open, + "Opens the table. Returns exit status; but does throw if " + "previously open stream was in error state.
You can call Close to " + "prevent this; anyway, calling Open more than once is not usually " + "needed.", + py::arg("rspecifier")) + .def("Done", &PyClass::Done, + "Returns true if we're done. It will also return true if there's " + "some kind of error and we can't read any more; in this case, you " + "can detect the error by calling Close and checking the return " + "status; otherwise the destructor will throw.") + .def("Key", &PyClass::Key, + "Only valid to call Key() if Done() returned false.") + .def("FreeCurrent", &PyClass::FreeCurrent, + "FreeCurrent() is provided as an optimization to save memory, for " + "large objects. It instructs the class to deallocate the current " + "value. The reference Value() will be invalidated by this.") + .def("Value", &PyClass::Value, + "Return reference to the current value. It's only valid to call " + "this if Done() returned false. The reference is valid till next " + "call to this object. It will throw if you are reading an scp " + "file, did not specify the 'permissive' (p) option and the file " + "cannot be read. [The permissive option makes it behave as if that " + "key does not even exist, if the corresponding file cannot be " + "read.] You probably wouldn't want to catch this exception; the " + "user can just specify the p option in the rspecifier. We make this " + "non-const to enable things like shallow swap on the held object in " + "situations where this would avoid making a redundant copy.", + py::return_value_policy::copy) + .def("Next", &PyClass::Next, + "Next goes to the next key. 
It will not throw; any error will " + "result in Done() returning true, and then the destructor will " + "throw unless you call Close().") + .def("IsOpen", &PyClass::IsOpen, + "Returns true if table is open for reading (does not imply stream " + "is in good state).") + .def("Close", &PyClass::Close, + "Close() will return false (failure) if Done() became true because " + "of an error/ condition rather than because we are really done " + "[e.g. because of an error or early termination in the archive]. If " + "there is an error and you don't call Close(), the destructor will " + "fail. Close()"); +} + +template <class Holder> +void pybind_random_access_table_reader(py::module& m, + const std::string& class_name, + const std::string& class_help_doc = "") { + using PyClass = RandomAccessTableReader<Holder>; + py::class_<PyClass>(m, class_name.c_str(), class_help_doc.c_str()) + .def(py::init<>()) + .def(py::init<const std::string&>(), + "This constructor equivalent to default constructor + 'open', but " + "throws on error.", + py::arg("rspecifier")) + .def("Open", &PyClass::Open, "Opens the table.", py::arg("rspecifier")) + .def("IsOpen", &PyClass::IsOpen, "Returns true if table is open") + .def("Close", &PyClass::Close, + "Close() will close the table [throws if it was not open], and " + "returns true on success (false if we were reading an archive and " + "we discovered an error in the archive).") + .def("HasKey", &PyClass::HasKey, + "Says if it has this key. If you are using the 'permissive' (p) " + "read option, it will return false for keys whose corresponding " + "entry in the scp file cannot be read.", + py::arg("key")) + .def("Value", &PyClass::Value, + "Value() may throw if you are reading an scp file, you do not have " + "the ' permissive' (p) option, and an entry in the scp file cannot " + "be read.
Typically you won't want to catch this error.", + py::return_value_policy::copy); +} + +template <class Holder> +void pybind_table_writer(py::module& m, const std::string& class_name, + const std::string& class_help_doc = "") { + using PyClass = TableWriter<Holder>; + py::class_<PyClass>(m, class_name.c_str(), class_help_doc.c_str()) + .def(py::init<>()) + .def(py::init<const std::string&>(), + "This constructor equivalent to default constructor + 'open', but " + "throws on error.", + py::arg("wspecifier")) + .def("Open", &PyClass::Open, + "Opens the table. See docs for wspecifier above. If it returns " + "true, it is open.", + py::arg("wspecifier")) + .def("IsOpen", &PyClass::IsOpen, "Returns true if open for writing.") + .def("Write", &PyClass::Write, + "Write the object. Throws KaldiFatalError on error via the " + "KALDI_ERR macro.", + py::arg("key"), py::arg("value")) + .def("Flush", &PyClass::Flush, + "Flush will flush any archive; it does not return error status or " + "throw, any errors will be reported on the next Write or Close. " + "Useful if we may be writing to a command in a pipe and want to " + "ensure good CPU utilization.") + .def("Close", &PyClass::Close, + "Close() is not necessary to call, as the destructor closes it; " + "it's mainly useful if you want to handle error states because the " + "destructor will throw on error if you do not call Close()."); +} + +#endif // KALDI_PYBIND_UTIL_KALDI_TABLE_PYBIND_H_ diff --git a/src/pybind/util/table.py b/src/pybind/util/table.py new file mode 100644 index 00000000000..dc1c362a3e4 --- /dev/null +++ b/src/pybind/util/table.py @@ -0,0 +1,913 @@ +""" +This file is modified from the PyKaldi project +https://github.com/pykaldi/pykaldi/blob/master/kaldi/util/table.py +""" +# +# +# Author: Dogan Can +# Author: Fangjun Kuang +# +# +""" +For detailed documentation of Kaldi tables, table readers/writers, table +read/write specifiers, see `Kaldi I/O mechanisms`_ and +`Kaldi I/O from a command-line perspective`_. + +..
_Kaldi I/O mechanisms: + http://kaldi-asr.org/doc/io.html +.. _Kaldi I/O from a command-line perspective: + http://kaldi-asr.org/doc/io_tut.html +""" + +# TODO(fangjun): set the PYTHONPATH environment variable outside this script +# to avoid setting sys.path for every Python script + +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir)) + +from kaldi_pybind.nnet3 import _SequentialNnetChainExampleReader +from kaldi_pybind.nnet3 import _RandomAccessNnetChainExampleReader +from kaldi_pybind.nnet3 import _NnetChainExampleWriter + +################################################################################ +# Sequential Readers +################################################################################ + + +class _SequentialReaderBase(object): + """Base class defining the Python API for sequential table readers.""" + + def __init__(self, rspecifier=""): + """ + This class is used for reading objects sequentially from an archive or + script file. It implements the iterator protocol similar to how Python + implements iteration over dictionaries. Each iteration returns a `(key, + value)` pair from the table in sequential order. + + Args: + rspecifier(str): Kaldi rspecifier for reading the table. + If provided, the table is opened for reading. + + Raises: + IOError: If opening the table for reading fails. + """ + super(_SequentialReaderBase, self).__init__() + if rspecifier != "": + if not self.Open(rspecifier): + raise IOError("Error opening sequential table reader with " + "rspecifier: {}".format(rspecifier)) + + def __enter__(self): + return self + + def __iter__(self): + while not self.Done(): + key = self.Key() + value = self.Value() + self.Next() + # WARNING(fangjun): after calling self.Next(), the `value` + # returned above is invalidated, so we use copy semantics + # in the C++ binding code. But if the user does not use the pythonic + # way for iteration, the extra copy in the C++ binding code is unnecessary.
+ yield key, value + + def Open(self, rspecifier): + """Opens the table for reading. + + Args: + rspecifier(str): Kaldi rspecifier for reading the table. + If provided, the table is opened for reading. + + Returns: + True if table is opened successfully, False otherwise. + + Raises: + IOError: If opening the table for reading fails. + """ + return super(_SequentialReaderBase, self).Open(rspecifier) + + def Done(self): + """Indicates whether the table reader is exhausted or not. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Returns: + True if the table reader is exhausted, False otherwise. + """ + return super(_SequentialReaderBase, self).Done() + + def Key(self): + """Returns the current key. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Returns: + str: The current key. + """ + return super(_SequentialReaderBase, self).Key() + + def FreeCurrent(self): + """Deallocates the current value. + + This method is provided as an optimization to save memory, for large + objects. + """ + super(_SequentialReaderBase, self).FreeCurrent() + + def Value(self): + """Returns the current value. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Returns: + The current value. + """ + return super(_SequentialReaderBase, self).Value() + + def Next(self): + """Advances the table reader. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + """ + super(_SequentialReaderBase, self).Next() + + def IsOpen(self): + """Indicates whether the table reader is open or not. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Returns: + True if the table reader is open, False otherwise. 
+ """ + return super(_SequentialReaderBase, self).IsOpen() + + def Close(self): + """Closes the table. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Returns: + True if table is closed successfully, False otherwise. + """ + return super(_SequentialReaderBase, self).Close() + + +class SequentialNnetChainExampleReader(_SequentialReaderBase, + _SequentialNnetChainExampleReader): + """Sequential table reader for nnet chain examples.""" + pass + + +################################################################################ +# Random Access Readers +################################################################################ + + +class _RandomAccessReaderBase(object): + """Base class defining the Python API for random access table readers.""" + + def __init__(self, rspecifier=""): + """ + This class is used for randomly accessing objects in an archive or + script file. It implements `__contains__` and `__getitem__` methods to + provide a dictionary-like interface for accessing table entries. e.g. + `reader[key]` returns the `value` associated with the `key`. + + Args: + rspecifier(str): Kaldi rspecifier for reading the table. + If provided, the table is opened for reading. + + Raises: + IOError: If opening the table for reading fails. + """ + super(_RandomAccessReaderBase, self).__init__() + if rspecifier != "": + if not self.Open(rspecifier): + raise IOError("Error opening random access table reader with " + "rspecifier: {}".format(rspecifier)) + + def __enter__(self): + return self + + def __contains__(self, key): + return self.HasKey(key) + + def __getitem__(self, key): + if self.HasKey(key): + return self.Value(key) + else: + raise KeyError(key) + + def Open(self, rspecifier): + """Opens the table for reading. + + Args: + rspecifier(str): Kaldi rspecifier for reading the table. + If provided, the table is opened for reading. 
+ + Returns: + True if table is opened successfully, False otherwise. + + Raises: + IOError: If opening the table for reading fails. + """ + return super(_RandomAccessReaderBase, self).Open(rspecifier) + + def HasKey(self, key): + """Checks whether the table has the key. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Args: + key (str): The key. + + Returns: + True if the table has the key, False otherwise. + """ + return super(_RandomAccessReaderBase, self).HasKey(key) + + def Value(self, key): + """Returns the value associated with the key. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Args: + key (str): The key. + + Returns: + The value associated with the key. + """ + return super(_RandomAccessReaderBase, self).Value(key) + + def IsOpen(self): + """Indicates whether the table reader is open or not. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Returns: + True if the table reader is open, False otherwise. + """ + return super(_RandomAccessReaderBase, self).IsOpen() + + def Close(self): + """Closes the table. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Returns: + True if table is closed successfully, False otherwise. 
+ """ + return super(_RandomAccessReaderBase, self).Close() + + +class RandomAccessNnetChainExampleReader(_RandomAccessReaderBase, + _RandomAccessNnetChainExampleReader): + """Random access table reader for nnet chain examples.""" + pass + + +################################################################################ +# Writers +################################################################################ + + +class _WriterBase(object): + """Base class defining the additional Python API for table writers.""" + + def __init__(self, wspecifier=""): + """ + + This class is used for writing objects to an archive or script file. It + implements the `__setitem__` method to provide a dictionary-like + interface for writing table entries, e.g. `writer[key] = value` writes + the pair `(key, value)` to the table. + + Args: + wspecifier (str): Kaldi wspecifier for writing the table. + If provided, the table is opened for writing. + + Raises: + IOError: If opening the table for writing fails. + """ + super(_WriterBase, self).__init__() + if wspecifier != "": + if not self.Open(wspecifier): + raise IOError( + "Error opening table writer with wspecifier: {}".format( + wspecifier)) + + def __enter__(self): + return self + + def __setitem__(self, key, value): + self.Write(key, value) + + def Open(self, wspecifier): + """Opens the table for writing. + + Args: + wspecifier(str): Kaldi wspecifier for writing the table. + If provided, the table is opened for writing. + + Returns: + True if table is opened successfully, False otherwise. + + Raises: + IOError: If opening the table for writing fails. + """ + return super(_WriterBase, self).Open(wspecifier) + + def Flush(self): + """Flushes the table contents to disk/pipe.""" + super(_WriterBase, self).Flush() + + def Write(self, key, value): + """Writes the `(key, value)` pair to the table. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. 
+ + Args: + key (str): The key. + value: The value. + """ + super(_WriterBase, self).Write(key, value) + + def IsOpen(self): + """Indicates whether the table writer is open or not. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Returns: + True if the table writer is open, False otherwise. + """ + return super(_WriterBase, self).IsOpen() + + def Close(self): + """Closes the table. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Returns: + True if table is closed successfully, False otherwise. + """ + return super(_WriterBase, self).Close() + + +class NnetChainExampleWriter(_WriterBase, _NnetChainExampleWriter): + """Table writer for nnet chain examples.""" + pass + + +if False: + # TODO(fangjun): enable the following once other wrappers are added + + class SequentialVectorReader(_SequentialReaderBase, + _kaldi_table.SequentialVectorReader): + """Sequential table reader for single precision vectors.""" + pass + + class SequentialDoubleVectorReader(_SequentialReaderBase, + _kaldi_table.SequentialDoubleVectorReader + ): + """Sequential table reader for double precision vectors.""" + pass + + class SequentialMatrixReader(_SequentialReaderBase, + _kaldi_table.SequentialMatrixReader): + """Sequential table reader for single precision matrices.""" + pass + + class SequentialDoubleMatrixReader(_SequentialReaderBase, + _kaldi_table.SequentialDoubleMatrixReader + ): + """Sequential table reader for double precision matrices.""" + pass + + class SequentialWaveReader(_SequentialReaderBase, + _kaldi_table.SequentialWaveReader): + """Sequential table reader for wave files.""" + pass + + class SequentialWaveInfoReader(_SequentialReaderBase, + _kaldi_table.SequentialWaveInfoReader): + """Sequential table reader for wave file headers.""" + pass + + class SequentialPosteriorReader(_SequentialReaderBase, + _kaldi_table.SequentialPosteriorReader): + 
"""Sequential table reader for frame posteriors.""" + pass + + class SequentialGaussPostReader(_SequentialReaderBase, + _kaldi_table.SequentialGaussPostReader): + """Sequential table reader for Gaussian-level frame posteriors.""" + pass + + class SequentialFstReader(_SequentialReaderBase, + _kaldi_table_ext.SequentialFstReader): + """Sequential table reader for FSTs over the tropical semiring.""" + pass + + class SequentialLogFstReader(_SequentialReaderBase, + _kaldi_table_ext.SequentialLogFstReader): + """Sequential table reader for FSTs over the log semiring.""" + pass + + class SequentialKwsIndexFstReader( + _SequentialReaderBase, + _kaldi_table_ext.SequentialKwsIndexFstReader): + """Sequential table reader for FSTs over the KWS index semiring.""" + pass + + class SequentialLatticeReader(_SequentialReaderBase, + _kaldi_table.SequentialLatticeReader): + """Sequential table reader for lattices.""" + pass + + class SequentialCompactLatticeReader( + _SequentialReaderBase, _kaldi_table.SequentialCompactLatticeReader): + """Sequential table reader for compact lattices.""" + pass + + class SequentialNnetExampleReader(_SequentialReaderBase, + _kaldi_table.SequentialNnetExampleReader): + """Sequential table reader for nnet examples.""" + pass + + class SequentialRnnlmExampleReader(_SequentialReaderBase, + _kaldi_table.SequentialRnnlmExampleReader + ): + """Sequential table reader for RNNLM examples.""" + pass + + class SequentialIntReader(_SequentialReaderBase, + _kaldi_table.SequentialIntReader): + """Sequential table reader for integers.""" + pass + + class SequentialFloatReader(_SequentialReaderBase, + _kaldi_table.SequentialFloatReader): + """Sequential table reader for single precision floats.""" + pass + + class SequentialDoubleReader(_SequentialReaderBase, + _kaldi_table.SequentialDoubleReader): + """Sequential table reader for double precision floats.""" + pass + + class SequentialBoolReader(_SequentialReaderBase, + _kaldi_table.SequentialBoolReader): + 
"""Sequential table reader for Booleans.""" + pass + + class SequentialIntVectorReader(_SequentialReaderBase, + _kaldi_table.SequentialIntVectorReader): + """Sequential table reader for integer sequences.""" + pass + + class SequentialIntVectorVectorReader( + _SequentialReaderBase, + _kaldi_table.SequentialIntVectorVectorReader): + """Sequential table reader for sequences of integer sequences.""" + pass + + class SequentialIntPairVectorReader( + _SequentialReaderBase, _kaldi_table.SequentialIntPairVectorReader): + """Sequential table reader for sequences of integer pairs.""" + pass + + class SequentialFloatPairVectorReader( + _SequentialReaderBase, + _kaldi_table.SequentialFloatPairVectorReader): + """Sequential table reader for sequences of single precision float pairs.""" + pass + + class RandomAccessVectorReader(_RandomAccessReaderBase, + _kaldi_table.RandomAccessVectorReader): + """Random access table reader for single precision vectors.""" + pass + + class RandomAccessDoubleVectorReader( + _RandomAccessReaderBase, + _kaldi_table.RandomAccessDoubleVectorReader): + """Random access table reader for double precision vectors.""" + pass + + class RandomAccessMatrixReader(_RandomAccessReaderBase, + _kaldi_table.RandomAccessMatrixReader): + """Random access table reader for single precision matrices.""" + pass + + class RandomAccessDoubleMatrixReader( + _RandomAccessReaderBase, + _kaldi_table.RandomAccessDoubleMatrixReader): + """Random access table reader for double precision matrices.""" + pass + + class RandomAccessWaveReader(_RandomAccessReaderBase, + _kaldi_table.RandomAccessWaveReader): + """Random access table reader for wave files.""" + pass + + class RandomAccessWaveInfoReader(_RandomAccessReaderBase, + _kaldi_table.RandomAccessWaveInfoReader): + """Random access table reader for wave file headers.""" + pass + + class RandomAccessPosteriorReader(_RandomAccessReaderBase, + _kaldi_table.RandomAccessPosteriorReader): + """Random access table reader for frame 
posteriors.""" + pass + + class RandomAccessGaussPostReader(_RandomAccessReaderBase, + _kaldi_table.RandomAccessGaussPostReader): + """Random access table reader for Gaussian-level frame posteriors.""" + pass + + class RandomAccessFstReader(_RandomAccessReaderBase, + _kaldi_table_ext.RandomAccessFstReader): + """Random access table reader for FSTs over the tropical semiring.""" + pass + + class RandomAccessLogFstReader(_RandomAccessReaderBase, + _kaldi_table_ext.RandomAccessLogFstReader): + """Random access table reader for FSTs over the log semiring.""" + pass + + class RandomAccessKwsIndexFstReader( + _RandomAccessReaderBase, + _kaldi_table_ext.RandomAccessKwsIndexFstReader): + """Random access table reader for FSTs over the KWS index semiring.""" + pass + + class RandomAccessLatticeReader(_RandomAccessReaderBase, + _kaldi_table.RandomAccessLatticeReader): + """Random access table reader for lattices.""" + pass + + class RandomAccessCompactLatticeReader( + _RandomAccessReaderBase, + _kaldi_table.RandomAccessCompactLatticeReader): + """Random access table reader for compact lattices.""" + pass + + class RandomAccessNnetExampleReader( + _RandomAccessReaderBase, + _kaldi_table.RandomAccessNnetExampleReader): + """Random access table reader for nnet examples.""" + pass + + class RandomAccessIntReader(_RandomAccessReaderBase, + _kaldi_table.RandomAccessIntReader): + """Random access table reader for integers.""" + pass + + class RandomAccessFloatReader(_RandomAccessReaderBase, + _kaldi_table.RandomAccessFloatReader): + """Random access table reader for single precision floats.""" + pass + + class RandomAccessDoubleReader(_RandomAccessReaderBase, + _kaldi_table.RandomAccessDoubleReader): + """Random access table reader for double precision floats.""" + pass + + class RandomAccessBoolReader(_RandomAccessReaderBase, + _kaldi_table.RandomAccessBoolReader): + """Random access table reader for Booleans.""" + pass + + class 
RandomAccessIntVectorReader(_RandomAccessReaderBase, + _kaldi_table.RandomAccessIntVectorReader): + """Random access table reader for integer sequences.""" + pass + + class RandomAccessIntVectorVectorReader( + _RandomAccessReaderBase, + _kaldi_table.RandomAccessIntVectorVectorReader): + """Random access table reader for sequences of integer sequences.""" + pass + + class RandomAccessIntPairVectorReader( + _RandomAccessReaderBase, + _kaldi_table.RandomAccessIntPairVectorReader): + """Random access table reader for sequences of integer pairs.""" + pass + + class RandomAccessFloatPairVectorReader( + _RandomAccessReaderBase, + _kaldi_table.RandomAccessFloatPairVectorReader): + """ + Random access table reader for sequences of single precision float pairs. + """ + pass + +################################################################################ +# Mapped Random Access Readers +################################################################################ + + class _RandomAccessReaderMappedBase(object): + """ + Base class defining the Python API for mapped random access table readers. + """ + + def __init__(self, table_rspecifier="", map_rspecifier=""): + """ + This class is used for randomly accessing objects in an archive or + script file. It implements `__contains__` and `__getitem__` methods to + provide a dictionary-like interface for accessing table entries. If a + **map_rspecifier** is provided, the map is used for converting the keys + to the actual keys used to query the table, e.g. `reader[key]` returns + the `value` associated with the key `map[key]`. Otherwise, it works like + a random access table reader. + + Args: + table_rspecifier(str): Kaldi rspecifier for reading the table. + If provided, the table is opened for reading. + map_rspecifier (str): Kaldi rspecifier for reading the map. + If provided, the map is opened for reading. + + Raises: + IOError: If opening the table or map for reading fails. 
+ """ + super(_RandomAccessReaderMappedBase, self).__init__() + if table_rspecifier != "" and map_rspecifier != "": + if not self.open(table_rspecifier, map_rspecifier): + raise IOError( + "Error opening mapped random access table reader " + "with table_rspecifier: {}, map_rspecifier: {}".format( + table_rspecifier, map_rspecifier)) + + def __enter__(self): + return self + + def __contains__(self, key): + return self.has_key(key) + + def __getitem__(self, key): + if self.has_key(key): + return self.value(key) + else: + raise KeyError(key) + + def open(self, table_rspecifier, map_rspecifier): + """Opens the table for reading. + + Args: + table_rspecifier(str): Kaldi rspecifier for reading the table. + If provided, the table is opened for reading. + map_rspecifier (str): Kaldi rspecifier for reading the map. + If provided, the map is opened for reading. + + Returns: + True if table is opened successfully, False otherwise. + + Raises: + IOError: If opening the table or map for reading fails. + """ + return super(_RandomAccessReaderMappedBase, + self).open(table_rspecifier, map_rspecifier) + + def has_key(self, key): + """Checks whether the table has the key. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Args: + key (str): The key. + + Returns: + True if the table has the key, False otherwise. + """ + return super(_RandomAccessReaderMappedBase, self).has_key(key) + + def value(self, key): + """Returns the value associated with the key. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Args: + key (str): The key. + + Returns: + The value associated with the key. + """ + return super(_RandomAccessReaderMappedBase, self).value(key) + + def is_open(self): + """Indicates whether the table reader is open or not. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. 
+ + Returns: + True if the table reader is open, False otherwise. + """ + return super(_RandomAccessReaderMappedBase, self).is_open() + + def close(self): + """Closes the table. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Returns: + True if table is closed successfully, False otherwise. + """ + return super(_RandomAccessReaderMappedBase, self).close() + + class RandomAccessVectorReaderMapped( + _RandomAccessReaderMappedBase, + _kaldi_table.RandomAccessVectorReaderMapped): + """Mapped random access table reader for single precision vectors.""" + pass + + class RandomAccessDoubleVectorReaderMapped( + _RandomAccessReaderMappedBase, + _kaldi_table.RandomAccessDoubleVectorReaderMapped): + """Mapped random access table reader for double precision vectors.""" + pass + + class RandomAccessMatrixReaderMapped( + _RandomAccessReaderMappedBase, + _kaldi_table.RandomAccessMatrixReaderMapped): + """Mapped random access table reader for single precision matrices.""" + pass + + class RandomAccessDoubleMatrixReaderMapped( + _RandomAccessReaderMappedBase, + _kaldi_table.RandomAccessDoubleMatrixReaderMapped): + """Mapped random access table reader for double precision matrices.""" + pass + + class RandomAccessFloatReaderMapped( + _RandomAccessReaderMappedBase, + _kaldi_table.RandomAccessFloatReaderMapped): + """Mapped random access table reader for single precision floats.""" + pass + + class VectorWriter(_WriterBase, _kaldi_table.VectorWriter): + """Table writer for single precision vectors.""" + + def write(self, key, value): + """Writes the `(key, value)` pair to the table. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Overrides write to accept both Vector and SubVector. + + Args: + key (str): The key. + value: The value. 
+ """ + super(VectorWriter, self).write(key, _matrix.Vector(value)) + + class DoubleVectorWriter(_WriterBase, _kaldi_table.DoubleVectorWriter): + """Table writer for double precision vectors.""" + + def write(self, key, value): + """Writes the `(key, value)` pair to the table. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Overrides write to accept both DoubleVector and DoubleSubVector. + + Args: + key (str): The key. + value: The value. + """ + super(DoubleVectorWriter, self).write(key, + _matrix.DoubleVector(value)) + + class MatrixWriter(_WriterBase, _kaldi_table.MatrixWriter): + """Table writer for single precision matrices.""" + + def write(self, key, value): + """Writes the `(key, value)` pair to the table. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Overrides write to accept both Matrix and SubMatrix. + + Args: + key (str): The key. + value: The value. + """ + super(MatrixWriter, self).write(key, _matrix.Matrix(value)) + + class DoubleMatrixWriter(_WriterBase, _kaldi_table.DoubleMatrixWriter): + """Table writer for double precision matrices.""" + + def write(self, key, value): + """Writes the `(key, value)` pair to the table. + + This method is provided for compatibility with the C++ API only; + most users should use the Pythonic API. + + Overrides write to accept both DoubleMatrix and DoubleSubMatrix. + + Args: + key (str): The key. + value: The value. 
+ """ + super(DoubleMatrixWriter, self).write(key, + _matrix.DoubleMatrix(value)) + + class WaveWriter(_WriterBase, _kaldi_table.WaveWriter): + """Table writer for wave files.""" + pass + + class PosteriorWriter(_WriterBase, _kaldi_table.PosteriorWriter): + """Table writer for frame posteriors.""" + pass + + class GaussPostWriter(_WriterBase, _kaldi_table.GaussPostWriter): + """Table writer for Gaussian-level frame posteriors.""" + pass + + class FstWriter(_WriterBase, _kaldi_table_ext.FstWriter): + """Table writer for FSTs over the tropical semiring.""" + pass + + class LogFstWriter(_WriterBase, _kaldi_table_ext.LogFstWriter): + """Table writer for FSTs over the log semiring.""" + pass + + class KwsIndexFstWriter(_WriterBase, _kaldi_table_ext.KwsIndexFstWriter): + """Table writer for FSTs over the KWS index semiring.""" + pass + + class LatticeWriter(_WriterBase, _kaldi_table.LatticeWriter): + """Table writer for lattices.""" + pass + + class CompactLatticeWriter(_WriterBase, _kaldi_table.CompactLatticeWriter): + """Table writer for compact lattices.""" + pass + + class NnetExampleWriter(_WriterBase, _kaldi_table.NnetExampleWriter): + """Table writer for nnet examples.""" + pass + + class RnnlmExampleWriter(_WriterBase, _kaldi_table.RnnlmExampleWriter): + """Table writer for RNNLM examples.""" + pass + + class IntWriter(_WriterBase, _kaldi_table.IntWriter): + """Table writer for integers.""" + pass + + class FloatWriter(_WriterBase, _kaldi_table.FloatWriter): + """Table writer for single precision floats.""" + pass + + class DoubleWriter(_WriterBase, _kaldi_table.DoubleWriter): + """Table writer for double precision floats.""" + pass + + class BoolWriter(_WriterBase, _kaldi_table.BoolWriter): + """Table writer for Booleans.""" + pass + + class IntVectorWriter(_WriterBase, _kaldi_table.IntVectorWriter): + """Table writer for integer sequences.""" + pass + + class IntVectorVectorWriter(_WriterBase, + _kaldi_table.IntVectorVectorWriter): + """Table writer for 
sequences of integer sequences.""" + pass + + class IntPairVectorWriter(_WriterBase, _kaldi_table.IntPairVectorWriter): + """Table writer for sequences of integer pairs.""" + pass + + class FloatPairVectorWriter(_WriterBase, + _kaldi_table.FloatPairVectorWriter): + """Table writer for sequences of single precision float pairs.""" + pass + +################################################################################ + + __all__ = [ + name for name in dir() if name[0] != '_' and not name.endswith('Base') + ]