Skip to content

Commit

Permalink
optimize save_audios() (#1131)
Browse files Browse the repository at this point in the history
* optimize save_audios()

- added sorted method : CutSet::sort_by_recording_id()
- allow to disable shuffling of CutSet inside `CutSet::save_audios()`
   - both changes improve cache hit ratio
- `CutSet::save_audios()` : show in log if caching was active
- caching.py : replace Union[] by Optional[]

* integrating suggestions from PR

- adding unit-test, removing logging.info(), documenting `shuffle_on_split`
  • Loading branch information
KarelVesely84 authored Sep 5, 2023
1 parent 088f180 commit fdf042b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 4 deletions.
4 changes: 2 additions & 2 deletions lhotse/caching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from functools import lru_cache, wraps
from threading import Lock
from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Dict, Optional

LHOTSE_CACHING_ENABLED = False

Expand Down Expand Up @@ -114,7 +114,7 @@ def enabled(cls) -> bool:
return cls.__enabled

@classmethod
def try_cache(cls, key: str) -> Union[bytes, None]:
def try_cache(cls, key: str) -> Optional[bytes]:
"""
Test if 'key' is in the cache. If yes return the bytes array,
otherwise return None.
Expand Down
22 changes: 20 additions & 2 deletions lhotse/cut/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from lhotse.audio import RecordingSet, null_result_on_audio_loading_error
from lhotse.augmentation import AugmentFn
from lhotse.caching import is_caching_enabled
from lhotse.cut.base import Cut
from lhotse.cut.data import DataCut
from lhotse.cut.mixed import MixedCut, MixTrack
Expand Down Expand Up @@ -1413,6 +1414,17 @@ def combine_same_recording_channels(self) -> "CutSet":
groups = groupby(lambda cut: (cut.recording.id, cut.start, cut.end), self)
return CutSet.from_cuts(MultiCut.from_mono(*cuts) for cuts in groups.values())

def sort_by_recording_id(self, ascending: bool = True) -> "CutSet":
    """
    Return a new ``CutSet`` with the cuts ordered alphabetically by ``cut.recording.id``.

    Sorting this way is advantageous before calling `save_audios()` on a
    `trim_to_supervision()` processed `CutSet`; also make sure that
    `set_caching_enabled(True)` was called.

    :param ascending: when True (the default) the IDs are in ascending order,
        otherwise in descending order.
    :return: a new ``CutSet`` built from the sorted cuts.
    """

    def recording_id_of(cut):
        return cut.recording.id

    # `sorted()` consumes the whole CutSet and builds the ordered list in memory.
    ordered_cuts = sorted(self, key=recording_id_of, reverse=not ascending)
    return CutSet.from_cuts(ordered_cuts)

def sort_by_duration(self, ascending: bool = False) -> "CutSet":
"""
Sort the CutSet according to cuts duration and return the result. Descending by default.
Expand Down Expand Up @@ -2326,6 +2338,7 @@ def save_audios(
executor: Optional[Executor] = None,
augment_fn: Optional[AugmentFn] = None,
progress_bar: bool = True,
shuffle_on_split: bool = True,
) -> "CutSet":
"""
Store waveforms of all cuts as audio recordings to disk.
Expand Down Expand Up @@ -2355,6 +2368,8 @@ def save_audios(
https://lhotse.readthedocs.io/en/latest/parallelism.html
:param progress_bar: Should a progress bar be displayed (automatically turned off
for parallel computation).
:param shuffle_on_split: Shuffle the ``CutSet`` before splitting it for the parallel workers.
It is active only when `num_jobs > 1`. The default is True.
:return: Returns a new ``CutSet``.
"""
from cytoolz import identity
Expand Down Expand Up @@ -2401,14 +2416,17 @@ def file_storage_path(cut: Cut, storage_path: Pathlike) -> Path:
)

# Parallel execution: prepare the CutSet splits
cut_sets = self.split(num_jobs, shuffle=True)
cut_sets = self.split(num_jobs, shuffle=shuffle_on_split)

# Initialize the default executor if None was given
if executor is None:
import multiprocessing

# The `is_caching_enabled()` state gets transferred to
# the spawned sub-processes implicitly (checked).
executor = ProcessPoolExecutor(
num_jobs, mp_context=multiprocessing.get_context("spawn")
max_workers=num_jobs,
mp_context=multiprocessing.get_context("spawn"),
)

# Submit the chunked tasks to parallel workers.
Expand Down
23 changes: 23 additions & 0 deletions test/cut/test_cut_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@
from lhotse.utils import is_module_available


@pytest.fixture
def mini_librispeeh2_cut_set():
    """Build a CutSet from the mini_librispeech2 recording and supervision fixture manifests."""
    # Recordings are loaded first, then supervisions.
    return CutSet.from_manifests(
        recordings=RecordingSet.from_file(
            "test/fixtures/mini_librispeech2/lhotse/recordings.jsonl.gz"
        ),
        supervisions=SupervisionSet.from_file(
            "test/fixtures/mini_librispeech2/lhotse/supervisions.jsonl.gz"
        ),
    )


@pytest.fixture
def cut_set_with_mixed_cut(cut1, cut2):
mixed_cut = MixedCut(
Expand All @@ -51,6 +62,18 @@ def test_cut_set_sort_by_duration(cut_set_with_mixed_cut, ascending, expected):
assert [c.duration for c in cs] == expected


@pytest.mark.parametrize(
    ["ascending", "expected"],
    [
        (True, ["lbi-3536-23268-0000", "lbi-6241-61943-0000", "lbi-8842-304647-0000"]),
        (False, ["lbi-8842-304647-0000", "lbi-6241-61943-0000", "lbi-3536-23268-0000"]),
    ],
)
def test_cut_set_sort_by_recording_id(mini_librispeeh2_cut_set, ascending, expected):
    """Check that sort_by_recording_id() yields recording IDs in the requested order."""
    sorted_cuts = mini_librispeeh2_cut_set.sort_by_recording_id(ascending)
    recording_ids = [cut.recording.id for cut in sorted_cuts]
    assert recording_ids == expected


def test_cut_set_iteration(cut_set_with_mixed_cut):
cuts = list(cut_set_with_mixed_cut)
assert len(cut_set_with_mixed_cut) == 3
Expand Down

0 comments on commit fdf042b

Please sign in to comment.