Skip to content

Commit

Permalink
Support export_to_kaldi on resampled recordings (lhotse-speech#1162)
Browse files Browse the repository at this point in the history
* Add the test.

* Fix the test.

* Pass the recording's transforms into `make_wavscp_channel_string_map`, because resampling a recording modifies its transforms.

* Update the other tests.

* black

* isort

---------

Co-authored-by: Piotr Żelasko <[email protected]>
  • Loading branch information
2 people authored and flyingleafe committed Oct 11, 2023
1 parent 0fd742a commit 56c9e01
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 14 deletions.
20 changes: 14 additions & 6 deletions lhotse/kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from lhotse.audio import AudioSource, Recording, RecordingSet, info
from lhotse.features import Features, FeatureSet
Expand Down Expand Up @@ -346,7 +346,9 @@ def export_to_kaldi(
save_kaldi_text_mapping(
data={
recording.id: make_wavscp_channel_string_map(
source, sampling_rate=recording.sampling_rate
source,
sampling_rate=recording.sampling_rate,
transforms=recording.transforms,
)[0]
for recording in recordings
for source in recording.sources
Expand Down Expand Up @@ -400,7 +402,9 @@ def export_to_kaldi(
save_kaldi_text_mapping(
data={
f"{recording.id}_{channel}": make_wavscp_channel_string_map(
source, sampling_rate=recording.sampling_rate
source,
sampling_rate=recording.sampling_rate,
transforms=recording.transforms,
)[channel]
for recording in recordings
for source in recording.sources
Expand Down Expand Up @@ -538,7 +542,7 @@ def save_kaldi_text_mapping(data: Dict[str, Any], path: Path):


def make_wavscp_channel_string_map(
source: AudioSource, sampling_rate: int
source: AudioSource, sampling_rate: int, transforms: Optional[List[Dict]] = None
) -> Dict[int, str]:
if source.type == "url":
raise ValueError("URL audio sources are not supported by Kaldi.")
Expand All @@ -549,14 +553,18 @@ def make_wavscp_channel_string_map(
)
return {0: f"{source.source} |"}
elif source.type == "file":
if Path(source.source).suffix == ".wav" and len(source.channels) == 1:
if (
Path(source.source).suffix == ".wav"
and len(source.channels) == 1
and transforms is None
):
# Note: for single-channel waves, we don't need to invoke ffmpeg; but
# for multi-channel waves, Kaldi is going to complain.
audios = dict()
for channel in source.channels:
audios[channel] = source.source
return audios
elif Path(source.source).suffix == ".sph":
if Path(source.source).suffix == ".sph":
# we will do this specifically using the sph2pipe because
# ffmpeg does not support shorten compression, which is sometimes
# used in the sph files
Expand Down
43 changes: 35 additions & 8 deletions test/test_kaldi_dirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lhotse.audio.recording_set
import lhotse.audio.source
import lhotse.audio.utils
from lhotse import Recording, RecordingSet, SupervisionSegment, SupervisionSet
from lhotse.audio import get_audio_duration_mismatch_tolerance

pytest.importorskip(
Expand Down Expand Up @@ -161,6 +162,32 @@ def test_multi_channel_recording(
assert segments == multi_channel_kaldi_dir["segments"]


def test_resample_recording(tmp_path, multi_channel_recording, multi_channel_kaldi_dir):
    """Check that exporting a resampled recording mentions the new rate in wav.scp.

    A mono WAV fixture is loaded, resampled to 16 kHz, paired with a single
    supervision covering the whole recording, and exported to a Kaldi data
    directory inside ``tmp_path``. The resulting ``wav.scp`` entry for the
    recording must contain the string ``"16000"``.
    """
    fixture_wav = os.path.join(os.path.dirname(__file__), "fixtures", "mono_c0.wav")

    with working_directory(tmp_path):
        # Build the recording first, then resample it as a separate step.
        original = Recording.from_file(recording_id="mono_c0", path=fixture_wav)
        resampled = original.resample(16000)

        whole_recording_segment = SupervisionSegment(
            id="Segment-c0",
            recording_id=resampled.id,
            start=0,
            duration=resampled.duration,
            channel=0,
            text="SIL",
        )

        recording_set = RecordingSet.from_recordings([resampled])
        supervision_set = SupervisionSet.from_segments([whole_recording_segment])
        lhotse.kaldi.export_to_kaldi(
            recording_set,
            supervision_set,
            output_dir=".",
            map_underscores_to=None,
            prefix_spk_id=False,
        )

        # Read the export back while still inside the temporary directory,
        # because the file name is relative to it.
        wav_scp = open_and_load("wav.scp")
        assert "16000" in wav_scp["mono_c0"]


@contextlib.contextmanager
def working_directory(path):
"""Changes working directory and returns to previous on exit."""
Expand Down Expand Up @@ -231,30 +258,30 @@ def test_fail_on_unknown_source_type(tmp_path):
type="unknown", channels=[0], source="http://example.com/"
)
with pytest.raises(ValueError):
lhotse.kaldi.make_wavscp_channel_string_map(source, 16000)
lhotse.kaldi.make_wavscp_channel_string_map(source, 16000, None)


def test_fail_on_url_source_type(tmp_path):
source = lhotse.audio.source.AudioSource(
type="url", channels=[0], source="http://example.com/"
)
with pytest.raises(ValueError):
lhotse.kaldi.make_wavscp_channel_string_map(source, 16000)
lhotse.kaldi.make_wavscp_channel_string_map(source, 16000, None)


def test_fail_on_command_multichannel_source_type(tmp_path):
source = lhotse.audio.source.AudioSource(
type="command", channels=[0, 1], source="false"
)
with pytest.raises(ValueError):
lhotse.kaldi.make_wavscp_channel_string_map(source, 16000)
lhotse.kaldi.make_wavscp_channel_string_map(source, 16000, None)


def test_ok_on_command_singlechannel_source_type(tmp_path):
source = lhotse.audio.source.AudioSource(
type="command", channels=[0], source="true"
)
out = lhotse.kaldi.make_wavscp_channel_string_map(source, 16000)
out = lhotse.kaldi.make_wavscp_channel_string_map(source, 16000, None)
assert list(out.keys()) == [0]
assert out[0] == "true |"

Expand All @@ -264,7 +291,7 @@ def test_ok_on_file_singlechannel_wav_source_type(tmp_path, channel):
source = lhotse.audio.source.AudioSource(
type="file", channels=[channel], source="nonexistent.wav"
)
out = lhotse.kaldi.make_wavscp_channel_string_map(source, 16000)
out = lhotse.kaldi.make_wavscp_channel_string_map(source, 16000, None)
assert list(out.keys()) == [channel]
assert out[channel] == "nonexistent.wav"

Expand All @@ -274,7 +301,7 @@ def test_ok_on_file_singlechannel_sph_source_type(tmp_path, channel):
source = lhotse.audio.source.AudioSource(
type="file", channels=[channel], source="nonexistent.sph"
)
out = lhotse.kaldi.make_wavscp_channel_string_map(source, 16000)
out = lhotse.kaldi.make_wavscp_channel_string_map(source, 16000, None)
assert list(out.keys()) == [channel]
assert out[channel].startswith("sph2pipe")
assert "nonexistent.sph" in out[channel]
Expand All @@ -286,7 +313,7 @@ def test_ok_on_file_singlechannel_mp3_source_type(tmp_path, channel):
source = lhotse.audio.source.AudioSource(
type="file", channels=[channel], source="nonexistent.mp3"
)
out = lhotse.kaldi.make_wavscp_channel_string_map(source, 16000)
out = lhotse.kaldi.make_wavscp_channel_string_map(source, 16000, None)
assert list(out.keys()) == [channel]
assert out[channel].startswith("ffmpeg")
assert "nonexistent.mp3" in out[channel]
Expand All @@ -297,7 +324,7 @@ def test_ok_on_file_multichannel_wav_source_type(tmp_path):
source = lhotse.audio.source.AudioSource(
type="file", channels=[0, 1, 2], source="nonexistent.wav"
)
out = lhotse.kaldi.make_wavscp_channel_string_map(source, 16000)
out = lhotse.kaldi.make_wavscp_channel_string_map(source, 16000, None)
assert list(out.keys()) == [0, 1, 2]
for channel in out.keys():
assert out[channel].startswith("ffmpeg")
Expand Down

0 comments on commit 56c9e01

Please sign in to comment.