Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/pybind/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ endif
CCFILES = kaldi_pybind.cc \
matrix/matrix_common_pybind.cc matrix/matrix_pybind.cc \
matrix/vector_pybind.cc \
util/table_types_pybind.cc
util/table_types_pybind.cc \
feat/wave_reader_pybind.cc

LIBNAME = kaldi_pybind


ADDLIBS = ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a
ADDLIBS = ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../feat/kaldi-feat.a

EXTRA_LDLIBS += $(foreach dep,$(ADDLIBS), $(dir $(dep))lib$(notdir $(basename $(dep))).so)

Expand All @@ -63,6 +64,7 @@ clean:
test: all
python3 tests/test_kaldi_pybind.py
python3 tests/test_matrix.py
python3 tests/test_wave.py

# valgrind-python.supp is from http://svn.python.org/projects/python/trunk/Misc/valgrind-python.supp
# since we do not compile Python from source, we follow the comment in valgrind-python.supp
Expand Down
105 changes: 105 additions & 0 deletions src/pybind/feat/wave_reader_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// pybind/feat/wave_reader_pybind.cc

// Copyright 2019 Microsoft Corporation (author: Xingyu Na)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "feat/wave_reader_pybind.h"
#include "feat/wave-reader.h"
#include "util/table-types.h"

using namespace kaldi;

// Registers the wave-related Python bindings on module `m`:
//   - the module attribute `kWaveSampleMax`;
//   - the `WaveInfo` and `WaveData` classes;
//   - `SequentialWaveReader` / `RandomAccessWaveReader` (WaveHolder tables);
//   - `SequentialWaveInfoReader` / `RandomAccessWaveInfoReader`
//     (WaveInfoHolder tables).
//
// Every accessor that hands a C++ reference back to Python (`WaveData::Data`
// and each reader's `Value`) is bound with `reference_internal` rather than
// plain `reference`, so Python keeps the owning C++ object alive for as long
// as the returned object is still referenced.
// NOTE(review): presumably an object obtained from a reader's Value() is only
// meaningful until that reader advances or is closed -- confirm against the
// table reader implementation.
void pybind_wave_reader(py::module& m) {
  // Local shorthands for the table reader instantiations bound below.
  using SeqWaveReader = SequentialTableReader<WaveHolder>;
  using RandWaveReader = RandomAccessTableReader<WaveHolder>;
  using SeqWaveInfoReader = SequentialTableReader<WaveInfoHolder>;
  using RandWaveInfoReader = RandomAccessTableReader<WaveInfoHolder>;

  m.attr("kWaveSampleMax") = py::cast(kWaveSampleMax);

  py::class_<WaveInfo>(m, "WaveInfo")
      .def(py::init<>())
      .def("IsStreamed", &WaveInfo::IsStreamed,
           "Is stream size unknown? Duration and SampleCount not valid if true.")
      .def("SampFreq", &WaveInfo::SampFreq,
           "Sample frequency, Hz.")
      .def("SampleCount", &WaveInfo::SampleCount,
           "Number of samples in stream. Invalid if IsStreamed() is true.")
      .def("Duration", &WaveInfo::Duration,
           "Approximate duration, seconds. Invalid if IsStreamed() is true.")
      .def("NumChannels", &WaveInfo::NumChannels,
           "Number of channels, 1 to 16.")
      .def("BlockAlign", &WaveInfo::BlockAlign,
           "Bytes per sample.")
      .def("DataBytes", &WaveInfo::DataBytes,
           "Wave data bytes. Invalid if IsStreamed() is true.")
      .def("ReverseBytes", &WaveInfo::ReverseBytes,
           "Is data file byte order different from machine byte order?");

  py::class_<WaveData>(m, "WaveData")
      .def(py::init<>())
      .def(py::init<const float, const Matrix<float>>(),
           py::arg("samp_freq"), py::arg("data"))
      .def("Duration", &WaveData::Duration,
           "Returns the duration in seconds")
      // Ties the returned object's lifetime to the WaveData it came from.
      .def("Data", &WaveData::Data,
           py::return_value_policy::reference_internal)
      .def("SampFreq", &WaveData::SampFreq)
      .def("Clear", &WaveData::Clear)
      .def("CopyFrom", &WaveData::CopyFrom)
      .def("Swap", &WaveData::Swap);

  py::class_<SeqWaveReader>(m, "SequentialWaveReader")
      .def(py::init<>())
      .def(py::init<const std::string&>(), py::arg("rspecifier"))
      .def("Open", &SeqWaveReader::Open, py::arg("rspecifier"))
      .def("Done", &SeqWaveReader::Done)
      .def("Key", &SeqWaveReader::Key)
      .def("FreeCurrent", &SeqWaveReader::FreeCurrent)
      .def("Value", &SeqWaveReader::Value,
           py::return_value_policy::reference_internal)
      .def("Next", &SeqWaveReader::Next)
      .def("IsOpen", &SeqWaveReader::IsOpen)
      .def("Close", &SeqWaveReader::Close);

  // Review remark recorded on the default constructor binding below:
  //   "should the copy constructor also be wrapped?"
  // with the follow-up reply "ok, I see"; no copy constructor is bound here.
  py::class_<RandWaveReader>(m, "RandomAccessWaveReader")
      .def(py::init<>())
      .def(py::init<const std::string&>(), py::arg("rspecifier"))
      .def("Open", &RandWaveReader::Open, py::arg("rspecifier"))
      .def("IsOpen", &RandWaveReader::IsOpen)
      .def("Close", &RandWaveReader::Close)
      .def("HasKey", &RandWaveReader::HasKey, py::arg("key"))
      .def("Value", &RandWaveReader::Value,
           py::return_value_policy::reference_internal);

  py::class_<SeqWaveInfoReader>(m, "SequentialWaveInfoReader")
      .def(py::init<>())
      .def(py::init<const std::string&>(), py::arg("rspecifier"))
      .def("Open", &SeqWaveInfoReader::Open, py::arg("rspecifier"))
      .def("Done", &SeqWaveInfoReader::Done)
      .def("Key", &SeqWaveInfoReader::Key)
      .def("FreeCurrent", &SeqWaveInfoReader::FreeCurrent)
      .def("Value", &SeqWaveInfoReader::Value,
           py::return_value_policy::reference_internal)
      .def("Next", &SeqWaveInfoReader::Next)
      .def("IsOpen", &SeqWaveInfoReader::IsOpen)
      .def("Close", &SeqWaveInfoReader::Close);

  py::class_<RandWaveInfoReader>(m, "RandomAccessWaveInfoReader")
      .def(py::init<>())
      .def(py::init<const std::string&>(), py::arg("rspecifier"))
      .def("Open", &RandWaveInfoReader::Open, py::arg("rspecifier"))
      .def("IsOpen", &RandWaveInfoReader::IsOpen)
      .def("Close", &RandWaveInfoReader::Close)
      .def("HasKey", &RandWaveInfoReader::HasKey, py::arg("key"))
      .def("Value", &RandWaveInfoReader::Value,
           py::return_value_policy::reference_internal);
}

25 changes: 25 additions & 0 deletions src/pybind/feat/wave_reader_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// pybind/feat/wave_reader_pybind.h

// Copyright 2019 Microsoft Corporation (author: Xingyu Na)

// See ../../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_PYBIND_FEAT_WAVE_READER_PYBIND_H_
#define KALDI_PYBIND_FEAT_WAVE_READER_PYBIND_H_

#include "pybind/kaldi_pybind.h"

void pybind_wave_reader(py::module& m);

#endif // KALDI_PYBIND_FEAT_WAVE_READER_PYBIND_H_
2 changes: 2 additions & 0 deletions src/pybind/kaldi_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "matrix/matrix_pybind.h"
#include "matrix/vector_pybind.h"
#include "util/table_types_pybind.h"
#include "feat/wave_reader_pybind.h"

void pybind_matrix(py::module& m);
PYBIND11_MODULE(kaldi_pybind, m) {
Expand All @@ -37,4 +38,5 @@ PYBIND11_MODULE(kaldi_pybind, m) {
pybind_matrix(m);
pybind_vector(m);
pybind_table_types(m);
pybind_wave_reader(m);
}
50 changes: 42 additions & 8 deletions src/pybind/matrix_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
class SequentialVectorReader(kp.SequentialBaseFloatVectorReader):
def __init__(self, rspecifier=None):
if not rspecifier:
kp.SequentialBaseFloatVectorReader.__init__(self)
kp.SequentialBaseFloatVectorReader.__init__(self)
else:
kp.SequentialBaseFloatVectorReader.__init__(self, rspecifier)
kp.SequentialBaseFloatVectorReader.__init__(self, rspecifier)

def __iter__(self):
return self
Expand All @@ -24,9 +24,27 @@ def __next__(self):
class SequentialMatrixReader(kp.SequentialBaseFloatMatrixReader):
    """Iterator-style wrapper around ``kp.SequentialBaseFloatMatrixReader``.

    Iterating over an instance yields ``(key, value)`` tuples until the
    underlying reader reports ``Done()``.
    """

    def __init__(self, rspecifier=None):
        """Construct the reader; a falsy ``rspecifier`` leaves it unopened."""
        base_init = kp.SequentialBaseFloatMatrixReader.__init__
        if rspecifier:
            base_init(self, rspecifier)
        else:
            base_init(self)

    def __iter__(self):
        """The reader is its own iterator."""
        return self

    def __next__(self):
        """Return the current ``(key, value)`` pair, then advance the reader."""
        if self.Done():
            raise StopIteration()
        # Read key before value, and both before Next() moves the reader on.
        current_key = self.Key()
        current_value = self.Value()
        self.Next()
        return current_key, current_value

class SequentialWaveReader(kp.SequentialWaveReader):
def __init__(self, rspecifier=None):
if not rspecifier:
kp.SequentialWaveReader.__init__(self)
else:
kp.SequentialWaveReader.__init__(self, rspecifier)

def __iter__(self):
return self
Expand All @@ -52,9 +70,9 @@ def __next__(self):
class RandomAccessVectorReader(kp.RandomAccessBaseFloatVectorReader):
def __init__(self, rspecifier=None):
if not rspecifier:
kp.RandomAccessBaseFloatVectorReader.__init__(self)
kp.RandomAccessBaseFloatVectorReader.__init__(self)
else:
kp.RandomAccessBaseFloatVectorReader.__init__(self, rspecifier)
kp.RandomAccessBaseFloatVectorReader.__init__(self, rspecifier)

def __getitem__(self, key):
if not self.HasKey(key):
Expand All @@ -68,9 +86,25 @@ def __contains__(self, key):
class RandomAccessMatrixReader(kp.RandomAccessBaseFloatMatrixReader):
    """Mapping-style wrapper around ``kp.RandomAccessBaseFloatMatrixReader``.

    Supports ``reader[key]`` and ``key in reader`` on top of the bound
    ``HasKey`` / ``Value`` methods.
    """

    def __init__(self, rspecifier=None):
        """Construct the reader; a falsy ``rspecifier`` leaves it unopened."""
        if not rspecifier:
            kp.RandomAccessBaseFloatMatrixReader.__init__(self)
        else:
            kp.RandomAccessBaseFloatMatrixReader.__init__(self, rspecifier)

    def __getitem__(self, key):
        """Return ``self.Value(key)``.

        Raises:
            KeyError: if ``self.HasKey(key)`` is false.
        """
        if not self.HasKey(key):
            # Message typo fixed: previously read "does not exits."
            raise KeyError("{} does not exist.".format(key))
        return self.Value(key)

    def __contains__(self, key):
        """Return whether the reader has ``key`` (delegates to ``HasKey``)."""
        return self.HasKey(key)

class RandomAccessWaveReader(kp.RandomAccessWaveReader):
def __init__(self, rspecifier=None):
if not rspecifier:
kp.RandomAccessWaveReader.__init__(self)
else:
kp.RandomAccessBaseFloatMatrixReader.__init__(self, rspecifier)
kp.RandomAccessWaveReader.__init__(self, rspecifier)

def __getitem__(self, key):
if not self.HasKey(key):
Expand Down
25 changes: 25 additions & 0 deletions src/pybind/tests/test_wave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env python

# Copyright 2019 Microsoft Corporation (author: Xingyu Na)
# Apache 2.0

import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir))

import unittest

import numpy as np

import kaldi

class TestWaveData(unittest.TestCase):
    """Checks for the ``kaldi.WaveData`` binding."""

    def test_duration(self):
        """A 1 x 16000 matrix with samp_freq=16000 must report Duration() == 1."""
        samples = kaldi.FloatMatrix(1, 16000)
        wave = kaldi.WaveData(samp_freq=16000, data=samples)
        measured = wave.Duration()
        self.assertEqual(1, measured)

if __name__ == '__main__':
unittest.main()