[TTS] Remove align interpolator
Signed-off-by: Ryan <[email protected]>
rlangman committed May 26, 2023
1 parent f9988d0 commit f23a1ef
Showing 2 changed files with 14 additions and 36 deletions.
11 changes: 4 additions & 7 deletions examples/tts/conf/fastpitch/fastpitch.yaml
@@ -44,6 +44,7 @@ model:
   dur_loss_scale: 0.1
   pitch_loss_scale: 0.1
   energy_loss_scale: 0.1
+  aligner_loss_scale: 0.1
 
   preprocessor:
     _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
@@ -93,22 +94,18 @@ model:
     field: energy
     stats_path: ${feature_stats_path}
 
-  align_prior_config:
-    hop_length: ${feature.hop_length}
-    use_beta_binomial_interpolator: false
-
   train_ds:
     dataset:
       _target_: nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset
      dataset_meta: ${train_ds_meta}
       weighted_sample_steps: ${weighted_sample_steps}
       sample_rate: ${feature.sample_rate}
       speaker_path: ${speaker_path}
+      align_prior_hop_length: ${feature.hop_length}
       featurizers: ${feature.featurizers}
       feature_processors:
         pitch: ${model.pitch_processor}
         energy: ${model.energy_processor}
-      align_prior_config: ${model.align_prior_config}
       min_duration: 0.1
       max_duration: 10.0
 
@@ -122,11 +119,11 @@ model:
       dataset_meta: ${val_ds_meta}
       sample_rate: ${feature.sample_rate}
       speaker_path: ${speaker_path}
+      align_prior_hop_length: ${feature.hop_length}
       featurizers: ${feature.featurizers}
       feature_processors:
         pitch: ${model.pitch_processor}
         energy: ${model.energy_processor}
-      align_prior_config: ${model.align_prior_config}
 
     dataloader_params:
       batch_size: ${batch_size}
@@ -155,7 +152,7 @@ model:
       text_tokenizer: ${model.text_tokenizer}
       sample_rate: ${feature.sample_rate}
       speaker_path: ${speaker_path}
-      align_prior_config: ${model.align_prior_config}
+      align_prior_hop_length: ${feature.hop_length}
       featurizers: ${feature.featurizers}
 
       feature_processors:
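For existing configs, the migration is mechanical: delete the model-level align_prior_config block and set align_prior_hop_length on each dataset instead. Below is a minimal sketch of the equivalent direct instantiation; the manifest, tokenizer, and hop length values are illustrative placeholders, not taken from this commit:

```python
# Sketch only: dataset_meta and text_tokenizer are hypothetical placeholders
# standing in for values that Hydra normally resolves from the YAML above.
from nemo.collections.tts.data.text_to_speech_dataset import TextToSpeechDataset

dataset = TextToSpeechDataset(
    dataset_meta=dataset_meta,      # dict describing manifests, as in ${train_ds_meta}
    sample_rate=22050,              # illustrative value for ${feature.sample_rate}
    text_tokenizer=text_tokenizer,  # tokenizer instance, as in ${model.text_tokenizer}
    align_prior_hop_length=256,     # new argument; replaces
                                    # align_prior_config={"hop_length": 256,
                                    #                     "use_beta_binomial_interpolator": False}
)
```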
39 changes: 10 additions & 29 deletions nemo/collections/tts/data/text_to_speech_dataset.py
@@ -55,12 +55,6 @@ class DatasetSample:
     speaker_index: int = None
 
 
-@dataclass
-class AlignPriorConfig:
-    hop_length: int
-    use_beta_binomial_interpolator: bool = False
-
-
 @experimental
 class TextToSpeechDataset(Dataset):
     """
@@ -78,8 +72,8 @@ class TextToSpeechDataset(Dataset):
         featurizers: Optional, list of featurizers to load feature data from. Should be the same config provided
             when running scripts.dataset_processing.tts.compute_features.py before training.
         feature_processors: Optional, list of feature processors to run on training examples.
-        align_prior_config: Optional, if provided alignment prior will be calculated and included in
-            batch output.
+        align_prior_hop_length: Optional int, hop length of audio features.
+            If provided alignment prior will be calculated and included in batch output.
         min_duration: Optional float, if provided audio files in the training manifest shorter than 'min_duration'
             will be ignored.
         max_duration: Optional float, if provided audio files in the training manifest longer than 'max_duration'
@@ -95,7 +89,7 @@ def __init__(
         speaker_path: Optional[Path] = None,
         featurizers: Optional[Dict[str, Featurizer]] = None,
         feature_processors: Optional[Dict[str, FeatureProcessor]] = None,
-        align_prior_config: Optional[Dict] = None,
+        align_prior_hop_length: Optional[int] = None,
         min_duration: Optional[float] = None,
         max_duration: Optional[float] = None,
     ):
@@ -104,6 +98,8 @@ def __init__(
         self.sample_rate = sample_rate
         self.text_tokenizer = text_tokenizer
         self.weighted_sample_steps = weighted_sample_steps
+        self.align_prior_hop_length = align_prior_hop_length
+        self.include_align_prior = self.align_prior_hop_length is not None
 
         if speaker_path:
             self.include_speaker = True
@@ -125,16 +121,6 @@ def __init__(
         else:
             self.feature_processors = []
 
-        if align_prior_config:
-            self.align_prior_config = AlignPriorConfig(**align_prior_config)
-            if self.align_prior_config.use_beta_binomial_interpolator:
-                self.beta_binomial_interpolator = BetaBinomialInterpolator()
-            else:
-                self.beta_binomial_interpolator = None
-        else:
-            self.align_prior_config = None
-            self.beta_binomial_interpolator = None
-
         self.data_samples = []
         self.sample_weights = []
         for dataset_name, dataset_info in dataset_meta.items():
@@ -224,15 +210,10 @@ def __getitem__(self, index):
             example["speaker"] = data.speaker
             example["speaker_index"] = data.speaker_index
 
-        if self.align_prior_config:
+        if self.include_align_prior:
             text_len = len(tokens)
-            spec_len = 1 + librosa.core.samples_to_frames(
-                audio.shape[0], hop_length=self.align_prior_config.hop_length
-            )
-            if self.beta_binomial_interpolator:
-                align_prior = self.beta_binomial_interpolator(w=spec_len, h=text_len)
-            else:
-                align_prior = beta_binomial_prior_distribution(phoneme_count=text_len, mel_count=spec_len)
+            spec_len = 1 + librosa.core.samples_to_frames(audio.shape[0], hop_length=self.align_prior_hop_length)
+            align_prior = beta_binomial_prior_distribution(phoneme_count=text_len, mel_count=spec_len)
             align_prior = torch.tensor(align_prior, dtype=torch.float32)
             example["align_prior"] = align_prior
 
@@ -270,7 +251,7 @@ def collate_fn(self, batch: List[dict]):
             if self.include_speaker:
                 speaker_list.append(example["speaker_index"])
 
-            if self.align_prior_config:
+            if self.include_align_prior:
                 prior_list.append(example["align_prior"])
 
         batch_audio_len = torch.IntTensor(audio_len_list)
@@ -293,7 +274,7 @@ def collate_fn(self, batch: List[dict]):
         if self.include_speaker:
             batch_dict["speaker_id"] = torch.IntTensor(speaker_list)
 
-        if self.align_prior_config:
+        if self.include_align_prior:
             spec_max_len = max([prior.shape[0] for prior in prior_list])
             text_max_len = max([prior.shape[1] for prior in prior_list])
             batch_dict["align_prior_matrix"] = stack_tensors(prior_list, max_lens=[text_max_len, spec_max_len],)
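With the interpolator option removed, each example now computes its alignment prior directly in __getitem__. The frame count comes from librosa: samples_to_frames(n, hop_length=h) is floor(n / h), so 10 seconds of 22.05 kHz audio at hop length 256 gives spec_len = 1 + 220500 // 256 = 862. The prior itself is a [spec_len, text_len] matrix, and the removed BetaBinomialInterpolator appears to have existed only to cache and resize precomputed priors, so dropping it trades a little compute per item for simpler code. Below is a self-contained sketch of the standard beta-binomial formulation used in FastPitch-style training; it approximates NeMo's beta_binomial_prior_distribution helper, whose exact implementation may differ (e.g. in its scaling factor):

```python
# Sketch of the standard 2D beta-binomial alignment prior (FastPitch/RAD-TTS
# style). Approximates beta_binomial_prior_distribution; the real NeMo helper
# may differ in details such as an extra scaling factor.
import numpy as np
from scipy.stats import betabinom

def beta_binomial_prior(phoneme_count: int, mel_count: int, scaling: float = 1.0) -> np.ndarray:
    """Return a [mel_count, phoneme_count] matrix whose rows are beta-binomial
    distributions that sweep across the phonemes as the mel frame advances."""
    k = np.arange(phoneme_count)
    prior = np.zeros((mel_count, phoneme_count))
    for t in range(1, mel_count + 1):
        a, b = scaling * t, scaling * (mel_count + 1 - t)
        prior[t - 1] = betabinom(phoneme_count - 1, a, b).pmf(k)  # mass near the diagonal
    return prior

# Example: each row sums to 1 and the peak moves across the phoneme axis.
print(beta_binomial_prior(phoneme_count=5, mel_count=8).round(2))
```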
