Replace multi-resolution discriminator; update AdamW default config #30

Merged (4 commits) on Oct 14, 2023
README.md: 3 additions, 7 deletions

@@ -82,14 +82,10 @@ See [example notebook](notebooks%2FBark%2BVocos.ipynb).

 ## Pre-trained models

-The provided models were trained up to 2.5 million generator iterations, which resulted in slightly better objective
-scores
-compared to those reported in the paper.
-
 | Model Name                                                                           | Dataset       | Training Iterations | Parameters
-|-------------------------------------------------------------------------------------|---------------|---------------------|------------|
-| [charactr/vocos-mel-24khz](https://huggingface.co/charactr/vocos-mel-24khz)          | LibriTTS      | 2.5 M               | 13.5 M
-| [charactr/vocos-encodec-24khz](https://huggingface.co/charactr/vocos-encodec-24khz)  | DNS Challenge | 2.5 M               | 7.9 M
+|-------------------------------------------------------------------------------------|---------------|---------------------|------------|
+| [charactr/vocos-mel-24khz](https://huggingface.co/charactr/vocos-mel-24khz)          | LibriTTS      | 1M                  | 13.5M
+| [charactr/vocos-encodec-24khz](https://huggingface.co/charactr/vocos-encodec-24khz)  | DNS Challenge | 2M                  | 7.9M

 ## Training
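The retrained checkpoints keep the same inference API, so downstream code should work unchanged. A minimal sketch of loading one of the tabulated models, with an illustrative dummy input shape:

```python
import torch
from vocos import Vocos

# Load the retrained mel checkpoint from the Hugging Face Hub
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

# Reconstruct 24 kHz audio from a (batch, n_mels, frames) mel spectrogram
mel = torch.randn(1, 100, 256)  # dummy mel input, for illustration only
audio = vocos.decode(mel)
```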
configs/vocos-encodec.yaml: 1 addition, 1 deletion

@@ -22,7 +22,7 @@ model:
   class_path: vocos.experiment.VocosEncodecExp
   init_args:
     sample_rate: 24000
-    initial_learning_rate: 2e-4
+    initial_learning_rate: 5e-4
     mel_loss_coeff: 45
     mrd_loss_coeff: 1.0
     num_warmup_steps: 0 # Optimizers warmup steps
configs/vocos-imdct.yaml: 1 addition, 1 deletion

@@ -22,7 +22,7 @@ model:
   class_path: vocos.experiment.VocosExp
   init_args:
     sample_rate: 24000
-    initial_learning_rate: 2e-4
+    initial_learning_rate: 5e-4
     mel_loss_coeff: 45
     mrd_loss_coeff: 0.1
     num_warmup_steps: 0 # Optimizers warmup steps
configs/vocos-resnet.yaml: 1 addition, 1 deletion

@@ -22,7 +22,7 @@ model:
   class_path: vocos.experiment.VocosExp
   init_args:
     sample_rate: 24000
-    initial_learning_rate: 2e-4
+    initial_learning_rate: 5e-4
     mel_loss_coeff: 45
     mrd_loss_coeff: 0.1
     num_warmup_steps: 0 # Optimizers warmup steps
configs/vocos.yaml: 1 addition, 1 deletion

@@ -22,7 +22,7 @@ model:
   class_path: vocos.experiment.VocosExp
   init_args:
     sample_rate: 24000
-    initial_learning_rate: 2e-4
+    initial_learning_rate: 5e-4
     mel_loss_coeff: 45
     mrd_loss_coeff: 0.1
     num_warmup_steps: 0 # Optimizers warmup steps
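All four training configs make the same change: `initial_learning_rate` moves from 2e-4 to 5e-4, while the loss coefficients are untouched context lines. As a quick sanity check of what a config actually carries, a PyYAML sketch (the path assumes a local checkout):

```python
import yaml

# Hypothetical local path into the repository checkout
with open("configs/vocos.yaml") as f:
    cfg = yaml.safe_load(f)

# Note: PyYAML reads the unquoted scalar 5e-4 as a string; the Lightning
# CLI converts it to a float when it builds the experiment
print(cfg["model"]["init_args"]["initial_learning_rate"])  # 5e-4 after this PR
```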
vocos/__init__.py: 1 addition, 1 deletion

@@ -1,4 +1,4 @@
 from vocos.pretrained import Vocos


-__version__ = "0.0.4"
+__version__ = "0.1.0"
vocos/discriminators.py: 54 additions, 52 deletions

@@ -1,12 +1,11 @@
 from typing import List, Optional, Tuple

 import torch
+from einops import rearrange
 from torch import nn
 from torch.nn import Conv2d
 from torch.nn.utils import weight_norm
-
-PeriodsType = Tuple[int, ...]
-ResolutionType = Tuple[int, int, int]
+from torchaudio.transforms import Spectrogram


 class MultiPeriodDiscriminator(nn.Module):
@@ -20,7 +19,7 @@ class MultiPeriodDiscriminator(nn.Module):
             Defaults to None.
     """

-    def __init__(self, periods: PeriodsType = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None):
+    def __init__(self, periods: Tuple[int, ...] = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None):
         super().__init__()
         self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])
@@ -104,30 +103,26 @@ def forward(
 class MultiResolutionDiscriminator(nn.Module):
     def __init__(
         self,
-        resolutions: Tuple[ResolutionType, ResolutionType, ResolutionType] = (
-            (1024, 256, 1024),
-            (2048, 512, 2048),
-            (512, 128, 512),
-        ),
+        fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
         num_embeddings: Optional[int] = None,
     ):
         """
-        Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet.
+        Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
         Additionally, it allows incorporating conditional information with a learned embeddings table.

         Args:
-            resolutions (tuple[tuple[int, int, int]]): Tuple of resolutions for each discriminator.
-                Each resolution should be a tuple of (n_fft, hop_length, win_length).
+            fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
             num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
                 Defaults to None.
         """

         super().__init__()
         self.discriminators = nn.ModuleList(
-            [DiscriminatorR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
+            [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
         )

     def forward(
-        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
+        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
         y_d_rs = []
         y_d_gs = []
@@ -148,40 +143,62 @@ def forward(
 class DiscriminatorR(nn.Module):
     def __init__(
         self,
-        resolution: Tuple[int, int, int],
-        channels: int = 64,
-        in_channels: int = 1,
+        window_length: int,
         num_embeddings: Optional[int] = None,
-        lrelu_slope: float = 0.1,
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
     ):
         super().__init__()
-        self.resolution = resolution
-        self.in_channels = in_channels
-        self.lrelu_slope = lrelu_slope
-        self.convs = nn.ModuleList(
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.spec_fn = Spectrogram(
+            n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
+        )
+        n_fft = window_length // 2 + 1
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+        convs = lambda: nn.ModuleList(
             [
-                weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
-                weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
-                weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
-                weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
-                weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
+                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
             ]
         )
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])

         if num_embeddings is not None:
             self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
             torch.nn.init.zeros_(self.emb.weight)
-        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
+        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))

-    def forward(
-        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
-    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+    def spectrogram(self, x):
+        # Remove DC offset
+        x = x - x.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        x = self.spec_fn(x)
+        x = torch.view_as_real(x)
+        x = rearrange(x, "b f t c -> b c t f")
+        # Split into bands
+        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
+        x_bands = self.spectrogram(x)
         fmap = []
-        x = self.spectrogram(x)
-        x = x.unsqueeze(1)
-        for l in self.convs:
-            x = l(x)
-            x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
-            fmap.append(x)
+        x = []
+        for band, stack in zip(x_bands, self.band_convs):
+            for i, layer in enumerate(stack):
+                band = layer(band)
+                band = torch.nn.functional.leaky_relu(band, 0.1)
+                if i > 0:
+                    fmap.append(band)
+            x.append(band)
+        x = torch.cat(x, dim=-1)
         if cond_embedding_id is not None:
             emb = self.emb(cond_embedding_id)
             h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
@@ -190,20 +207,5 @@ def forward(
         x = self.conv_post(x)
         fmap.append(x)
         x += h
-        x = torch.flatten(x, 1, -1)

         return x, fmap
-
-    def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
-        n_fft, hop_length, win_length = self.resolution
-        magnitude_spectrogram = torch.stft(
-            x,
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            window=None,  # interestingly rectangular window kind of works here
-            center=True,
-            return_complex=True,
-        ).abs()
-
-        return magnitude_spectrogram
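To make the new interface concrete, here is a small smoke-test sketch of the band-split discriminator as defined above (the input shapes are assumptions for illustration, not part of the PR):

```python
import torch
from vocos.discriminators import MultiResolutionDiscriminator

# One DiscriminatorR per FFT size, each splitting its complex spectrogram
# into five frequency bands with a dedicated conv stack
mrd = MultiResolutionDiscriminator(fft_sizes=(2048, 1024, 512))

y = torch.randn(2, 24000)      # batch of real waveforms (2 x 1 s at 24 kHz)
y_hat = torch.randn(2, 24000)  # batch of generated waveforms

real_logits, gen_logits, fmap_real, fmap_gen = mrd(y=y, y_hat=y_hat)
print(len(real_logits))  # 3: one output per resolution
```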
vocos/experiment.py: 2 additions, 2 deletions

@@ -78,8 +78,8 @@ def configure_optimizers(self):
             {"params": self.head.parameters()},
         ]

-        opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate)
-        opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate)
+        opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9))
+        opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9))

         max_steps = self.trainer.max_steps // 2  # Max steps per optimizer
         scheduler_disc = transformers.get_cosine_schedule_with_warmup(
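The only change here is pinning AdamW's betas to (0.8, 0.9) instead of relying on the stock (0.9, 0.999); lowered betas in this range are common in GAN vocoder recipes. A self-contained sketch of the resulting optimizer/scheduler pairing, with placeholder parameters and step counts:

```python
import torch
import transformers

# Placeholder parameters standing in for the generator modules above
params = [torch.nn.Parameter(torch.randn(8, 8))]

# lr comes from the updated configs (now 5e-4); betas are the new explicit default
opt = torch.optim.AdamW(params, lr=5e-4, betas=(0.8, 0.9))
scheduler = transformers.get_cosine_schedule_with_warmup(
    opt, num_warmup_steps=0, num_training_steps=1_000_000
)
```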