Skip to content

Commit

Permalink
[TTS] Fixed epoch LR scheduling
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan <[email protected]>
  • Loading branch information
rlangman committed May 25, 2023
1 parent c3d568d commit 37f75d6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
20 changes: 14 additions & 6 deletions nemo/collections/tts/models/hifigan.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):

self.log_audio = cfg.get("log_audio", False)
self.log_config = cfg.get("log_config", None)
self.lr_schedule_interval = None
self.automatic_optimization = False

@property
Expand Down Expand Up @@ -144,8 +145,17 @@ def configure_optimizers(self):
optimizer=optim_d, scheduler_config=sched_config, train_dataloader=self._train_dl
)

self.lr_schedule_interval = scheduler_g["interval"]

return [optim_g, optim_d], [scheduler_g, scheduler_d]

def update_lr(self, interval="step"):
schedulers = self.lr_schedulers()
if schedulers is not None and self.lr_schedule_interval == interval:
sch1, sch2 = schedulers
sch1.step()
sch2.step()

@typecheck()
def forward(self, *, spec):
"""
Expand Down Expand Up @@ -199,12 +209,7 @@ def training_step(self, batch, batch_idx):
self.manual_backward(loss_g)
optim_g.step()

# Run schedulers
schedulers = self.lr_schedulers()
if schedulers is not None:
sch1, sch2 = schedulers
sch1.step()
sch2.step()
self.update_lr()

metrics = {
"g_loss_fm_mpd": loss_fm_mpd,
Expand All @@ -221,6 +226,9 @@ def training_step(self, batch, batch_idx):
self.log_dict(metrics, on_step=True, sync_dist=True)
self.log("g_l1_loss", loss_mel, prog_bar=True, logger=False, sync_dist=True)

def training_epoch_end(self, outputs) -> None:
self.update_lr("epoch")

def validation_step(self, batch, batch_idx):
audio, audio_len, audio_mel, audio_mel_len = self._process_batch(batch)

Expand Down
4 changes: 3 additions & 1 deletion scripts/dataset_processing/tts/preprocess_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def get_args():
"--output_manifest", required=True, type=Path, help="Path to output training manifest with processed text.",
)
parser.add_argument(
"--overwrite", default=False, type=bool, help="Whether to overwrite the output manifest file if it exists.",
"--overwrite",
action=argparse.BooleanOptionalAction,
help="Whether to overwrite the output manifest file if it exists.",
)
parser.add_argument(
"--lower_case", default=False, type=bool, help="Whether to convert the final text to lower case.",
Expand Down

0 comments on commit 37f75d6

Please sign in to comment.