Commit

Fix type errors (#10)
* Fix type errors

* Fix loss initial type

* Fix variable Tuple length

* Fix typing

---------

Co-authored-by: Hubert Siuzdak <[email protected]>
alealv and hubertsiuzdak authored Oct 13, 2023
1 parent 86db50a commit 41c7da6
Showing 8 changed files with 47 additions and 27 deletions.
2 changes: 1 addition & 1 deletion metrics/UTMOS.py
@@ -26,7 +26,7 @@ def __init__(self, device, ckpt_path="epoch=3-step=7459.ckpt"):
             download_file(UTMOS_CKPT_URL, filepath)
         self.model = BaselineLightningModule.load_from_checkpoint(filepath).eval().to(device)
 
-    def score(self, wavs: torch.tensor) -> torch.tensor:
+    def score(self, wavs: torch.Tensor) -> torch.Tensor:
         """
         Args:
             wavs: audio waveform to be evaluated. When len(wavs) == 1 or 2,
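
Context for this change: torch.Tensor is the tensor class and is valid in annotations, while torch.tensor is a factory function, so static checkers such as mypy reject it as a type hint. A minimal sketch of the corrected style (the function name is illustrative, not from this repository):

    import torch

    def peak(wav: torch.Tensor) -> torch.Tensor:
        # Annotate with the class torch.Tensor; torch.tensor(...) is only for constructing tensors.
        return wav.abs().max()

    print(peak(torch.tensor([0.25, -0.5])))  # tensor(0.5000)
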
27 changes: 17 additions & 10 deletions vocos/discriminators.py
@@ -1,10 +1,13 @@
-from typing import Tuple, List
+from typing import List, Optional, Tuple
 
 import torch
 from torch import nn
 from torch.nn import Conv2d
 from torch.nn.utils import weight_norm
 
+PeriodsType = Tuple[int, ...]
+ResolutionType = Tuple[int, int, int]
+
 
 class MultiPeriodDiscriminator(nn.Module):
     """
@@ -17,12 +20,12 @@ class MultiPeriodDiscriminator(nn.Module):
             Defaults to None.
     """
 
-    def __init__(self, periods: Tuple[int] = (2, 3, 5, 7, 11), num_embeddings: int = None):
+    def __init__(self, periods: PeriodsType = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None):
         super().__init__()
         self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])
 
     def forward(
-        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
         y_d_rs = []
         y_d_gs = []
@@ -47,7 +50,7 @@ def __init__(
         kernel_size: int = 5,
         stride: int = 3,
         lrelu_slope: float = 0.1,
-        num_embeddings: int = None,
+        num_embeddings: Optional[int] = None,
     ):
         super().__init__()
         self.period = period
@@ -68,7 +71,7 @@ def __init__(
         self.lrelu_slope = lrelu_slope
 
     def forward(
-        self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
+        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         x = x.unsqueeze(1)
         fmap = []
@@ -101,8 +104,12 @@ def forward(
 class MultiResolutionDiscriminator(nn.Module):
     def __init__(
         self,
-        resolutions: Tuple[Tuple[int, int, int]] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)),
-        num_embeddings: int = None,
+        resolutions: Tuple[ResolutionType, ResolutionType, ResolutionType] = (
+            (1024, 256, 1024),
+            (2048, 512, 2048),
+            (512, 128, 512),
+        ),
+        num_embeddings: Optional[int] = None,
     ):
         """
         Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet.
@@ -120,7 +127,7 @@ def __init__(
         )
 
     def forward(
-        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
         y_d_rs = []
         y_d_gs = []
@@ -144,7 +151,7 @@ def __init__(
         resolution: Tuple[int, int, int],
         channels: int = 64,
         in_channels: int = 1,
-        num_embeddings: int = None,
+        num_embeddings: Optional[int] = None,
         lrelu_slope: float = 0.1,
     ):
         super().__init__()
@@ -166,7 +173,7 @@ def __init__(
         self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
 
     def forward(
-        self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
+        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         fmap = []
         x = self.spectrogram(x)
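
Context for this change: Tuple[int] means a tuple of exactly one int, so it never matched the five-period default, and a default of None needs Optional[...]. The new PeriodsType alias uses Tuple[int, ...] for a variable-length tuple. A minimal sketch of the same pattern (the function name is illustrative, not from this repository):

    from typing import Optional, Tuple

    PeriodsType = Tuple[int, ...]  # variable-length tuple of ints

    def count_discriminators(periods: PeriodsType = (2, 3, 5, 7, 11),
                             num_embeddings: Optional[int] = None) -> int:
        # Optional[int] documents that None is allowed; a bare "int = None" is rejected by strict checkers.
        return len(periods) if num_embeddings is None else num_embeddings

    print(count_discriminators())           # 5
    print(count_discriminators((2, 3), 8))  # 8
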
5 changes: 4 additions & 1 deletion vocos/feature_extractors.py
@@ -82,7 +82,10 @@ def get_encodec_codes(self, audio):
         codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
         return codes
 
-    def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
+    def forward(self, audio: torch.Tensor, **kwargs):
+        bandwidth_id = kwargs.get("bandwidth_id")
+        if bandwidth_id is None:
+            raise ValueError("The 'bandwidth_id' argument is required")
         self.encodec.eval()  # Force eval mode as Pytorch Lightning automatically sets child modules to training mode
         self.encodec.set_target_bandwidth(self.bandwidths[bandwidth_id])
         codes = self.get_encodec_codes(audio)
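
Context for this change: taking **kwargs gives every feature extractor the same forward signature, while the explicit check keeps bandwidth_id effectively required for the Encodec extractor. A minimal standalone sketch of the pattern (not the actual class):

    import torch

    def forward(audio: torch.Tensor, **kwargs) -> torch.Tensor:
        # Read the keyword argument from kwargs and fail loudly if the caller omitted it.
        bandwidth_id = kwargs.get("bandwidth_id")
        if bandwidth_id is None:
            raise ValueError("The 'bandwidth_id' argument is required")
        return audio * 0 + bandwidth_id  # placeholder computation

    print(forward(torch.zeros(3), bandwidth_id=torch.tensor(2)))  # tensor([2., 2., 2.])
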
9 changes: 8 additions & 1 deletion vocos/heads.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from torch import nn
 from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
@@ -81,7 +83,12 @@ class IMDCTSymExpHead(FourierHead):
     """
 
     def __init__(
-        self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = None, clip_audio: bool = False,
+        self,
+        dim: int,
+        mdct_frame_len: int,
+        padding: str = "same",
+        sample_rate: Optional[int] = None,
+        clip_audio: bool = False,
     ):
         super().__init__()
         out_dim = mdct_frame_len // 2
10 changes: 5 additions & 5 deletions vocos/loss.py
@@ -51,7 +51,7 @@ def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[
             Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
                 the sub-discriminators
         """
-        loss = 0
+        loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype)
         gen_losses = []
         for dg in disc_outputs:
             l = torch.mean(torch.clamp(1 - dg, min=0))
@@ -79,15 +79,15 @@ def forward(
                 the sub-discriminators for real outputs, and a list of
                 loss values for generated outputs.
         """
-        loss = 0
+        loss = torch.zeros(1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype)
         r_losses = []
         g_losses = []
         for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
             r_loss = torch.mean(torch.clamp(1 - dr, min=0))
             g_loss = torch.mean(torch.clamp(1 + dg, min=0))
             loss += r_loss + g_loss
-            r_losses.append(r_loss.item())
-            g_losses.append(g_loss.item())
+            r_losses.append(r_loss)
+            g_losses.append(g_loss)
 
         return loss, r_losses, g_losses
 
@@ -106,7 +106,7 @@ def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tens
         Returns:
             Tensor: The calculated feature matching loss.
         """
-        loss = 0
+        loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
         for dr, dg in zip(fmap_r, fmap_g):
             for rl, gl in zip(dr, dg):
                 loss += torch.mean(torch.abs(rl - gl))
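
Context for this change: initialising the accumulator with torch.zeros makes it a Tensor on the right device and dtype from the start, instead of the int 0 that only becomes a Tensor after the first +=, and returning the per-discriminator losses as tensors matches the List[Tensor] annotation. A minimal sketch of the accumulation pattern:

    import torch

    disc_outputs = [torch.tensor([0.2, -0.1]), torch.tensor([0.7])]
    # Start as a tensor matching the inputs' device/dtype so the annotated Tensor type holds throughout.
    loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype)
    for dg in disc_outputs:
        loss += torch.mean(torch.clamp(1 - dg, min=0))
    print(loss)  # tensor([1.2500])
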
3 changes: 2 additions & 1 deletion vocos/models.py
@@ -74,7 +74,8 @@ def _init_weights(self, m):
             nn.init.trunc_normal_(m.weight, std=0.02)
             nn.init.constant_(m.bias, 0)
 
-    def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        bandwidth_id = kwargs.get('bandwidth_id', None)
         x = self.embed(x)
         if self.adanorm:
             assert bandwidth_id is not None
8 changes: 4 additions & 4 deletions vocos/modules.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 from torch import nn
@@ -21,7 +21,7 @@ def __init__(
         self,
         dim: int,
         intermediate_dim: int,
-        layer_scale_init_value: Optional[float] = None,
+        layer_scale_init_value: float,
         adanorm_num_embeddings: Optional[int] = None,
     ):
         super().__init__()
@@ -106,9 +106,9 @@ def __init__(
         self,
         dim: int,
         kernel_size: int = 3,
-        dilation: tuple[int] = (1, 3, 5),
+        dilation: Tuple[int, int, int] = (1, 3, 5),
         lrelu_slope: float = 0.1,
-        layer_scale_init_value: float = None,
+        layer_scale_init_value: Optional[float] = None,
     ):
         super().__init__()
         self.lrelu_slope = lrelu_slope
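
Context for this change: tuple[int] is the built-in generic (subscriptable at runtime only on Python 3.9+) and would also mean a one-element tuple, while Tuple[int, int, int] states the intended three dilation rates and works on older interpreters. A minimal sketch (the function name is illustrative):

    from typing import Tuple

    def receptive_sizes(kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)) -> list:
        # Tuple[int, int, int] documents a fixed-length triple; tuple[int] would describe a 1-tuple.
        return [(kernel_size - 1) * d + 1 for d in dilation]

    print(receptive_sizes())  # [3, 7, 11]
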
10 changes: 6 additions & 4 deletions vocos/pretrained.py
@@ -1,4 +1,6 @@
-from typing import Tuple, Any, Union, Dict
+from __future__ import annotations
+
+from typing import Any, Dict, Self, Tuple, Union
 
 import torch
 import yaml
@@ -45,7 +47,7 @@ def __init__(
         self.head = head
 
     @classmethod
-    def from_hparams(cls, config_path: str) -> "Vocos":
+    def from_hparams(cls, config_path: str) -> Self:
         """
         Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
         """
@@ -58,13 +60,13 @@ def from_hparams(cls, config_path: str) -> "Vocos":
         return model
 
     @classmethod
-    def from_pretrained(self, repo_id: str) -> "Vocos":
+    def from_pretrained(cls, repo_id: str) -> Self:
         """
         Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
        """
         config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
         model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
-        model = self.from_hparams(config_path)
+        model = cls.from_hparams(config_path)
         state_dict = torch.load(model_path, map_location="cpu")
         if isinstance(model.feature_extractor, EncodecFeatures):
             encodec_parameters = {
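
Context for this change: typing.Self (new in Python 3.11; available from typing_extensions on older versions) keeps factory classmethods accurate for subclasses, from __future__ import annotations defers annotation evaluation so the "Vocos" string form is no longer needed, and a classmethod's first parameter is conventionally cls, which the body now uses via cls.from_hparams. A minimal sketch of the pattern (the class name is illustrative, not the actual Vocos class):

    from __future__ import annotations

    from typing import Self  # requires Python 3.11+; earlier versions can import Self from typing_extensions

    class Vocoder:
        def __init__(self, config: dict):
            self.config = config

        @classmethod
        def from_hparams(cls, config: dict) -> Self:
            # Returning Self (rather than the literal class name) stays correct when a subclass calls this.
            return cls(config)

    print(type(Vocoder.from_hparams({"sample_rate": 24000})).__name__)  # Vocoder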
