Skip to content

Commit

Permalink
Fix discriminator update in AudioCodecModel (#7209)
Browse files Browse the repository at this point in the history
Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju committed Aug 10, 2023
1 parent 9c818ed commit ed4ce52
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
3 changes: 3 additions & 0 deletions examples/tts/audio_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
# limitations under the License.

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.tts.models import AudioCodecModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf/audio_codec", config_name="audio_codec")
def main(cfg):
logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg))
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = AudioCodecModel(cfg=cfg.model, trainer=trainer)
Expand Down
4 changes: 3 additions & 1 deletion examples/tts/conf/audio_codec/encodec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ model:
samples_per_frame: ${samples_per_frame}
time_domain_loss_scale: 0.1
# Probability of updating the discriminator during each training step
disc_update_prob: 0.67
# For example, update the discriminator on 2 out of every 3 batches (the first 2 batches of each 3-batch period)
disc_updates_per_period: 2
disc_update_period: 3

# All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length]
mel_loss_resolutions: [
Expand Down
32 changes: 25 additions & 7 deletions nemo/collections/tts/models/audio_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.sample_rate = cfg.sample_rate
self.samples_per_frame = cfg.samples_per_frame

self.disc_update_prob = cfg.get("disc_update_prob", 1.0)
self.disc_updates_per_period = cfg.get("disc_updates_per_period", 1)
self.disc_update_period = cfg.get("disc_update_period", 1)
if self.disc_updates_per_period > self.disc_update_period:
raise ValueError(
f'Number of discriminator updates ({self.disc_updates_per_period}) per period must be less or equal to the configured period ({self.disc_update_period})'
)

self.audio_encoder = instantiate(cfg.audio_encoder)

# Optionally, add gaussian noise to encoder output as an information bottleneck
Expand Down Expand Up @@ -204,18 +210,26 @@ def _process_batch(self, batch):

return audio, audio_len, audio_gen, commit_loss

@property
def disc_update_prob(self) -> float:
    """Fraction of training steps on which the discriminator is updated.

    Computed as ``disc_updates_per_period / disc_update_period``.
    """
    updates = self.disc_updates_per_period
    period = self.disc_update_period
    return updates / period

def should_update_disc(self, batch_idx) -> bool:
    """Tell whether the discriminator is updated for the given batch.

    Batches are grouped into consecutive periods of ``disc_update_period``
    batches. The first ``disc_updates_per_period`` batches of every period
    trigger a discriminator update; the remaining ones do not.
    """
    # Position of this batch inside its period, in [0, disc_update_period).
    position_in_period = batch_idx % self.disc_update_period
    return self.disc_updates_per_period > position_in_period

def training_step(self, batch, batch_idx):
optim_gen, optim_disc = self.optimizers()
optim_gen.zero_grad()

audio, audio_len, audio_gen, commit_loss = self._process_batch(batch)

if self.disc_update_prob < random.random():
loss_disc = None
else:
if self.should_update_disc(batch_idx):
# Train discriminator
optim_disc.zero_grad()

disc_scores_real, disc_scores_gen, _, _ = self.discriminator(
audio_real=audio, audio_gen=audio_gen.detach()
)
Expand All @@ -224,6 +238,9 @@ def training_step(self, batch, batch_idx):

self.manual_backward(train_disc_loss)
optim_disc.step()
optim_disc.zero_grad()
else:
loss_disc = None

loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
train_loss_time_domain = self.time_domain_loss_scale * loss_time_domain
Expand All @@ -245,6 +262,7 @@ def training_step(self, batch, batch_idx):

self.manual_backward(loss_gen_all)
optim_gen.step()
optim_gen.zero_grad()

self.update_lr()

Expand Down

0 comments on commit ed4ce52

Please sign in to comment.