formatting (#6157)
Signed-off-by: Артём Земляк <[email protected]>
Co-authored-by: Артём Земляк <[email protected]>
ArtyomZemlyak and Артём Земляк committed Mar 14, 2023
1 parent 036579d commit 899cf81
Showing 1 changed file with 55 additions and 5 deletions.
60 changes: 55 additions & 5 deletions nemo/collections/tts/models/univnet.py
@@ -18,7 +18,7 @@
 import torch
 import torch.nn.functional as F
 from hydra.utils import instantiate
-from omegaconf import DictConfig, open_dict
+from omegaconf import DictConfig, OmegaConf, open_dict
 from pytorch_lightning.loggers.wandb import WandbLogger
 
 from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, GeneratorLoss
@@ -30,7 +30,7 @@
 from nemo.core.classes.common import PretrainedModelInfo, typecheck
 from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType
 from nemo.core.neural_types.neural_type import NeuralType
-from nemo.core.optim.lr_scheduler import compute_max_steps
+from nemo.core.optim.lr_scheduler import compute_max_steps, prepare_lr_scheduler
 from nemo.utils import logging, model_utils
 
 HAVE_WANDB = True
@@ -90,11 +90,61 @@ def _get_max_steps(self):
             drop_last=self._train_dl.drop_last,
         )
 
+    @staticmethod
+    def get_warmup_steps(max_steps, warmup_steps, warmup_ratio):
+        if warmup_steps is not None and warmup_ratio is not None:
+            raise ValueError(f'Either use warmup_steps or warmup_ratio for scheduler')
+
+        if warmup_steps is not None:
+            return warmup_steps
+
+        if warmup_ratio is not None:
+            return warmup_ratio * max_steps
+
+        raise ValueError(f'Specify warmup_steps or warmup_ratio for scheduler')
+
     def configure_optimizers(self):
-        optim_g = instantiate(self._cfg.optim, params=self.generator.parameters(),)
-        optim_d = instantiate(self._cfg.optim, params=itertools.chain(self.mrd.parameters(), self.mpd.parameters()),)
+        optim_config = self._cfg.optim.copy()
+
+        OmegaConf.set_struct(optim_config, False)
+        sched_config = optim_config.pop("sched", None)
+        OmegaConf.set_struct(optim_config, True)
+
+        # Backward compatibility
+        if sched_config is None and 'sched' in self._cfg:
+            sched_config = self._cfg.sched
+
+        optim_g = instantiate(optim_config, params=self.generator.parameters(),)
+        optim_d = instantiate(optim_config, params=itertools.chain(self.mrd.parameters(), self.mpd.parameters()),)
 
-        return [optim_g, optim_d]
+        if sched_config is not None:
+            max_steps = self._cfg.get("max_steps", None)
+            if max_steps is None or max_steps < 0:
+                max_steps = self._get_max_steps()
+
+            warmup_steps = UnivNetModel.get_warmup_steps(
+                max_steps=max_steps,
+                warmup_steps=sched_config.get("warmup_steps", None),
+                warmup_ratio=sched_config.get("warmup_ratio", None),
+            )
+
+            OmegaConf.set_struct(sched_config, False)
+            sched_config["max_steps"] = max_steps
+            sched_config["warmup_steps"] = warmup_steps
+            sched_config.pop("warmup_ratio", None)
+            OmegaConf.set_struct(sched_config, True)
+
+            scheduler_g = prepare_lr_scheduler(
+                optimizer=optim_g, scheduler_config=sched_config, train_dataloader=self._train_dl
+            )
+
+            scheduler_d = prepare_lr_scheduler(
+                optimizer=optim_d, scheduler_config=sched_config, train_dataloader=self._train_dl
+            )
+
+            return [optim_g, optim_d], [scheduler_g, scheduler_d]
+        else:
+            return [optim_g, optim_d]
 
     @typecheck()
     def forward(self, *, spec):
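For readers tracing the change, here is a minimal sketch (not part of the commit, with a hypothetical config) of the path the new code enables: the optimizer config may now carry a nested "sched" block, which configure_optimizers() pops off before instantiating the optimizers, resolves warmup_steps/warmup_ratio against max_steps via get_warmup_steps, and then passes to prepare_lr_scheduler for both the generator and discriminator optimizers. Config names and values below are illustrative only.

from omegaconf import OmegaConf

# Hypothetical "optim" section, shaped like the one this model reads.
optim_cfg = OmegaConf.create(
    {
        "lr": 2e-4,
        "sched": {"name": "CosineAnnealing", "warmup_ratio": 0.02, "min_lr": 1e-5},
    }
)

# Split the scheduler block out so only optimizer kwargs remain, mirroring the
# OmegaConf.set_struct(...) / pop("sched") steps in configure_optimizers().
OmegaConf.set_struct(optim_cfg, False)
sched_cfg = optim_cfg.pop("sched", None)
OmegaConf.set_struct(optim_cfg, True)

# Resolve warmup the way get_warmup_steps does: exactly one of warmup_steps or
# warmup_ratio may be set, and a ratio is scaled by the total number of steps.
max_steps = 10_000  # in the model this comes from cfg.max_steps or _get_max_steps()
if sched_cfg.get("warmup_steps") is not None:
    warmup_steps = sched_cfg["warmup_steps"]
else:
    warmup_steps = sched_cfg["warmup_ratio"] * max_steps

print(int(warmup_steps))  # 200

In the model itself, the resolved sched config is then handed to prepare_lr_scheduler once per optimizer, which is why configure_optimizers now returns two lists: the optimizers and their matching schedulers.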
