Skip to content

Commit

Permalink
[TTS] Add output audio format to preprocessing (NVIDIA#6889)
Browse files: browse the repository at this point in the history.
* [TTS] Add output audio format to preprocessing

Signed-off-by: Ryan <[email protected]>

* [TTS] Add format validation

Signed-off-by: Ryan <[email protected]>

* [TTS] Fix data tutorial

Signed-off-by: Ryan <[email protected]>

---------

Signed-off-by: Ryan <[email protected]>
Signed-off-by: dorotat <[email protected]>
  • Loading branch information
rlangman authored and dorotat-nv committed Aug 24, 2023
1 parent ed541bd commit 9541c22
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
23 changes: 23 additions & 0 deletions scripts/dataset_processing/tts/preprocess_audio.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -32,6 +32,7 @@
--num_workers=1 \
--trim_config_path="<nemo_root_path>/examples/tts/conf/trim/energy.yaml" \
--output_sample_rate=22050 \
--output_format=flac \
--volume_level=0.95 \
--min_duration=0.5 \
--max_duration=20.0 \
Expand Down Expand Up @@ -97,6 +98,12 @@ def get_args():
parser.add_argument(
"--output_sample_rate", default=0, type=int, help="If provided, rate to resample the audio to."
)
parser.add_argument(
"--output_format",
default="wav",
type=str,
help="If provided, format output audio will be saved as. If not provided, will keep original format.",
)
parser.add_argument(
"--volume_level", default=0.0, type=float, help="If provided, peak volume to normalize audio to."
)
Expand All @@ -123,12 +130,18 @@ def _process_entry(
overwrite_audio: bool,
audio_trimmer: AudioTrimmer,
output_sample_rate: int,
output_format: str,
volume_level: float,
) -> Tuple[dict, float, float]:
audio_filepath = Path(entry["audio_filepath"])

audio_path, audio_path_rel = get_abs_rel_paths(input_path=audio_filepath, base_path=input_audio_dir)

if not output_format:
output_format = audio_path.suffix

output_path = output_audio_dir / audio_path_rel
output_path = output_path.with_suffix(output_format)
output_path.parent.mkdir(exist_ok=True, parents=True)

if output_path.exists() and not overwrite_audio:
Expand Down Expand Up @@ -159,6 +172,9 @@ def _process_entry(

if os.path.isabs(audio_filepath):
entry["audio_filepath"] = str(output_path)
else:
output_filepath = audio_path_rel.with_suffix(output_format)
entry["audio_filepath"] = str(output_filepath)

return entry, original_duration, output_duration

Expand All @@ -175,6 +191,7 @@ def main():
num_workers = args.num_workers
max_entries = args.max_entries
output_sample_rate = args.output_sample_rate
output_format = args.output_format
volume_level = args.volume_level
min_duration = args.min_duration
max_duration = args.max_duration
Expand All @@ -192,6 +209,11 @@ def main():
else:
audio_trimmer = None

if output_format:
if output_format.upper() not in sf.available_formats():
raise ValueError(f"Unsupported output audio format: {output_format}")
output_format = f".{output_format}"

output_audio_dir.mkdir(exist_ok=True, parents=True)

entries = read_manifest(input_manifest_path)
Expand All @@ -207,6 +229,7 @@ def main():
overwrite_audio=overwrite_audio,
audio_trimmer=audio_trimmer,
output_sample_rate=output_sample_rate,
output_format=output_format,
volume_level=volume_level,
)
for entry in tqdm(entries)
Expand Down
5 changes: 4 additions & 1 deletion tutorials/tts/FastPitch_Data_Preparation.ipynb
Original file line number | Diff line number | Diff line change
Expand Up @@ -452,6 +452,8 @@
"num_workers = 4\n",
"# Downsample data from 48khz to 44.1khz for compatibility\n",
"output_sample_rate = 44100\n",
"# Format of output audio files. Use \"flac\" to compress to a smaller file size.\n",
"output_format = \"flac\"\n",
"# Method for silence trimming. Can use \"energy.yaml\" or \"vad.yaml\".\n",
"# We use VAD for VCTK because the audio has background noise.\n",
"trim_config_path = NEMO_CONFIG_DIR / \"trim\" / \"vad.yaml\"\n",
Expand All @@ -475,6 +477,7 @@
" f\"--output_audio_dir={output_audio_dir}\",\n",
" f\"--num_workers={num_workers}\",\n",
" f\"--output_sample_rate={output_sample_rate}\",\n",
" f\"--output_format={output_format}\",\n",
" f\"--trim_config_path={trim_config_path}\",\n",
" f\"--volume_level={volume_level}\",\n",
" f\"--min_duration={min_duration}\",\n",
Expand Down Expand Up @@ -532,7 +535,7 @@
"source": [
"audio_file = \"p228_009.wav\"\n",
"audio_filepath = input_audio_dir / audio_file\n",
"processed_audio_filepath = output_audio_dir / audio_file\n",
"processed_audio_filepath = output_audio_dir / audio_file.replace(\".wav\", \".flac\")\n",
"\n",
"print(\"Original audio.\")\n",
"ipd.display(ipd.Audio(audio_filepath))\n",
Expand Down

0 comments on commit 9541c22

Please sign in to comment.