From 41c7da6ac394ea78780d88268b967583d9187a70 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alejandro=20Gast=C3=B3n=20Alvarez?=
Date: Fri, 13 Oct 2023 15:09:22 +0200
Subject: [PATCH] Fix types errors (#10)

* Fix types errors

* Fix loss initial type

* Fix variable Tuple length

* Fix typing

---------

Co-authored-by: Hubert Siuzdak <35269911+hubertsiuzdak@users.noreply.github.com>
---
 metrics/UTMOS.py            |  2 +-
 vocos/discriminators.py     | 27 +++++++++++++++++----------
 vocos/feature_extractors.py |  5 ++++-
 vocos/heads.py              |  9 ++++++++-
 vocos/loss.py               | 10 +++++-----
 vocos/models.py             |  3 ++-
 vocos/modules.py            |  8 ++++----
 vocos/pretrained.py         | 10 ++++++----
 8 files changed, 47 insertions(+), 27 deletions(-)

diff --git a/metrics/UTMOS.py b/metrics/UTMOS.py
index 5e6e9a5..5c42e8a 100644
--- a/metrics/UTMOS.py
+++ b/metrics/UTMOS.py
@@ -26,7 +26,7 @@ def __init__(self, device, ckpt_path="epoch=3-step=7459.ckpt"):
             download_file(UTMOS_CKPT_URL, filepath)
         self.model = BaselineLightningModule.load_from_checkpoint(filepath).eval().to(device)
 
-    def score(self, wavs: torch.tensor) -> torch.tensor:
+    def score(self, wavs: torch.Tensor) -> torch.Tensor:
         """
         Args:
             wavs: audio waveform to be evaluated. When len(wavs) == 1 or 2,
diff --git a/vocos/discriminators.py b/vocos/discriminators.py
index 2f6dece..64ab318 100644
--- a/vocos/discriminators.py
+++ b/vocos/discriminators.py
@@ -1,10 +1,13 @@
-from typing import Tuple, List
+from typing import List, Optional, Tuple
 
 import torch
 from torch import nn
 from torch.nn import Conv2d
 from torch.nn.utils import weight_norm
 
+PeriodsType = Tuple[int, ...]
+ResolutionType = Tuple[int, int, int]
+
 
 class MultiPeriodDiscriminator(nn.Module):
     """
@@ -17,12 +20,12 @@ class MultiPeriodDiscriminator(nn.Module):
             Defaults to None.
     """
 
-    def __init__(self, periods: Tuple[int] = (2, 3, 5, 7, 11), num_embeddings: int = None):
+    def __init__(self, periods: PeriodsType = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None):
         super().__init__()
         self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])
 
     def forward(
-        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
         y_d_rs = []
         y_d_gs = []
@@ -47,7 +50,7 @@ def __init__(
         kernel_size: int = 5,
         stride: int = 3,
         lrelu_slope: float = 0.1,
-        num_embeddings: int = None,
+        num_embeddings: Optional[int] = None,
     ):
         super().__init__()
         self.period = period
@@ -68,7 +71,7 @@ def __init__(
         self.lrelu_slope = lrelu_slope
 
     def forward(
-        self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
+        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         x = x.unsqueeze(1)
         fmap = []
@@ -101,8 +104,12 @@
 class MultiResolutionDiscriminator(nn.Module):
     def __init__(
         self,
-        resolutions: Tuple[Tuple[int, int, int]] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)),
-        num_embeddings: int = None,
+        resolutions: Tuple[ResolutionType, ResolutionType, ResolutionType] = (
+            (1024, 256, 1024),
+            (2048, 512, 2048),
+            (512, 128, 512),
+        ),
+        num_embeddings: Optional[int] = None,
     ):
         """
         Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet.
@@ -120,7 +127,7 @@ def __init__(
         )
 
     def forward(
-        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
         y_d_rs = []
         y_d_gs = []
@@ -144,7 +151,7 @@ def __init__(
         resolution: Tuple[int, int, int],
         channels: int = 64,
         in_channels: int = 1,
-        num_embeddings: int = None,
+        num_embeddings: Optional[int] = None,
         lrelu_slope: float = 0.1,
     ):
         super().__init__()
@@ -166,7 +173,7 @@ def __init__(
         self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
 
     def forward(
-        self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
+        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         fmap = []
         x = self.spectrogram(x)
diff --git a/vocos/feature_extractors.py b/vocos/feature_extractors.py
index 0b4d47b..799f1b4 100644
--- a/vocos/feature_extractors.py
+++ b/vocos/feature_extractors.py
@@ -82,7 +82,10 @@ def get_encodec_codes(self, audio):
         codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
         return codes
 
-    def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
+    def forward(self, audio: torch.Tensor, **kwargs):
+        bandwidth_id = kwargs.get("bandwidth_id")
+        if bandwidth_id is None:
+            raise ValueError("The 'bandwidth_id' argument is required")
         self.encodec.eval()  # Force eval mode as Pytorch Lightning automatically sets child modules to training mode
         self.encodec.set_target_bandwidth(self.bandwidths[bandwidth_id])
         codes = self.get_encodec_codes(audio)
diff --git a/vocos/heads.py b/vocos/heads.py
index 21863f1..24f5cfc 100644
--- a/vocos/heads.py
+++ b/vocos/heads.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from torch import nn
 from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
@@ -81,7 +83,12 @@ class IMDCTSymExpHead(FourierHead):
     """
 
     def __init__(
-        self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = None, clip_audio: bool = False,
+        self,
+        dim: int,
+        mdct_frame_len: int,
+        padding: str = "same",
+        sample_rate: Optional[int] = None,
+        clip_audio: bool = False,
     ):
         super().__init__()
         out_dim = mdct_frame_len // 2
diff --git a/vocos/loss.py b/vocos/loss.py
index e6b0ed5..029f6ac 100644
--- a/vocos/loss.py
+++ b/vocos/loss.py
@@ -51,7 +51,7 @@ def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[
             Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
                                          the sub-discriminators
         """
-        loss = 0
+        loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype)
         gen_losses = []
         for dg in disc_outputs:
             l = torch.mean(torch.clamp(1 - dg, min=0))
@@ -79,15 +79,15 @@ def forward(
                                                        the sub-discriminators for real outputs, and a list of
                                                        loss values for generated outputs.
         """
-        loss = 0
+        loss = torch.zeros(1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype)
         r_losses = []
         g_losses = []
         for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
             r_loss = torch.mean(torch.clamp(1 - dr, min=0))
             g_loss = torch.mean(torch.clamp(1 + dg, min=0))
             loss += r_loss + g_loss
-            r_losses.append(r_loss.item())
-            g_losses.append(g_loss.item())
+            r_losses.append(r_loss)
+            g_losses.append(g_loss)
 
         return loss, r_losses, g_losses
 
@@ -106,7 +106,7 @@ def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tens
         Returns:
             Tensor: The calculated feature matching loss.
         """
-        loss = 0
+        loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
         for dr, dg in zip(fmap_r, fmap_g):
             for rl, gl in zip(dr, dg):
                 loss += torch.mean(torch.abs(rl - gl))
diff --git a/vocos/models.py b/vocos/models.py
index 886a88a..09ed55d 100644
--- a/vocos/models.py
+++ b/vocos/models.py
@@ -74,7 +74,8 @@ def _init_weights(self, m):
             nn.init.trunc_normal_(m.weight, std=0.02)
             nn.init.constant_(m.bias, 0)
 
-    def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        bandwidth_id = kwargs.get('bandwidth_id', None)
         x = self.embed(x)
         if self.adanorm:
             assert bandwidth_id is not None
diff --git a/vocos/modules.py b/vocos/modules.py
index 9688a97..af1d6db 100644
--- a/vocos/modules.py
+++ b/vocos/modules.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 from torch import nn
@@ -21,7 +21,7 @@ def __init__(
         self,
         dim: int,
         intermediate_dim: int,
-        layer_scale_init_value: Optional[float] = None,
+        layer_scale_init_value: float,
         adanorm_num_embeddings: Optional[int] = None,
     ):
         super().__init__()
@@ -106,9 +106,9 @@ def __init__(
         self,
         dim: int,
         kernel_size: int = 3,
-        dilation: tuple[int] = (1, 3, 5),
+        dilation: Tuple[int, int, int] = (1, 3, 5),
         lrelu_slope: float = 0.1,
-        layer_scale_init_value: float = None,
+        layer_scale_init_value: Optional[float] = None,
     ):
         super().__init__()
         self.lrelu_slope = lrelu_slope
diff --git a/vocos/pretrained.py b/vocos/pretrained.py
index 3d6b5c8..6f5cde1 100644
--- a/vocos/pretrained.py
+++ b/vocos/pretrained.py
@@ -1,4 +1,6 @@
-from typing import Tuple, Any, Union, Dict
+from __future__ import annotations
+
+from typing import Any, Dict, Self, Tuple, Union
 
 import torch
 import yaml
@@ -45,7 +47,7 @@ def __init__(
         self.head = head
 
     @classmethod
-    def from_hparams(cls, config_path: str) -> "Vocos":
+    def from_hparams(cls, config_path: str) -> Self:
         """
         Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
         """
@@ -58,13 +60,13 @@ def from_hparams(cls, config_path: str) -> "Vocos":
         return model
 
     @classmethod
-    def from_pretrained(self, repo_id: str) -> "Vocos":
+    def from_pretrained(cls, repo_id: str) -> Self:
         """
         Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
         """
         config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
         model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
-        model = self.from_hparams(config_path)
+        model = cls.from_hparams(config_path)
         state_dict = torch.load(model_path, map_location="cpu")
         if isinstance(model.feature_extractor, EncodecFeatures):
             encodec_parameters = {