Fix types errors #10

Merged · 4 commits · Oct 13, 2023
Changes from 2 commits
23 changes: 13 additions & 10 deletions vocos/discriminators.py
@@ -1,11 +1,14 @@
from typing import Tuple, List
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.nn import Conv2d
from torch.nn.utils import weight_norm


PeriodsType = Tuple[int, int, int, int, int]
ResolutionType = Tuple[int, int, int]

class MultiPeriodDiscriminator(nn.Module):
"""
Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
@@ -17,12 +20,12 @@ class MultiPeriodDiscriminator(nn.Module):
Defaults to None.
"""

def __init__(self, periods: Tuple[int] = (2, 3, 5, 7, 11), num_embeddings: int = None):
def __init__(self, periods: PeriodsType = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None):
super().__init__()
self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])

def forward(
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
y_d_rs = []
y_d_gs = []
@@ -47,7 +50,7 @@ def __init__(
kernel_size: int = 5,
stride: int = 3,
lrelu_slope: float = 0.1,
num_embeddings: int = None,
num_embeddings: Optional[int] = None,
):
super().__init__()
self.period = period
Expand All @@ -68,7 +71,7 @@ def __init__(
self.lrelu_slope = lrelu_slope

def forward(
self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
x = x.unsqueeze(1)
fmap = []
@@ -101,8 +104,8 @@ def forward(
class MultiResolutionDiscriminator(nn.Module):
def __init__(
self,
resolutions: Tuple[Tuple[int, int, int]] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)),
num_embeddings: int = None,
resolutions: Tuple[ResolutionType, ResolutionType, ResolutionType] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)),
num_embeddings: Optional[int] = None,
):
"""
Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet.
Expand All @@ -120,7 +123,7 @@ def __init__(
)

def forward(
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
y_d_rs = []
y_d_gs = []
@@ -144,7 +147,7 @@ def __init__(
resolution: Tuple[int, int, int],
channels: int = 64,
in_channels: int = 1,
num_embeddings: int = None,
num_embeddings: Optional[int] = None,
lrelu_slope: float = 0.1,
):
super().__init__()
Expand All @@ -166,7 +169,7 @@ def __init__(
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))

def forward(
self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
x = self.spectrogram(x)
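The discriminators.py changes follow one pattern: parameters that default to None get an explicit Optional[...] annotation, and fixed-length tuple arguments get dedicated aliases (PeriodsType, ResolutionType). A minimal standalone sketch of the same idea; the function and dictionary below are illustrative, not part of the PR:

from typing import Optional, Tuple

# A fixed-length alias documents that exactly five periods are expected;
# Tuple[int] alone would mean a 1-tuple, and Tuple[int, ...] any length.
PeriodsType = Tuple[int, int, int, int, int]

def build_period_configs(
    periods: PeriodsType = (2, 3, 5, 7, 11),
    num_embeddings: Optional[int] = None,  # explicit Optional makes the None default legal
) -> list:
    # With `num_embeddings: int = None`, a strict checker rejects the default
    # because None is not an int; Optional[int] (i.e. int | None) fixes that.
    return [{"period": p, "num_embeddings": num_embeddings} for p in periods]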
6 changes: 3 additions & 3 deletions vocos/loss.py
@@ -51,7 +51,7 @@ def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[
Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
the sub-discriminators
"""
loss = 0
loss = torch.tensor(0)
gen_losses = []
for dg in disc_outputs:
l = torch.mean(torch.clamp(1 - dg, min=0))
@@ -79,7 +79,7 @@ def forward(
the sub-discriminators for real outputs, and a list of
loss values for generated outputs.
"""
loss = 0
loss = torch.tensor(0)
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
@@ -106,7 +106,7 @@ def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tens
Returns:
Tensor: The calculated feature matching loss.
"""
loss = 0
loss = torch.tensor(0)
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
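The loss.py hunks replace the Python int accumulator `loss = 0` with a scalar tensor so the variable keeps a single inferred type through the `loss += ...` loop. A hedged sketch of the accumulation pattern (the function name is illustrative); it uses a float literal because in-place `+=` cannot cast a float result into an integer (long) tensor:

import torch
from typing import List

def hinge_generator_loss(disc_outputs: List[torch.Tensor]) -> torch.Tensor:
    # Start from a scalar tensor rather than the int 0 so the checker sees a
    # Tensor throughout; the 0.0 keeps the dtype float so the in-place adds
    # of float-valued means are valid.
    loss = torch.tensor(0.0)
    for dg in disc_outputs:
        loss += torch.mean(torch.clamp(1 - dg, min=0))
    return loss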
6 changes: 3 additions & 3 deletions vocos/modules.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple

import torch
from torch import nn
@@ -106,9 +106,9 @@ def __init__(
self,
dim: int,
kernel_size: int = 3,
dilation: tuple[int] = (1, 3, 5),
dilation: Tuple[int, int, int] = (1, 3, 5),
lrelu_slope: float = 0.1,
layer_scale_init_value: float = None,
layer_scale_init_value: Optional[float] = None,
):
super().__init__()
self.lrelu_slope = lrelu_slope
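The modules.py hunk fixes two distinct issues in one annotation: the built-in generic `tuple[...]` is only valid as a runtime annotation on Python 3.9+, and `tuple[int]` / `Tuple[int]` describes a one-element tuple, not the three-element default. A standalone sketch under those assumptions (names below are illustrative):

from typing import Optional, Tuple

def resblock_config(
    dim: int,
    dilation: Tuple[int, int, int] = (1, 3, 5),       # exactly three dilation rates
    layer_scale_init_value: Optional[float] = None,    # None is now an allowed value
) -> dict:
    # Tuple[int] would declare a 1-tuple, so the (1, 3, 5) default would be
    # flagged by a type checker; Tuple[int, int, int] matches it exactly.
    return {"dim": dim, "dilation": dilation, "layer_scale": layer_scale_init_value}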
10 changes: 6 additions & 4 deletions vocos/pretrained.py
@@ -1,4 +1,6 @@
from typing import Tuple, Any, Union, Dict
from __future__ import annotations

from typing import Any, Dict, Self, Tuple, Union

import torch
import yaml
@@ -45,7 +47,7 @@ def __init__(
self.head = head

@classmethod
def from_hparams(cls, config_path: str) -> "Vocos":
def from_hparams(cls, config_path: str) -> Self:
"""
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
@@ -58,13 +60,13 @@ def from_hparams(cls, config_path: str) -> "Vocos":
return model

@classmethod
def from_pretrained(self, repo_id: str) -> "Vocos":
def from_pretrained(cls, repo_id: str) -> Self:
"""
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
"""
config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
model = self.from_hparams(config_path)
model = cls.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu")
if isinstance(model.feature_extractor, EncodecFeatures):
encodec_parameters = {
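The pretrained.py hunks make two related fixes: from_pretrained now takes `cls` rather than `self` (the old signature only worked because Python binds the class to whatever the first classmethod parameter is called), and both constructors return `Self` instead of the string literal "Vocos", so subclasses are typed as themselves. A minimal sketch of the pattern; the Model class is illustrative, and note that `typing.Self` requires Python 3.11, which `from __future__ import annotations` does not change:

from __future__ import annotations

from typing import Self  # typing.Self exists only on Python 3.11+


class Model:
    def __init__(self, config: dict) -> None:
        self.config = config

    @classmethod
    def from_hparams(cls, config_path: str) -> Self:
        # The first parameter of a classmethod is the class object, so it is
        # conventionally named cls; returning Self types subclasses correctly.
        return cls({"config_path": config_path})

    @classmethod
    def from_pretrained(cls, repo_id: str) -> Self:
        # Using cls.from_hparams (not self.from_hparams) makes the classmethod
        # chain explicit and avoids a misleadingly named first parameter.
        return cls.from_hparams(repo_id + "/config.yaml")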