Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize save_audios() #1131

Merged
merged 2 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lhotse/caching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from functools import lru_cache, wraps
from threading import Lock
from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Dict, Optional

LHOTSE_CACHING_ENABLED = False

Expand Down Expand Up @@ -114,7 +114,7 @@ def enabled(cls) -> bool:
return cls.__enabled

@classmethod
def try_cache(cls, key: str) -> Union[bytes, None]:
def try_cache(cls, key: str) -> Optional[bytes]:
"""
Test if 'key' is in the chache. If yes return the bytes array,
otherwise return None.
Expand Down
22 changes: 20 additions & 2 deletions lhotse/cut/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from lhotse.audio import RecordingSet, null_result_on_audio_loading_error
from lhotse.augmentation import AugmentFn
from lhotse.caching import is_caching_enabled
from lhotse.cut.base import Cut
from lhotse.cut.data import DataCut
from lhotse.cut.mixed import MixedCut, MixTrack
Expand Down Expand Up @@ -1413,6 +1414,17 @@ def combine_same_recording_channels(self) -> "CutSet":
groups = groupby(lambda cut: (cut.recording.id, cut.start, cut.end), self)
return CutSet.from_cuts(MultiCut.from_mono(*cuts) for cuts in groups.values())

def sort_by_recording_id(self, ascending: bool = True) -> "CutSet":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a unit test to cover this method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to that.

I found a similar unit-test test_cut_set_sort_by_duration().
Could you point me to the place, where this test is called ?

I did not find any code calling test_cut_set_sort_by_duration() in the /test
folder and there are some arguments necessary to be filled...
Thx, Karel

"""
Sort the CutSet alphabetically according to 'recording_id'. Ascending by default.
This is advantageous before caling `save_audios()` on a `trim_to_supervision()`
processed `CutSet`, also make sure that `set_caching_enabled(True)` was called.
"""
return CutSet.from_cuts(
sorted(self, key=(lambda cut: cut.recording.id), reverse=not ascending)
)

def sort_by_duration(self, ascending: bool = False) -> "CutSet":
"""
Sort the CutSet according to cuts duration and return the result. Descending by default.
Expand Down Expand Up @@ -2326,6 +2338,7 @@ def save_audios(
executor: Optional[Executor] = None,
augment_fn: Optional[AugmentFn] = None,
progress_bar: bool = True,
shuffle_on_split: bool = True,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This option is undocumented

) -> "CutSet":
"""
Store waveforms of all cuts as audio recordings to disk.
Expand Down Expand Up @@ -2355,6 +2368,8 @@ def save_audios(
https://lhotse.readthedocs.io/en/latest/parallelism.html
:param progress_bar: Should a progress bar be displayed (automatically turned off
for parallel computation).
:param shuffle_on_split: Shuffle the ``CutSet`` before splitting it for the parallel workers.
It is active only when `num_jobs > 1`. The default is True.
:return: Returns a new ``CutSet``.
"""
from cytoolz import identity
Expand Down Expand Up @@ -2401,14 +2416,17 @@ def file_storage_path(cut: Cut, storage_path: Pathlike) -> Path:
)

# Parallel execution: prepare the CutSet splits
cut_sets = self.split(num_jobs, shuffle=True)
cut_sets = self.split(num_jobs, shuffle=shuffle_on_split)

# Initialize the default executor if None was given
if executor is None:
import multiprocessing

# The `is_caching_enabled()` state gets transfered to
# the spawned sub-processes implictly (checked).
executor = ProcessPoolExecutor(
num_jobs, mp_context=multiprocessing.get_context("spawn")
max_workers=num_jobs,
mp_context=multiprocessing.get_context("spawn"),
)

# Submit the chunked tasks to parallel workers.
Expand Down
23 changes: 23 additions & 0 deletions test/cut/test_cut_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@
from lhotse.utils import is_module_available


@pytest.fixture
def mini_librispeeh2_cut_set():
recordings = RecordingSet.from_file(
"test/fixtures/mini_librispeech2/lhotse/recordings.jsonl.gz"
)
supervisions = SupervisionSet.from_file(
"test/fixtures/mini_librispeech2/lhotse/supervisions.jsonl.gz"
)
return CutSet.from_manifests(recordings=recordings, supervisions=supervisions)


@pytest.fixture
def cut_set_with_mixed_cut(cut1, cut2):
mixed_cut = MixedCut(
Expand All @@ -51,6 +62,18 @@ def test_cut_set_sort_by_duration(cut_set_with_mixed_cut, ascending, expected):
assert [c.duration for c in cs] == expected


@pytest.mark.parametrize(
["ascending", "expected"],
[
(True, ["lbi-3536-23268-0000", "lbi-6241-61943-0000", "lbi-8842-304647-0000"]),
(False, ["lbi-8842-304647-0000", "lbi-6241-61943-0000", "lbi-3536-23268-0000"]),
],
)
def test_cut_set_sort_by_recording_id(mini_librispeeh2_cut_set, ascending, expected):
cs = mini_librispeeh2_cut_set.sort_by_recording_id(ascending)
assert [c.recording.id for c in cs] == expected


def test_cut_set_iteration(cut_set_with_mixed_cut):
cuts = list(cut_set_with_mixed_cut)
assert len(cut_set_with_mixed_cut) == 3
Expand Down