[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

Signed-off-by: hsiehjackson <[email protected]>
pre-commit-ci[bot] authored and hsiehjackson committed Apr 7, 2023
1 parent d575591 commit 575b952
Showing 10 changed files with 221 additions and 163 deletions.
20 changes: 11 additions & 9 deletions examples/tts/fastpitch_finetune_adapters.py
@@ -85,8 +85,7 @@ def main(cfg):
# Initialize FastPitchModel
model = FastPitchModel(cfg=update_model_config_to_support_adapter(cfg.model), trainer=trainer)
model.maybe_init_from_pretrained_checkpoint(cfg=cfg)



# Extract adapter parameters
with open_dict(cfg.model.adapter):
# get bool variable to determine if weighted speaker embeddings should be used or not
@@ -100,7 +99,7 @@ def main(cfg):
# Name of the adapter checkpoint which will be saved after training
adapter_state_dict_name = cfg.model.adapter.pop("adapter_state_dict_name", None)
weight_speaker_state_dict_name = cfg.model.adapter.pop("weight_speaker_state_dict_name", None)

# augment adapter name with module name, if not provided by user
if adapter_module_name is not None and ':' not in adapter_name:
adapter_name = f'{adapter_module_name}:{adapter_name}'
@@ -110,21 +109,21 @@ def main(cfg):

# Freeze model
model.freeze()

# Add a weighted speaker embedding module to the speaker representation
if add_weight_speaker and model.fastpitch.speaker_emb is not None:
old_emb = model.fastpitch.speaker_emb
new_emb = WeightedSpeakerEmbedding(pretrained_embedding=old_emb, speaker_list=add_weight_speaker_list)
model.fastpitch.speaker_emb = new_emb
model.fastpitch.speaker_emb.embedding_weight.requires_grad = True

############ Current experiment - unfreeze gst and titanet embedding
# for name, param in model.fastpitch.speaker_encoder.gst_module.style_attention.named_parameters():
# for name, param in model.fastpitch.speaker_encoder.gst_module.named_parameters():
# for name, param in model.fastpitch.speaker_encoder.sv_projection_module.named_parameters():
# param.requires_grad = True
############ Current experiment - unfreeze gst and titanet embedding

# Setup adapters
if adapter_global_cfg is not None:
add_global_adapter_cfg(model, adapter_global_cfg)
@@ -135,14 +134,14 @@ def main(cfg):
# enable adapters
model.set_enabled_adapters(enabled=False)
model.set_enabled_adapters(adapter_name, enabled=True)

# Set model to training mode.
model = model.train()
# Then, unfreeze just the adapter weights that were enabled above (not the rest of the model)
model.unfreeze_enabled_adapters()
# summarize the model
model.summarize()

lr_logger = pl.callbacks.LearningRateMonitor()
epoch_time_logger = LogEpochTimeCallback()
trainer.callbacks.extend([lr_logger, epoch_time_logger])
@@ -157,7 +156,10 @@ def main(cfg):

# Save the adapter modules in a separate file
model.save_adapters(os.path.join(state_path, adapter_state_dict_name))
torch.save(model.state_dict()['fastpitch.speaker_emb.embedding_weight'], os.path.join(state_path, weight_speaker_state_dict_name))
torch.save(
model.state_dict()['fastpitch.speaker_emb.embedding_weight'],
os.path.join(state_path, weight_speaker_state_dict_name),
)


if __name__ == '__main__':
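A note on the WeightedSpeakerEmbedding used in this script: the class itself is defined elsewhere in the repository, but its usage here (it wraps the frozen pretrained speaker table and exposes a single trainable parameter, embedding_weight) suggests roughly the following shape. This is a minimal sketch under those assumptions, not the actual NeMo implementation; the class name comes from the diff, everything inside is inferred.

import torch
import torch.nn as nn

class WeightedSpeakerEmbedding(nn.Module):
    # Sketch: represent a new speaker as a learned softmax-weighted mix of
    # pretrained speaker embeddings. Only `embedding_weight` is trainable,
    # matching the single-parameter unfreeze in the script above.
    def __init__(self, pretrained_embedding: nn.Embedding, speaker_list=None):
        super().__init__()
        self.pretrained_embedding = pretrained_embedding  # frozen base table
        if speaker_list is None:
            speaker_list = list(range(pretrained_embedding.num_embeddings))
        self.register_buffer("speaker_ids", torch.tensor(speaker_list, dtype=torch.long))
        self.embedding_weight = nn.Parameter(torch.zeros(len(speaker_list)))

    def forward(self, speaker_id: torch.Tensor) -> torch.Tensor:
        # speaker_id is ignored in this sketch: every lookup returns the same
        # mixed vector, i.e. the one new speaker being fine-tuned.
        weights = torch.softmax(self.embedding_weight, dim=0)   # (n_mix,)
        table = self.pretrained_embedding(self.speaker_ids)     # (n_mix, emb_dim)
        mixed = (weights.unsqueeze(-1) * table).sum(dim=0)      # (emb_dim,)
        return mixed.expand(speaker_id.shape[0], -1)            # (batch, emb_dim)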
53 changes: 23 additions & 30 deletions nemo/collections/tts/data/dataset.py
@@ -46,15 +46,15 @@
AlignPriorMatrix,
Durations,
Energy,
GSTRefAudio,
LMTokens,
LogMel,
P_voiced,
Pitch,
SpeakerEmbedding,
SpeakerID,
TTSDataType,
Voiced_mask,
GSTRefAudio,
SpeakerEmbedding,
WithLens,
)
from nemo.core.classes import Dataset
@@ -211,7 +211,7 @@ def __init__(
self.text_normalizer.normalize
if isinstance(self.text_normalizer, Normalizer)
else self.text_normalizer
)
)

self.text_normalizer_call_kwargs = (
text_normalizer_call_kwargs if text_normalizer_call_kwargs is not None else {}
@@ -343,6 +343,7 @@ def __init__(
getattr(self, f"add_{data_type.name}")(**kwargs)

self.pad_multiple = pad_multiple

@staticmethod
def filter_files(data, ignore_file, min_duration, max_duration, total_duration):
if ignore_file:
@@ -453,12 +454,12 @@ def add_pitch(self, **kwargs):
if pitch_stats_path is not None:
with open(Path(pitch_stats_path), 'r', encoding="utf-8") as pitch_f:
self.pitch_stats = json.load(pitch_f)

def add_speaker_embedding(self, **kwargs):
embedding_path = kwargs.pop("speaker_embedding_path", None)
assert embedding_path is not None, "Speaker embedding path is required but got None."
self.speaker_embeddings = torch.from_numpy(np.load(embedding_path))

# saving voiced_mask and p_voiced with pitch
def add_voiced_mask(self, **kwargs):
self.voiced_mask_folder = kwargs.pop('voiced_mask_folder', None)
@@ -497,7 +498,7 @@ def add_gst_ref_audio(self, **kwargs):
for i, d in enumerate(self.data):
self.speaker_to_index[d.get('speaker_id', None)].append(i)
self.speaker_to_index = {k: set(v) for k, v in self.speaker_to_index.items()}

def get_spec(self, audio):
with torch.cuda.amp.autocast(enabled=False):
spec = self.stft(audio)
@@ -695,24 +696,25 @@ def __getitem__(self, index):
speaker_id = None
if SpeakerID in self.sup_data_types_set:
speaker_id = torch.tensor(sample["speaker_id"]).long()

gst_ref_audio, gst_ref_audio_length = None, None
if GSTRefAudio in self.sup_data_types_set:
ref_sample = sample
if self.speaker_to_index is not None:
if len(self.speaker_to_index[sample["speaker_id"]]) > 1:
ref_pool = self.speaker_to_index[sample["speaker_id"]] - set([index])
else:
else:
ref_pool = self.speaker_to_index[sample["speaker_id"]]
ref_sample = self.data[random.sample(ref_pool, 1)[0]]

ref_features = self.featurizer.process(
ref_sample["audio_filepath"],
trim=self.trim,
trim_ref=self.trim_ref,
trim_top_db=self.trim_top_db,
trim_frame_length=self.trim_frame_length,
trim_hop_length=self.trim_hop_length)
trim_hop_length=self.trim_hop_length,
)

gst_ref_audio, gst_ref_audio_length = ref_features, torch.tensor(ref_features.shape[0]).long()

@@ -723,7 +725,7 @@ def __getitem__(self, index):
if self.speaker_to_index is not None:
if len(self.speaker_to_index[sample["speaker_id"]]) > 1:
ref_pool = self.speaker_to_index[sample["speaker_id"]] - set([index])
else:
else:
ref_pool = self.speaker_to_index[sample["speaker_id"]]
ref_sample_index = random.sample(ref_pool, 1)[0]
speaker_emb = self.speaker_embeddings[ref_sample_index]
@@ -793,7 +795,7 @@ def general_collate_fn(self, batch):
max_pitches_len = max(pitches_lengths).item() if Pitch in self.sup_data_types_set else None
max_energies_len = max(energies_lengths).item() if Energy in self.sup_data_types_set else None
max_gst_ref_audio_len = max(gst_ref_audio_lengths).item() if GSTRefAudio in self.sup_data_types_set else None

if LogMel in self.sup_data_types_set:
log_mel_pad = torch.finfo(batch[0][4].dtype).tiny

@@ -818,21 +820,8 @@
p_voiceds,
audios_shifted,
gst_ref_audios,
speaker_embs
) = (
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[],
[]
)
speaker_embs,
) = ([], [], [], [], [], [], [], [], [], [], [], [])

for i, sample_tuple in enumerate(batch):
(
@@ -892,9 +881,11 @@

if SpeakerID in self.sup_data_types_set:
speaker_ids.append(speaker_id)

if GSTRefAudio in self.sup_data_types_set:
gst_ref_audios.append(general_padding(gst_ref_audio, gst_ref_audios_length.item(), max_gst_ref_audio_len))
gst_ref_audios.append(
general_padding(gst_ref_audio, gst_ref_audios_length.item(), max_gst_ref_audio_len)
)

if SpeakerEmbedding in self.sup_data_types_set:
speaker_embs.append(speaker_emb)
@@ -917,7 +908,9 @@
"p_voiced": torch.stack(p_voiceds) if P_voiced in self.sup_data_types_set else None,
"audio_shifted": torch.stack(audios_shifted) if audio_shifted is not None else None,
"gst_ref_audio": torch.stack(gst_ref_audios) if GSTRefAudio in self.sup_data_types_set else None,
"gst_ref_audio_lens": torch.stack(gst_ref_audio_lengths) if GSTRefAudio in self.sup_data_types_set else None,
"gst_ref_audio_lens": torch.stack(gst_ref_audio_lengths)
if GSTRefAudio in self.sup_data_types_set
else None,
"speaker_embedding": torch.stack(speaker_embs) if SpeakerEmbedding in self.sup_data_types_set else None,
}

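The same-speaker reference sampling that __getitem__ performs twice above (once for GST reference audio, once for speaker embeddings) boils down to: index the dataset by speaker, then draw a different utterance from the same speaker when one exists. A standalone sketch of that pattern follows, with hypothetical helper names and the same data layout assumed as in the dataset class:

import random
from collections import defaultdict

def build_speaker_index(data):
    # Map each speaker_id to the set of dataset indices that belong to it,
    # mirroring the speaker_to_index construction in add_gst_ref_audio above.
    speaker_to_index = defaultdict(list)
    for i, d in enumerate(data):
        speaker_to_index[d.get("speaker_id", None)].append(i)
    return {k: set(v) for k, v in speaker_to_index.items()}

def pick_reference_index(speaker_to_index, speaker_id, current_index):
    # Prefer a different utterance from the same speaker; fall back to the
    # current utterance when that speaker has only one sample.
    pool = speaker_to_index[speaker_id]
    if len(pool) > 1:
        pool = pool - {current_index}
    # sorted() turns the set into a sequence before sampling.
    return random.choice(sorted(pool))

One caveat: the diff calls random.sample(ref_pool, 1)[0] directly on a set, which is deprecated since Python 3.9 and raises TypeError on 3.11+; converting the pool to a sequence first, as in the sketch, avoids that.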
4 changes: 1 addition & 3 deletions nemo/collections/tts/losses/aligner_loss.py
@@ -58,9 +58,7 @@ def forward(self, attn_logprob, in_lens, out_lens):
# Convert to log probabilities
# Note: Mask out probs beyond key_len
key_inds = torch.arange(max_key_len + 1, device=attn_logprob.device, dtype=torch.long)
attn_logprob.masked_fill_(
key_inds.view(1, 1, -1) > key_lens.view(1, -1, 1), -1e15 # key_inds >= key_lens+1
)
attn_logprob.masked_fill_(key_inds.view(1, 1, -1) > key_lens.view(1, -1, 1), -1e15) # key_inds >= key_lens+1
attn_logprob = self.log_softmax(attn_logprob)

# Target sequences
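For reference, the masked_fill_ call reformatted above implements the "mask out probs beyond key_len" comment: every key index greater than an item's key length is pushed to -1e15, so log_softmax assigns it negligible probability. A minimal, self-contained illustration of the same broadcasting trick, using simplified (batch, query, key) shapes rather than the loss's actual tensor layout:

import torch

batch, max_query_len, max_key_len = 2, 5, 6
attn_logprob = torch.randn(batch, max_query_len, max_key_len + 1)
key_lens = torch.tensor([4, 6])  # true key length per batch element

key_inds = torch.arange(max_key_len + 1, dtype=torch.long)
# Broadcast (1, 1, K+1) against (B, 1, 1): True wherever the key index
# exceeds that item's key length, i.e. key_inds >= key_lens + 1.
mask = key_inds.view(1, 1, -1) > key_lens.view(-1, 1, 1)
attn_logprob = attn_logprob.masked_fill(mask, -1e15)
# After log_softmax, masked positions carry effectively zero probability.
log_probs = torch.log_softmax(attn_logprob, dim=-1)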