Skip to content

Commit

Permalink
Merge branch 'sim_nj' of https://github.com/desh2608/lhotse into sim_nj
Browse files Browse the repository at this point in the history
  • Loading branch information
desh2608 committed Feb 8, 2023
2 parents efaffcc + 132c938 commit 26a69bb
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 13 deletions.
47 changes: 35 additions & 12 deletions lhotse/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,18 +475,9 @@ def move_to_memory(
channels=channels, offset=ifnone(offset, 0), duration=duration
)
stream = BytesIO()
if torchaudio_soundfile_supports_format() and format == "flac":
# Prefer saving with soundfile backend whenever possible to avoid issue:
# https://github.com/pytorch/audio/issues/2662
# Saving with sox_io backend to FLAC may corrupt the file, IDK about other
# formats but would rather be on the safe side.
torchaudio.backend.soundfile_backend.save(
stream, torch.from_numpy(audio), self.sampling_rate, format=format
)
else:
torchaudio.backend.sox_io_backend.save(
stream, torch.from_numpy(audio), self.sampling_rate, format=format
)
torchaudio_save_flac_safe(
stream, torch.from_numpy(audio), self.sampling_rate, format=format
)
channels = (ifnone(channels, self.channel_ids),)
if isinstance(channels, int):
channels = [channels]
Expand Down Expand Up @@ -2255,6 +2246,38 @@ def read_sph(
return audio, sampling_rate


def torchaudio_save_flac_safe(
dest: Union[str, Path, BytesIO],
src: Union[torch.Tensor, np.ndarray],
sample_rate: int,
*args,
**kwargs,
):
import torchaudio

src = torch.as_tensor(src)
saving_flac = kwargs.get("format") == "flac" or (
not isinstance(dest, BytesIO) and str(dest).endswith(".flac")
)
if torchaudio_soundfile_supports_format() and saving_flac:
# Prefer saving with soundfile backend whenever possible to avoid issue:
# https://github.com/pytorch/audio/issues/2662
# Saving with sox_io backend to FLAC may corrupt the file.
torchaudio.backend.soundfile_backend.save(
dest,
src,
sample_rate=sample_rate,
format=kwargs.pop("format", "flac"),
bits_per_sample=kwargs.pop("bits_per_sample", 16),
*args,
**kwargs,
)
else:
torchaudio.backend.sox_io_backend.save(
dest, src, sample_rate=sample_rate, *args, **kwargs
)


class AudioLoadingError(Exception):
pass

Expand Down
41 changes: 40 additions & 1 deletion lhotse/cut/mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import warnings
from dataclasses import dataclass
from functools import partial, reduce
from io import BytesIO
from operator import add
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
from intervaltree import IntervalTree

from lhotse.audio import AudioMixer, Recording, audio_energy
from lhotse.audio import AudioMixer, Recording, audio_energy, torchaudio_save_flac_safe
from lhotse.augmentation import AugmentFn
from lhotse.cut.base import Cut
from lhotse.cut.data import DataCut
Expand Down Expand Up @@ -346,6 +347,44 @@ def move_to_memory(
],
)

def to_mono(
self,
encoding: str = "flac",
bits_per_sample: Optional[int] = 16,
**kwargs,
) -> "Cut":
"""
Convert this MixedCut to a MonoCut by mixing all tracks and channels into a single one.
The result audio array is stored in memory, and can be saved to disk by calling
``cut.save_audio(path, ...)`` on the result.
.. hint:: the resulting MonoCut will have ``custom`` field populated with the
``custom`` value from the first track of the MixedCut.
:param encoding: Audio encoding argument supported by ``torchaudio.save``. See
https://pytorch.org/audio/stable/backend.html#save (sox_io backend) and
https://pytorch.org/audio/stable/backend.html#id3 (soundfile backend) for more details.
:param bits_per_sample: Audio bits_per_sample argument supported by ``torchaudio.save``. See
https://pytorch.org/audio/stable/backend.html#save (sox_io backend) and
https://pytorch.org/audio/stable/backend.html#id3 (soundfile backend) for more details.
:return: a new MonoCut instance.
"""
samples = self.load_audio(mono_downmix=True)
stream = BytesIO()
torchaudio_save_flac_safe(
stream,
samples,
self.sampling_rate,
format=encoding,
bits_per_sample=bits_per_sample,
)
recording = Recording.from_bytes(stream.getvalue(), recording_id=self.id)
return fastcopy(
recording.to_cut(),
supervisions=[fastcopy(s, channel=0) for s in self.supervisions],
custom=self.tracks[0].cut.custom,
)

def truncate(
self,
*,
Expand Down
32 changes: 32 additions & 0 deletions test/cut/test_cut_with_in_memory_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,35 @@ def test_mixed_cut_move_to_memory():
feats = cut.load_features()
feats_mem = cut_mem.load_features()
np.testing.assert_almost_equal(feats, feats_mem, decimal=1)


def test_mixed_cut_to_mono():
path = "test/fixtures/libri/cuts.json"
cut = CutSet.from_file(path)[0]
cut = cut.pad(duration=cut.duration + 2.0).append(cut)

cut_mem = cut.to_mono("wav")
assert isinstance(cut_mem, MonoCut)
assert not cut_mem.has_features

audio = cut.load_audio()
audio_mem = cut_mem.load_audio()
np.testing.assert_almost_equal(audio, audio_mem, decimal=1)


def test_mixed_cut_to_mono_with_custom():
path = "test/fixtures/libri/cuts.json"
cut = CutSet.from_file(path)[0]
cut.custom_str = "custom_str"
cut = cut.pad(duration=cut.duration + 2.0).append(cut)

cut_mem = cut.to_mono("wav")
assert isinstance(cut_mem, MonoCut)
assert not cut_mem.has_features
assert cut_mem.custom is not None
assert "custom_str" in cut_mem.custom
assert cut_mem.custom_str == "custom_str"

audio = cut.load_audio()
audio_mem = cut_mem.load_audio()
np.testing.assert_almost_equal(audio, audio_mem, decimal=1)

0 comments on commit 26a69bb

Please sign in to comment.