Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parts and num_jobs options for tedlium #1030

Merged
merged 2 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions lhotse/bin/modes/recipes/tedlium.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import List

import click

from lhotse.bin.modes import download, prepare
from lhotse.recipes.tedlium import download_tedlium, prepare_tedlium
from lhotse.recipes.tedlium import TEDLIUM_PARTS, download_tedlium, prepare_tedlium
from lhotse.utils import Pathlike


Expand All @@ -10,11 +12,33 @@
"tedlium_dir", type=click.Path(exists=True, dir_okay=True, file_okay=False)
)
@click.argument("output_dir", type=click.Path())
def tedlium(tedlium_dir: Pathlike, output_dir: Pathlike):
@click.option(
"--parts",
"-p",
type=click.Choice(TEDLIUM_PARTS),
multiple=True,
default=list(TEDLIUM_PARTS),
help="Which parts of TED-LIUM v3 to prepare (by default all).",
)
@click.option(
"-j",
"--num-jobs",
type=int,
default=1,
help="How many threads to use (can give good speed-ups with slow disks).",
)
def tedlium(
    tedlium_dir: Pathlike, output_dir: Pathlike, parts: List[str], num_jobs: int
):
    """
    TED-LIUM v3 recording and supervision manifest preparation.

    TEDLIUM_DIR is the path to the unpacked TED-LIUM v3 corpus;
    OUTPUT_DIR is where the manifests will be written.
    """
    # Forward the CLI options straight to the recipe implementation.
    # (A stale duplicate call without parts/num_jobs was removed here.)
    prepare_tedlium(
        tedlium_root=tedlium_dir,
        output_dir=output_dir,
        dataset_parts=parts,
        num_jobs=num_jobs,
    )


@download.command()
Expand Down
116 changes: 73 additions & 43 deletions lhotse/recipes/tedlium.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,21 @@
import logging
import shutil
import tarfile
from concurrent.futures.thread import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, Optional, Sequence, Union

from lhotse import (
Recording,
RecordingSet,
SupervisionSegment,
SupervisionSet,
validate_recordings_and_supervisions,
)
from lhotse.qa import fix_manifests
from lhotse.utils import Pathlike, safe_extract, urlretrieve_progress

TEDLIUM_PARTS = ("train", "dev", "test")


def download_tedlium(
target_dir: Pathlike = ".", force_download: Optional[bool] = False
Expand All @@ -82,7 +85,10 @@ def download_tedlium(


def prepare_tedlium(
    tedlium_root: Pathlike,
    output_dir: Optional[Pathlike] = None,
    dataset_parts: Union[str, Sequence[str]] = TEDLIUM_PARTS,
    num_jobs: int = 1,
) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]:
    """
    Prepare manifests for the TED-LIUM v3 corpus.

    The manifests are created in a dict with three splits: train, dev and test.
    Each split contains a RecordingSet and SupervisionSet in a dict under keys
    'recordings' and 'supervisions'.

    :param tedlium_root: Path to the unpacked TED-LIUM data.
    :param output_dir: Path where the manifests should be written.
    :param dataset_parts: Which parts of the dataset to prepare.
        By default, all parts are prepared.
    :param num_jobs: Number of parallel jobs to use.
    :return: A dict with standard corpus splits containing the manifests.
    """
    tedlium_root = Path(tedlium_root)
    output_dir = Path(output_dir) if output_dir is not None else None
    if output_dir is not None:
        # Ensure the target directory exists before any manifest is written.
        output_dir.mkdir(parents=True, exist_ok=True)
    corpus = {}

    # Allow a single part to be passed as a plain string, e.g. dataset_parts="dev".
    dataset_parts = [dataset_parts] if isinstance(dataset_parts, str) else dataset_parts

    with ThreadPoolExecutor(num_jobs) as ex:
        for split in dataset_parts:
            logging.info(f"Processing {split} split...")
            root = tedlium_root / "legacy" / split
            recordings = RecordingSet.from_dir(
                root / "sph", pattern="*.sph", num_jobs=num_jobs
            )
            stms = list((root / "stm").glob("*.stm"))
            assert len(stms) == len(recordings), (
                f"Mismatch: found {len(recordings)} "
                f"sphere files and {len(stms)} STM files. "
                f"You might be missing some parts of TEDLIUM..."
            )
            # Parse all STM transcript files in parallel; each call returns
            # the list of supervision segments for one recording.
            futures = [ex.submit(_parse_stm_file, stm) for stm in stms]
            segments = []
            for future in futures:
                segments.extend(future.result())

            supervisions = SupervisionSet.from_segments(segments)
            # NOTE(review): fix_manifests presumably reconciles recordings
            # with supervisions (drops orphans / trims bounds) — confirm.
            recordings, supervisions = fix_manifests(recordings, supervisions)

            corpus[split] = {"recordings": recordings, "supervisions": supervisions}
            validate_recordings_and_supervisions(**corpus[split])

            if output_dir is not None:
                recordings.to_file(output_dir / f"tedlium_recordings_{split}.jsonl.gz")
                supervisions.to_file(
                    output_dir / f"tedlium_supervisions_{split}.jsonl.gz"
                )

    return corpus


def _parse_stm_file(stm: Path) -> Sequence[SupervisionSegment]:
    """
    Parse a single STM transcript file into supervision segments.

    :param stm: Path to an ``.stm`` file (note: the previous ``str``
        annotation was wrong — ``.open()`` is called on it).
    :return: A list of ``SupervisionSegment`` objects, one per scored
        transcript line; lines marked ``ignore_time_segment_in_scoring``
        are skipped.
    """
    segments = []
    with stm.open() as f:
        for idx, line in enumerate(f):
            # STM line format: <rec_id> <channel> <speaker> <start> <end> <label> <words...>
            rec_id, _, _, start, end, _, *words = line.split()
            start, end = float(start), float(end)
            # Normalize the noise token spelling: {NOISE} -> [NOISE].
            text = " ".join(words).replace("{NOISE}", "[NOISE]")
            if text == "ignore_time_segment_in_scoring":
                continue
            segments.append(
                SupervisionSegment(
                    id=f"{rec_id}-{idx}",
                    recording_id=rec_id,
                    start=start,
                    # Round to avoid float artifacts from end - start.
                    duration=round(end - start, ndigits=8),
                    channel=0,
                    text=text,
                    language="English",
                    speaker=rec_id,
                )
            )
    return segments