diff --git a/nemo/collections/tts/losses/audio_codec_loss.py b/nemo/collections/tts/losses/audio_codec_loss.py index 124256f103d99..5a36d03783717 100644 --- a/nemo/collections/tts/losses/audio_codec_loss.py +++ b/nemo/collections/tts/losses/audio_codec_loss.py @@ -119,6 +119,7 @@ class MultiResolutionMelLoss(Loss): mel_dims: Dimension of mel spectrogram to compute for each resolution. Should be same length as 'resolutions'. log_guard: Value to add to mel spectrogram to avoid taking log of 0. """ + def __init__(self, sample_rate: int, resolutions: List[List], mel_dims: List[int], log_guard: float = 1.0): super(MultiResolutionMelLoss, self).__init__() assert len(resolutions) == len(mel_dims) @@ -186,6 +187,7 @@ class STFTLoss(Loss): log_guard: Value to add to magnitude spectrogram to avoid taking log of 0. sqrt_guard: Value to add to when computing absolute value of STFT to avoid NaN loss. """ + def __init__(self, resolution: List[int], log_guard: float = 1.0, sqrt_guard: float = 1e-5): super(STFTLoss, self).__init__() self.loss_fn = MaskedMAELoss() @@ -242,6 +244,7 @@ class MultiResolutionSTFTLoss(Loss): log_guard: Value to add to magnitude spectrogram to avoid taking log of 0. sqrt_guard: Value to add to when computing absolute value of STFT to avoid NaN loss. """ + def __init__(self, resolutions: List[List], log_guard: float = 1.0, sqrt_guard: float = 1e-5): super(MultiResolutionSTFTLoss, self).__init__() self.loss_fns = torch.nn.ModuleList(