From 1f8f715133383b93aae5c53dd700ec53132f366c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ante=20Jukic=CC=81?= Date: Tue, 22 Aug 2023 17:36:50 -0700 Subject: [PATCH] Update AudioCodec API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ante Jukić --- examples/tts/audio_codec.py | 2 +- examples/tts/conf/audio_codec/encodec.yaml | 2 +- nemo/collections/tts/models/audio_codec.py | 168 ++++++++++++++++-- .../tts/modules/encodec_modules.py | 68 +++---- nemo/collections/tts/parts/utils/callbacks.py | 54 ++++-- 5 files changed, 228 insertions(+), 66 deletions(-) diff --git a/examples/tts/audio_codec.py b/examples/tts/audio_codec.py index 9244721298de..800edfb7fb0f 100644 --- a/examples/tts/audio_codec.py +++ b/examples/tts/audio_codec.py @@ -23,7 +23,7 @@ @hydra_runner(config_path="conf/audio_codec", config_name="audio_codec") def main(cfg): - logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg)) + logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg, resolve=True)) trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) model = AudioCodecModel(cfg=cfg.model, trainer=trainer) diff --git a/examples/tts/conf/audio_codec/encodec.yaml b/examples/tts/conf/audio_codec/encodec.yaml index bb408d7d48e1..a0f7a50c92dd 100644 --- a/examples/tts/conf/audio_codec/encodec.yaml +++ b/examples/tts/conf/audio_codec/encodec.yaml @@ -90,7 +90,7 @@ model: - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator log_audio: true log_encoding: true - log_quantized: true + log_dequantized: true dataset: _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index b6f217408a81..be38406d3fa1 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -19,6 +19,7 @@ import torch import torch.nn.functional as F +from einops import rearrange from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer @@ -36,7 +37,7 @@ from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType from nemo.core.neural_types.neural_type import NeuralType from nemo.core.optim.lr_scheduler import compute_max_steps, prepare_lr_scheduler -from nemo.utils import model_utils +from nemo.utils import logging, model_utils from nemo.utils.decorators import experimental @@ -49,9 +50,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) + # Expected sample rate for the input audio self.sample_rate = cfg.sample_rate + + # Number of samples in each audio frame that is encoded self.samples_per_frame = cfg.samples_per_frame + # Discriminator updates self.disc_updates_per_period = cfg.get("disc_updates_per_period", 1) self.disc_update_period = cfg.get("disc_update_period", 1) if self.disc_updates_per_period > self.disc_update_period: @@ -59,6 +64,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): f'Number of discriminator updates ({self.disc_updates_per_period}) per period must be less or equal to the configured period ({self.disc_update_period})' ) + # Encoder setup self.audio_encoder = instantiate(cfg.audio_encoder) # Optionally, add gaussian noise to encoder output as an information bottleneck @@ -71,11 +77,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if "vector_quantizer" in cfg: self.vector_quantizer = instantiate(cfg.vector_quantizer) else: + logging.warning('Vector quantizer will not be used.') self.vector_quantizer = None + # Decoder setup self.audio_decoder = instantiate(cfg.audio_decoder) + + # Discriminator setup self.discriminator = instantiate(cfg.discriminator) + # Loss setup mel_loss_dim = cfg.get("mel_loss_dim", 64) mel_loss_resolutions = cfg.mel_loss_resolutions self.time_domain_loss_scale = cfg.get("time_domain_loss_scale", 1.0) @@ -95,7 +106,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.disc_loss_fn = instantiate(cfg.discriminator_loss) self.feature_loss_fn = RelativeFeatureMatchingLoss() + # Log setup self.log_config = cfg.get("log_config", None) + + # Optimizer setup self.lr_schedule_interval = None self.automatic_optimization = False @@ -110,6 +124,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): }, ) def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply encoder on the input audio signal. Input will be padded with zeros so + the last frame has full `self.samples_per_frame` samples. + + Args: + audio: input time-domain signal + audio_len: valid length for each example in the batch + + Returns: + Encoder output `encoded` and its length in number of frames `encoded_len` + """ audio, audio_len = self.pad_audio(audio, audio_len) encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len) return encoded, encoded_len @@ -125,6 +149,17 @@ def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[to }, ) def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply decoder on the input encoded representation. Note that the input is a + non-quantized or dequantized representation. + + Args: + inputs: encoded signal + input_len: valid length for each example in the batch + + Returns: + Decoded output `audio` in the time domain and its length in number of samples `audio_len`. + Note that `audio_len` will be a multiple of `self.samples_per_frame`. + """ audio, audio_len = self.audio_decoder(inputs=inputs, input_len=input_len) return audio, audio_len @@ -133,28 +168,107 @@ def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[t "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), "encoded_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"indices": NeuralType(('N', 'B', 'T_encoded'), Index())}, + output_types={"tokens": NeuralType(('B', 'C', 'T_encoded'), Index())}, ) - def quantize_encode(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor: + def quantize(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor: + """Quantize the continuous encoded representation into a discrete + representation for each frame. + + Args: + encoded: encoded signal representation + encoded_len: valid length of the encoded representation in frames + + Returns: + A tensor of tokens for each codebook for each frame. + """ if not self.vector_quantizer: raise ValueError("Cannot quantize without quantizer") - indices = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len) - return indices + # vector quantizer is returning [C, B, T], where C is the number of codebooks + tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len) + # use batch first for the output + tokens = rearrange(tokens, 'C B T -> B C T') + return tokens @typecheck( input_types={ - "indices": NeuralType(('N', 'B', 'T_encoded'), Index()), - "encoded_len": NeuralType(tuple('B'), LengthsType()), + "tokens": NeuralType(('B', 'C', 'T_encoded'), Index()), + "tokens_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"quantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),}, + output_types={"dequantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),}, ) - def quantize_decode(self, indices: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor: + def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Tensor: + """Convert the discrete input tokens into a continuous encoded representation. + + Args: + tokens: discrete tokens for each codebook for each time frame + tokens_len: valid length of each example in the batch + + Returns: + Continuous encoded representation of the discrete input representation. + """ if not self.vector_quantizer: raise ValueError("Cannot dequantize without quantizer") - quantized = self.vector_quantizer.decode(indices=indices, input_len=encoded_len) - return quantized + # vector quantizer is using [C, B, T], where C is the number of codebooks + tokens = rearrange(tokens, 'B C T -> C B T') + dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len) + return dequantized + + @typecheck( + input_types={ + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "tokens": NeuralType(('B', 'C', 'T_encoded'), Index()), + "tokens_len": NeuralType(tuple('B'), LengthsType()), + }, + ) + def encode(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert input time-domain audio signal into a discrete representation (tokens). + + Args: + audio: input time-domain signal, shape (batch, number of samples) + audio_len: valid length for each example in the batch, shape (batch size,) + + Returns: + Tokens for each codebook for each frame, shape (batch, number of codebooks, number of frames), + and the corresponding valid lengths, shape (batch,) + """ + # Apply encoder to obtain a continuous vector for each frame + encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len) + # Apply quantizer to obtain discrete representation per frame + tokens = self.quantize(encoded=encoded, encoded_len=encoded_len) + return tokens, encoded_len + + @typecheck( + input_types={ + "tokens": NeuralType(('B', 'C', 'T_encoded'), Index()), + "tokens_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + }, + ) + def decode(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert discrete input tokens into a continuous time-domain signal. + + Args: + tokens: discrete tokens for each codebook for each time frame, shape (batch, number of codebooks, number of frames) + tokens_len: valid lengths, shape (batch,) + + Returns: + Decoded output `audio` in the time domain and its length in number of samples `audio_len`. + Note that `audio_len` will be a multiple of `self.samples_per_frame`. + """ + # Convert a discrete representation to a dequantized vector for each frame + dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len) + # Apply decoder to obtain time-domain audio for each frame + audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len) + + return audio, audio_len @typecheck( input_types={ @@ -167,20 +281,40 @@ def quantize_decode(self, indices: torch.Tensor, encoded_len: torch.Tensor) -> t }, ) def forward(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - audio, audio_len = self.pad_audio(audio, audio_len) + """Apply encoder, quantizer, decoder on the input time-domain signal. + + Args: + audio: input time-domain signal + audio_len: valid length for each example in the batch + + Returns: + Reconstructed time-domain signal `output_audio` and its length in number of samples `output_audio_len`. + """ encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len) if self.vector_quantizer: - indices = self.quantize_encode(encoded=encoded, encoded_len=encoded_len) - quantized = self.quantize_decode(indices=indices, encoded_len=encoded_len) - output_audio, output_audio_len = self.decode_audio(inputs=quantized, input_len=encoded_len) + # quantize to discrete tokens + tokens = self.quantize(encoded=encoded, encoded_len=encoded_len) + # decode tokens to audio + output_audio, output_audio_len = self.decode(tokens=tokens, tokens_len=encoded_len) else: + # no quantization, directly decode to audio output_audio, output_audio_len = self.decode_audio(inputs=encoded, input_len=encoded_len) return output_audio, output_audio_len - # Zero pad the end of the audio so that we do not have a partial end frame. def pad_audio(self, audio, audio_len): + """Zero pad the end of the audio so that we do not have a partial end frame. + The output will be zero-padded to have an integer number of frames of + length `self.samples_per_frame`. + + Args: + audio: input time-domain signal + audio_len: valid length for each example in the batch + + Returns: + Padded time-domain signal `padded_audio` and its length `padded_len`. + """ padded_len = self.samples_per_frame * torch.ceil(audio_len / self.samples_per_frame).int() max_len = padded_len.max().item() num_padding = max_len - audio.shape[1] @@ -357,8 +491,10 @@ def configure_optimizers(self): optim_d = instantiate(optim_config, params=disc_params) if sched_config is None: + logging.debug('Scheduler is not used') return [optim_g, optim_d] + logging.debug('Setting up schedulers') OmegaConf.set_struct(sched_config, False) sched_config["max_steps"] = self.max_steps OmegaConf.set_struct(sched_config, True) diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py index 70c9bf713aca..031b2001e5ca 100644 --- a/nemo/collections/tts/modules/encodec_modules.py +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -293,7 +293,7 @@ def forward(self, inputs, input_len): for res_block, up_sample_conv, up_sample_rate in zip( self.res_blocks, self.up_sample_conv_layers, self.up_sample_rates ): - audio_len *= up_sample_rate + audio_len = audio_len * up_sample_rate out = self.activation(out) # [B, C / 2, T * up_sample_rate] out = up_sample_conv(out, audio_len) @@ -610,8 +610,8 @@ def _quantize(self, inputs: Tensor) -> Tensor: def _dequantize(self, indices: Tensor) -> Tensor: # [B, D] - quantized = F.embedding(indices, self.codes) - return quantized + dequantized = F.embedding(indices, self.codes) + return dequantized @property def input_types(self): @@ -623,7 +623,7 @@ def input_types(self): @property def output_types(self): return { - "quantized": NeuralType(('B', 'T', 'D'), EncodedRepresentation()), + "dequantized": NeuralType(('B', 'T', 'D'), EncodedRepresentation()), "indices": NeuralType(('B', 'T'), Index()), } @@ -635,16 +635,16 @@ def forward(self, inputs, input_len): # [B, T] indices = indices_flat.view(*inputs.shape[:-1]) # [B, T, D] - quantized = self._dequantize(indices=indices) + dequantized = self._dequantize(indices=indices) if self.training: # We do expiry of codes here because buffers are in sync and all the workers will make the same decision. self._expire_codes(inputs=input_flat) self._update_codes(inputs=input_flat, indices=indices_flat) - quantized = _mask_3d(quantized, input_len) + dequantized = _mask_3d(dequantized, input_len) indices = mask_sequence_tensor(indices, input_len) - return quantized, indices + return dequantized, indices @typecheck( input_types={ @@ -664,13 +664,13 @@ def encode(self, inputs, input_len): @typecheck( input_types={"indices": NeuralType(('B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType()),}, - output_types={"quantized": NeuralType(('B', 'T', 'D'), EncodedRepresentation())}, + output_types={"dequantized": NeuralType(('B', 'T', 'D'), EncodedRepresentation())}, ) def decode(self, indices, input_len): # [B, T, D] - quantized = self._dequantize(indices=indices) - quantized = _mask_3d(quantized, input_len) - return quantized + dequantized = self._dequantize(indices=indices) + dequantized = _mask_3d(dequantized, input_len) + return dequantized class ResidualVectorQuantizer(NeuralModule): @@ -741,8 +741,8 @@ def input_types(self): @property def output_types(self): return { - "quantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), - "indices": NeuralType(('B', 'T'), Index()), + "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "indices": NeuralType(('D', 'B', 'T'), Index()), "commit_loss": NeuralType((), LossType()), } @@ -751,35 +751,35 @@ def forward(self, inputs: Tensor, input_len: Tensor) -> Tuple[Tensor, Tensor, fl residual = rearrange(inputs, "B D T -> B T D") index_list = [] - quantized = torch.zeros_like(residual) + dequantized = torch.zeros_like(residual) for codebook in self.codebooks: - quantized_i, indices_i = codebook(inputs=residual, input_len=input_len) + dequantized_i, indices_i = codebook(inputs=residual, input_len=input_len) if self.training: - quantized_i = residual + (quantized_i - residual).detach() - quantized_i_const = quantized_i.detach() - commit_loss_i = self._commit_loss(input=residual, target=quantized_i_const, input_len=input_len) + dequantized_i = residual + (dequantized_i - residual).detach() + dequantized_i_const = dequantized_i.detach() + commit_loss_i = self._commit_loss(input=residual, target=dequantized_i_const, input_len=input_len) commit_loss = commit_loss + commit_loss_i - residual = residual - quantized_i_const + residual = residual - dequantized_i_const else: - residual = residual - quantized_i + residual = residual - dequantized_i - quantized = quantized + quantized_i + dequantized = dequantized + dequantized_i index_list.append(indices_i) - # [N, B, T] + # [N, B, T], first dimension is number of codebooks indices = torch.stack(index_list) - quantized = rearrange(quantized, "B T D -> B D T") - return quantized, indices, commit_loss + dequantized = rearrange(dequantized, "B T D -> B D T") + return dequantized, indices, commit_loss @typecheck( input_types={ "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), "input_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"indices": NeuralType(('N', 'B', 'T'), Index())}, + output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, ) def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor: residual = rearrange(inputs, "B D T -> B T D") @@ -788,8 +788,8 @@ def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor: # [B, T] indices_i = codebook.encode(inputs=residual, input_len=input_len) # [B, D, T] - quantized_i = codebook.decode(indices=indices_i, input_len=input_len) - residual = residual - quantized_i + dequantized_i = codebook.decode(indices=indices_i, input_len=input_len) + residual = residual - dequantized_i index_list.append(indices_i) # [N, B, T] indices = torch.stack(index_list) @@ -797,16 +797,16 @@ def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor: @typecheck( input_types={ - "indices": NeuralType(('N', 'B', 'T'), Index()), + "indices": NeuralType(('D', 'B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"quantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, + output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, ) def decode(self, indices: Tensor, input_len: Tensor) -> Tensor: # [B, T, D] - quantized = torch.zeros([indices.shape[1], indices.shape[2], self.codebook_dim], device=indices.device) + dequantized = torch.zeros([indices.shape[1], indices.shape[2], self.codebook_dim], device=indices.device) for codebook_indices, codebook in zip(indices, self.codebooks): - quantized_i = codebook.decode(indices=codebook_indices, input_len=input_len) - quantized = quantized + quantized_i - quantized = rearrange(quantized, "B T D -> B D T") - return quantized + dequantized_i = codebook.decode(indices=codebook_indices, input_len=input_len) + dequantized = dequantized + dequantized_i + dequantized = rearrange(dequantized, "B T D -> B D T") + return dequantized diff --git a/nemo/collections/tts/parts/utils/callbacks.py b/nemo/collections/tts/parts/utils/callbacks.py index 0d408658d8ad..304efad3f5f7 100644 --- a/nemo/collections/tts/parts/utils/callbacks.py +++ b/nemo/collections/tts/parts/utils/callbacks.py @@ -29,6 +29,7 @@ from pytorch_lightning.loggers.wandb import WandbLogger from nemo.collections.tts.parts.utils.helpers import create_plot +from nemo.utils import logging from nemo.utils.decorators import experimental HAVE_WANDB = True @@ -160,17 +161,28 @@ def __init__( self.log_wandb = log_wandb if log_tensorboard: + logging.info('Creating tensorboard logger') self.tensorboard_logger = _get_logger(self.loggers, TensorBoardLogger) else: + logging.debug('Not using tensorbord logger') self.tensorboard_logger = None if log_wandb: if not HAVE_WANDB: raise ValueError("Wandb not installed.") + logging.info('Creating wandb logger') self.wandb_logger = _get_logger(self.loggers, WandbLogger) else: + logging.debug('Not using wandb logger') self.wandb_logger = None + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tlog_epochs: %s', self.log_epochs) + logging.debug('\tepoch_frequency: %s', self.epoch_frequency) + logging.debug('\toutput_dir: %s', self.output_dir) + logging.debug('\tlog_tensorboard: %s', self.log_tensorboard) + logging.debug('\tlog_wandb: %s', self.log_wandb) + def _log_audio(self, audio: AudioArtifact, log_dir: Path, step: int): if log_dir: filepath = log_dir / audio.filename @@ -270,10 +282,20 @@ class AudioCodecArtifactGenerator(ArtifactGenerator): Generator for logging Audio Codec model outputs. """ - def __init__(self, log_audio: bool = True, log_encoding: bool = False, log_quantized: bool = False): + def __init__(self, log_audio: bool = True, log_encoding: bool = False, log_dequantized: bool = False): + # Log reconstructed audio (decoder output) self.log_audio = log_audio + # Log encoded representation of the input audio (encoder output) self.log_encoding = log_encoding - self.log_quantized = log_quantized + # Log dequantized encoded representation of the input audio (decoder input) + self.log_dequantized = log_dequantized + # Input audio will be logged only once + self.input_audio_logged = False + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tlog_audio: %s', self.log_audio) + logging.debug('\tlog_encoding: %s', self.log_encoding) + logging.debug('\tlog_dequantized: %s', self.log_dequantized) def _generate_audio(self, model, audio_ids, audio, audio_len): if not self.log_audio: @@ -284,10 +306,14 @@ def _generate_audio(self, model, audio_ids, audio, audio_len): audio_pred, audio_pred_len = model(audio=audio, audio_len=audio_len) audio_artifacts = [] + # Log output audio for i, audio_id in enumerate(audio_ids): audio_pred_i = audio_pred[i, : audio_pred_len[i]].cpu().numpy() audio_artifact = AudioArtifact( - id=f"audio_{audio_id}", data=audio_pred_i, filename=f"{audio_id}.wav", sample_rate=model.sample_rate, + id=f"audio_out_{audio_id}", + data=audio_pred_i, + filename=f"{audio_id}_audio_out.wav", + sample_rate=model.sample_rate, ) audio_artifacts.append(audio_artifact) @@ -296,7 +322,7 @@ def _generate_audio(self, model, audio_ids, audio, audio_len): def _generate_images(self, model, audio_ids, audio, audio_len): image_artifacts = [] - if not self.log_encoding and not self.log_quantized: + if not self.log_encoding and not self.log_dequantized: return image_artifacts with torch.no_grad(): @@ -309,30 +335,30 @@ def _generate_images(self, model, audio_ids, audio, audio_len): encoded_artifact = ImageArtifact( id=f"encoded_{audio_id}", data=encoded_i, - filename=f"{audio_id}_encode.png", + filename=f"{audio_id}_encoded.png", x_axis="Audio Frames", y_axis="Channels", ) image_artifacts.append(encoded_artifact) - if not self.log_quantized: + if not self.log_dequantized: return image_artifacts with torch.no_grad(): # [B, D, T] - indices = model.quantize_encode(encoded=encoded, encoded_len=encoded_len) - quantized = model.quantize_decode(indices=indices, encoded_len=encoded_len) + tokens = model.quantize(encoded=encoded, encoded_len=encoded_len) + dequantized = model.dequantize(tokens=tokens, tokens_len=encoded_len) for i, audio_id in enumerate(audio_ids): - quantized_i = quantized[i, :, : encoded_len[i]].cpu().numpy() - quantized_artifact = ImageArtifact( - id=f"quantized_{audio_id}", - data=quantized_i, - filename=f"{audio_id}_quantized.png", + dequantized_i = dequantized[i, :, : encoded_len[i]].cpu().numpy() + dequantized_artifact = ImageArtifact( + id=f"dequantized_{audio_id}", + data=dequantized_i, + filename=f"{audio_id}_dequantized.png", x_axis="Audio Frames", y_axis="Channels", ) - image_artifacts.append(quantized_artifact) + image_artifacts.append(dequantized_artifact) return image_artifacts