Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix types errors #10

Merged
merged 4 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion metrics/UTMOS.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, device, ckpt_path="epoch=3-step=7459.ckpt"):
download_file(UTMOS_CKPT_URL, filepath)
self.model = BaselineLightningModule.load_from_checkpoint(filepath).eval().to(device)

def score(self, wavs: torch.tensor) -> torch.tensor:
def score(self, wavs: torch.Tensor) -> torch.Tensor:
"""
Args:
wavs: audio waveform to be evaluated. When len(wavs) == 1 or 2,
Expand Down
27 changes: 17 additions & 10 deletions vocos/discriminators.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Tuple, List
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.nn import Conv2d
from torch.nn.utils import weight_norm

PeriodsType = Tuple[int, ...]
ResolutionType = Tuple[int, int, int]


class MultiPeriodDiscriminator(nn.Module):
"""
Expand All @@ -17,12 +20,12 @@ class MultiPeriodDiscriminator(nn.Module):
Defaults to None.
"""

def __init__(self, periods: Tuple[int] = (2, 3, 5, 7, 11), num_embeddings: int = None):
def __init__(self, periods: PeriodsType = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None):
super().__init__()
self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])

def forward(
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
y_d_rs = []
y_d_gs = []
Expand All @@ -47,7 +50,7 @@ def __init__(
kernel_size: int = 5,
stride: int = 3,
lrelu_slope: float = 0.1,
num_embeddings: int = None,
num_embeddings: Optional[int] = None,
):
super().__init__()
self.period = period
Expand All @@ -68,7 +71,7 @@ def __init__(
self.lrelu_slope = lrelu_slope

def forward(
self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
x = x.unsqueeze(1)
fmap = []
Expand Down Expand Up @@ -101,8 +104,12 @@ def forward(
class MultiResolutionDiscriminator(nn.Module):
def __init__(
self,
resolutions: Tuple[Tuple[int, int, int]] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)),
num_embeddings: int = None,
resolutions: Tuple[ResolutionType, ResolutionType, ResolutionType] = (
(1024, 256, 1024),
(2048, 512, 2048),
(512, 128, 512),
),
num_embeddings: Optional[int] = None,
):
"""
Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet.
Expand All @@ -120,7 +127,7 @@ def __init__(
)

def forward(
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
y_d_rs = []
y_d_gs = []
Expand All @@ -144,7 +151,7 @@ def __init__(
resolution: Tuple[int, int, int],
channels: int = 64,
in_channels: int = 1,
num_embeddings: int = None,
num_embeddings: Optional[int] = None,
lrelu_slope: float = 0.1,
):
super().__init__()
Expand All @@ -166,7 +173,7 @@ def __init__(
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))

def forward(
self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
x = self.spectrogram(x)
Expand Down
5 changes: 4 additions & 1 deletion vocos/feature_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def get_encodec_codes(self, audio):
codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
return codes

def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
def forward(self, audio: torch.Tensor, **kwargs):
bandwidth_id = kwargs.get("bandwidth_id")
if bandwidth_id is None:
raise ValueError("The 'bandwidth_id' argument is required")
self.encodec.eval() # Force eval mode as Pytorch Lightning automatically sets child modules to training mode
self.encodec.set_target_bandwidth(self.bandwidths[bandwidth_id])
codes = self.get_encodec_codes(audio)
Expand Down
9 changes: 8 additions & 1 deletion vocos/heads.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import torch
from torch import nn
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
Expand Down Expand Up @@ -76,7 +78,12 @@ class IMDCTSymExpHead(FourierHead):
"""

def __init__(
self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = None, clip_audio: bool = False,
self,
dim: int,
mdct_frame_len: int,
padding: str = "same",
sample_rate: Optional[int] = None,
clip_audio: bool = False,
):
super().__init__()
out_dim = mdct_frame_len // 2
Expand Down
10 changes: 5 additions & 5 deletions vocos/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[
Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
the sub-discriminators
"""
loss = 0
loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype)
gen_losses = []
for dg in disc_outputs:
l = torch.mean(torch.clamp(1 - dg, min=0))
Expand Down Expand Up @@ -79,15 +79,15 @@ def forward(
the sub-discriminators for real outputs, and a list of
loss values for generated outputs.
"""
loss = 0
loss = torch.zeros(1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype)
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean(torch.clamp(1 - dr, min=0))
g_loss = torch.mean(torch.clamp(1 + dg, min=0))
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
r_losses.append(r_loss)
g_losses.append(g_loss)

return loss, r_losses, g_losses

Expand All @@ -106,7 +106,7 @@ def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tens
Returns:
Tensor: The calculated feature matching loss.
"""
loss = 0
loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
Expand Down
3 changes: 2 additions & 1 deletion vocos/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def _init_weights(self, m):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)

def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
bandwidth_id = kwargs.get('bandwidth_id', None)
x = self.embed(x)
if self.adanorm:
assert bandwidth_id is not None
Expand Down
8 changes: 4 additions & 4 deletions vocos/modules.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple

import torch
from torch import nn
Expand All @@ -21,7 +21,7 @@ def __init__(
self,
dim: int,
intermediate_dim: int,
layer_scale_init_value: Optional[float] = None,
layer_scale_init_value: float,
adanorm_num_embeddings: Optional[int] = None,
):
super().__init__()
Expand Down Expand Up @@ -106,9 +106,9 @@ def __init__(
self,
dim: int,
kernel_size: int = 3,
dilation: tuple[int] = (1, 3, 5),
dilation: Tuple[int, int, int] = (1, 3, 5),
lrelu_slope: float = 0.1,
layer_scale_init_value: float = None,
layer_scale_init_value: Optional[float] = None,
):
super().__init__()
self.lrelu_slope = lrelu_slope
Expand Down
10 changes: 6 additions & 4 deletions vocos/pretrained.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Tuple, Any, Union, Dict
from __future__ import annotations

from typing import Any, Dict, Self, Tuple, Union

import torch
import yaml
Expand Down Expand Up @@ -45,7 +47,7 @@ def __init__(
self.head = head

@classmethod
def from_hparams(cls, config_path: str) -> "Vocos":
def from_hparams(cls, config_path: str) -> Self:
"""
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
Expand All @@ -58,13 +60,13 @@ def from_hparams(cls, config_path: str) -> "Vocos":
return model

@classmethod
def from_pretrained(self, repo_id: str) -> "Vocos":
def from_pretrained(cls, repo_id: str) -> Self:
"""
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
"""
config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
model = self.from_hparams(config_path)
model = cls.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu")
if isinstance(model.feature_extractor, EncodecFeatures):
encodec_parameters = {
Expand Down