[TTS] Update AudioCodec API #7310

Merged 1 commit on Aug 28, 2023

2 changes: 1 addition & 1 deletion examples/tts/audio_codec.py
@@ -23,7 +23,7 @@

@hydra_runner(config_path="conf/audio_codec", config_name="audio_codec")
def main(cfg):
- logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg))
+ logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg, resolve=True))
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = AudioCodecModel(cfg=cfg.model, trainer=trainer)
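The added `resolve=True` makes OmegaConf substitute interpolations such as `${...}` before serializing, so the logged config shows concrete values rather than references. A minimal standalone sketch (the config keys here are invented for illustration):

```python
from omegaconf import OmegaConf

# Hypothetical config with an interpolation, for illustration only.
cfg = OmegaConf.create({"sample_rate": 22050, "model": {"sample_rate": "${sample_rate}"}})

print(OmegaConf.to_yaml(cfg))                # model.sample_rate printed as ${sample_rate}
print(OmegaConf.to_yaml(cfg, resolve=True))  # model.sample_rate printed as 22050
```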
2 changes: 1 addition & 1 deletion examples/tts/conf/audio_codec/encodec.yaml
@@ -90,7 +90,7 @@ model:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
log_audio: true
log_encoding: true
- log_quantized: true
+ log_dequantized: true
Collaborator:
The final codebook embeddings are a quantized version of the encoder output, so calling them "dequantized" might be misleading. I would favor keeping the convention in EnCodec and DAC and referring to the codebook indices as "codes" (instead of just "indices") and the corresponding embeddings as "quantized".

Collaborator Author:
Re: dequantized
We have the following:

continuous encoded representation --quantize--> discrete/quantized representation --dequantize--> continuous representation

With `log_dequantized`, we log the final continuous representation after dequantization.
Using "dequantized" is the correct way to refer to the continuous (e.g., float) output of the `_dequantize` method.

In this PR I did not want to touch RVQ much, but it would be nice to change the naming there as well to be consistent.

Re: indices
I absolutely agree that we should change the name.
I originally renamed it to "codes" in this PR, but decided to scrap that because `self.codes` in `EuclideanCodebook` denotes the embeddings, and I wanted to avoid confusion.
Another option, which I think would be very appropriate, would be to use "tokens" to denote the discrete representation instead of "indices".

Collaborator:
After talking about it, I agree it makes the most sense to refer to the indices as "tokens"; the codebook embeddings can be called either "codes" or "dequantized" depending on the context. Some of the renaming and convention changes can be left for a future PR.
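To make the agreed naming concrete, the round trip through the updated low-level methods looks roughly like this (a sketch using the method names from this diff; `model` is an `AudioCodecModel` instance and `audio`/`audio_len` are assumed to exist):

```python
# continuous encoded representation --quantize--> tokens --dequantize--> continuous representation

# encoded: (B, D, T_encoded), continuous encoder output
encoded, encoded_len = model.encode_audio(audio=audio, audio_len=audio_len)

# tokens: (B, C, T_encoded), one discrete token per codebook per frame
tokens = model.quantize(encoded=encoded, encoded_len=encoded_len)

# dequantized: (B, D, T_encoded), continuous representation recovered from the tokens
dequantized = model.dequantize(tokens=tokens, tokens_len=encoded_len)

# reconstructed waveform, (B, T_audio)
audio_out, audio_out_len = model.decode_audio(inputs=dequantized, input_len=encoded_len)
```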


dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
168 changes: 152 additions & 16 deletions nemo/collections/tts/models/audio_codec.py
@@ -19,6 +19,7 @@

import torch
import torch.nn.functional as F
+ from einops import rearrange
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
@@ -36,7 +37,7 @@
from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType
from nemo.core.neural_types.neural_type import NeuralType
from nemo.core.optim.lr_scheduler import compute_max_steps, prepare_lr_scheduler
- from nemo.utils import model_utils
+ from nemo.utils import logging, model_utils
from nemo.utils.decorators import experimental


@@ -49,16 +50,21 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

super().__init__(cfg=cfg, trainer=trainer)

# Expected sample rate for the input audio
self.sample_rate = cfg.sample_rate

# Number of samples in each audio frame that is encoded
self.samples_per_frame = cfg.samples_per_frame

# Discriminator updates
self.disc_updates_per_period = cfg.get("disc_updates_per_period", 1)
self.disc_update_period = cfg.get("disc_update_period", 1)
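# e.g., disc_updates_per_period=1 with disc_update_period=2 presumably updates the discriminator on one out of every two steps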
if self.disc_updates_per_period > self.disc_update_period:
raise ValueError(
f'Number of discriminator updates ({self.disc_updates_per_period}) per period must be less or equal to the configured period ({self.disc_update_period})'
)

# Encoder setup
self.audio_encoder = instantiate(cfg.audio_encoder)

# Optionally, add gaussian noise to encoder output as an information bottleneck
@@ -71,11 +77,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if "vector_quantizer" in cfg:
self.vector_quantizer = instantiate(cfg.vector_quantizer)
else:
logging.warning('Vector quantizer will not be used.')
self.vector_quantizer = None

# Decoder setup
self.audio_decoder = instantiate(cfg.audio_decoder)

# Discriminator setup
self.discriminator = instantiate(cfg.discriminator)

# Loss setup
mel_loss_dim = cfg.get("mel_loss_dim", 64)
mel_loss_resolutions = cfg.mel_loss_resolutions
self.time_domain_loss_scale = cfg.get("time_domain_loss_scale", 1.0)
@@ -95,7 +106,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.disc_loss_fn = instantiate(cfg.discriminator_loss)
self.feature_loss_fn = RelativeFeatureMatchingLoss()

# Log setup
self.log_config = cfg.get("log_config", None)

# Optimizer setup
self.lr_schedule_interval = None
self.automatic_optimization = False

@@ -110,6 +124,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
},
)
def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply encoder on the input audio signal. Input will be padded with zeros so
the last frame has full `self.samples_per_frame` samples.

Args:
audio: input time-domain signal
audio_len: valid length for each example in the batch

Returns:
Encoder output `encoded` and its length in number of frames `encoded_len`
"""
audio, audio_len = self.pad_audio(audio, audio_len)
encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len)
return encoded, encoded_len
@@ -125,6 +149,17 @@ def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
},
)
def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply decoder on the input encoded representation. Note that the input is a
non-quantized or dequantized representation.

Args:
inputs: encoded signal
input_len: valid length for each example in the batch

Returns:
Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
Note that `audio_len` will be a multiple of `self.samples_per_frame`.
"""
audio, audio_len = self.audio_decoder(inputs=inputs, input_len=input_len)
return audio, audio_len

@@ -133,28 +168,107 @@ def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
"encoded_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={"indices": NeuralType(('N', 'B', 'T_encoded'), Index())},
output_types={"tokens": NeuralType(('B', 'C', 'T_encoded'), Index())},
)
- def quantize_encode(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor:
+ def quantize(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor:
"""Quantize the continuous encoded representation into a discrete
representation for each frame.

Args:
encoded: encoded signal representation
encoded_len: valid length of the encoded representation in frames

Returns:
A tensor of tokens for each codebook for each frame.
"""
if not self.vector_quantizer:
raise ValueError("Cannot quantize without quantizer")

- indices = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len)
- return indices
+ # vector quantizer is returning [C, B, T], where C is the number of codebooks
+ tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len)
+ # use batch first for the output
+ tokens = rearrange(tokens, 'C B T -> B C T')
+ return tokens

@typecheck(
input_types={
"indices": NeuralType(('N', 'B', 'T_encoded'), Index()),
"encoded_len": NeuralType(tuple('B'), LengthsType()),
"tokens": NeuralType(('B', 'C', 'T_encoded'), Index()),
"tokens_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={"quantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),},
output_types={"dequantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),},
)
- def quantize_decode(self, indices: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor:
+ def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Tensor:
"""Convert the discrete input tokens into a continuous encoded representation.

Args:
tokens: discrete tokens for each codebook for each time frame
tokens_len: valid length of each example in the batch

Returns:
Continuous encoded representation of the discrete input representation.
"""
if not self.vector_quantizer:
raise ValueError("Cannot dequantize without quantizer")

- quantized = self.vector_quantizer.decode(indices=indices, input_len=encoded_len)
- return quantized
+ # vector quantizer is using [C, B, T], where C is the number of codebooks
+ tokens = rearrange(tokens, 'B C T -> C B T')
+ dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len)
+ return dequantized

@typecheck(
input_types={
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"tokens": NeuralType(('B', 'C', 'T_encoded'), Index()),
"tokens_len": NeuralType(tuple('B'), LengthsType()),
},
)
def encode(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert input time-domain audio signal into a discrete representation (tokens).

Args:
audio: input time-domain signal, shape (batch, number of samples)
audio_len: valid length for each example in the batch, shape (batch size,)

Returns:
Tokens for each codebook for each frame, shape (batch, number of codebooks, number of frames),
and the corresponding valid lengths, shape (batch,)
"""
# Apply encoder to obtain a continuous vector for each frame
encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)
# Apply quantizer to obtain discrete representation per frame
tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
return tokens, encoded_len

@typecheck(
input_types={
"tokens": NeuralType(('B', 'C', 'T_encoded'), Index()),
"tokens_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
},
)
def decode(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert discrete input tokens into a continuous time-domain signal.

Args:
tokens: discrete tokens for each codebook for each time frame, shape (batch, number of codebooks, number of frames)
tokens_len: valid lengths, shape (batch,)

Returns:
Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
Note that `audio_len` will be a multiple of `self.samples_per_frame`.
"""
# Convert a discrete representation to a dequantized vector for each frame
dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len)
# Apply decoder to obtain time-domain audio for each frame
audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len)

return audio, audio_len

@typecheck(
input_types={
@@ -167,20 +281,40 @@ def quantize_decode(self, indices: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor:
},
)
def forward(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- audio, audio_len = self.pad_audio(audio, audio_len)
"""Apply encoder, quantizer, decoder on the input time-domain signal.

Args:
audio: input time-domain signal
audio_len: valid length for each example in the batch

Returns:
Reconstructed time-domain signal `output_audio` and its length in number of samples `output_audio_len`.
"""
encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)

if self.vector_quantizer:
- indices = self.quantize_encode(encoded=encoded, encoded_len=encoded_len)
- quantized = self.quantize_decode(indices=indices, encoded_len=encoded_len)
- output_audio, output_audio_len = self.decode_audio(inputs=quantized, input_len=encoded_len)
+ # quantize to discrete tokens
+ tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
+ # decode tokens to audio
+ output_audio, output_audio_len = self.decode(tokens=tokens, tokens_len=encoded_len)
else:
+ # no quantization, directly decode to audio
output_audio, output_audio_len = self.decode_audio(inputs=encoded, input_len=encoded_len)

return output_audio, output_audio_len

- # Zero pad the end of the audio so that we do not have a partial end frame.
def pad_audio(self, audio, audio_len):
"""Zero pad the end of the audio so that we do not have a partial end frame.
The output will be zero-padded to have an integer number of frames of
length `self.samples_per_frame`.

Args:
audio: input time-domain signal
audio_len: valid length for each example in the batch

Returns:
Padded time-domain signal `padded_audio` and its length `padded_len`.
"""
padded_len = self.samples_per_frame * torch.ceil(audio_len / self.samples_per_frame).int()
max_len = padded_len.max().item()
num_padding = max_len - audio.shape[1]
@@ -357,8 +491,10 @@ def configure_optimizers(self):
optim_d = instantiate(optim_config, params=disc_params)

if sched_config is None:
logging.debug('Scheduler is not used')
return [optim_g, optim_d]

logging.debug('Setting up schedulers')
OmegaConf.set_struct(sched_config, False)
sched_config["max_steps"] = self.max_steps
OmegaConf.set_struct(sched_config, True)
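For reference, the new high-level `encode`/`decode` pair added in this PR could be exercised end to end along these lines (a sketch: model construction and checkpoint loading are elided, and the input sizes are illustrative assumptions):

```python
import torch

# model: an AudioCodecModel instance (construction elided).
batch_size, num_samples = 2, 16000  # illustrative sizes
audio = torch.randn(batch_size, num_samples)
audio_len = torch.tensor([16000, 12000])

# audio -> tokens, shape (batch, num_codebooks, num_frames)
tokens, tokens_len = model.encode(audio=audio, audio_len=audio_len)

# tokens -> audio; output length is a multiple of model.samples_per_frame
audio_out, audio_out_len = model.decode(tokens=tokens, tokens_len=tokens_len)
```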