Skip to content

Commit

Permalink
[TTS] Add additional config to preprocess_text and compute_feature_stats (#7321)
Browse files Browse the repository at this point in the history

* [TTS] Add additional config to preprocess_text and compute_feature_stats

Signed-off-by: Ryan <[email protected]>

* [TTS] Rename batch_size to joblib_batch_size

Signed-off-by: Ryan <[email protected]>

---------

Signed-off-by: Ryan <[email protected]>
  • Loading branch information
rlangman committed Aug 29, 2023
1 parent 90feee2 commit f265ac4
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 43 deletions.
89 changes: 53 additions & 36 deletions scripts/dataset_processing/tts/compute_feature_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
$ python <nemo_root_path>/scripts/dataset_processing/tts/compute_feature_stats.py \
--feature_config_path=<nemo_root_path>/examples/tts/conf/features/feature_22050.yaml
--manifest_path=<data_root_path>/manifest.json \
--audio_dir=<data_root_path>/audio \
--feature_dir=<data_root_path>/features \
--manifest_path=<data_root_path>/manifest1.json \
--manifest_path=<data_root_path>/manifest2.json \
--audio_dir=<data_root_path>/audio1 \
--audio_dir=<data_root_path>/audio2 \
--feature_dir=<data_root_path>/features1 \
--feature_dir=<data_root_path>/features2 \
--stats_path=<data_root_path>/feature_stats.json
The output dictionary will contain the feature statistics for every speaker, as well as a "default" entry
Expand Down Expand Up @@ -74,13 +77,17 @@ def get_args():
"--feature_config_path", required=True, type=Path, help="Path to feature config file.",
)
parser.add_argument(
"--manifest_path", required=True, type=Path, help="Path to training manifest.",
"--manifest_path", required=True, type=Path, action="append", help="Path(s) to training manifest.",
)
parser.add_argument(
"--audio_dir", required=True, type=Path, help="Path to base directory with audio data.",
"--audio_dir", required=True, type=Path, action="append", help="Path(s) to base directory with audio data.",
)
parser.add_argument(
"--feature_dir", required=True, type=Path, help="Path to directory where feature data was stored.",
"--feature_dir",
required=True,
type=Path,
action="append",
help="Path(s) to directory where feature data was stored.",
)
parser.add_argument(
"--feature_names", default="pitch,energy", type=str, help="Comma separated list of features to process.",
Expand Down Expand Up @@ -118,26 +125,35 @@ def main():
args = get_args()

feature_config_path = args.feature_config_path
manifest_path = args.manifest_path
audio_dir = args.audio_dir
feature_dir = args.feature_dir
manifest_paths = args.manifest_path
audio_dirs = args.audio_dir
feature_dirs = args.feature_dir
feature_name_str = args.feature_names
mask_field = args.mask_field
stats_path = args.stats_path
overwrite = args.overwrite

if not manifest_path.exists():
raise ValueError(f"Manifest {manifest_path} does not exist.")

if not audio_dir.exists():
raise ValueError(f"Audio directory {audio_dir} does not exist.")

if not feature_dir.exists():
if not (len(manifest_paths) == len(audio_dirs) == len(feature_dirs)):
raise ValueError(
f"Feature directory {feature_dir} does not exist. "
f"Please check that the path is correct and that you ran compute_features.py"
f"Need same number of manifest, audio_dir, and feature_dir. Received: "
f"{len(manifest_paths)}, "
f"{len(audio_dirs)}, "
f"{len(feature_dirs)}"
)

for (manifest_path, audio_dir, feature_dir) in zip(manifest_paths, audio_dirs, feature_dirs):
if not manifest_path.exists():
raise ValueError(f"Manifest {manifest_path} does not exist.")

if not audio_dir.exists():
raise ValueError(f"Audio directory {audio_dir} does not exist.")

if not feature_dir.exists():
raise ValueError(
f"Feature directory {feature_dir} does not exist. "
f"Please check that the path is correct and that you ran compute_features.py"
)

if stats_path.exists():
if overwrite:
print(f"Will overwrite existing stats path: {stats_path}")
Expand All @@ -156,29 +172,30 @@ def main():
# for that speaker
feature_stats = {name: defaultdict(list) for name in feature_names}

entries = read_manifest(manifest_path)
for (manifest_path, audio_dir, feature_dir) in zip(manifest_paths, audio_dirs, feature_dirs):
entries = read_manifest(manifest_path)

for entry in tqdm(entries):
speaker = entry["speaker"]
for entry in tqdm(entries):
speaker = entry["speaker"]

entry_dict = {}
for featurizer in featurizers:
feature_dict = featurizer.load(manifest_entry=entry, audio_dir=audio_dir, feature_dir=feature_dir)
entry_dict.update(feature_dict)
entry_dict = {}
for featurizer in featurizers:
feature_dict = featurizer.load(manifest_entry=entry, audio_dir=audio_dir, feature_dir=feature_dir)
entry_dict.update(feature_dict)

if mask_field:
mask = entry_dict[mask_field]
else:
mask = None
if mask_field:
mask = entry_dict[mask_field]
else:
mask = None

for feature_name in feature_names:
values = entry_dict[feature_name]
if mask is not None:
values = values[mask]
for feature_name in feature_names:
values = entry_dict[feature_name]
if mask is not None:
values = values[mask]

feature_stat_dict = feature_stats[feature_name]
feature_stat_dict["default"].append(values)
feature_stat_dict[speaker].append(values)
feature_stat_dict = feature_stats[feature_name]
feature_stat_dict["default"].append(values)
feature_stat_dict[speaker].append(values)

stat_dict = defaultdict(dict)
for feature_name in feature_names:
Expand Down
42 changes: 35 additions & 7 deletions scripts/dataset_processing/tts/preprocess_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
--output_manifest="<data_root_path>/manifest_processed.json" \
--normalizer_config_path="<nemo_root_path>/examples/tts/conf/text/normalizer_en.yaml" \
--lower_case=True \
--num_workers=1
--num_workers=4 \
--batch_size=16
"""

import argparse
Expand Down Expand Up @@ -54,7 +55,13 @@ def get_args():
help="Whether to overwrite the output manifest file if it exists.",
)
parser.add_argument(
"--lower_case", default=False, type=bool, help="Whether to convert the final text to lower case.",
"--text_key", default="text", type=str, help="Input text field to normalize.",
)
parser.add_argument(
"--normalized_text_key", default="normalized_text", type=str, help="Output field to save normalized text to.",
)
parser.add_argument(
"--lower_case", action=argparse.BooleanOptionalAction, help="Whether to convert the final text to lower case.",
)
parser.add_argument(
"--normalizer_config_path",
Expand All @@ -65,6 +72,9 @@ def get_args():
parser.add_argument(
"--num_workers", default=1, type=int, help="Number of parallel threads to use. If -1 all CPUs are used."
)
parser.add_argument(
"--joblib_batch_size", type=int, help="Batch size for joblib workers. Defaults to 'auto' if not provided."
)
parser.add_argument(
"--max_entries", default=0, type=int, help="If provided, maximum number of entries in the manifest to process."
)
Expand All @@ -73,8 +83,15 @@ def get_args():
return args


def _process_entry(entry: dict, normalizer: Normalizer, lower_case: bool, lower_case_norm: bool) -> dict:
text = entry["text"]
def _process_entry(
entry: dict,
normalizer: Normalizer,
text_key: str,
normalized_text_key: str,
lower_case: bool,
lower_case_norm: bool,
) -> dict:
text = entry[text_key]

if normalizer is not None:
if lower_case_norm:
Expand All @@ -84,7 +101,7 @@ def _process_entry(entry: dict, normalizer: Normalizer, lower_case: bool, lower_
if lower_case:
text = text.lower()

entry["normalized_text"] = text
entry[normalized_text_key] = text

return entry

Expand All @@ -94,8 +111,11 @@ def main():

input_manifest_path = args.input_manifest
output_manifest_path = args.output_manifest
text_key = args.text_key
normalized_text_key = args.normalized_text_key
lower_case = args.lower_case
num_workers = args.num_workers
batch_size = args.joblib_batch_size
max_entries = args.max_entries
overwrite = args.overwrite

Expand All @@ -117,9 +137,17 @@ def main():
if max_entries:
entries = entries[:max_entries]

output_entries = Parallel(n_jobs=num_workers)(
if not batch_size:
batch_size = 'auto'

output_entries = Parallel(n_jobs=num_workers, batch_size=batch_size)(
delayed(_process_entry)(
entry=entry, normalizer=normalizer, lower_case=lower_case, lower_case_norm=lower_case_norm
entry=entry,
normalizer=normalizer,
text_key=text_key,
normalized_text_key=normalized_text_key,
lower_case=lower_case,
lower_case_norm=lower_case_norm,
)
for entry in tqdm(entries)
)
Expand Down

0 comments on commit f265ac4

Please sign in to comment.