Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 3, 2023
1 parent 2f11b29 commit 97ee56f
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions nemo/collections/tts/losses/audio_codec_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class MultiResolutionMelLoss(Loss):
mel_dims: Dimension of mel spectrogram to compute for each resolution. Should be same length as 'resolutions'.
log_guard: Value to add to mel spectrogram to avoid taking log of 0.
"""

def __init__(self, sample_rate: int, resolutions: List[List], mel_dims: List[int], log_guard: float = 1.0):
super(MultiResolutionMelLoss, self).__init__()
assert len(resolutions) == len(mel_dims)
Expand Down Expand Up @@ -186,6 +187,7 @@ class STFTLoss(Loss):
log_guard: Value to add to magnitude spectrogram to avoid taking log of 0.
sqrt_guard: Value to add to when computing absolute value of STFT to avoid NaN loss.
"""

def __init__(self, resolution: List[int], log_guard: float = 1.0, sqrt_guard: float = 1e-5):
super(STFTLoss, self).__init__()
self.loss_fn = MaskedMAELoss()
Expand Down Expand Up @@ -242,6 +244,7 @@ class MultiResolutionSTFTLoss(Loss):
log_guard: Value to add to magnitude spectrogram to avoid taking log of 0.
sqrt_guard: Value to add to when computing absolute value of STFT to avoid NaN loss.
"""

def __init__(self, resolutions: List[List], log_guard: float = 1.0, sqrt_guard: float = 1e-5):
super(MultiResolutionSTFTLoss, self).__init__()
self.loss_fns = torch.nn.ModuleList(
Expand Down

0 comments on commit 97ee56f

Please sign in to comment.