[lhotse] Support for NeMo tarred manifests with offset field (#10035)
* [lhotse] Support for NeMo tarred manifests with offset field

Signed-off-by: Piotr Żelasko <[email protected]>

* typo fix

Signed-off-by: Piotr Żelasko <[email protected]>

* fix basename

Signed-off-by: Piotr Żelasko <[email protected]>

* relieve heavy CPU memory usage for super-long tarred recordings

Signed-off-by: Piotr Żelasko <[email protected]>

* Tests and fixes

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>
pzelasko authored Aug 7, 2024
1 parent 8dbe1da commit e5e648d
Showing 2 changed files with 184 additions and 17 deletions.
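The manifest format added here looks roughly like the sketch below: several JSONL entries reference the same tar member, and each entry beyond the first appends a '-subN' suffix to its audio_filepath, with an offset/duration pair selecting the segment. Field values are illustrative and mirror the test fixture added in this commit:

{"audio_filepath": "audio_1", "offset": 0.0, "duration": 3.0, "text": "irrelevant", "lang": "en", "shard_id": 0}
{"audio_filepath": "audio_1-sub1", "offset": 4.0, "duration": 5.0, "text": "irrelevant-2", "lang": "en", "shard_id": 0}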
89 changes: 73 additions & 16 deletions nemo/collections/common/data/lhotse/nemo_adapters.py
@@ -25,6 +25,7 @@
import soundfile
from cytoolz import groupby
from lhotse import AudioSource, Recording, SupervisionSegment
from lhotse.audio.backend import LibsndfileBackend
from lhotse.cut import Cut
from lhotse.dataset.dataloading import resolve_seed
from lhotse.lazy import LazyIteratorChain, LazyJsonlIterator
@@ -329,17 +330,26 @@ def __iter__(self) -> Generator[Cut, None, None]:
# Propagate the random seed
extra_fields = [ExtraField.from_dict({"seed": seed, **field_cfg}) for field_cfg in self.extra_fields or ()]

# Handle NeMo tarred manifests with offsets.
# They have multiple JSONL entries where audio paths end with '-sub1', '-sub2', etc. for each offset.
offset_pattern = re.compile(r'^.+(-sub\d+)$')
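# For illustration: offset_pattern.match("audio_1-sub2") captures "-sub2" in group(1),
# so entries "audio_1", "audio_1-sub1", and "audio_1-sub2" all reduce to the same
# basename "audio_1" and can be grouped per tar member below.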

for sid in shard_ids:
manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0]
shard_manifest = {data["audio_filepath"]: data for data in self.shard_id_to_manifest[sid]}

def basename(d: dict) -> str:
return (
k[: -len(m.group(1))] if (m := offset_pattern.match(k := d["audio_filepath"])) is not None else k
)

shard_manifest: dict[str, list[dict]] = groupby(basename, self.shard_id_to_manifest[sid])
tar_path = self.shard_id_to_tar_path[sid]
with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar:
for tar_info in tar:
assert tar_info.name in shard_manifest, (
f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). "
f"Cannot locate JSON entry for tar file '{tar_info.name}'"
)
data = shard_manifest[tar_info.name]
raw_audio = tar.extractfile(tar_info).read()
# Note: Lhotse has a Recording.from_bytes() utility that we won't use here because
# the profiling indicated significant overhead in torchaudio ffmpeg integration
@@ -353,21 +363,29 @@
num_samples=meta.frames,
duration=meta.duration,
)
cut = recording.to_cut()
cut.supervisions.append(
SupervisionSegment(
id=cut.id,
recording_id=cut.recording_id,
start=0,
duration=cut.duration,
text=data.get(self.text_field),
language=data.get(self.lang_field),
cuts_for_recording = []
for data in sorted(shard_manifest[tar_info.name], key=lambda d: d["audio_filepath"]):
# Cut the recording into the corresponding segment and discard audio data outside of it.
cut = make_cut_with_subset_inmemory_recording(
recording, offset=data.get("offset", 0.0), duration=data.get("duration")
)
)
cut.custom = _to_custom_attr_dict(data)
for extra_field in extra_fields:
extra_field.attach_to(cut)
yield cut
cut.supervisions.append(
SupervisionSegment(
id=cut.id,
recording_id=cut.recording_id,
start=0,
duration=cut.duration,
text=data.get(self.text_field),
language=data.get(self.lang_field),
)
)
cut.custom = _to_custom_attr_dict(data)
for extra_field in extra_fields:
extra_field.attach_to(cut)
cuts_for_recording.append(cut)
del recording # free the memory - helps with very large audio files
del raw_audio
yield from cuts_for_recording

def __len__(self) -> int:
return len(self.source)
Expand All @@ -376,6 +394,45 @@ def __add__(self, other):
return LazyIteratorChain(self, other)


def make_cut_with_subset_inmemory_recording(
recording: Recording, offset: float = 0.0, duration: float | None = None
) -> Cut:
"""
This method is built specifically to optimize CPU memory usage during dataloading
when reading tarfiles containing very long recordings (1h+).
Normally each cut would hold a reference to the long in-memory recording and load
the necessary subset of audio (there wouldn't be a separate copy of the long recording for each cut).
This is fairly efficient already, but we don't actually need to hold the unused full recording in memory.
Instead, we re-create each cut so that it only holds a reference to the subset of recording necessary.
This allows us to discard unused data which would otherwise be held in memory as part of sampling buffering.
"""

# Fast path: no offset and (almost) matching duration (within 200ms; leeway for different audio codec behavior).
cut = recording.to_cut()
if offset == 0.0 and (duration is None or abs(duration - recording.duration) < 0.2):
return cut

# Otherwise, apply the memory optimization.
cut = cut.truncate(offset=offset, duration=duration, preserve_id=True)
audiobytes = BytesIO()
LibsndfileBackend().save_audio(audiobytes, cut.load_audio(), sampling_rate=cut.sampling_rate, format="wav")
audiobytes.seek(0)
new_recording = Recording(
id=recording.id,
sampling_rate=recording.sampling_rate,
num_samples=cut.num_samples,
duration=cut.duration,
sources=[
AudioSource(
type="memory",
channels=recording.channel_ids,
source=audiobytes.getvalue(),
)
],
)
return new_recording.to_cut()
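# Illustration with assumed values: for a 10 s in-memory recording,
# make_cut_with_subset_inmemory_recording(recording, offset=4.0, duration=5.0) returns
# a cut with start == 0.0 and duration == 5.0, backed by a new recording that holds
# only the ~5 s of re-encoded WAV bytes instead of the full 10 s buffer.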


class ExtraField:
TYPE = None
SUPPORTED_TYPES = {}
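As a self-contained sketch of the '-subN' grouping logic in nemo_adapters.py above (the entry values are assumed for illustration, not taken from the commit):

import re

from cytoolz import groupby

offset_pattern = re.compile(r'^.+(-sub\d+)$')

def basename(d: dict) -> str:
    return k[: -len(m.group(1))] if (m := offset_pattern.match(k := d["audio_filepath"])) is not None else k

entries = [
    {"audio_filepath": "audio_1", "offset": 0.0, "duration": 3.0},
    {"audio_filepath": "audio_1-sub1", "offset": 4.0, "duration": 5.0},
]
# Both entries collapse onto the same tar member name, so they can be matched
# against a single tar file entry and cut at their respective offsets.
assert groupby(basename, entries) == {"audio_1": entries}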
112 changes: 111 additions & 1 deletion tests/collections/common/test_lhotse_dataloading.py
@@ -21,10 +21,11 @@
import numpy as np
import pytest
import torch
from lhotse import CutSet, MonoCut, NumpyFilesWriter, Recording
from lhotse import CutSet, MonoCut, NumpyFilesWriter, Recording, SupervisionSegment, compute_num_samples
from lhotse.audio import AudioLoadingError
from lhotse.cut import Cut, MixedCut
from lhotse.cut.text import TextPairExample
from lhotse.testing.dummies import dummy_recording
from omegaconf import OmegaConf

from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper
@@ -1718,3 +1719,112 @@ def test_dataloader_from_nemo_manifest_with_extra_questions_field_sample(
assert isinstance(c, MonoCut)
assert hasattr(c, "question")
assert c.question == "some question number 8"


@pytest.fixture(scope="session")
def nemo_tarred_manifest_path_with_offset(tmp_path_factory) -> Tuple[str, str]:
"""10 utterances of length 1s as a NeMo tarred manifest."""
from lhotse.serialization import SequentialJsonlWriter
from lhotse.shar.writers import TarWriter

root = tmp_path_factory.mktemp("nemo_tar_offset")
root.mkdir(exist_ok=True)
recording = dummy_recording(0, duration=10.0, with_data=True)

with (
TarWriter(f"{root}/audios_0.tar", shard_size=None) as tar_writer,
SequentialJsonlWriter(root / "tarred_audio_filepaths.jsonl") as mft_writer,
):
tar_writer.write(recording.id, BytesIO(recording.sources[0].source))
mft_writer.write(
{ # segment 0-3s
"audio_filepath": recording.id,
"offset": 0.0,
"duration": 3.0,
"text": "irrelevant",
"lang": "en",
"shard_id": 0,
}
)
mft_writer.write(
{ # segment 4-9s
"audio_filepath": recording.id + "-sub1",
"offset": 4.0,
"duration": 5.0,
"text": "irrelevant-2",
"lang": "en",
"shard_id": 0,
}
)
mft_writer.write(
{ # full recording - for reference
"audio_filepath": recording.id + "-sub2",
"offset": 0.0,
"duration": 10.0,
"text": "irrelevant irrelevant-2",
"lang": "en",
"shard_id": 0,
}
)
return mft_writer.path, tar_writer.output_paths[0]


def test_dataloader_from_tarred_nemo_manifest_with_offset(nemo_tarred_manifest_path_with_offset: tuple[str, str]):
json_mft, tar_mft = nemo_tarred_manifest_path_with_offset
config = OmegaConf.create(
{
"manifest_filepath": json_mft,
"tarred_audio_filepaths": tar_mft,
"sample_rate": 16000,
"shuffle": False,
"num_workers": 0,
"batch_size": 3,
"seed": 0,
"shard_seed": 0,
"force_finite": True,
}
)

dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity())

# Loads all three examples in a single mini-batch (that's why batch_size=3).
batches = [b for b in dl]
assert len(batches) == 1
(batch,) = batches
assert len(batch) == 3

# Validate example containing full 10s recording.
full_cut = batch[-1]
assert full_cut.start == 0.0
assert full_cut.duration == 10.0
assert full_cut.supervisions[0].text == "irrelevant irrelevant-2"
assert full_cut.supervisions[0].language == "en"
full_audio = full_cut.load_audio()
assert full_audio.shape[1] == full_cut.num_samples == 160000 # 10s * 16kHz

# Validate segment 0-3s.
cut = batch[0]
assert cut.start == 0.0
assert cut.duration == 3.0
assert cut.supervisions[0].text == "irrelevant"
assert cut.supervisions[0].language == "en"
audio = cut.load_audio()
assert audio.shape[1] == cut.num_samples
# Check the audio for the segment is identical to a slice of the full audio.
np.testing.assert_equal(audio, full_audio[:, : compute_num_samples(cut.duration, cut.sampling_rate)])

# Validate segment 4-9s.
# Note: LazyNeMoTarredIterator removes the offset information, as it creates a new recording
# that's a "subset" of the original recording as a memory saving optimization.
# Hence, we will not see cut.start == 4.0.
cut = batch[1]
assert cut.start == 0.0
assert cut.duration == 5.0
assert cut.supervisions[0].text == "irrelevant-2"
assert cut.supervisions[0].language == "en"
audio = cut.load_audio()
assert audio.shape[1] == cut.num_samples
# Check the audio for the segment is identical to a slice of the full audio.
np.testing.assert_equal(
audio, full_audio[:, compute_num_samples(4.0, cut.sampling_rate) : compute_num_samples(9.0, cut.sampling_rate)]
)
