Fix and enhance TIMIT recipe #1072

Merged (4 commits) on Aug 4, 2023
216 changes: 127 additions & 89 deletions lhotse/recipes/timit.py
@@ -7,15 +7,15 @@
import logging
import zipfile
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Dict, Optional, Union

from tqdm import tqdm
from tqdm.auto import tqdm

from lhotse import validate_recordings_and_supervisions
from lhotse.audio import Recording, RecordingSet
from lhotse.supervision import SupervisionSegment, SupervisionSet
from lhotse.supervision import AlignmentItem, SupervisionSegment, SupervisionSet
from lhotse.utils import Pathlike, resumable_download


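Side note on the import change above (illustrative, not part of the diff): ThreadPoolExecutor is constrained by the GIL for CPU-bound work such as parsing transcripts and building manifests, while ProcessPoolExecutor runs workers in separate processes. That is also why the per-file logic is factored out into the module-level prepare_recording function further down: callables submitted to a process pool must be picklable. A minimal sketch of the submit/result pattern the recipe now uses, with square standing in as a placeholder for the per-file work:

from concurrent.futures import ProcessPoolExecutor

from tqdm.auto import tqdm


def square(x):
    # Placeholder for CPU-bound per-item work (prepare_recording in the recipe).
    return x * x


if __name__ == "__main__":
    with ProcessPoolExecutor(4) as ex:
        # Submit everything first, then resolve the futures under a progress bar,
        # mirroring the loop structure in prepare_timit below.
        futures = [ex.submit(square, i) for i in range(100)]
        results = [f.result() for f in tqdm(futures, "Computing squares")]
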
@@ -37,15 +37,21 @@ def download_timit(
zip_path = target_dir / zip_name
corpus_dir = zip_path.with_suffix("")
completed_detector = corpus_dir / ".completed"

if completed_detector.is_file():
logging.info(f"Skipping {zip_name} because {completed_detector} exists.")
return corpus_dir

resumable_download(base_url, filename=zip_path, force_download=force_download)

with zipfile.ZipFile(zip_path) as zip_file:
corpus_dir.mkdir(parents=True, exist_ok=True)
for names in zip_file.namelist():
for names in tqdm(zip_file.namelist(), "Extracting files"):
zip_file.extract(names, str(corpus_dir))

zip_path.unlink()
completed_detector.touch()

return corpus_dir


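For orientation, a hedged usage sketch of the two entry points in this file (not part of the diff); the target directory, output path, and num_jobs value are arbitrary examples:

from pathlib import Path

from lhotse.recipes.timit import download_timit, prepare_timit

# Resumable download + extraction; the ".completed" marker makes re-runs a no-op.
corpus_dir = download_timit(target_dir=Path("data"))

# Build per-split RecordingSet/SupervisionSet manifests and write them as
# timit_recordings_*.jsonl.gz / timit_supervisions_*.jsonl.gz under output_dir.
manifests = prepare_timit(corpus_dir, output_dir="data/manifests", num_jobs=4)
train = manifests["TRAIN"]  # {"recordings": ..., "supervisions": ...}
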
@@ -59,6 +65,7 @@ def prepare_timit(
Returns the manifests which consist of the Recordings and Supervisions.
:param corpus_dir: Pathlike, the path of the data dir.
:param output_dir: Pathlike, the path where to write and save the manifests.
:param supervision_lvl: str='phone', the level of supervision: 'phone', 'word' or 'text'.
:param num_phones: int=48, the number of phones (60, 48 or 39) used for modeling; 48 is the default.
:return: a Dict whose key is the dataset part, and the value is a Dict with the keys 'recordings' and 'supervisions'.
"""
@@ -81,97 +88,129 @@

dev_spks, test_spks = get_speakers()

with ThreadPoolExecutor(num_jobs) as ex:
for part in dataset_parts:
wav_files = []

if part == "TRAIN":
print("starting....")
wav_files = glob.glob(str(corpus_dir) + "/TRAIN/*/*/*.WAV")
# filter the SA (dialect sentences)
wav_files = list(
filter(lambda x: x.split("/")[-1][:2] != "SA", wav_files)
)
elif part == "DEV":
wav_files = glob.glob(str(corpus_dir) + "/TEST/*/*/*.WAV")
# filter the SA (dialect sentences)
wav_files = list(
filter(lambda x: x.split("/")[-1][:2] != "SA", wav_files)
)
wav_files = list(
filter(lambda x: x.split("/")[-2].lower() in dev_spks, wav_files)
)
else:
wav_files = glob.glob(str(corpus_dir) + "/TEST/*/*/*.WAV")
# filter the SA (dialect sentences)
wav_files = list(
filter(lambda x: x.split("/")[-1][:2] != "SA", wav_files)
)
wav_files = list(
filter(lambda x: x.split("/")[-2].lower() in test_spks, wav_files)
)

logging.debug(f"{part} dataset manifest generation.")
recordings = []
supervisions = []

for wav_file in tqdm(wav_files):
items = str(wav_file).strip().split("/")
idx = items[-2] + "-" + items[-1][:-4]
speaker = items[-2]
transcript_file = Path(wav_file).with_suffix(".PHN")
if not Path(wav_file).is_file():
logging.warning(f"No such file: {wav_file}")
continue
if not Path(transcript_file).is_file():
logging.warning(f"No transcript: {transcript_file}")
continue
text = []
with open(transcript_file, "r") as f:
lines = f.readlines()
for line in lines:
phone = line.rstrip("\n").split(" ")[-1]
if num_phones != 60:
phone = phones_dict[str(phone)]
text.append(phone)

text = " ".join(text).replace("h#", "sil")

recording = Recording.from_file(path=wav_file, recording_id=idx)
recordings.append(recording)
segment = SupervisionSegment(
id=idx,
recording_id=idx,
start=0.0,
duration=recording.duration,
channel=0,
language="English",
speaker=speaker,
text=text.strip(),
)

supervisions.append(segment)

recording_set = RecordingSet.from_recordings(recordings)
supervision_set = SupervisionSet.from_segments(supervisions)
validate_recordings_and_supervisions(recording_set, supervision_set)

if output_dir is not None:
supervision_set.to_file(
output_dir / f"timit_supervisions_{part}.jsonl.gz"
for part in dataset_parts:
wav_files = []

if part == "TRAIN":
print("starting....")
wav_files = glob.glob(str(corpus_dir) + "/data/TRAIN/*/*/*.WAV")
# filter the SA (dialect sentences)
# wav_files = list(filter(lambda x: x.split("/")[-1][:2] != "SA", wav_files))
elif part == "DEV":
wav_files = glob.glob(str(corpus_dir) + "/data/TEST/*/*/*.WAV")
# filter the SA (dialect sentences)
# wav_files = list(filter(lambda x: x.split("/")[-1][:2] != "SA", wav_files))
wav_files = list(
filter(lambda x: x.split("/")[-2].lower() in dev_spks, wav_files)
)
else:
wav_files = glob.glob(str(corpus_dir) + "/data/TEST/*/*/*.WAV")
# filter the SA (dialect sentences)
# wav_files = list(filter(lambda x: x.split("/")[-1][:2] != "SA", wav_files))
wav_files = list(
filter(lambda x: x.split("/")[-2].lower() in test_spks, wav_files)
)

logging.debug(f"{part} dataset manifest generation.")
recordings = []
supervisions = []

if num_jobs <= 1:
for wav_file in tqdm(wav_files, f"Preparing {part} manifest"):
try:
recording, supervision = prepare_recording(
wav_file, num_phones, phones_dict
)
recording_set.to_file(
output_dir / f"timit_recordings_{part}.jsonl.gz"
recordings.append(recording)
supervisions.append(supervision)
except FileNotFoundError as e:
logging.warning(e.strerror)
else:
with ProcessPoolExecutor(num_jobs) as ex:
results = []
for wav_file in wav_files:
results.append(
ex.submit(prepare_recording, wav_file, num_phones, phones_dict)
)

manifests[part] = {
"recordings": recording_set,
"supervisions": supervision_set,
}
for r in tqdm(results, f"Preparing {part} manifest"):
try:
recording, supervision = r.result()
recordings.append(recording)
supervisions.append(supervision)
except FileNotFoundError as e:
logging.warning(e.strerror)

recording_set = RecordingSet.from_recordings(recordings)
supervision_set = SupervisionSet.from_segments(supervisions)
validate_recordings_and_supervisions(recording_set, supervision_set)

if output_dir is not None:
supervision_set.to_file(output_dir / f"timit_supervisions_{part}.jsonl.gz")
recording_set.to_file(output_dir / f"timit_recordings_{part}.jsonl.gz")

manifests[part] = {
"recordings": recording_set,
"supervisions": supervision_set,
}

return manifests


def prepare_recording(wav_file, num_phones, phones_dict):
items = str(wav_file).strip().split("/")
idx = items[-2] + "-" + items[-1][:-4]
speaker = items[-2]

text_file = Path(wav_file).with_suffix(".TXT")
word_file = Path(wav_file).with_suffix(".WRD")
phone_file = Path(wav_file).with_suffix(".PHN")

recording = Recording.from_file(path=wav_file, recording_id=idx)

with open(text_file, "r") as f:
text = " ".join(f.read().rstrip("\n").split(" ")[2:])

word_alignments = []
with open(word_file, "r") as f:
lines = f.readlines()
for line in lines:
st, et, word = line.strip().split(" ")
start = float(st) / recording.sampling_rate
end = float(et) / recording.sampling_rate

word_alignments.append(AlignmentItem(word, start, end - start))

phone_alignments = []
with open(phone_file, "r") as f:
lines = f.readlines()
for line in lines:
st, et, phone = line.strip().split(" ")
start = float(st) / recording.sampling_rate
end = float(et) / recording.sampling_rate
if num_phones != 60:
phone = phones_dict[phone]

phone_alignments.append(AlignmentItem(phone, start, end - start))

segment = SupervisionSegment(
id=idx,
recording_id=idx,
start=0.0,
duration=recording.duration,
channel=0,
language="English",
speaker=speaker,
gender="male" if speaker.lower().startswith("m") else "female",
text=text.strip(),
)
segment = segment.with_alignment("word", word_alignments).with_alignment(
"phone", phone_alignments
)

return recording, segment


def get_phonemes(num_phones):
"""
Choose and convert the phones for modeling.
@@ -318,7 +357,6 @@ def get_phonemes(num_phones):


def get_speakers():

# List of test speakers
test_spk = [
"fdhc0",
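
To illustrate the main enhancement in this diff (word- and phone-level alignments attached to each SupervisionSegment), a sketch of reading them back from a written manifest, assuming lhotse's standard manifest API and the output file names used in prepare_timit above:

from lhotse import SupervisionSet

sups = SupervisionSet.from_file("data/manifests/timit_supervisions_TRAIN.jsonl.gz")
for seg in sups:
    # seg.alignment maps an alignment type to a list of AlignmentItem
    # (symbol, start, duration), populated by with_alignment() above.
    words = seg.alignment["word"]
    phones = seg.alignment["phone"]
    print(seg.id, words[0].symbol, round(words[0].start, 3), round(words[0].duration, 3))
    print(" ".join(p.symbol for p in phones))
    break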