Skip to content

Commit

Permalink
style error fix
Browse files Browse the repository at this point in the history
Signed-off-by: Paarth Neekhara <[email protected]>
  • Loading branch information
paarthneekhara committed Aug 20, 2021
1 parent 8a6f8b0 commit e9dfa32
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions examples/tts/fastpitch2_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
from nemo.collections.common.callbacks import LogEpochTimeCallback
from nemo.collections.tts.models import FastPitchModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="fastpitch_align_44100")
def main(cfg):
if hasattr(cfg.model.optim, 'sched'):
logging.warning("You are using an optimizer scheduler while finetuning. Are you sure this is intended?")
if (cfg.model.optim.lr > 1e-3 or cfg.model.optim.lr < 1e-5):
if cfg.model.optim.lr > 1e-3 or cfg.model.optim.lr < 1e-5:
logging.warning("The recommended learning rate for finetuning is 2e-4")
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
Expand Down
6 changes: 3 additions & 3 deletions nemo/collections/asr/data/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def __getitem__(self, item):
pitch -= self.pitch_avg
pitch[pitch == -self.pitch_avg] = 0.0 # Zero out values that were perviously zero
pitch /= self.pitch_std

speaker = None
if self.collection[item].speaker is not None:
speaker = torch.zeros_like(text_len).fill_(self.collection[item].speaker)
Expand All @@ -649,7 +649,7 @@ def _collate_fn(self, batch):
audio, audio_len, text, text_len, attn_prior = super()._collate_fn(list(zip(*batch[:5])))
pitch_list = batch[5]
speaker_list = batch[6]

pitch = torch.zeros(len(pitch_list), max([pitch.shape[0] for pitch in pitch_list]))

for i, pitch_i in enumerate(pitch_list):
Expand All @@ -660,7 +660,7 @@ def _collate_fn(self, batch):
speakers.append(speaker_i)

speakers = torch.stack(speakers).to(text_len.dtype) if speakers[0] is not None else None

return audio, audio_len, text, text_len, attn_prior, pitch, speakers


Expand Down

0 comments on commit e9dfa32

Please sign in to comment.