Update AudioCodec API
Signed-off-by: Ante Jukić <[email protected]>
anteju committed Aug 28, 2023
1 parent 6861215 commit 1f8f715
Showing 5 changed files with 228 additions and 66 deletions.
examples/tts/audio_codec.py (2 changes: 1 addition & 1 deletion)

@@ -23,7 +23,7 @@

@hydra_runner(config_path="conf/audio_codec", config_name="audio_codec")
def main(cfg):
logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg))
logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg, resolve=True))
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = AudioCodecModel(cfg=cfg.model, trainer=trainer)
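For context, a minimal sketch (not part of the commit; the config values are made up) of what `resolve=True` changes when logging the config:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"sample_rate": 22050, "model": {"sample_rate": "${sample_rate}"}})
print(OmegaConf.to_yaml(cfg))                # model.sample_rate: ${sample_rate}
print(OmegaConf.to_yaml(cfg, resolve=True))  # model.sample_rate: 22050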
examples/tts/conf/audio_codec/encodec.yaml (2 changes: 1 addition & 1 deletion)

@@ -90,7 +90,7 @@ model:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
log_audio: true
log_encoding: true
log_quantized: true
log_dequantized: true

dataset:
_target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
nemo/collections/tts/models/audio_codec.py (168 changes: 152 additions & 16 deletions)

@@ -19,6 +19,7 @@

import torch
import torch.nn.functional as F
from einops import rearrange
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
@@ -36,7 +37,7 @@
from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType
from nemo.core.neural_types.neural_type import NeuralType
from nemo.core.optim.lr_scheduler import compute_max_steps, prepare_lr_scheduler
from nemo.utils import model_utils
from nemo.utils import logging, model_utils
from nemo.utils.decorators import experimental


@@ -49,16 +50,21 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

super().__init__(cfg=cfg, trainer=trainer)

# Expected sample rate for the input audio
self.sample_rate = cfg.sample_rate

# Number of samples in each audio frame that is encoded
self.samples_per_frame = cfg.samples_per_frame

# Discriminator updates
self.disc_updates_per_period = cfg.get("disc_updates_per_period", 1)
self.disc_update_period = cfg.get("disc_update_period", 1)
if self.disc_updates_per_period > self.disc_update_period:
raise ValueError(
f'Number of discriminator updates ({self.disc_updates_per_period}) per period must be less than or equal to the configured period ({self.disc_update_period})'
)
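# Illustration only (an assumption, not code from this commit): one schedule
# consistent with this check treats every `disc_update_period` training steps
# as one period and updates the discriminator `disc_updates_per_period` times
# within it, e.g. should_update = step % disc_update_period < disc_updates_per_period.
# With the defaults (1, 1), the discriminator is updated on every step.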

# Encoder setup
self.audio_encoder = instantiate(cfg.audio_encoder)

# Optionally, add gaussian noise to encoder output as an information bottleneck
@@ -71,11 +77,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if "vector_quantizer" in cfg:
self.vector_quantizer = instantiate(cfg.vector_quantizer)
else:
logging.warning('Vector quantizer will not be used.')
self.vector_quantizer = None

# Decoder setup
self.audio_decoder = instantiate(cfg.audio_decoder)

# Discriminator setup
self.discriminator = instantiate(cfg.discriminator)

# Loss setup
mel_loss_dim = cfg.get("mel_loss_dim", 64)
mel_loss_resolutions = cfg.mel_loss_resolutions
self.time_domain_loss_scale = cfg.get("time_domain_loss_scale", 1.0)
@@ -95,7 +106,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.disc_loss_fn = instantiate(cfg.discriminator_loss)
self.feature_loss_fn = RelativeFeatureMatchingLoss()

# Log setup
self.log_config = cfg.get("log_config", None)

# Optimizer setup
self.lr_schedule_interval = None
self.automatic_optimization = False

@@ -110,6 +124,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
},
)
def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply encoder on the input audio signal. Input will be padded with zeros so
the last frame has full `self.samples_per_frame` samples.
Args:
audio: input time-domain signal
audio_len: valid length for each example in the batch
Returns:
Encoder output `encoded` and its length in number of frames `encoded_len`
"""
audio, audio_len = self.pad_audio(audio, audio_len)
encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len)
return encoded, encoded_len
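# Usage sketch (illustrative, not from the commit): assumes `model` is an
# instantiated AudioCodecModel; shapes are hypothetical. Padding guarantees a
# whole number of frames.
import torch
audio = torch.randn(2, 22055)             # (batch, samples)
audio_len = torch.tensor([22055, 16000])  # valid samples per example
encoded, encoded_len = model.encode_audio(audio=audio, audio_len=audio_len)
# encoded: (batch, D, T_encoded);
# expected: encoded_len[i] == ceil(audio_len[i] / model.samples_per_frame)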
@@ -125,6 +149,17 @@ def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
},
)
def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply decoder on the input encoded representation. Note that the input is a
non-quantized or dequantized representation.
Args:
inputs: encoded signal
input_len: valid length for each example in the batch
Returns:
Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
Note that `audio_len` will be a multiple of `self.samples_per_frame`.
"""
audio, audio_len = self.audio_decoder(inputs=inputs, input_len=input_len)
return audio, audio_len

@@ -133,28 +168,107 @@ def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
"encoded_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={"indices": NeuralType(('N', 'B', 'T_encoded'), Index())},
output_types={"tokens": NeuralType(('B', 'C', 'T_encoded'), Index())},
)
def quantize_encode(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor:
def quantize(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor:
"""Quantize the continuous encoded representation into a discrete
representation for each frame.
Args:
encoded: encoded signal representation
encoded_len: valid length of the encoded representation in frames
Returns:
A tensor of tokens for each codebook for each frame.
"""
if not self.vector_quantizer:
raise ValueError("Cannot quantize without quantizer")

indices = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len)
return indices
# The vector quantizer returns [C, B, T], where C is the number of codebooks
tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len)
# Use a batch-first layout for the output
tokens = rearrange(tokens, 'C B T -> B C T')
return tokens
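# Shape check for the layout change above (illustrative, runnable on its own):
from einops import rearrange
x = torch.zeros(8, 2, 50)                                   # quantizer-native [C, B, T]
assert rearrange(x, 'C B T -> B C T').shape == (2, 8, 50)   # batch-first [B, C, T]
# Continuing the sketch from encode_audio above:
tokens = model.quantize(encoded=encoded, encoded_len=encoded_len)
# tokens: (batch, num_codebooks, T_encoded) integer codebook indices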

@typecheck(
input_types={
"indices": NeuralType(('N', 'B', 'T_encoded'), Index()),
"encoded_len": NeuralType(tuple('B'), LengthsType()),
"tokens": NeuralType(('B', 'C', 'T_encoded'), Index()),
"tokens_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={"quantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),},
output_types={"dequantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),},
)
def quantize_decode(self, indices: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor:
def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Tensor:
"""Convert the discrete input tokens into a continuous encoded representation.
Args:
tokens: discrete tokens for each codebook for each time frame
tokens_len: valid length of each example in the batch
Returns:
Continuous encoded representation of the discrete input representation.
"""
if not self.vector_quantizer:
raise ValueError("Cannot dequantize without quantizer")

quantized = self.vector_quantizer.decode(indices=indices, input_len=encoded_len)
return quantized
# The vector quantizer expects [C, B, T], where C is the number of codebooks
tokens = rearrange(tokens, 'B C T -> C B T')
dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len)
return dequantized
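# Continuing the sketch (hypothetical `model`, as above): dequantization
# restores a continuous representation of the same frame length, up to
# quantization error.
dequantized = model.dequantize(tokens=tokens, tokens_len=encoded_len)
# dequantized: (batch, D, T_encoded)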

@typecheck(
input_types={
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"tokens": NeuralType(('B', 'C', 'T_encoded'), Index()),
"tokens_len": NeuralType(tuple('B'), LengthsType()),
},
)
def encode(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert input time-domain audio signal into a discrete representation (tokens).
Args:
audio: input time-domain signal, shape (batch, number of samples)
audio_len: valid length for each example in the batch, shape (batch size,)
Returns:
Tokens for each codebook for each frame, shape (batch, number of codebooks, number of frames),
and the corresponding valid lengths, shape (batch,)
"""
# Apply encoder to obtain a continuous vector for each frame
encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)
# Apply quantizer to obtain discrete representation per frame
tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
return tokens, encoded_len
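# Usage sketch (hypothetical `model` and inputs, as above): `encode` composes
# encode_audio and quantize into a single waveform-to-tokens call.
tokens, tokens_len = model.encode(audio=audio, audio_len=audio_len)
# tokens: (batch, num_codebooks, T_encoded) integer indices; tokens_len: frames per example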

@typecheck(
input_types={
"tokens": NeuralType(('B', 'C', 'T_encoded'), Index()),
"tokens_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
},
)
def decode(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert discrete input tokens into a continuous time-domain signal.
Args:
tokens: discrete tokens for each codebook for each time frame, shape (batch, number of codebooks, number of frames)
tokens_len: valid lengths, shape (batch,)
Returns:
Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
Note that `audio_len` will be a multiple of `self.samples_per_frame`.
"""
# Convert a discrete representation to a dequantized vector for each frame
dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len)
# Apply decoder to obtain time-domain audio for each frame
audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len)

return audio, audio_len
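# Continuing the sketch: `decode` completes the roundtrip from tokens back to audio.
reconstructed, reconstructed_len = model.decode(tokens=tokens, tokens_len=tokens_len)
# reconstructed: (batch, T_audio); reconstructed_len is a multiple of
# model.samples_per_frame, so it can slightly exceed the original audio_len.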

@typecheck(
input_types={
Expand All @@ -167,20 +281,40 @@ def quantize_decode(self, indices: torch.Tensor, encoded_len: torch.Tensor) -> t
},
)
def forward(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
audio, audio_len = self.pad_audio(audio, audio_len)
"""Apply encoder, quantizer, decoder on the input time-domain signal.
Args:
audio: input time-domain signal
audio_len: valid length for each example in the batch
Returns:
Reconstructed time-domain signal `output_audio` and its length in number of samples `output_audio_len`.
"""
encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)

if self.vector_quantizer:
indices = self.quantize_encode(encoded=encoded, encoded_len=encoded_len)
quantized = self.quantize_decode(indices=indices, encoded_len=encoded_len)
output_audio, output_audio_len = self.decode_audio(inputs=quantized, input_len=encoded_len)
# quantize to discrete tokens
tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
# decode tokens to audio
output_audio, output_audio_len = self.decode(tokens=tokens, tokens_len=encoded_len)
else:
# no quantization, directly decode to audio
output_audio, output_audio_len = self.decode_audio(inputs=encoded, input_len=encoded_len)

return output_audio, output_audio_len
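# Inference sketch (assumes a trained `model`, as above): the forward pass
# reconstructs the waveform, applying quantization only if a quantizer is configured.
model.eval()
with torch.no_grad():
    output_audio, output_audio_len = model(audio=audio, audio_len=audio_len)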

# Zero pad the end of the audio so that we do not have a partial end frame.
def pad_audio(self, audio, audio_len):
"""Zero pad the end of the audio so that we do not have a partial end frame.
The output will be zero-padded to have an integer number of frames of
length `self.samples_per_frame`.
Args:
audio: input time-domain signal
audio_len: valid length for each example in the batch
Returns:
Padded time-domain signal `padded_audio` and its length `padded_len`.
"""
padded_len = self.samples_per_frame * torch.ceil(audio_len / self.samples_per_frame).int()
max_len = padded_len.max().item()
num_padding = max_len - audio.shape[1]
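# Worked example (illustrative): with samples_per_frame = 256 and
# audio_len = [22055, 16000]:
#   ceil(22055 / 256) = 87 -> padded_len = 87 * 256 = 22272
#   ceil(16000 / 256) = 63 -> padded_len = 63 * 256 = 16128
# Each example is then zero-padded up to max_len = 22272 samples.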
@@ -357,8 +491,10 @@ def configure_optimizers(self):
optim_d = instantiate(optim_config, params=disc_params)

if sched_config is None:
logging.debug('Scheduler is not used')
return [optim_g, optim_d]

logging.debug('Setting up schedulers')
OmegaConf.set_struct(sched_config, False)
sched_config["max_steps"] = self.max_steps
OmegaConf.set_struct(sched_config, True)