[TTS] FastPitch speaker encoder (NVIDIA#6417)
* Add initial codes

Signed-off-by: hsiehjackson <[email protected]>

* Remove wemb

Signed-off-by: hsiehjackson <[email protected]>

* Fix import

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Restore aligner loss

Signed-off-by: hsiehjackson <[email protected]>

* Add ConditionalInput

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix error and support pre-trained config

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Follow comments

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Rename config

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change copyright and random weight test

Signed-off-by: hsiehjackson <[email protected]>

* Add initial codes

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: hsiehjackson <[email protected]>

* Fix import error

Signed-off-by: hsiehjackson <[email protected]>

* Add initial codes

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: hsiehjackson <[email protected]>

* Fix dataset error

Signed-off-by: hsiehjackson <[email protected]>

* Remove reference speaker embedding

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: hsiehjackson <[email protected]>

* Remove SV encoder

Signed-off-by: hsiehjackson <[email protected]>

* Follow comments

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: hsiehjackson <[email protected]>

* Fix length type

Signed-off-by: hsiehjackson <[email protected]>

* Fix append

Signed-off-by: hsiehjackson <[email protected]>

* Move error msg

Signed-off-by: hsiehjackson <[email protected]>

* Add look-up into speaker encoder

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: hsiehjackson <[email protected]>

* Add valueerror msg

Signed-off-by: hsiehjackson <[email protected]>

* Move lookup

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: hsiehjackson <[email protected]>

* Remove unused

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: hsiehjackson <[email protected]>

* Fix error

Signed-off-by: hsiehjackson <[email protected]>

* Rebase and Fix error

Signed-off-by: hsiehjackson <[email protected]>

* Fix spk encoder

Signed-off-by: hsiehjackson <[email protected]>

* Rename n_speakers

Signed-off-by: hsiehjackson <[email protected]>

* Follow comments

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix n_speakers None error

Signed-off-by: hsiehjackson <[email protected]>

---------

Signed-off-by: hsiehjackson <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: hsiehjackson <[email protected]>
hsiehjackson and pre-commit-ci[bot] committed Jun 2, 2023
1 parent 10c551b commit c880e2c
Showing 6 changed files with 432 additions and 28 deletions.
26 changes: 23 additions & 3 deletions examples/tts/conf/fastpitch_align_44100_adapter.yaml
@@ -7,7 +7,7 @@ name: FastPitch
train_dataset: ???
validation_datasets: ???
sup_data_path: ???
sup_data_types: [ "align_prior_matrix", "pitch", "speaker_id"]
sup_data_types: [ "align_prior_matrix", "pitch", "speaker_id", "reference_audio"]

# Default values from librosa.pyin
pitch_fmin: 65.40639132514966
@@ -35,10 +35,8 @@ model:
learn_alignment: true
bin_loss_warmup_epochs: 100

n_speakers: 1
max_token_duration: 75
symbols_embedding_dim: 384
speaker_embedding_dim: 384
pitch_embedding_kernel_size: 3

pitch_fmin: ${pitch_fmin}
@@ -248,6 +246,28 @@ model:
n_layers: 2
condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ]

speaker_encoder:
_target_: nemo.collections.tts.modules.submodules.SpeakerEncoder
lookup_module:
_target_: nemo.collections.tts.modules.submodules.SpeakerLookupTable
n_speakers: ???
embedding_dim: ${model.symbols_embedding_dim}
gst_module:
_target_: nemo.collections.tts.modules.submodules.GlobalStyleToken
gst_size: ${model.symbols_embedding_dim}
n_style_token: 10
n_style_attn_head: 4
reference_encoder:
_target_: nemo.collections.tts.modules.submodules.ReferenceEncoder
n_mels: ${model.n_mel_channels}
cnn_filters: [32, 32, 64, 64, 128, 128]
dropout: 0.2
gru_hidden: ${model.symbols_embedding_dim}
kernel_size: 3
stride: 2
padding: 1
bias: true
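The `speaker_encoder` block above is built by Hydra-style instantiation: each `_target_` names a class, and nested blocks become constructor arguments that are themselves instantiated. A minimal sketch of that mechanism, using hypothetical toy stand-ins for the configured classes (this is an illustrative re-implementation, not NeMo's actual `instantiate()`):

```python
from dataclasses import dataclass

def instantiate_from_cfg(cfg, registry):
    """Recursively build an object from a dict whose '_target_' names a class."""
    if not isinstance(cfg, dict) or "_target_" not in cfg:
        return cfg  # plain value, pass through unchanged
    kwargs = {k: instantiate_from_cfg(v, registry) for k, v in cfg.items() if k != "_target_"}
    return registry[cfg["_target_"]](**kwargs)

# Toy stand-ins for the modules named in the config above.
@dataclass
class SpeakerLookupTable:
    n_speakers: int
    embedding_dim: int

@dataclass
class SpeakerEncoder:
    lookup_module: object = None
    gst_module: object = None

registry = {
    "SpeakerLookupTable": SpeakerLookupTable,
    "SpeakerEncoder": SpeakerEncoder,
}
cfg = {
    "_target_": "SpeakerEncoder",
    "lookup_module": {"_target_": "SpeakerLookupTable", "n_speakers": 10, "embedding_dim": 384},
}
encoder = instantiate_from_cfg(cfg, registry)
```

This mirrors why `lookup_module`, `gst_module`, and `reference_encoder` are all optional in the config: any branch left out simply never reaches the constructor.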

optim:
name: adamw
lr: 1e-3
47 changes: 47 additions & 0 deletions nemo/collections/tts/data/dataset.py
@@ -50,6 +50,7 @@
LogMel,
P_voiced,
Pitch,
ReferenceAudio,
SpeakerID,
TTSDataType,
Voiced_mask,
@@ -483,6 +484,13 @@ def add_energy(self, **kwargs):
def add_speaker_id(self, **kwargs):
pass

def add_reference_audio(self, **kwargs):
assert SpeakerID in self.sup_data_types, "Please add speaker_id in sup_data_types."
"""Add a mapping for each speaker to their manifest indexes"""
self.speaker_to_index_map = defaultdict(set)
for i, d in enumerate(self.data):
self.speaker_to_index_map[d['speaker_id']].add(i)

def get_spec(self, audio):
with torch.cuda.amp.autocast(enabled=False):
spec = self.stft(audio)
@@ -522,6 +530,12 @@ def _pad_wav_to_multiple(self, wav):
)
return wav

# Random sample a reference index from the same speaker
def sample_reference_index(self, speaker_id):
reference_pool = self.speaker_to_index_map[speaker_id]
reference_index = random.sample(reference_pool, 1)[0]
return reference_index
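Taken together with `add_reference_audio` above, the sampling logic amounts to the following self-contained sketch (hypothetical plain-dict manifest entries, not the real dataset class). One caveat for a standalone version: `random.sample()` stopped accepting sets in Python 3.11, so the pool should be converted to a sequence first:

```python
import random
from collections import defaultdict

def build_speaker_index_map(data):
    """Map each speaker_id to the set of manifest indexes that belong to it."""
    speaker_to_index_map = defaultdict(set)
    for i, d in enumerate(data):
        speaker_to_index_map[d["speaker_id"]].add(i)
    return speaker_to_index_map

def sample_reference_index(speaker_to_index_map, speaker_id, rng=random):
    """Randomly pick an utterance index from the same speaker."""
    pool = sorted(speaker_to_index_map[speaker_id])  # sets are not samplable in 3.11+
    return rng.sample(pool, 1)[0]

data = [
    {"speaker_id": 0, "audio_filepath": "a.wav"},
    {"speaker_id": 1, "audio_filepath": "b.wav"},
    {"speaker_id": 0, "audio_filepath": "c.wav"},
]
index_map = build_speaker_index_map(data)
ref = sample_reference_index(index_map, 0)
```

Note the sampled reference may be the target utterance itself, since the pool is every index owned by the speaker, including the current one.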

def __getitem__(self, index):
sample = self.data[index]

@@ -683,6 +697,19 @@ def __getitem__(self, index):
if SpeakerID in self.sup_data_types_set:
speaker_id = torch.tensor(sample["speaker_id"]).long()

reference_audio, reference_audio_length = None, None
if ReferenceAudio in self.sup_data_types_set:
reference_index = self.sample_reference_index(sample["speaker_id"])
reference_audio = self.featurizer.process(
self.data[reference_index]["audio_filepath"],
trim=self.trim,
trim_ref=self.trim_ref,
trim_top_db=self.trim_top_db,
trim_frame_length=self.trim_frame_length,
trim_hop_length=self.trim_hop_length,
)
reference_audio_length = torch.tensor(reference_audio.shape[0]).long()

return (
audio,
audio_length,
Expand All @@ -700,6 +727,8 @@ def __getitem__(self, index):
voiced_mask,
p_voiced,
audio_shifted,
reference_audio,
reference_audio_length,
)

def __len__(self):
@@ -733,6 +762,8 @@ def general_collate_fn(self, batch):
voiced_masks,
p_voiceds,
_,
_,
reference_audio_lengths,
) = zip(*batch)

max_audio_len = max(audio_lengths).item()
Expand All @@ -741,6 +772,9 @@ def general_collate_fn(self, batch):
max_durations_len = max([len(i) for i in durations_list]) if Durations in self.sup_data_types_set else None
max_pitches_len = max(pitches_lengths).item() if Pitch in self.sup_data_types_set else None
max_energies_len = max(energies_lengths).item() if Energy in self.sup_data_types_set else None
max_reference_audio_len = (
max(reference_audio_lengths).item() if ReferenceAudio in self.sup_data_types_set else None
)

if LogMel in self.sup_data_types_set:
log_mel_pad = torch.finfo(batch[0][4].dtype).tiny
Expand All @@ -765,6 +799,7 @@ def general_collate_fn(self, batch):
voiced_masks,
p_voiceds,
audios_shifted,
reference_audios,
) = (
[],
[],
Expand All @@ -776,6 +811,7 @@ def general_collate_fn(self, batch):
[],
[],
[],
[],
)

for i, sample_tuple in enumerate(batch):
Expand All @@ -796,6 +832,8 @@ def general_collate_fn(self, batch):
voiced_mask,
p_voiced,
audio_shifted,
reference_audio,
reference_audios_length,
) = sample_tuple

audio = general_padding(audio, audio_len.item(), max_audio_len)
@@ -834,6 +872,11 @@ def general_collate_fn(self, batch):
if SpeakerID in self.sup_data_types_set:
speaker_ids.append(speaker_id)

if ReferenceAudio in self.sup_data_types_set:
reference_audios.append(
general_padding(reference_audio, reference_audios_length.item(), max_reference_audio_len)
)

data_dict = {
"audio": torch.stack(audios),
"audio_lens": torch.stack(audio_lengths),
Expand All @@ -851,6 +894,10 @@ def general_collate_fn(self, batch):
"voiced_mask": torch.stack(voiced_masks) if Voiced_mask in self.sup_data_types_set else None,
"p_voiced": torch.stack(p_voiceds) if P_voiced in self.sup_data_types_set else None,
"audio_shifted": torch.stack(audios_shifted) if audio_shifted is not None else None,
"reference_audio": torch.stack(reference_audios) if ReferenceAudio in self.sup_data_types_set else None,
"reference_audio_lens": torch.stack(reference_audio_lengths)
if ReferenceAudio in self.sup_data_types_set
else None,
}

return data_dict
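The collate changes above pad each reference clip to the longest one in the batch before stacking. Stripped of tensors, the padding step looks like this sketch over plain Python lists, where `general_padding` is a simplified stand-in for NeMo's helper of the same name:

```python
def general_padding(x, cur_len, max_len, pad_value=0.0):
    """Right-pad a 1-D sequence to max_len (simplified stand-in)."""
    return list(x) + [pad_value] * (max_len - cur_len)

def collate_reference_audio(reference_audios):
    """Pad variable-length reference clips so they can be stacked into a batch."""
    lengths = [len(a) for a in reference_audios]
    max_len = max(lengths)
    padded = [general_padding(a, n, max_len) for a, n in zip(reference_audios, lengths)]
    return padded, lengths

batch = [[0.1, 0.2, 0.3], [0.5]]
padded, lens = collate_reference_audio(batch)
```

The original lengths travel alongside the padded batch (as `reference_audio_lens` in `data_dict`) so downstream modules can mask out the padding.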
78 changes: 67 additions & 11 deletions nemo/collections/tts/models/fastpitch.py
@@ -139,14 +139,17 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
output_fft = instantiate(self._cfg.output_fft)
duration_predictor = instantiate(self._cfg.duration_predictor)
pitch_predictor = instantiate(self._cfg.pitch_predictor)
speaker_encoder = instantiate(self._cfg.get("speaker_encoder", None))
energy_embedding_kernel_size = cfg.get("energy_embedding_kernel_size", 0)
energy_predictor = instantiate(self._cfg.get("energy_predictor", None))

# [TODO] may remove if we change the pre-trained config
# cfg: condition_types = [ "add" ]
n_speakers = cfg.get("n_speakers", 0)
speaker_emb_condition_prosody = cfg.get("speaker_emb_condition_prosody", False)
speaker_emb_condition_decoder = cfg.get("speaker_emb_condition_decoder", False)
speaker_emb_condition_aligner = cfg.get("speaker_emb_condition_aligner", False)
if cfg.n_speakers > 1:
if n_speakers > 1 and "add" not in input_fft.cond_input.condition_types:
input_fft.cond_input.condition_types.append("add")
if speaker_emb_condition_prosody:
duration_predictor.cond_input.condition_types.append("add")
Expand All @@ -163,7 +166,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
pitch_predictor,
energy_predictor,
self.aligner,
cfg.n_speakers,
speaker_encoder,
n_speakers,
cfg.symbols_embedding_dim,
cfg.pitch_embedding_kernel_size,
energy_embedding_kernel_size,
@@ -305,6 +309,9 @@ def parse(self, str_input: str, normalize=True) -> torch.tensor:
"attn_prior": NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True),
"mel_lens": NeuralType(('B'), LengthsType(), optional=True),
"input_lens": NeuralType(('B'), LengthsType(), optional=True),
# reference_* data is used for multi-speaker FastPitch training
"reference_spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
"reference_spec_lens": NeuralType(('B'), LengthsType(), optional=True),
}
)
def forward(
Expand All @@ -320,6 +327,8 @@ def forward(
attn_prior=None,
mel_lens=None,
input_lens=None,
reference_spec=None,
reference_spec_lens=None,
):
return self.fastpitch(
text=text,
Expand All @@ -332,21 +341,43 @@ def forward(
attn_prior=attn_prior,
mel_lens=mel_lens,
input_lens=input_lens,
reference_spec=reference_spec,
reference_spec_lens=reference_spec_lens,
)
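The speaker embedding derived from `reference_spec` conditions the FFT blocks through the `condition_types` listed in the config ("add", "cat", "layernorm"). The simplest of these, "add", broadcasts one speaker vector onto every timestep of the hidden sequence. A toy numeric sketch over nested lists (illustrative only, not NeMo's `ConditionalInput`):

```python
def condition_add(hidden, speaker_vec):
    """'add' conditioning: add the speaker vector to every frame of the sequence."""
    return [[h + c for h, c in zip(frame, speaker_vec)] for frame in hidden]

# Two frames of a 3-dim hidden sequence, one 3-dim speaker embedding.
hidden = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
speaker_vec = [0.5, -0.5, 0.0]
conditioned = condition_add(hidden, speaker_vec)
```

"cat" would instead concatenate the vector to each frame (growing the feature dimension), and "layernorm" would predict per-channel scale and shift from it.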

@typecheck(output_types={"spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType())})
def generate_spectrogram(
self, tokens: 'torch.tensor', speaker: Optional[int] = None, pace: float = 1.0
self,
tokens: 'torch.tensor',
speaker: Optional[int] = None,
pace: float = 1.0,
reference_spec: Optional['torch.tensor'] = None,
reference_spec_lens: Optional['torch.tensor'] = None,
) -> torch.tensor:
if self.training:
logging.warning("generate_spectrogram() is meant to be called in eval mode.")
if isinstance(speaker, int):
speaker = torch.tensor([speaker]).to(self.device)
spect, *_ = self(text=tokens, durs=None, pitch=None, speaker=speaker, pace=pace)
spect, *_ = self(
text=tokens,
durs=None,
pitch=None,
speaker=speaker,
pace=pace,
reference_spec=reference_spec,
reference_spec_lens=reference_spec_lens,
)
return spect

def training_step(self, batch, batch_idx):
attn_prior, durs, speaker, energy = None, None, None, None
attn_prior, durs, speaker, energy, reference_audio, reference_audio_len = (
None,
None,
None,
None,
None,
None,
)
if self.learn_alignment:
assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}"
batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
Expand All @@ -358,10 +389,17 @@ def training_step(self, batch, batch_idx):
pitch = batch_dict.get("pitch", None)
energy = batch_dict.get("energy", None)
speaker = batch_dict.get("speaker_id", None)
reference_audio = batch_dict.get("reference_audio", None)
reference_audio_len = batch_dict.get("reference_audio_lens", None)
else:
audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

mels, spec_len = self.preprocessor(input_signal=audio, length=audio_lens)
reference_spec, reference_spec_len = None, None
if reference_audio is not None:
reference_spec, reference_spec_len = self.preprocessor(
input_signal=reference_audio, length=reference_audio_len
)

(
mels_pred,
Expand All @@ -384,6 +422,8 @@ def training_step(self, batch, batch_idx):
speaker=speaker,
pace=1.0,
spec=mels if self.learn_alignment else None,
reference_spec=reference_spec,
reference_spec_lens=reference_spec_len,
attn_prior=attn_prior,
mel_lens=spec_len,
input_lens=text_lens,
@@ -441,7 +481,14 @@ def training_step(self, batch, batch_idx):
return loss

def validation_step(self, batch, batch_idx):
attn_prior, durs, speaker, energy = None, None, None, None
attn_prior, durs, speaker, energy, reference_audio, reference_audio_len = (
None,
None,
None,
None,
None,
None,
)
if self.learn_alignment:
assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}"
batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
Expand All @@ -453,10 +500,17 @@ def validation_step(self, batch, batch_idx):
pitch = batch_dict.get("pitch", None)
energy = batch_dict.get("energy", None)
speaker = batch_dict.get("speaker_id", None)
reference_audio = batch_dict.get("reference_audio", None)
reference_audio_len = batch_dict.get("reference_audio_lens", None)
else:
audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

mels, mel_lens = self.preprocessor(input_signal=audio, length=audio_lens)
reference_spec, reference_spec_len = None, None
if reference_audio is not None:
reference_spec, reference_spec_len = self.preprocessor(
input_signal=reference_audio, length=reference_audio_len
)

# Calculate val loss on ground truth durations to better align L2 loss in time
(mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch, energy_pred, energy_tgt,) = self(
Expand All @@ -467,6 +521,8 @@ def validation_step(self, batch, batch_idx):
speaker=speaker,
pace=1.0,
spec=mels if self.learn_alignment else None,
reference_spec=reference_spec,
reference_spec_lens=reference_spec_len,
attn_prior=attn_prior,
mel_lens=mel_lens,
input_lens=text_lens,
@@ -496,13 +552,13 @@ def validation_epoch_end(self, outputs):
mel_loss = collect("mel_loss")
dur_loss = collect("dur_loss")
pitch_loss = collect("pitch_loss")
self.log("val_loss", val_loss)
self.log("val_mel_loss", mel_loss)
self.log("val_dur_loss", dur_loss)
self.log("val_pitch_loss", pitch_loss)
self.log("val_loss", val_loss, sync_dist=True)
self.log("val_mel_loss", mel_loss, sync_dist=True)
self.log("val_dur_loss", dur_loss, sync_dist=True)
self.log("val_pitch_loss", pitch_loss, sync_dist=True)
if outputs[0]["energy_loss"] is not None:
energy_loss = collect("energy_loss")
self.log("val_energy_loss", energy_loss)
self.log("val_energy_loss", energy_loss, sync_dist=True)
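The `sync_dist=True` additions make each logged validation metric an average over all DDP ranks instead of a rank-local value, so multi-GPU runs report one consistent number. Conceptually the reduction is just a mean across ranks; a sketch of the effect (not Lightning's actual all-reduce implementation):

```python
def sync_dist_mean(per_rank_values):
    """What sync_dist=True effectively does: average the metric over all ranks."""
    return sum(per_rank_values) / len(per_rank_values)

# Four ranks each computed a slightly different val_loss on their data shard.
val_loss_logged = sync_dist_mean([0.9, 1.1, 1.0, 1.0])
```

Without the flag, only rank 0's shard-local value would typically end up in the logs.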

_, _, _, _, _, spec_target, spec_predict = outputs[0].values()
