From 6cb955bffc7efbbbdcd4c8a1836daf7c5070fd1c Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 10 May 2023 09:50:25 -0700 Subject: [PATCH 1/5] [TTS] Create EnCodec training recipe Signed-off-by: Ryan --- .../encodec/bottleneck/bottleneck_noise.yaml | 2 + .../encodec/bottleneck/bottleneck_vq.yaml | 4 + examples/tts/conf/encodec/encodec.yaml | 134 ++++++ .../tts/conf/encodec/sample/sample_22050.yaml | 4 + .../tts/conf/encodec/sample/sample_24000.yaml | 4 + .../tts/conf/encodec/sample/sample_44100.yaml | 4 + .../tts/conf/encodec/sample/sample_48000.yaml | 4 + examples/tts/conf/hifigan/hifigan_data.yaml | 1 + examples/tts/encodec.py | 31 ++ nemo/collections/tts/data/vocoder_dataset.py | 11 + nemo/collections/tts/losses/encodec_loss.py | 163 +++++++ nemo/collections/tts/losses/loss.py | 48 ++ nemo/collections/tts/models/__init__.py | 2 + nemo/collections/tts/models/encodec.py | 365 ++++++++++++++ nemo/collections/tts/modules/common.py | 27 ++ .../tts/modules/encodec_modules.py | 450 ++++++++++++++++++ .../tts/modules/vector_quantization.py | 403 ++++++++++++++++ nemo/collections/tts/parts/utils/callbacks.py | 65 +++ .../tts/parts/utils/distributed.py | 28 ++ tests/collections/tts/losses/test_loss.py | 44 ++ 20 files changed, 1794 insertions(+) create mode 100644 examples/tts/conf/encodec/bottleneck/bottleneck_noise.yaml create mode 100644 examples/tts/conf/encodec/bottleneck/bottleneck_vq.yaml create mode 100644 examples/tts/conf/encodec/encodec.yaml create mode 100644 examples/tts/conf/encodec/sample/sample_22050.yaml create mode 100644 examples/tts/conf/encodec/sample/sample_24000.yaml create mode 100644 examples/tts/conf/encodec/sample/sample_44100.yaml create mode 100644 examples/tts/conf/encodec/sample/sample_48000.yaml create mode 100644 examples/tts/encodec.py create mode 100644 nemo/collections/tts/losses/encodec_loss.py create mode 100644 nemo/collections/tts/losses/loss.py create mode 100644 nemo/collections/tts/models/encodec.py create mode 100644 nemo/collections/tts/modules/encodec_modules.py create mode 100644 nemo/collections/tts/modules/vector_quantization.py create mode 100644 nemo/collections/tts/parts/utils/distributed.py create mode 100644 tests/collections/tts/losses/test_loss.py diff --git a/examples/tts/conf/encodec/bottleneck/bottleneck_noise.yaml b/examples/tts/conf/encodec/bottleneck/bottleneck_noise.yaml new file mode 100644 index 000000000000..41136e35e275 --- /dev/null +++ b/examples/tts/conf/encodec/bottleneck/bottleneck_noise.yaml @@ -0,0 +1,2 @@ +encoder_noise_stdev: 0.2 +vector_quantizer: null \ No newline at end of file diff --git a/examples/tts/conf/encodec/bottleneck/bottleneck_vq.yaml b/examples/tts/conf/encodec/bottleneck/bottleneck_vq.yaml new file mode 100644 index 000000000000..8efade7dfa17 --- /dev/null +++ b/examples/tts/conf/encodec/bottleneck/bottleneck_vq.yaml @@ -0,0 +1,4 @@ +encoder_noise_stdev: 0.0 +vector_quantizer: + _target_: nemo.collections.tts.modules.vector_quantization.ResidualVectorQuantizer + num_codebooks: 8 \ No newline at end of file diff --git a/examples/tts/conf/encodec/encodec.yaml b/examples/tts/conf/encodec/encodec.yaml new file mode 100644 index 000000000000..7ef1341f3665 --- /dev/null +++ b/examples/tts/conf/encodec/encodec.yaml @@ -0,0 +1,134 @@ +# This config contains the default values for training EnCodec +# If you want to train model on other dataset, you can change config values according to your dataset. +# Most dataset-specific arguments are in the head of the config file, see below. 
+ +name: EnCodec + +defaults: + - sample: ??? + - bottleneck: ??? + +max_epochs: ??? +batch_size: 16 +weighted_sampling_steps_per_epoch: null + +train_ds_meta: ??? +val_ds_meta: ??? +log_ds_meta: ??? + +log_dir: ??? + +model: + + max_epochs: ${max_epochs} + steps_per_epoch: ${weighted_sampling_steps_per_epoch} + + sample_rate: ${sample.sample_rate} + disc_update_steps: 2 + encoder_noise_stdev: ${bottleneck.encoder_noise_stdev} + + mel_loss_resolutions: [ + [32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048] + ] + + train_ds: + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} + sample_rate: ${sample.sample_rate} + n_samples: ${sample.train_n_samples} + min_duration: 1.01 + max_duration: null + dataset_meta: ${train_ds_meta} + + dataloader_params: + batch_size: ${batch_size} + drop_last: true + num_workers: 4 + + validation_ds: + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + sample_rate: ${sample.sample_rate} + n_samples: null + min_duration: null + max_duration: null + trunc_duration: 10.0 + dataset_meta: ${val_ds_meta} + + dataloader_params: + batch_size: 8 + num_workers: 2 + + log_config: + log_dir: ${log_dir} + log_epochs: [10, 50] + epoch_frequency: 100 + log_tensorboard: false + log_wandb: false + + generators: + - _target_: nemo.collections.tts.parts.utils.callbacks.EnCodecArtifactGenerator + log_audio: true + log_encoding: true + + dataset: + _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset + sample_rate: ${sample.sample_rate} + n_samples: null + min_duration: null + max_duration: null + trunc_duration: 15.0 + dataset_meta: ${log_ds_meta} + + dataloader_params: + batch_size: 4 + num_workers: 2 + + audio_encoder: + _target_: nemo.collections.tts.modules.encodec_modules.SEANetEncoder + down_sample_rates: ${sample.down_sample_rates} + + generator: + _target_: nemo.collections.tts.modules.encodec_modules.SEANetDecoder + up_sample_rates: ${sample.up_sample_rates} + + vector_quantizer: ${bottleneck.vector_quantizer} + + discriminator: + _target_: nemo.collections.tts.modules.encodec_modules.MultiResolutionDiscriminatorSTFT + resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]] + + optim: + _target_: torch.optim.Adam + lr: 3e-4 + betas: [0.5, 0.9] + + sched: + name: ExponentialLR + gamma: 0.999 + +trainer: + num_nodes: 1 + devices: 1 + accelerator: gpu + strategy: ddp + precision: 32 + max_epochs: ${max_epochs} + accumulate_grad_batches: 1 + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + check_val_every_n_epoch: 5 + benchmark: false + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + create_wandb_logger: false + checkpoint_callback_params: + monitor: val_loss + resume_if_exists: false + resume_ignore_no_checkpoint: false diff --git a/examples/tts/conf/encodec/sample/sample_22050.yaml b/examples/tts/conf/encodec/sample/sample_22050.yaml new file mode 100644 index 000000000000..93728e92a589 --- /dev/null +++ b/examples/tts/conf/encodec/sample/sample_22050.yaml @@ -0,0 +1,4 @@ +sample_rate: 22050 +train_n_samples: 22016 +down_sample_rates: [2, 4, 4, 8] +up_sample_rates: [8, 4, 4, 2] diff --git a/examples/tts/conf/encodec/sample/sample_24000.yaml 
b/examples/tts/conf/encodec/sample/sample_24000.yaml new file mode 100644 index 000000000000..35d61a48cbc3 --- /dev/null +++ b/examples/tts/conf/encodec/sample/sample_24000.yaml @@ -0,0 +1,4 @@ +sample_rate: 24000 +train_n_samples: 24000 +down_sample_rates: [2, 4, 5, 8] +up_sample_rates: [8, 5, 4, 2] diff --git a/examples/tts/conf/encodec/sample/sample_44100.yaml b/examples/tts/conf/encodec/sample/sample_44100.yaml new file mode 100644 index 000000000000..1519c7b1d4fc --- /dev/null +++ b/examples/tts/conf/encodec/sample/sample_44100.yaml @@ -0,0 +1,4 @@ +sample_rate: 44100 +train_n_samples: 44032 +down_sample_rates: [2, 4, 8, 8] +up_sample_rates: [8, 8, 4, 2] diff --git a/examples/tts/conf/encodec/sample/sample_48000.yaml b/examples/tts/conf/encodec/sample/sample_48000.yaml new file mode 100644 index 000000000000..b13af98f6967 --- /dev/null +++ b/examples/tts/conf/encodec/sample/sample_48000.yaml @@ -0,0 +1,4 @@ +sample_rate: 48000 +train_n_samples: 48000 +down_sample_rates: [2, 4, 5, 8] +up_sample_rates: [8, 5, 4, 2] diff --git a/examples/tts/conf/hifigan/hifigan_data.yaml b/examples/tts/conf/hifigan/hifigan_data.yaml index fde2f169aa8d..62ce3344636e 100644 --- a/examples/tts/conf/hifigan/hifigan_data.yaml +++ b/examples/tts/conf/hifigan/hifigan_data.yaml @@ -92,6 +92,7 @@ model: n_samples: null min_duration: null max_duration: null + trunc_duration: 15.0 dataset_meta: ${log_ds_meta} dataloader_params: diff --git a/examples/tts/encodec.py b/examples/tts/encodec.py new file mode 100644 index 000000000000..1443a1f54f40 --- /dev/null +++ b/examples/tts/encodec.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl + +from nemo.collections.tts.models import EnCodecModel +from nemo.core.config import hydra_runner +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf/encodec", config_name="encodec") +def main(cfg): + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + model = EnCodecModel(cfg=cfg.model, trainer=trainer) + trainer.fit(model) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/nemo/collections/tts/data/vocoder_dataset.py b/nemo/collections/tts/data/vocoder_dataset.py index 9bb115ba2448..6bf03068a395 100644 --- a/nemo/collections/tts/data/vocoder_dataset.py +++ b/nemo/collections/tts/data/vocoder_dataset.py @@ -65,6 +65,7 @@ class VocoderDataset(Dataset): will be ignored. max_duration: Optional float, if provided audio files in the training manifest longer than 'max_duration' will be ignored. + trunc_duration: Optional int, if provided audio will be truncated to at most 'trunc_duration' seconds. num_audio_retries: Number of read attempts to make when sampling audio file, to avoid training failing from sporadic IO errors. 
""" @@ -78,6 +79,7 @@ def __init__( feature_processors: Optional[Dict[str, FeatureProcessor]] = None, min_duration: Optional[float] = None, max_duration: Optional[float] = None, + trunc_duration: Optional[float] = None, num_audio_retries: int = 5, ): super().__init__() @@ -88,6 +90,11 @@ def __init__( self.num_audio_retries = num_audio_retries self.load_precomputed_mel = False + if trunc_duration: + self.trunc_samples = int(trunc_duration * self.sample_rate) + else: + self.trunc_samples = None + if feature_processors: logging.info(f"Found feature processors {feature_processors.keys()}") self.feature_processors = list(feature_processors.values()) @@ -132,6 +139,10 @@ def _sample_audio(self, audio_filepath: Path) -> Tuple[torch.Tensor, torch.Tenso else: audio_segment = self._segment_audio(audio_filepath) audio_array = audio_segment.samples + + if self.trunc_samples: + audio_array = audio_array[: self.trunc_samples] + audio = torch.tensor(audio_array) audio_len = torch.tensor(audio.shape[0]) return audio, audio_len diff --git a/nemo/collections/tts/losses/encodec_loss.py b/nemo/collections/tts/losses/encodec_loss.py new file mode 100644 index 000000000000..38e21bac6ac3 --- /dev/null +++ b/nemo/collections/tts/losses/encodec_loss.py @@ -0,0 +1,163 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List + +import torch +import torch.nn.functional as F + +from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures +from nemo.collections.tts.losses.loss import MaskedLoss +from nemo.core.classes import Loss, typecheck +from nemo.core.neural_types import AudioSignal, LengthsType, LossType, NeuralType, VoidType + + +class MultiResolutionMelLoss(Loss): + def __init__( + self, sample_rate: int, mel_dim: int, resolutions: List[List], l1_scale: float = 1.0, l2_scale: float = 1.0 + ): + super(MultiResolutionMelLoss, self).__init__() + + self.l1_loss_fn = MaskedLoss("l1", loss_scale=l1_scale) + self.l2_loss_fn = MaskedLoss("l2", loss_scale=l2_scale) + + self.mel_features = torch.nn.ModuleList() + for n_fft, hop_len, win_len in resolutions: + mel_feature = FilterbankFeatures( + sample_rate=sample_rate, + nfilt=mel_dim, + n_window_size=win_len, + n_window_stride=hop_len, + n_fft=n_fft, + pad_to=1, + mag_power=1.0, + log_zero_guard_type="add", + log_zero_guard_value=1.0, + mel_norm=None, + normalize=None, + preemph=None, + dither=0.0, + use_grads=True, + ) + self.mel_features.append(mel_feature) + + @property + def input_types(self): + return { + "audio_real": NeuralType(('B', 'T'), AudioSignal()), + "audio_gen": NeuralType(('B', 'T'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "loss": [NeuralType(elements_type=LossType())], + } + + @typecheck() + def forward(self, audio_real, audio_gen, audio_len): + len_diff = audio_real.shape[1] - audio_gen.shape[1] + audio_gen = F.pad(audio_gen, (0, len_diff)) + + loss = 0.0 + for mel_feature in self.mel_features: + mel_real, mel_real_len = mel_feature(x=audio_real, seq_len=audio_len) + mel_gen, _ = mel_feature(x=audio_gen, seq_len=audio_len) + loss = loss + self.l1_loss_fn(predicted=mel_gen, target=mel_real, target_len=mel_real_len) + loss = loss + self.l2_loss_fn(predicted=mel_gen, target=mel_real, target_len=mel_real_len) + + loss /= len(self.mel_features) + + return loss + + +class FeatureMatchingLoss(Loss): + @property + def input_types(self): + return { + "fmaps_real": [[NeuralType(elements_type=VoidType())]], + "fmaps_gen": [[NeuralType(elements_type=VoidType())]], + } + + @property + def output_types(self): + return { + "loss": NeuralType(elements_type=LossType()), + } + + @typecheck() + def forward(self, fmaps_real, fmaps_gen): + loss = 0.0 + for fmap_real, fmap_gen in zip(fmaps_real, fmaps_gen): + # [B, ..., time] + for feat_real, feat_gen in zip(fmap_real, fmap_gen): + # [B, ...] 
+ feat_mean = torch.mean(torch.abs(feat_real), dim=-1) + diff = torch.mean(torch.abs(feat_real - feat_gen), dim=-1) + feat_loss = diff / (feat_mean + 1e-2) + # [1] + feat_loss = torch.mean(feat_loss) / len(fmap_real) + loss = loss + feat_loss + + loss /= len(fmaps_real) + + return loss + + +class GeneratorLoss(Loss): + @property + def input_types(self): + return { + "disc_scores": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, disc_scores): + loss = 0.0 + for disc_score in disc_scores: + loss = loss + torch.mean((1 - disc_score) ** 2) + + loss /= len(disc_scores) + + return loss + + +class DiscriminatorLoss(Loss): + @property + def input_types(self): + return { + "disc_scores_real": [NeuralType(('B', 'C', 'T'), VoidType())], + "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, disc_scores_real, disc_scores_gen): + loss = 0.0 + for disc_score_real, disc_score_gen in zip(disc_scores_real, disc_scores_gen): + loss_real = torch.mean((1 - disc_score_real) ** 2) + loss_gen = torch.mean(disc_score_gen ** 2) + loss = loss + (loss_real + loss_gen) / 2 + + loss /= len(disc_scores_real) + + return loss diff --git a/nemo/collections/tts/losses/loss.py b/nemo/collections/tts/losses/loss.py new file mode 100644 index 000000000000..4eae8d72f7a0 --- /dev/null +++ b/nemo/collections/tts/losses/loss.py @@ -0,0 +1,48 @@ +import torch + +from nemo.core.classes import Loss, typecheck +from nemo.core.neural_types import LengthsType, LossType, NeuralType, PredictionsType, RegressionValuesType + + +class MaskedLoss(Loss): + def __init__(self, loss_type: str, loss_scale: float = 1.0): + super(MaskedLoss, self).__init__() + self.loss_scale = loss_scale + + if loss_type == "l1": + self.loss_fn = torch.nn.L1Loss(reduction='none') + elif loss_type == "l2": + self.loss_fn = torch.nn.MSELoss(reduction='none') + else: + raise ValueError(f"Unknown loss type {loss_type}") + + @property + def input_types(self): + return { + "target": NeuralType(('B', 'D', 'T'), RegressionValuesType()), + "predicted": NeuralType(('B', 'D', 'T'), PredictionsType()), + "target_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "loss": NeuralType(elements_type=LossType()), + } + + @typecheck() + def forward(self, predicted, target, target_len): + assert target.shape[2] == predicted.shape[2] + + # [B, D, T] + loss = self.loss_fn(input=predicted, target=target) + # [B, T] + loss = torch.mean(loss, dim=1) + # [B] + loss = torch.sum(loss, dim=1) / torch.clamp(target_len, min=1.0) + + # [1] + loss = torch.mean(loss) + loss = self.loss_scale * loss + + return loss diff --git a/nemo/collections/tts/models/__init__.py b/nemo/collections/tts/models/__init__.py index db02a0d6bda4..31e649688521 100644 --- a/nemo/collections/tts/models/__init__.py +++ b/nemo/collections/tts/models/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. 
from nemo.collections.tts.models.aligner import AlignerModel +from nemo.collections.tts.models.encodec import EnCodecModel from nemo.collections.tts.models.fastpitch import FastPitchModel from nemo.collections.tts.models.fastpitch_ssl import FastPitchModel_SSL from nemo.collections.tts.models.hifigan import HifiGanModel @@ -42,4 +43,5 @@ "VitsModel", "WaveGlowModel", "SpectrogramEnhancerModel", + "EnCodecModel", ] diff --git a/nemo/collections/tts/models/encodec.py b/nemo/collections/tts/models/encodec.py new file mode 100644 index 000000000000..bfb55cb307c5 --- /dev/null +++ b/nemo/collections/tts/models/encodec.py @@ -0,0 +1,365 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from pathlib import Path +from typing import List, Tuple + +import torch +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Trainer + +from nemo.collections.tts.losses.encodec_loss import ( + DiscriminatorLoss, + FeatureMatchingLoss, + GeneratorLoss, + MultiResolutionMelLoss, +) +from nemo.collections.tts.modules.common import GaussianDropout +from nemo.collections.tts.parts.utils.callbacks import LoggingCallback +from nemo.collections.tts.parts.utils.helpers import get_batch_size, get_num_workers +from nemo.core import ModelPT +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType +from nemo.core.neural_types.neural_type import NeuralType +from nemo.core.optim.lr_scheduler import compute_max_steps, prepare_lr_scheduler +from nemo.utils import model_utils +from nemo.utils.decorators import experimental + + +@experimental +class EnCodecModel(ModelPT): + """EnCodec model (https://github.com/facebookresearch/encodec) that encodes and decodes audio.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Convert to Hydra 1.0 compatible DictConfig + cfg = model_utils.convert_model_config_to_dict_config(cfg) + cfg = model_utils.maybe_update_config_version(cfg) + + super().__init__(cfg=cfg, trainer=trainer) + + self.sample_rate = cfg.sample_rate + self.disc_update_steps = cfg.get("disc_update_steps", 1) + self.audio_encoder = instantiate(cfg.audio_encoder) + + # Optionally, add gaussian noise to encoder output as an information bottleneck + encoder_noise_stdev = cfg.get("encoder_noise_stdev", 0.0) + if encoder_noise_stdev: + self.encoder_noise = GaussianDropout(stdev=encoder_noise_stdev) + else: + self.encoder_noise = None + + if "vector_quantizer" in cfg: + self.vector_quantizer = instantiate(cfg.vector_quantizer) + else: + self.vector_quantizer = None + + self.generator = instantiate(cfg.generator) + self.discriminator = instantiate(cfg.discriminator) + + mel_loss_dim = cfg.get("mel_loss_dim", 64) + mel_loss_resolutions = cfg.mel_loss_resolutions + mel_loss_l1_scale = cfg.get("mel_loss_l1_scale", 0.1) + mel_loss_l2_scale = 
cfg.get("mel_loss_l2_scale", 1.0) + self.gen_loss_scale = cfg.get("gen_loss_scale", 4.0) + self.feature_loss_scale = cfg.get("feature_loss_scale", 4.0) + + self.mel_loss_fn = MultiResolutionMelLoss( + sample_rate=self.sample_rate, + mel_dim=mel_loss_dim, + resolutions=mel_loss_resolutions, + l1_scale=mel_loss_l1_scale, + l2_scale=mel_loss_l2_scale, + ) + self.gen_loss_fn = GeneratorLoss() + self.feature_loss_fn = FeatureMatchingLoss() + self.disc_loss_fn = DiscriminatorLoss() + + self.log_config = cfg.get("log_config", None) + self.lr_schedule_interval = None + self.automatic_optimization = False + + @typecheck( + input_types={ + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + }, + ) + def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len) + return encoded, encoded_len + + @typecheck( + input_types={ + "inputs": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + }, + ) + def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + audio, audio_len = self.generator(inputs=inputs, input_len=input_len) + return audio, audio_len + + @typecheck( + input_types={ + "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"indices": NeuralType(('N', 'B', 'T_encoded'), Index())}, + ) + def quantize_encode(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor: + if not self.vector_quantizer: + raise ValueError("Cannot quantize without quantizer") + + indices = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len) + return indices + + @typecheck( + input_types={ + "indices": NeuralType(('N', 'B', 'T_encoded'), Index()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"quantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),}, + ) + def quantize_decode(self, indices: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor: + if not self.vector_quantizer: + raise ValueError("Cannot dequantize without quantizer") + + quantized = self.vector_quantizer.decode(indices=indices, input_len=encoded_len) + return quantized + + @typecheck( + input_types={ + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "output_audio": NeuralType(('B', 'T_audio'), EncodedRepresentation()), + "output_audio_len": NeuralType(tuple('B'), LengthsType()), + }, + ) + def forward(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len) + + if self.vector_quantizer: + indices = self.quantize_encode(encoded=encoded, encoded_len=encoded_len) + quantized = self.quantize_decode(indices=indices, encoded_len=encoded_len) + output_audio, output_audio_len = self.decode_audio(inputs=quantized, input_len=encoded_len) + else: + output_audio, output_audio_len = self.decode_audio(inputs=encoded, 
input_len=encoded_len) + + return output_audio, output_audio_len + + def _process_batch(self, batch): + # [B, T_audio] + audio = batch.get("audio") + # [B] + audio_len = batch.get("audio_lens") + + # [B, D, T_encoded] + encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len) + + if self.encoder_noise is not None: + encoded = self.encoder_noise(encoded) + + if self.vector_quantizer: + encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len) + else: + commit_loss = None + + # [B, T] + audio_gen, audio_gen_len = self.generator(inputs=encoded, input_len=encoded_len) + + return audio, audio_len, audio_gen, commit_loss + + def training_step(self, batch, batch_idx): + + optim_gen, optim_disc = self.optimizers() + optim_gen.zero_grad() + + audio, audio_len, audio_gen, commit_loss = self._process_batch(batch) + + if batch_idx % self.disc_update_steps != 0: + loss_disc = None + else: + # Train discriminator + optim_disc.zero_grad() + + disc_scores_real, disc_scores_gen, _, _ = self.discriminator( + audio_real=audio, audio_gen=audio_gen.detach() + ) + loss_disc = self.disc_loss_fn(disc_scores_real=disc_scores_real, disc_scores_gen=disc_scores_gen) + + self.manual_backward(loss_disc) + optim_disc.step() + + loss_mel = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + + _, disc_scores_gen, fmaps_real, fmaps_gen = self.discriminator(audio_real=audio, audio_gen=audio_gen) + + loss_gen = self.gen_loss_fn(disc_scores=disc_scores_gen) + train_loss_gen = self.gen_loss_scale * loss_gen + + loss_feature = self.feature_loss_fn(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen) + train_loss_feature = self.feature_loss_scale * loss_feature + + loss_gen_all = loss_mel + train_loss_gen + train_loss_feature + if commit_loss is not None: + loss_gen_all += commit_loss + + self.manual_backward(loss_gen_all) + optim_gen.step() + + self.update_lr() + + metrics = { + "g_loss_mel": loss_mel, + "g_loss_gen": loss_gen, + "g_loss_feature": loss_feature, + "g_loss": loss_gen_all, + "global_step": self.global_step, + "lr": optim_gen.param_groups[0]['lr'], + } + + if loss_disc is not None: + metrics["d_loss"] = loss_disc + + if commit_loss is not None: + metrics["g_loss_commit"] = commit_loss + + self.log_dict(metrics, on_step=True, sync_dist=True) + self.log("t_loss", loss_mel, prog_bar=True, logger=False, sync_dist=True) + + def training_epoch_end(self, outputs): + self.update_lr("epoch") + + def validation_step(self, batch, batch_idx): + audio, audio_len, audio_gen, _ = self._process_batch(batch) + loss_mel = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + self.log_dict({"val_loss": loss_mel}, on_epoch=True, sync_dist=True) + + @staticmethod + def _setup_train_dataloader(cfg): + dataset = instantiate(cfg.dataset) + sampler = dataset.get_sampler(cfg.dataloader_params.batch_size) + data_loader = torch.utils.data.DataLoader( + dataset, collate_fn=dataset.collate_fn, sampler=sampler, **cfg.dataloader_params + ) + return data_loader + + @staticmethod + def _setup_test_dataloader(cfg): + dataset = instantiate(cfg.dataset) + data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params) + return data_loader + + def setup_training_data(self, cfg): + self._train_dl = self._setup_train_dataloader(cfg) + + def setup_validation_data(self, cfg): + self._validation_dl = self._setup_test_dataloader(cfg) + + def setup_test_data(self, cfg): + pass + + @property + def max_steps(self): + if 
"max_steps" in self._cfg: + return self._cfg.get("max_steps") + + if "max_epochs" not in self._cfg: + raise ValueError("Must specify 'max_steps' or 'max_epochs'.") + + if "steps_per_epoch" in self._cfg: + return self._cfg.max_epochs * self._cfg.steps_per_epoch + + return compute_max_steps( + max_epochs=self._cfg.max_epochs, + accumulate_grad_batches=self.trainer.accumulate_grad_batches, + limit_train_batches=self.trainer.limit_train_batches, + num_workers=get_num_workers(self.trainer), + num_samples=len(self._train_dl.dataset), + batch_size=get_batch_size(self._train_dl), + drop_last=self._train_dl.drop_last, + ) + + def configure_optimizers(self): + optim_config = self._cfg.optim.copy() + + OmegaConf.set_struct(optim_config, False) + sched_config = optim_config.pop("sched", None) + OmegaConf.set_struct(optim_config, True) + + gen_params = itertools.chain(self.audio_encoder.parameters(), self.generator.parameters()) + disc_params = self.discriminator.parameters() + optim_g = instantiate(optim_config, params=gen_params) + optim_d = instantiate(optim_config, params=disc_params) + + if sched_config is None: + return [optim_g, optim_d] + + OmegaConf.set_struct(sched_config, False) + sched_config["max_steps"] = self.max_steps + OmegaConf.set_struct(sched_config, True) + + scheduler_g = prepare_lr_scheduler( + optimizer=optim_g, scheduler_config=sched_config, train_dataloader=self._train_dl + ) + + scheduler_d = prepare_lr_scheduler( + optimizer=optim_d, scheduler_config=sched_config, train_dataloader=self._train_dl + ) + + self.lr_schedule_interval = scheduler_g["interval"] + + return [optim_g, optim_d], [scheduler_g, scheduler_d] + + def update_lr(self, interval="step"): + schedulers = self.lr_schedulers() + if schedulers is not None and self.lr_schedule_interval == interval: + sch1, sch2 = schedulers + sch1.step() + sch2.step() + + def configure_callbacks(self): + if not self.log_config: + return + + data_loader = self._setup_test_dataloader(self.log_config) + generators = instantiate(self.log_config.generators) + log_dir = Path(self.log_config.log_dir) if self.log_config.log_dir else None + log_callback = LoggingCallback( + generators=generators, + data_loader=data_loader, + log_epochs=self.log_config.log_epochs, + epoch_frequency=self.log_config.epoch_frequency, + output_dir=log_dir, + loggers=self.trainer.loggers, + log_tensorboard=self.log_config.log_tensorboard, + log_wandb=self.log_config.log_wandb, + ) + + return [log_callback] + + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + return [] diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 7f6652f8455d..5f7d6153a7d1 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -764,3 +764,30 @@ def forward(self, queries, keys, query_lens, mask=None, key_lens=None, attn_prio attn = self.softmax(attn) # softmax along T2 return attn, attn_logprob + + +class GaussianDropout(torch.nn.Module): + """ + Gaussian dropout using multiplicative gaussian noise. 
+ + https://keras.io/api/layers/regularization_layers/gaussian_dropout/ + + Can be an effective alternative bottleneck to VAE or VQ: + + https://www.deepmind.com/publications/gaussian-dropout-as-an-information-bottleneck-layer + + Unlike some other implementations, this takes the standard deviation of the noise as input + instead of the 'rate' typically defined as: stdev = sqrt(rate / (1 - rate)) + """ + + def __init__(self, stdev=1.0): + super(GaussianDropout, self).__init__() + self.stdev = stdev + + def forward(self, inputs): + if not self.training: + return inputs + + noise = torch.normal(mean=1.0, std=self.stdev, size=inputs.shape, device=inputs.device) + out = noise * inputs + return out diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py new file mode 100644 index 000000000000..2af6aeded42b --- /dev/null +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -0,0 +1,450 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange + +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, LengthsType, VoidType +from nemo.core.neural_types.neural_type import NeuralType + + +def get_padding(kernel_size: int, dilation: int = 1) -> int: + return (kernel_size * dilation - dilation) // 2 + + +def get_padding_2d(kernel_size: Tuple[int, int], dilation: Tuple[int, int]) -> Tuple[int, int]: + paddings = (get_padding(kernel_size[0], dilation[0]), get_padding(kernel_size[1], dilation[1])) + return paddings + + +def get_down_sample_padding(kernel_size: int, stride: int) -> int: + return (kernel_size - stride + 1) // 2 + + +def get_up_sample_padding(kernel_size: int, stride: int) -> Tuple[int, int]: + output_padding = (kernel_size - stride) % 2 + padding = (kernel_size - stride + 1) // 2 + return padding, output_padding + + +class Conv1dNorm(NeuralModule): + def __init__( + self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: Optional[int] = None + ): + super().__init__() + if not padding: + padding = get_padding(kernel_size) + conv = nn.Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding + ) + self.conv = nn.utils.weight_norm(conv) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + } + + @property + def output_types(self): + return { + "out": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + def forward(self, inputs): + return self.conv(inputs) + + +class ConvTranspose1dNorm(NeuralModule): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1): + super().__init__() + padding, output_padding = 
get_up_sample_padding(kernel_size, stride) + conv = nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + self.conv = nn.utils.weight_norm(conv) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + } + + @property + def output_types(self): + return { + "out": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + def forward(self, inputs): + return self.conv(inputs) + + +class Conv2dNorm(NeuralModule): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int], + stride: Tuple[int, int] = (1, 1), + dilation: Tuple[int, int] = (1, 1), + ): + super().__init__() + assert len(kernel_size) == len(dilation) + padding = get_padding_2d(kernel_size, dilation) + conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + ) + self.conv = nn.utils.weight_norm(conv) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'H', 'T'), VoidType()), + } + + @property + def output_types(self): + return { + "out": [NeuralType(('B', 'C', 'H', 'T'), VoidType())], + } + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + def forward(self, inputs): + return self.conv(inputs) + + +class SEANetResnetBlock(NeuralModule): + def __init__(self, channels: int): + super().__init__() + self.activation = nn.ELU() + hidden_channels = channels // 2 + self.pre_conv = Conv1dNorm(in_channels=channels, out_channels=channels, kernel_size=1) + self.res_conv1 = Conv1dNorm(in_channels=channels, out_channels=hidden_channels, kernel_size=3) + self.res_conv2 = Conv1dNorm(in_channels=hidden_channels, out_channels=channels, kernel_size=1) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T_input'), VoidType()), + } + + @property + def output_types(self): + return { + "out": [NeuralType(('B', 'C', 'T_out'), VoidType())], + } + + def remove_weight_norm(self): + self.pre_conv.remove_weight_norm() + self.res_conv1.remove_weight_norm() + self.res_conv2.remove_weight_norm() + + def forward(self, inputs): + res = self.activation(inputs) + res = self.res_conv1(res) + res = self.activation(res) + res = self.res_conv2(res) + + out = self.pre_conv(inputs) + res + return out + + +class SEANetEncoder(NeuralModule): + def __init__( + self, + down_sample_rates: Iterable[int] = (2, 4, 5, 8), + base_channels: int = 32, + in_kernel_size: int = 7, + out_kernel_size: int = 7, + encoded_dim: int = 128, + ): + assert in_kernel_size > 0 + assert out_kernel_size > 0 + + super().__init__() + + self.down_sample_rates = down_sample_rates + self.activation = nn.ELU() + self.pre_conv = Conv1dNorm(in_channels=1, out_channels=base_channels, kernel_size=in_kernel_size) + + in_channels = base_channels + self.res_blocks = nn.ModuleList([]) + self.down_sample_conv_layers = nn.ModuleList([]) + for i, down_sample_rate in enumerate(self.down_sample_rates): + res_block = SEANetResnetBlock(channels=in_channels) + self.res_blocks.append(res_block) + + out_channels = 2 * in_channels + kernel_size = 2 * down_sample_rate + down_sample_conv = Conv1dNorm( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=down_sample_rate, + padding=get_down_sample_padding(kernel_size, 
down_sample_rate), + ) + in_channels = out_channels + self.down_sample_conv_layers.append(down_sample_conv) + + self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=encoded_dim, kernel_size=out_kernel_size) + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'C', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "encoded": [NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation())], + "encoded_len": [NeuralType(tuple('B'), LengthsType())], + } + + def remove_weight_norm(self): + self.pre_conv.remove_weight_norm() + for res_block in self.res_blocks: + res_block.remove_weight_norm() + for down_sample_conv in self.down_sample_conv_layers: + down_sample_conv.remove_weight_norm() + + # TODO: Add masking + def forward(self, audio, audio_len): + encoded_len = audio_len + audio = rearrange(audio, "B T -> B 1 T") + # [B, C, T_audio] + out = self.pre_conv(audio) + for res_block, down_sample_conv, down_sample_rate in zip( + self.res_blocks, self.down_sample_conv_layers, self.down_sample_rates + ): + encoded_len = torch.div(encoded_len, down_sample_rate, rounding_mode="floor") + # [B, C, T] + out = res_block(out) + out = self.activation(out) + # [B, 2 * C, T / down_sample_rate] + out = down_sample_conv(out) + + out = self.activation(out) + # [B, encoded_dim, T_encoded] + encoded = self.post_conv(out) + return encoded, encoded_len + + +class SEANetDecoder(NeuralModule): + def __init__( + self, + up_sample_rates: Iterable[int] = (8, 5, 4, 2), + base_channels: int = 512, + in_kernel_size: int = 7, + out_kernel_size: int = 3, + encoded_dim: int = 128, + ): + assert in_kernel_size > 0 + assert out_kernel_size > 0 + + super().__init__() + + self.up_sample_rates = up_sample_rates + self.activation = nn.ELU() + self.pre_conv = Conv1dNorm(in_channels=encoded_dim, out_channels=base_channels, kernel_size=in_kernel_size) + + in_channels = base_channels + self.res_blocks = nn.ModuleList([]) + self.up_sample_conv_layers = nn.ModuleList([]) + for i, up_sample_rate in enumerate(self.up_sample_rates): + out_channels = in_channels // 2 + kernel_size = 2 * up_sample_rate + up_sample_conv = ConvTranspose1dNorm( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=up_sample_rate + ) + in_channels = out_channels + self.up_sample_conv_layers.append(up_sample_conv) + + res_block = SEANetResnetBlock(channels=in_channels) + self.res_blocks.append(res_block) + + self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=1, kernel_size=out_kernel_size) + self.out_activation = nn.Tanh() + + @property + def input_types(self): + return { + "inputs": [NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation())], + "input_len": [NeuralType(tuple('B'), LengthsType())], + } + + @property + def output_types(self): + return { + "audio": NeuralType(('B', 'C', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + def remove_weight_norm(self): + self.pre_conv.remove_weight_norm() + for up_sample_conv in self.up_sample_conv_layers: + up_sample_conv.remove_weight_norm() + for res_block in self.res_blocks: + res_block.remove_weight_norm() + + # TODO: Add masking + def forward(self, inputs, input_len): + audio_len = input_len + # [B, C, T_encoded] + out = self.pre_conv(inputs) + for res_block, up_sample_conv, up_sample_rate in zip( + self.res_blocks, self.up_sample_conv_layers, self.up_sample_rates + ): + audio_len *= up_sample_rate + out = 
self.activation(out) + # [B, C / 2, T * up_sample_rate] + out = up_sample_conv(out) + out = res_block(out) + + out = self.activation(out) + # [B, 1, T_audio] + out = self.post_conv(out) + audio = self.out_activation(out) + audio = rearrange(audio, "B 1 T -> B T") + return audio, audio_len + + +class DiscriminatorSTFT(NeuralModule): + def __init__(self, resolution, lrelu_slope=0.1): + super().__init__() + + self.n_fft, self.hop_length, self.win_length = resolution + self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) + self.activation = nn.LeakyReLU(lrelu_slope) + + self.conv_layers = nn.ModuleList( + [ + Conv2dNorm(2, 32, kernel_size=(3, 9)), + Conv2dNorm(32, 32, kernel_size=(3, 9), dilation=(1, 1), stride=(1, 2)), + Conv2dNorm(32, 32, kernel_size=(3, 9), dilation=(2, 1), stride=(1, 2)), + Conv2dNorm(32, 32, kernel_size=(3, 9), dilation=(4, 1), stride=(1, 2)), + Conv2dNorm(32, 32, kernel_size=(3, 3)), + ] + ) + self.conv_post = Conv2dNorm(32, 1, kernel_size=(3, 3)) + + def stft(self, audio): + # [B, fft, T_spec] + out = torch.stft( + audio, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + normalized=True, + center=True, + return_complex=True, + ) + out = rearrange(out, "B fft T -> B 1 T fft") + # [batch, 2, T_spec, fft] + out = torch.cat([out.real, out.imag], dim=1) + return out + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + } + + @property + def output_types(self): + return { + "scores": NeuralType(('B', 'C', 'T_spec'), VoidType()), + "fmap": [NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())], + } + + def forward(self, audio): + fmap = [] + + # [batch, 2, T_spec, fft] + out = self.stft(audio) + for conv in self.conv_layers: + # [batch, filters, T_spec, fft // 2**i] + out = conv(out) + out = self.activation(out) + fmap.append(out) + # [batch, 1, T_spec, fft // 8] + scores = self.conv_post(out) + fmap.append(scores) + scores = rearrange(scores, "B 1 T C -> B C T") + + return scores, fmap + + +class MultiResolutionDiscriminatorSTFT(NeuralModule): + def __init__(self, resolutions): + super().__init__() + self.discriminators = nn.ModuleList([DiscriminatorSTFT(res) for res in resolutions]) + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()), + } + + @property + def output_types(self): + return { + "scores_real": [NeuralType(('B', 'C', 'T_spec'), VoidType())], + "scores_gen": [NeuralType(('B', 'C', 'T_spec'), VoidType())], + "fmaps_real": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], + "fmaps_gen": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], + } + + def forward(self, audio_real, audio_gen): + scores_real = [] + scores_gen = [] + fmaps_real = [] + fmaps_gen = [] + + for disc in self.discriminators: + score_real, fmap_real = disc(audio=audio_real) + scores_real.append(score_real) + fmaps_real.append(fmap_real) + + score_gen, fmap_gen = disc(audio=audio_gen) + scores_gen.append(score_gen) + fmaps_gen.append(fmap_gen) + + return scores_real, scores_gen, fmaps_real, fmaps_gen diff --git a/nemo/collections/tts/modules/vector_quantization.py b/nemo/collections/tts/modules/vector_quantization.py new file mode 100644 index 000000000000..e905fddc1267 --- /dev/null +++ b/nemo/collections/tts/modules/vector_quantization.py @@ -0,0 +1,403 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import Tensor + +from nemo.collections.tts.losses.loss import MaskedLoss +from nemo.collections.tts.parts.utils.distributed import broadcast_tensors +from nemo.core.classes.common import typecheck +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types.elements import EncodedRepresentation, Index, LengthsType, LossType +from nemo.core.neural_types.neural_type import NeuralType +from nemo.utils.decorators import experimental + + +def _ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None: + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def _laplace_smoothing(inputs: Tensor, n_categories: int, epsilon: float = 1e-5) -> Tensor: + input_sum = inputs.sum() + smoothed = (inputs + epsilon) / (input_sum + n_categories * epsilon) + return input_sum * smoothed + + +def _compute_distances(input1: Tensor, input2: Tensor) -> Tensor: + """ + Compute pairwise L2 distance between two input tensors + + Args: + input1: [B, D] first tensor. + input2: [N, D] second tensor. + + Returns: + [(B, D)] tensor of distances. + """ + input2 = rearrange(input2, "N D -> D N") + distances = input1.pow(2).sum(1, keepdim=True) - (2 * input1 @ input2) + input2.pow(2).sum(0, keepdim=True) + return distances + + +def _sample_vectors(samples: Tensor, num_sample: int) -> Tensor: + """ + Randomly sample from the input batch. + + Args: + samples: [B, D] tensor with features to sample. + num_sample: Number of samples to draw. + If the value is less than or equal to B, then the samples will be unique. + If the value is greater than B, then samples will be drawn with replacement. + + Returns: + Tensor with num_sample values randomly sampled from the input batch. + """ + device = samples.device + total_samples = samples.shape[0] + + if total_samples >= num_sample: + indices = torch.randperm(total_samples, device=device)[:num_sample] + else: + indices = torch.randint(low=0, high=total_samples, size=(num_sample,), device=device) + + return samples[indices] + + +def _k_means(samples: Tensor, num_clusters: int, num_iters: int = 10) -> Tuple[Tensor, Tensor]: + """ + K-means clustering algorithm. + + Args: + samples: [B, D] tensor with features to cluster + num_clusters: K, the number of clusters. + num_iters: Number of iterations of K-means to run. 
+ + Returns: + [K, D] cluster means and [K] bins counting how many input samples belong to each cluster + """ + assert num_iters > 0 + + input_dim = samples.shape[1] + # [K, D] + means = _sample_vectors(samples=samples, num_sample=num_clusters) + + for _ in range(num_iters): + # [B, K] + dists = _compute_distances(samples, means) + + # [N] + buckets = dists.min(dim=1).indices + buckets_repeated = repeat(buckets, "B -> B D", D=input_dim) + # [K] + bin_counts = torch.bincount(buckets, minlength=num_clusters) + bin_counts_expanded = rearrange(bin_counts, "K -> K ()") + + # [K, D] + new_means = buckets.new_zeros(num_clusters, input_dim, dtype=samples.dtype) + new_means.scatter_add_(dim=0, index=buckets_repeated, src=samples) + new_means = new_means / torch.clamp(bin_counts_expanded, min=1) + means = torch.where(bin_counts_expanded == 0, means, new_means) + + return means, bin_counts + + +@experimental +class EuclideanCodebook(NeuralModule): + """ + Codebook with Euclidean distance. + + Args: + codebook_size: Number of codes to use. + codebook_dim: Dimension of each code. + decay: Decay for exponential moving average over the codebooks. + threshold_ema_dead_code: Threshold for dead code expiration. + During every iteration, replace codes with exponential moving average cluster size less than threshold + with randomly selected values from the current batch. + kmeans_iters: Optional int, if provided codes will be initialized from the centroids learned from + kmeans_iters iterations of k-means clustering on the first training batch. + """ + + def __init__( + self, + codebook_size: int, + codebook_dim: int, + decay: float = 0.99, + threshold_ema_dead_code: Optional[int] = 2, + kmeans_iters: Optional[int] = None, + ): + super().__init__() + self.decay = decay + + if kmeans_iters: + codes = nn.init.kaiming_uniform_(torch.empty(codebook_size, codebook_dim)) + else: + codes = torch.zeros(codebook_size, codebook_dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("initialized", Tensor([not kmeans_iters])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("codes", codes) + self.register_buffer("codes_avg", codes.clone()) + + @torch.jit.ignore + def _init_codes(self, data): + if self.initialized: + return + + codes, cluster_size = _k_means(samples=data, num_clusters=self.codebook_size, num_iters=self.kmeans_iters) + self.codes.data.copy_(codes) + self.codes_avg.data.copy_(codes.clone()) + self.cluster_size.data.copy_(cluster_size) + self.initialized.data.copy_(Tensor([True])) + broadcast_tensors(self.buffers()) + + def _expire_codes(self, inputs: Tensor) -> None: + if not self.threshold_ema_dead_code: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + samples = _sample_vectors(samples=inputs, num_sample=self.codebook_size) + expired_codes = rearrange(expired_codes, "K -> K ()") + modified_codes = torch.where(expired_codes, samples, self.codes) + self.codes.data.copy_(modified_codes) + + broadcast_tensors(self.buffers()) + + def _update_codes(self, inputs: Tensor, indices: Tensor) -> None: + code_onehot = F.one_hot(indices, self.codebook_size).type(inputs.dtype) + code_onehot = rearrange(code_onehot, "B N -> N B") + # [N] + code_counts = code_onehot.sum(1) + _ema_inplace(moving_avg=self.cluster_size, new=code_counts, decay=self.decay) + # [N, D] + code_sum = code_onehot @ inputs + 
_ema_inplace(moving_avg=self.codes_avg, new=code_sum, decay=self.decay) + + cluster_size_smoothed = _laplace_smoothing(self.cluster_size, n_categories=self.codebook_size) + cluster_size_smoothed = rearrange(cluster_size_smoothed, "N -> N ()") + codes_normalized = self.codes_avg / cluster_size_smoothed + self.codes.data.copy_(codes_normalized) + + def _quantize(self, inputs: Tensor) -> Tensor: + # [B, N] + dist = _compute_distances(inputs, self.codes) + # [B] + indices = dist.min(dim=1).indices + return indices + + def _dequantize(self, indices: Tensor) -> Tensor: + # [B, D] + quantized = F.embedding(indices, self.codes) + return quantized + + @property + def input_types(self): + return {"inputs": NeuralType(('B', 'T', 'D'), EncodedRepresentation())} + + @property + def output_types(self): + return { + "quantized": NeuralType(('B', 'T', 'D'), EncodedRepresentation()), + "indices": NeuralType(('B', 'T'), Index()), + } + + def forward(self, inputs): + input_flat = rearrange(inputs, "B T D -> (B T) D") + self._init_codes(input_flat) + # [(B T)] + indices_flat = self._quantize(inputs=input_flat) + # [B, T] + indices = indices_flat.view(*inputs.shape[:-1]) + # [B, T, D] + quantized = self._dequantize(indices=indices) + + if self.training: + # We do expiry of codes here because buffers are in sync and all the workers will make the same decision. + self._expire_codes(inputs=input_flat) + self._update_codes(inputs=input_flat, indices=indices_flat) + + return quantized, indices + + @typecheck( + input_types={"inputs": NeuralType(('B', 'T', 'D'), EncodedRepresentation())}, + output_types={"indices": NeuralType(('B', 'T'), Index())}, + ) + def encode(self, inputs): + input_flat = rearrange(inputs, "B T D -> (B T) D") + # [(B T)] + indices_flat = self._quantize(inputs=input_flat) + # [B, T] + indices = indices_flat.view(*inputs.shape[:-1]) + return indices + + @typecheck( + input_types={"indices": NeuralType(('B', 'T'), Index())}, + output_types={"quantized": NeuralType(('B', 'T', 'D'), EncodedRepresentation())}, + ) + def decode(self, indices): + # [B, T, D] + quantized = self._dequantize(indices=indices) + return quantized + + +class ResidualVectorQuantizer(NeuralModule): + """ + Residual vector quantization (RVQ) algorithm as described in https://arxiv.org/pdf/2107.03312.pdf. + + Args: + num_codebooks: Number of codebooks to use. + commit_loss_scale: Loss scale for codebook commit loss. + codebook_size: Number of codes to use for each codebook. + codebook_dim: Dimension of each code. + decay: Decay for exponential moving average over the codebooks. + threshold_ema_dead_code: Threshold for dead code expiration. + During every iteration, replace codes with exponential moving average cluster size less than threshold + with randomly selected values from the current batch. + kmeans_iters: Optional int, if provided codes will be initialized from the centroids learned from + kmeans_iters iterations of k-means clustering on the first training batch. 
+ """ + + def __init__( + self, + num_codebooks: int, + commit_loss_scale: float = 1.0, + codebook_size: int = 1024, + codebook_dim: int = 128, + decay: float = 0.99, + threshold_ema_dead_code: Optional[int] = 2, + kmeans_iters: Optional[int] = 50, + ): + super().__init__() + self.codebook_dim = codebook_dim + + if commit_loss_scale: + self.commit_loss_fn = MaskedLoss(loss_type="l2", loss_scale=commit_loss_scale) + else: + self.commit_loss_fn = None + + self.codebooks = nn.ModuleList( + [ + EuclideanCodebook( + codebook_size=codebook_size, + codebook_dim=codebook_dim, + decay=decay, + threshold_ema_dead_code=threshold_ema_dead_code, + kmeans_iters=kmeans_iters, + ) + for _ in range(num_codebooks) + ] + ) + + def _commit_loss(self, input, target, input_len): + if not self.commit_loss_fn: + return 0.0 + + return self.commit_loss_fn( + predicted=rearrange(input, "B T D -> B D T"), + target=rearrange(target, "B T D -> B D T"), + target_len=input_len, + ) + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "quantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "indices": NeuralType(('B', 'T'), Index()), + "commit_loss": NeuralType((), LossType()), + } + + # TODO: Add Masking + def forward(self, inputs: Tensor, input_len: Tensor) -> Tuple[Tensor, Tensor, float]: + commit_loss = 0.0 + residual = rearrange(inputs, "B D T -> B T D") + + index_list = [] + quantized = torch.zeros_like(residual) + for codebook in self.codebooks: + quantized_i, indices_i = codebook(residual) + + if self.training: + quantized_i = residual + (quantized_i - residual).detach() + quantized_i_const = quantized_i.detach() + commit_loss_i = self._commit_loss(input=residual, target=quantized_i_const, input_len=input_len) + commit_loss = commit_loss + commit_loss_i + + residual = residual - quantized_i_const + + else: + residual = residual - quantized_i + + quantized = quantized + quantized_i + index_list.append(indices_i) + + # [N, B, T] + indices = torch.stack(index_list) + quantized = rearrange(quantized, "B T D -> B D T") + return quantized, indices, commit_loss + + @typecheck( + input_types={ + "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"indices": NeuralType(('N', 'B', 'T'), Index())}, + ) + def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor: + residual = rearrange(inputs, "B D T -> B T D") + index_list = [] + for codebook in self.codebooks: + # [B, T] + indices_i = codebook.encode(inputs=residual) + # [B, D, T] + quantized_i = codebook.decode(indices=indices_i) + residual = residual - quantized_i + index_list.append(indices_i) + # [N, B, T] + indices = torch.stack(index_list) + return indices + + @typecheck( + input_types={ + "indices": NeuralType(('N', 'B', 'T'), Index()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"quantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, + ) + def decode(self, indices: Tensor, input_len: Tensor) -> Tensor: + # [B, T, D] + quantized = torch.zeros([indices.shape[1], indices.shape[2], self.codebook_dim], device=indices.device) + for codebook_indices, codebook in zip(indices, self.codebooks): + quantized_i = codebook.decode(indices=codebook_indices) + quantized = quantized + quantized_i + quantized = rearrange(quantized, "B T D -> B D T") + return quantized 
diff --git a/nemo/collections/tts/parts/utils/callbacks.py b/nemo/collections/tts/parts/utils/callbacks.py index 2320e5b21a7c..1701c84f594a 100644 --- a/nemo/collections/tts/parts/utils/callbacks.py +++ b/nemo/collections/tts/parts/utils/callbacks.py @@ -265,6 +265,71 @@ def generate_artifacts( return audio_artifacts, [] +class EnCodecArtifactGenerator(ArtifactGenerator): + """ + Generator for logging EnCodec model outputs. + """ + + def __init__(self, log_audio: bool = True, log_encoding: bool = False): + self.log_audio = log_audio + self.log_encoding = log_encoding + + def _generate_audio(self, model, audio_ids, audio, audio_len): + if not self.log_audio: + return [] + + with torch.no_grad(): + # [B, T] + audio_pred, audio_pred_len = model(audio=audio, audio_len=audio_len) + + audio_artifacts = [] + for i, audio_id in enumerate(audio_ids): + audio_pred_i = audio_pred[i][: audio_pred_len[i]].cpu().numpy() + audio_artifact = AudioArtifact( + id=f"audio_{audio_id}", data=audio_pred_i, filename=f"{audio_id}.wav", sample_rate=model.sample_rate, + ) + audio_artifacts.append(audio_artifact) + + return audio_artifacts + + def _generate_encodings(self, model, audio_ids, audio, audio_len): + if not self.log_encoding: + return [] + + with torch.no_grad(): + # [B, D, T] + encoded, encoded_len = model.encode_audio(audio=audio, audio_len=audio_len) + + image_artifacts = [] + for i, audio_id in enumerate(audio_ids): + encoded_i = encoded[i][:, : encoded_len[i]].cpu().numpy() + encoded_artifact = ImageArtifact( + id=f"encoded_{audio_id}", + data=encoded_i, + filename=f"{audio_id}_encode.png", + x_axis="Audio Frames", + y_axis="Channels", + ) + image_artifacts.append(encoded_artifact) + + return image_artifacts + + def generate_artifacts( + self, model: LightningModule, batch_dict: Dict + ) -> Tuple[List[AudioArtifact], List[ImageArtifact]]: + + audio_filepaths = batch_dict.get("audio_filepaths") + audio_ids = [create_id(p) for p in audio_filepaths] + + audio = batch_dict.get("audio") + audio_len = batch_dict.get("audio_lens") + + audio_artifacts = self._generate_audio(model=model, audio_ids=audio_ids, audio=audio, audio_len=audio_len) + image_artifacts = self._generate_encodings(model=model, audio_ids=audio_ids, audio=audio, audio_len=audio_len) + + return audio_artifacts, image_artifacts + + class FastPitchArtifactGenerator(ArtifactGenerator): """ Generator for logging FastPitch model outputs. diff --git a/nemo/collections/tts/parts/utils/distributed.py b/nemo/collections/tts/parts/utils/distributed.py new file mode 100644 index 000000000000..75db69e67059 --- /dev/null +++ b/nemo/collections/tts/parts/utils/distributed.py @@ -0,0 +1,28 @@ +from typing import Iterable + +import torch + + +def _is_distributed(): + return torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1 + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def broadcast_tensors(tensors: Iterable[torch.Tensor], src: int = 0): + """ + Broadcast the tensors from the given parameters to all workers. + This can be used to ensure that all workers have the same model to start with. 
+ """ + if not _is_distributed(): + return + + tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] + handles = [] + for tensor in tensors: + handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) + handles.append(handle) + for handle in handles: + handle.wait() diff --git a/tests/collections/tts/losses/test_loss.py b/tests/collections/tts/losses/test_loss.py new file mode 100644 index 000000000000..ba3f399566e1 --- /dev/null +++ b/tests/collections/tts/losses/test_loss.py @@ -0,0 +1,44 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from nemo.collections.tts.losses.loss import MaskedLoss + + +class TestTTSLoss: + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_masked_loss_l1(self): + loss_fn = MaskedLoss("l1") + target = torch.tensor([[[1.0], [2.0], [0.0]], [[3.0], [0.0], [0.0]]]).transpose(1, 2) + predicted = torch.tensor([[[0.5], [1.0], [0.0]], [[4.5], [0.0], [0.0]]]).transpose(1, 2) + target_len = torch.tensor([2, 1]) + + loss = loss_fn(predicted=predicted, target=target, target_len=target_len) + + assert loss == 1.125 + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_masked_loss_l2(self): + loss_fn = MaskedLoss("l2") + target = torch.tensor([[[1.0], [2.0], [4.0]], [[3.0], [0.0], [0.0]]]).transpose(1, 2) + predicted = torch.tensor([[[0.5], [1.0], [4.0]], [[4.5], [0.0], [0.0]]]).transpose(1, 2) + target_len = torch.tensor([3, 1]) + + loss = loss_fn(predicted=predicted, target=target, target_len=target_len) + + assert loss == (4 / 3) From 583301183372700c25f025c857577ae1cced9c5e Mon Sep 17 00:00:00 2001 From: Ryan Date: Tue, 20 Jun 2023 09:34:42 -0700 Subject: [PATCH 2/5] [TTS] Update encodec recipe Signed-off-by: Ryan --- .../encodec/bottleneck/bottleneck_noise.yaml | 2 - .../encodec/bottleneck/bottleneck_vq.yaml | 4 - examples/tts/conf/encodec/encodec.yaml | 59 ++++++--- .../tts/conf/encodec/sample/sample_22050.yaml | 4 - .../tts/conf/encodec/sample/sample_24000.yaml | 4 - .../tts/conf/encodec/sample/sample_44100.yaml | 4 - .../tts/conf/encodec/sample/sample_48000.yaml | 4 - nemo/collections/tts/losses/encodec_loss.py | 122 +++++++++++++++--- nemo/collections/tts/losses/loss.py | 48 ------- nemo/collections/tts/models/encodec.py | 60 ++++++--- .../tts/modules/encodec_modules.py | 105 ++++++++++++--- .../tts/modules/vector_quantization.py | 52 ++++++-- nemo/collections/tts/parts/utils/callbacks.py | 48 +++++-- nemo/collections/tts/parts/utils/helpers.py | 5 +- .../{test_loss.py => test_encodec_loss.py} | 8 +- 15 files changed, 350 insertions(+), 179 deletions(-) delete mode 100644 examples/tts/conf/encodec/bottleneck/bottleneck_noise.yaml delete mode 100644 examples/tts/conf/encodec/bottleneck/bottleneck_vq.yaml delete mode 100644 examples/tts/conf/encodec/sample/sample_22050.yaml delete mode 100644 examples/tts/conf/encodec/sample/sample_24000.yaml delete mode 100644 
examples/tts/conf/encodec/sample/sample_44100.yaml delete mode 100644 examples/tts/conf/encodec/sample/sample_48000.yaml delete mode 100644 nemo/collections/tts/losses/loss.py rename tests/collections/tts/losses/{test_loss.py => test_encodec_loss.py} (89%) diff --git a/examples/tts/conf/encodec/bottleneck/bottleneck_noise.yaml b/examples/tts/conf/encodec/bottleneck/bottleneck_noise.yaml deleted file mode 100644 index 41136e35e275..000000000000 --- a/examples/tts/conf/encodec/bottleneck/bottleneck_noise.yaml +++ /dev/null @@ -1,2 +0,0 @@ -encoder_noise_stdev: 0.2 -vector_quantizer: null \ No newline at end of file diff --git a/examples/tts/conf/encodec/bottleneck/bottleneck_vq.yaml b/examples/tts/conf/encodec/bottleneck/bottleneck_vq.yaml deleted file mode 100644 index 8efade7dfa17..000000000000 --- a/examples/tts/conf/encodec/bottleneck/bottleneck_vq.yaml +++ /dev/null @@ -1,4 +0,0 @@ -encoder_noise_stdev: 0.0 -vector_quantizer: - _target_: nemo.collections.tts.modules.vector_quantization.ResidualVectorQuantizer - num_codebooks: 8 \ No newline at end of file diff --git a/examples/tts/conf/encodec/encodec.yaml b/examples/tts/conf/encodec/encodec.yaml index 7ef1341f3665..cbaae4259aa7 100644 --- a/examples/tts/conf/encodec/encodec.yaml +++ b/examples/tts/conf/encodec/encodec.yaml @@ -1,32 +1,44 @@ -# This config contains the default values for training EnCodec +# This config contains the default values for training 24khz EnCodec # If you want to train model on other dataset, you can change config values according to your dataset. # Most dataset-specific arguments are in the head of the config file, see below. name: EnCodec -defaults: - - sample: ??? - - bottleneck: ??? - max_epochs: ??? +# Adjust batch size based on GPU memory batch_size: 16 +# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. +# If null, then weighted sampling is disabled. weighted_sampling_steps_per_epoch: null +# Dataset metadata for each manifest +# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 train_ds_meta: ??? val_ds_meta: ??? -log_ds_meta: ??? +log_ds_meta: ??? log_dir: ??? +# Modify these values based on your sample rate +sample_rate: 24000 +train_n_samples: 24000 +down_sample_rates: [2, 4, 5, 8] +up_sample_rates: [8, 5, 4, 2] +# The number of samples per encoded audio frame. Should be the product of the down_sample_rates. +# For example 2 * 4 * 5 * 8 = 320. 
+samples_per_frame: 320 + model: max_epochs: ${max_epochs} steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample.sample_rate} - disc_update_steps: 2 - encoder_noise_stdev: ${bottleneck.encoder_noise_stdev} + sample_rate: ${sample_rate} + samples_per_frame: ${samples_per_frame} + # Probability of updating the discriminator during each training step + disc_update_prob: 0.67 + # All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length] mel_loss_resolutions: [ [32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048] ] @@ -35,8 +47,8 @@ model: dataset: _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} - sample_rate: ${sample.sample_rate} - n_samples: ${sample.train_n_samples} + sample_rate: ${sample_rate} + n_samples: ${train_n_samples} min_duration: 1.01 max_duration: null dataset_meta: ${train_ds_meta} @@ -49,17 +61,19 @@ model: validation_ds: dataset: _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset - sample_rate: ${sample.sample_rate} + sample_rate: ${sample_rate} n_samples: null min_duration: null max_duration: null - trunc_duration: 10.0 + trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss dataset_meta: ${val_ds_meta} dataloader_params: batch_size: 8 num_workers: 2 + # Configures how audio samples are generated and saved during training. + # Remove this section to disable logging. log_config: log_dir: ${log_dir} log_epochs: [10, 50] @@ -71,14 +85,15 @@ model: - _target_: nemo.collections.tts.parts.utils.callbacks.EnCodecArtifactGenerator log_audio: true log_encoding: true + log_quantized: true dataset: _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset - sample_rate: ${sample.sample_rate} + sample_rate: ${sample_rate} n_samples: null min_duration: null max_duration: null - trunc_duration: 15.0 + trunc_duration: 15.0 # Only log the first 15 seconds of generated audio. dataset_meta: ${log_ds_meta} dataloader_params: @@ -87,13 +102,15 @@ model: audio_encoder: _target_: nemo.collections.tts.modules.encodec_modules.SEANetEncoder - down_sample_rates: ${sample.down_sample_rates} + down_sample_rates: ${down_sample_rates} - generator: + audio_decoder: _target_: nemo.collections.tts.modules.encodec_modules.SEANetDecoder - up_sample_rates: ${sample.up_sample_rates} + up_sample_rates: ${up_sample_rates} - vector_quantizer: ${bottleneck.vector_quantizer} + vector_quantizer: + _target_: nemo.collections.tts.modules.vector_quantization.ResidualVectorQuantizer + num_codebooks: 8 discriminator: _target_: nemo.collections.tts.modules.encodec_modules.MultiResolutionDiscriminatorSTFT @@ -113,10 +130,10 @@ trainer: devices: 1 accelerator: gpu strategy: ddp - precision: 32 + precision: 32 # Vector quantization only works with 32-bit precision. 
max_epochs: ${max_epochs} accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager + enable_checkpointing: False # Provided by exp_manager logger: false # Provided by exp_manager log_every_n_steps: 100 check_val_every_n_epoch: 5 diff --git a/examples/tts/conf/encodec/sample/sample_22050.yaml b/examples/tts/conf/encodec/sample/sample_22050.yaml deleted file mode 100644 index 93728e92a589..000000000000 --- a/examples/tts/conf/encodec/sample/sample_22050.yaml +++ /dev/null @@ -1,4 +0,0 @@ -sample_rate: 22050 -train_n_samples: 22016 -down_sample_rates: [2, 4, 4, 8] -up_sample_rates: [8, 4, 4, 2] diff --git a/examples/tts/conf/encodec/sample/sample_24000.yaml b/examples/tts/conf/encodec/sample/sample_24000.yaml deleted file mode 100644 index 35d61a48cbc3..000000000000 --- a/examples/tts/conf/encodec/sample/sample_24000.yaml +++ /dev/null @@ -1,4 +0,0 @@ -sample_rate: 24000 -train_n_samples: 24000 -down_sample_rates: [2, 4, 5, 8] -up_sample_rates: [8, 5, 4, 2] diff --git a/examples/tts/conf/encodec/sample/sample_44100.yaml b/examples/tts/conf/encodec/sample/sample_44100.yaml deleted file mode 100644 index 1519c7b1d4fc..000000000000 --- a/examples/tts/conf/encodec/sample/sample_44100.yaml +++ /dev/null @@ -1,4 +0,0 @@ -sample_rate: 44100 -train_n_samples: 44032 -down_sample_rates: [2, 4, 8, 8] -up_sample_rates: [8, 8, 4, 2] diff --git a/examples/tts/conf/encodec/sample/sample_48000.yaml b/examples/tts/conf/encodec/sample/sample_48000.yaml deleted file mode 100644 index b13af98f6967..000000000000 --- a/examples/tts/conf/encodec/sample/sample_48000.yaml +++ /dev/null @@ -1,4 +0,0 @@ -sample_rate: 48000 -train_n_samples: 48000 -down_sample_rates: [2, 4, 5, 8] -up_sample_rates: [8, 5, 4, 2] diff --git a/nemo/collections/tts/losses/encodec_loss.py b/nemo/collections/tts/losses/encodec_loss.py index 38e21bac6ac3..7448b2decabe 100644 --- a/nemo/collections/tts/losses/encodec_loss.py +++ b/nemo/collections/tts/losses/encodec_loss.py @@ -16,21 +16,104 @@ import torch import torch.nn.functional as F +from einops import rearrange from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures -from nemo.collections.tts.losses.loss import MaskedLoss from nemo.core.classes import Loss, typecheck -from nemo.core.neural_types import AudioSignal, LengthsType, LossType, NeuralType, VoidType +from nemo.core.neural_types import ( + AudioSignal, + LengthsType, + LossType, + NeuralType, + PredictionsType, + RegressionValuesType, + VoidType, +) + + +class MaskedLoss(Loss): + def __init__(self, loss_fn, loss_scale: float = 1.0): + super(MaskedLoss, self).__init__() + self.loss_scale = loss_scale + self.loss_fn = loss_fn + + @property + def input_types(self): + return { + "target": NeuralType(('B', 'D', 'T'), RegressionValuesType()), + "predicted": NeuralType(('B', 'D', 'T'), PredictionsType()), + "target_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "loss": NeuralType(elements_type=LossType()), + } + + @typecheck() + def forward(self, predicted, target, target_len): + assert target.shape[2] == predicted.shape[2] + + # [B, D, T] + loss = self.loss_fn(input=predicted, target=target) + # [B, T] + loss = torch.mean(loss, dim=1) + # [B] + loss = torch.sum(loss, dim=1) / torch.clamp(target_len, min=1.0) + + # [1] + loss = torch.mean(loss) + loss = self.loss_scale * loss + + return loss + + +class MaskedMAELoss(MaskedLoss): + def __init__(self, loss_scale: float = 1.0): + loss_fn = torch.nn.L1Loss(reduction='none') + 
super(MaskedMAELoss, self).__init__(loss_fn=loss_fn, loss_scale=loss_scale) + + +class MaskedMSELoss(MaskedLoss): + def __init__(self, loss_scale: float = 1.0): + loss_fn = torch.nn.MSELoss(reduction='none') + super(MaskedMSELoss, self).__init__(loss_fn=loss_fn, loss_scale=loss_scale) + + +class AudioLoss(Loss): + def __init__(self): + super(AudioLoss, self).__init__() + self.loss_fn = MaskedMAELoss() + + @property + def input_types(self): + return { + "audio_real": NeuralType(('B', 'T'), AudioSignal()), + "audio_gen": NeuralType(('B', 'T'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "loss": [NeuralType(elements_type=LossType())], + } + + @typecheck() + def forward(self, audio_real, audio_gen, audio_len): + audio_real = rearrange(audio_real, "B T -> B 1 T") + audio_gen = rearrange(audio_gen, "B T -> B 1 T") + loss = self.loss_fn(target=audio_real, predicted=audio_gen, target_len=audio_len) + return loss class MultiResolutionMelLoss(Loss): - def __init__( - self, sample_rate: int, mel_dim: int, resolutions: List[List], l1_scale: float = 1.0, l2_scale: float = 1.0 - ): + def __init__(self, sample_rate: int, mel_dim: int, resolutions: List[List], l1_scale: float = 1.0): super(MultiResolutionMelLoss, self).__init__() - self.l1_loss_fn = MaskedLoss("l1", loss_scale=l1_scale) - self.l2_loss_fn = MaskedLoss("l2", loss_scale=l2_scale) + self.l1_loss_fn = MaskedMAELoss(loss_scale=l1_scale) + self.l2_loss_fn = MaskedMSELoss() self.mel_features = torch.nn.ModuleList() for n_fft, hop_len, win_len in resolutions: @@ -68,22 +151,23 @@ def output_types(self): @typecheck() def forward(self, audio_real, audio_gen, audio_len): - len_diff = audio_real.shape[1] - audio_gen.shape[1] - audio_gen = F.pad(audio_gen, (0, len_diff)) - loss = 0.0 for mel_feature in self.mel_features: mel_real, mel_real_len = mel_feature(x=audio_real, seq_len=audio_len) mel_gen, _ = mel_feature(x=audio_gen, seq_len=audio_len) - loss = loss + self.l1_loss_fn(predicted=mel_gen, target=mel_real, target_len=mel_real_len) - loss = loss + self.l2_loss_fn(predicted=mel_gen, target=mel_real, target_len=mel_real_len) + loss += self.l1_loss_fn(predicted=mel_gen, target=mel_real, target_len=mel_real_len) + loss += self.l2_loss_fn(predicted=mel_gen, target=mel_real, target_len=mel_real_len) loss /= len(self.mel_features) return loss -class FeatureMatchingLoss(Loss): +class RelativeFeatureMatchingLoss(Loss): + def __init__(self, div_guard=1e-2): + super(RelativeFeatureMatchingLoss, self).__init__() + self.div_guard = div_guard + @property def input_types(self): return { @@ -106,10 +190,10 @@ def forward(self, fmaps_real, fmaps_gen): # [B, ...] 
feat_mean = torch.mean(torch.abs(feat_real), dim=-1) diff = torch.mean(torch.abs(feat_real - feat_gen), dim=-1) - feat_loss = diff / (feat_mean + 1e-2) + feat_loss = diff / (feat_mean + self.div_guard) # [1] feat_loss = torch.mean(feat_loss) / len(fmap_real) - loss = loss + feat_loss + loss += feat_loss loss /= len(fmaps_real) @@ -131,7 +215,7 @@ def output_types(self): def forward(self, disc_scores): loss = 0.0 for disc_score in disc_scores: - loss = loss + torch.mean((1 - disc_score) ** 2) + loss += torch.mean(F.relu(1 - disc_score)) loss /= len(disc_scores) @@ -154,9 +238,9 @@ def output_types(self): def forward(self, disc_scores_real, disc_scores_gen): loss = 0.0 for disc_score_real, disc_score_gen in zip(disc_scores_real, disc_scores_gen): - loss_real = torch.mean((1 - disc_score_real) ** 2) - loss_gen = torch.mean(disc_score_gen ** 2) - loss = loss + (loss_real + loss_gen) / 2 + loss_real = torch.mean(F.relu(1 - disc_score_real)) + loss_gen = torch.mean(F.relu(1 + disc_score_gen)) + loss += (loss_real + loss_gen) / 2 loss /= len(disc_scores_real) diff --git a/nemo/collections/tts/losses/loss.py b/nemo/collections/tts/losses/loss.py deleted file mode 100644 index 4eae8d72f7a0..000000000000 --- a/nemo/collections/tts/losses/loss.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch - -from nemo.core.classes import Loss, typecheck -from nemo.core.neural_types import LengthsType, LossType, NeuralType, PredictionsType, RegressionValuesType - - -class MaskedLoss(Loss): - def __init__(self, loss_type: str, loss_scale: float = 1.0): - super(MaskedLoss, self).__init__() - self.loss_scale = loss_scale - - if loss_type == "l1": - self.loss_fn = torch.nn.L1Loss(reduction='none') - elif loss_type == "l2": - self.loss_fn = torch.nn.MSELoss(reduction='none') - else: - raise ValueError(f"Unknown loss type {loss_type}") - - @property - def input_types(self): - return { - "target": NeuralType(('B', 'D', 'T'), RegressionValuesType()), - "predicted": NeuralType(('B', 'D', 'T'), PredictionsType()), - "target_len": NeuralType(tuple('B'), LengthsType()), - } - - @property - def output_types(self): - return { - "loss": NeuralType(elements_type=LossType()), - } - - @typecheck() - def forward(self, predicted, target, target_len): - assert target.shape[2] == predicted.shape[2] - - # [B, D, T] - loss = self.loss_fn(input=predicted, target=target) - # [B, T] - loss = torch.mean(loss, dim=1) - # [B] - loss = torch.sum(loss, dim=1) / torch.clamp(target_len, min=1.0) - - # [1] - loss = torch.mean(loss) - loss = self.loss_scale * loss - - return loss diff --git a/nemo/collections/tts/models/encodec.py b/nemo/collections/tts/models/encodec.py index bfb55cb307c5..b0480588044e 100644 --- a/nemo/collections/tts/models/encodec.py +++ b/nemo/collections/tts/models/encodec.py @@ -13,19 +13,23 @@ # limitations under the License. 
import itertools +import math +import random from pathlib import Path from typing import List, Tuple import torch +import torch.nn.functional as F from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer from nemo.collections.tts.losses.encodec_loss import ( + AudioLoss, DiscriminatorLoss, - FeatureMatchingLoss, GeneratorLoss, MultiResolutionMelLoss, + RelativeFeatureMatchingLoss, ) from nemo.collections.tts.modules.common import GaussianDropout from nemo.collections.tts.parts.utils.callbacks import LoggingCallback @@ -51,7 +55,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) self.sample_rate = cfg.sample_rate - self.disc_update_steps = cfg.get("disc_update_steps", 1) + self.samples_per_frame = cfg.samples_per_frame + + self.disc_update_prob = cfg.get("disc_update_prob", 1.0) self.audio_encoder = instantiate(cfg.audio_encoder) # Optionally, add gaussian noise to encoder output as an information bottleneck @@ -66,25 +72,26 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): else: self.vector_quantizer = None - self.generator = instantiate(cfg.generator) + self.audio_decoder = instantiate(cfg.audio_decoder) self.discriminator = instantiate(cfg.discriminator) mel_loss_dim = cfg.get("mel_loss_dim", 64) mel_loss_resolutions = cfg.mel_loss_resolutions - mel_loss_l1_scale = cfg.get("mel_loss_l1_scale", 0.1) - mel_loss_l2_scale = cfg.get("mel_loss_l2_scale", 1.0) - self.gen_loss_scale = cfg.get("gen_loss_scale", 4.0) - self.feature_loss_scale = cfg.get("feature_loss_scale", 4.0) + self.audio_loss_scale = cfg.get("audio_loss_scale", 0.1) + self.mel_loss_scale = cfg.get("mel_loss_scale", 1.0) + mel_loss_l1_scale = cfg.get("mel_loss_l1_scale", 1.0) + self.gen_loss_scale = cfg.get("gen_loss_scale", 3.0) + self.feature_loss_scale = cfg.get("feature_loss_scale", 3.0) + self.audio_loss_fn = AudioLoss() self.mel_loss_fn = MultiResolutionMelLoss( sample_rate=self.sample_rate, mel_dim=mel_loss_dim, resolutions=mel_loss_resolutions, l1_scale=mel_loss_l1_scale, - l2_scale=mel_loss_l2_scale, ) self.gen_loss_fn = GeneratorLoss() - self.feature_loss_fn = FeatureMatchingLoss() + self.feature_loss_fn = RelativeFeatureMatchingLoss() self.disc_loss_fn = DiscriminatorLoss() self.log_config = cfg.get("log_config", None) @@ -102,6 +109,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): }, ) def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + audio, audio_len = self.pad_audio(audio, audio_len) encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len) return encoded, encoded_len @@ -116,7 +124,7 @@ def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[to }, ) def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - audio, audio_len = self.generator(inputs=inputs, input_len=input_len) + audio, audio_len = self.audio_decoder(inputs=inputs, input_len=input_len) return audio, audio_len @typecheck( @@ -158,6 +166,7 @@ def quantize_decode(self, indices: torch.Tensor, encoded_len: torch.Tensor) -> t }, ) def forward(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + audio, audio_len = self.pad_audio(audio, audio_len) encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len) if self.vector_quantizer: @@ -169,11 +178,20 @@ def forward(self, audio: torch.Tensor, audio_len: 
torch.Tensor) -> Tuple[torch.T return output_audio, output_audio_len + # Zero pad the end of the audio so that we do not have a partial end frame. + def pad_audio(self, audio, audio_len): + padded_len = self.samples_per_frame * torch.ceil(audio_len / self.samples_per_frame).int() + max_len = padded_len.max().item() + num_padding = max_len - audio.shape[1] + padded_audio = F.pad(audio, (0, num_padding)) + return padded_audio, padded_len + def _process_batch(self, batch): # [B, T_audio] audio = batch.get("audio") # [B] audio_len = batch.get("audio_lens") + audio, audio_len = self.pad_audio(audio, audio_len) # [B, D, T_encoded] encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len) @@ -187,18 +205,17 @@ def _process_batch(self, batch): commit_loss = None # [B, T] - audio_gen, audio_gen_len = self.generator(inputs=encoded, input_len=encoded_len) + audio_gen, audio_gen_len = self.audio_decoder(inputs=encoded, input_len=encoded_len) return audio, audio_len, audio_gen, commit_loss def training_step(self, batch, batch_idx): - optim_gen, optim_disc = self.optimizers() optim_gen.zero_grad() audio, audio_len, audio_gen, commit_loss = self._process_batch(batch) - if batch_idx % self.disc_update_steps != 0: + if self.disc_update_prob < random.random(): loss_disc = None else: # Train discriminator @@ -212,7 +229,11 @@ def training_step(self, batch, batch_idx): self.manual_backward(loss_disc) optim_disc.step() + loss_audio = self.audio_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + train_loss_audio = self.audio_loss_scale * loss_audio + loss_mel = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + train_loss_mel = self.mel_loss_scale * loss_mel _, disc_scores_gen, fmaps_real, fmaps_gen = self.discriminator(audio_real=audio, audio_gen=audio_gen) @@ -222,7 +243,7 @@ def training_step(self, batch, batch_idx): loss_feature = self.feature_loss_fn(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen) train_loss_feature = self.feature_loss_scale * loss_feature - loss_gen_all = loss_mel + train_loss_gen + train_loss_feature + loss_gen_all = train_loss_audio + train_loss_mel + train_loss_gen + train_loss_feature if commit_loss is not None: loss_gen_all += commit_loss @@ -232,6 +253,7 @@ def training_step(self, batch, batch_idx): self.update_lr() metrics = { + "g_loss_audio": loss_audio, "g_loss_mel": loss_mel, "g_loss_gen": loss_gen, "g_loss_feature": loss_feature, @@ -247,15 +269,17 @@ def training_step(self, batch, batch_idx): metrics["g_loss_commit"] = commit_loss self.log_dict(metrics, on_step=True, sync_dist=True) - self.log("t_loss", loss_mel, prog_bar=True, logger=False, sync_dist=True) + self.log("t_loss", train_loss_mel, prog_bar=True, logger=False, sync_dist=True) def training_epoch_end(self, outputs): self.update_lr("epoch") def validation_step(self, batch, batch_idx): audio, audio_len, audio_gen, _ = self._process_batch(batch) + loss_audio = self.audio_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) loss_mel = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) - self.log_dict({"val_loss": loss_mel}, on_epoch=True, sync_dist=True) + metrics = {"val_loss": loss_audio + loss_mel, "val_loss_audio": loss_audio, "val_loss_mel": loss_mel} + self.log_dict(metrics, on_epoch=True, sync_dist=True) @staticmethod def _setup_train_dataloader(cfg): @@ -309,7 +333,7 @@ def configure_optimizers(self): sched_config = optim_config.pop("sched", None) OmegaConf.set_struct(optim_config, True) - gen_params = 
itertools.chain(self.audio_encoder.parameters(), self.generator.parameters()) + gen_params = itertools.chain(self.audio_encoder.parameters(), self.audio_decoder.parameters()) disc_params = self.discriminator.parameters() optim_g = instantiate(optim_config, params=gen_params) optim_d = instantiate(optim_config, params=disc_params) @@ -342,7 +366,7 @@ def update_lr(self, interval="step"): def configure_callbacks(self): if not self.log_config: - return + return [] data_loader = self._setup_test_dataloader(self.log_config) generators = instantiate(self.log_config.generators) diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py index 2af6aeded42b..ba1c8aa9a348 100644 --- a/nemo/collections/tts/modules/encodec_modules.py +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -18,6 +18,7 @@ import torch.nn as nn from einops import rearrange +from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor from nemo.core.classes.module import NeuralModule from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, LengthsType, VoidType from nemo.core.neural_types.neural_type import NeuralType @@ -50,7 +51,12 @@ def __init__( if not padding: padding = get_padding(kernel_size) conv = nn.Conv1d( - in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + padding_mode="reflect", ) self.conv = nn.utils.weight_norm(conv) @@ -58,6 +64,7 @@ def __init__( def input_types(self): return { "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "lengths": NeuralType(tuple('B'), LengthsType()), } @property @@ -69,8 +76,10 @@ def output_types(self): def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv) - def forward(self, inputs): - return self.conv(inputs) + def forward(self, inputs, lengths): + out = self.conv(inputs) + out = mask_sequence_tensor(out, lengths) + return out class ConvTranspose1dNorm(NeuralModule): @@ -84,6 +93,7 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride stride=stride, padding=padding, output_padding=output_padding, + padding_mode="zeros", ) self.conv = nn.utils.weight_norm(conv) @@ -91,6 +101,7 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride def input_types(self): return { "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "lengths": NeuralType(tuple('B'), LengthsType()), } @property @@ -102,8 +113,10 @@ def output_types(self): def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv) - def forward(self, inputs): - return self.conv(inputs) + def forward(self, inputs, lengths): + out = self.conv(inputs) + out = mask_sequence_tensor(out, lengths) + return out class Conv2dNorm(NeuralModule): @@ -125,6 +138,7 @@ def __init__( stride=stride, dilation=dilation, padding=padding, + padding_mode="reflect", ) self.conv = nn.utils.weight_norm(conv) @@ -160,6 +174,7 @@ def __init__(self, channels: int): def input_types(self): return { "inputs": NeuralType(('B', 'C', 'T_input'), VoidType()), + "lengths": NeuralType(tuple('B'), LengthsType()), } @property @@ -173,13 +188,54 @@ def remove_weight_norm(self): self.res_conv1.remove_weight_norm() self.res_conv2.remove_weight_norm() - def forward(self, inputs): + def forward(self, inputs, lengths): res = self.activation(inputs) - res = self.res_conv1(res) + res = self.res_conv1(res, lengths) res = 
self.activation(res) - res = self.res_conv2(res) + res = self.res_conv2(res, lengths) + + out = self.pre_conv(inputs, lengths) + res + out = mask_sequence_tensor(out, lengths) + return out + + +class SEANetRNN(NeuralModule): + def __init__(self, dim: int, num_layers: int, rnn_type: str = "lstm", use_skip: bool = False): + super().__init__() + self.use_skip = use_skip + if rnn_type == "lstm": + self.rnn = torch.nn.LSTM(input_size=dim, hidden_size=dim, num_layers=num_layers) + elif rnn_type == "gru": + self.rnn = torch.nn.GRU(input_size=dim, hidden_size=dim, num_layers=num_layers) + else: + raise ValueError(f"Unknown RNN type {rnn_type}") - out = self.pre_conv(inputs) + res + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "lengths": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "out": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + def forward(self, inputs, lengths): + inputs = rearrange(inputs, "B C T -> B T C") + + packed_inputs = nn.utils.rnn.pack_padded_sequence( + inputs, lengths=lengths.cpu(), batch_first=True, enforce_sorted=False + ) + packed_out, _ = self.rnn(packed_inputs) + out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True) + + if self.use_skip: + out = out + inputs + + out = rearrange(out, "B T C -> B C T") return out @@ -191,6 +247,9 @@ def __init__( in_kernel_size: int = 7, out_kernel_size: int = 7, encoded_dim: int = 128, + rnn_layers: int = 2, + rnn_type: str = "lstm", + rnn_skip: bool = True, ): assert in_kernel_size > 0 assert out_kernel_size > 0 @@ -220,6 +279,7 @@ def __init__( in_channels = out_channels self.down_sample_conv_layers.append(down_sample_conv) + self.rnn = SEANetRNN(dim=in_channels, num_layers=rnn_layers, rnn_type=rnn_type, use_skip=rnn_skip) self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=encoded_dim, kernel_size=out_kernel_size) @property @@ -243,25 +303,26 @@ def remove_weight_norm(self): for down_sample_conv in self.down_sample_conv_layers: down_sample_conv.remove_weight_norm() - # TODO: Add masking def forward(self, audio, audio_len): encoded_len = audio_len audio = rearrange(audio, "B T -> B 1 T") # [B, C, T_audio] - out = self.pre_conv(audio) + out = self.pre_conv(audio, encoded_len) for res_block, down_sample_conv, down_sample_rate in zip( self.res_blocks, self.down_sample_conv_layers, self.down_sample_rates ): - encoded_len = torch.div(encoded_len, down_sample_rate, rounding_mode="floor") # [B, C, T] - out = res_block(out) + out = res_block(out, encoded_len) out = self.activation(out) + + encoded_len = encoded_len // down_sample_rate # [B, 2 * C, T / down_sample_rate] - out = down_sample_conv(out) + out = down_sample_conv(out, encoded_len) + out = self.rnn(out, encoded_len) out = self.activation(out) # [B, encoded_dim, T_encoded] - encoded = self.post_conv(out) + encoded = self.post_conv(out, encoded_len) return encoded, encoded_len @@ -273,6 +334,9 @@ def __init__( in_kernel_size: int = 7, out_kernel_size: int = 3, encoded_dim: int = 128, + rnn_layers: int = 2, + rnn_type: str = "lstm", + rnn_skip: bool = True, ): assert in_kernel_size > 0 assert out_kernel_size > 0 @@ -282,6 +346,7 @@ def __init__( self.up_sample_rates = up_sample_rates self.activation = nn.ELU() self.pre_conv = Conv1dNorm(in_channels=encoded_dim, out_channels=base_channels, kernel_size=in_kernel_size) + self.rnn = SEANetRNN(dim=base_channels, num_layers=rnn_layers, rnn_type=rnn_type, use_skip=rnn_skip) in_channels = base_channels 
self.res_blocks = nn.ModuleList([]) @@ -322,23 +387,23 @@ def remove_weight_norm(self): for res_block in self.res_blocks: res_block.remove_weight_norm() - # TODO: Add masking def forward(self, inputs, input_len): audio_len = input_len # [B, C, T_encoded] - out = self.pre_conv(inputs) + out = self.pre_conv(inputs, audio_len) + out = self.rnn(out, audio_len) for res_block, up_sample_conv, up_sample_rate in zip( self.res_blocks, self.up_sample_conv_layers, self.up_sample_rates ): audio_len *= up_sample_rate out = self.activation(out) # [B, C / 2, T * up_sample_rate] - out = up_sample_conv(out) - out = res_block(out) + out = up_sample_conv(out, audio_len) + out = res_block(out, audio_len) out = self.activation(out) # [B, 1, T_audio] - out = self.post_conv(out) + out = self.post_conv(out, audio_len) audio = self.out_activation(out) audio = rearrange(audio, "B 1 T -> B T") return audio, audio_len diff --git a/nemo/collections/tts/modules/vector_quantization.py b/nemo/collections/tts/modules/vector_quantization.py index e905fddc1267..3b4d9fa3ab53 100644 --- a/nemo/collections/tts/modules/vector_quantization.py +++ b/nemo/collections/tts/modules/vector_quantization.py @@ -20,8 +20,9 @@ from einops import rearrange, repeat from torch import Tensor -from nemo.collections.tts.losses.loss import MaskedLoss +from nemo.collections.tts.losses.encodec_loss import MaskedMSELoss from nemo.collections.tts.parts.utils.distributed import broadcast_tensors +from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor from nemo.core.classes.common import typecheck from nemo.core.classes.module import NeuralModule from nemo.core.neural_types.elements import EncodedRepresentation, Index, LengthsType, LossType @@ -117,6 +118,22 @@ def _k_means(samples: Tensor, num_clusters: int, num_iters: int = 10) -> Tuple[T return means, bin_counts +def _mask_3d(tensor: Tensor, lengths: Tensor): + """ + Mask 3d tensor with time on 1st axis. 
+ + Args: + tensor: tensor of shape (B, T, D) + lengths: LongTensor of shape (B,) + Returns: + Masked Tensor (B, T, D) + """ + batch_size, max_lengths, _ = tensor.shape + mask = torch.ones(batch_size, max_lengths, 1).cumsum(dim=1).type_as(lengths) + mask = mask <= rearrange(lengths, "b -> b 1 1") + return tensor * mask + + @experimental class EuclideanCodebook(NeuralModule): """ @@ -215,7 +232,10 @@ def _dequantize(self, indices: Tensor) -> Tensor: @property def input_types(self): - return {"inputs": NeuralType(('B', 'T', 'D'), EncodedRepresentation())} + return { + "inputs": NeuralType(('B', 'T', 'D'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } @property def output_types(self): @@ -224,7 +244,7 @@ def output_types(self): "indices": NeuralType(('B', 'T'), Index()), } - def forward(self, inputs): + def forward(self, inputs, input_len): input_flat = rearrange(inputs, "B T D -> (B T) D") self._init_codes(input_flat) # [(B T)] @@ -239,27 +259,34 @@ def forward(self, inputs): self._expire_codes(inputs=input_flat) self._update_codes(inputs=input_flat, indices=indices_flat) + quantized = _mask_3d(quantized, input_len) + indices = mask_sequence_tensor(indices, input_len) return quantized, indices @typecheck( - input_types={"inputs": NeuralType(('B', 'T', 'D'), EncodedRepresentation())}, + input_types={ + "inputs": NeuralType(('B', 'T', 'D'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, output_types={"indices": NeuralType(('B', 'T'), Index())}, ) - def encode(self, inputs): + def encode(self, inputs, input_len): input_flat = rearrange(inputs, "B T D -> (B T) D") # [(B T)] indices_flat = self._quantize(inputs=input_flat) # [B, T] indices = indices_flat.view(*inputs.shape[:-1]) + indices = mask_sequence_tensor(indices, input_len) return indices @typecheck( - input_types={"indices": NeuralType(('B', 'T'), Index())}, + input_types={"indices": NeuralType(('B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType()),}, output_types={"quantized": NeuralType(('B', 'T', 'D'), EncodedRepresentation())}, ) - def decode(self, indices): + def decode(self, indices, input_len): # [B, T, D] quantized = self._dequantize(indices=indices) + quantized = _mask_3d(quantized, input_len) return quantized @@ -294,7 +321,7 @@ def __init__( self.codebook_dim = codebook_dim if commit_loss_scale: - self.commit_loss_fn = MaskedLoss(loss_type="l2", loss_scale=commit_loss_scale) + self.commit_loss_fn = MaskedMSELoss(loss_scale=commit_loss_scale) else: self.commit_loss_fn = None @@ -336,7 +363,6 @@ def output_types(self): "commit_loss": NeuralType((), LossType()), } - # TODO: Add Masking def forward(self, inputs: Tensor, input_len: Tensor) -> Tuple[Tensor, Tensor, float]: commit_loss = 0.0 residual = rearrange(inputs, "B D T -> B T D") @@ -344,7 +370,7 @@ def forward(self, inputs: Tensor, input_len: Tensor) -> Tuple[Tensor, Tensor, fl index_list = [] quantized = torch.zeros_like(residual) for codebook in self.codebooks: - quantized_i, indices_i = codebook(residual) + quantized_i, indices_i = codebook(inputs=residual, input_len=input_len) if self.training: quantized_i = residual + (quantized_i - residual).detach() @@ -377,9 +403,9 @@ def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor: index_list = [] for codebook in self.codebooks: # [B, T] - indices_i = codebook.encode(inputs=residual) + indices_i = codebook.encode(inputs=residual, input_len=input_len) # [B, D, T] - quantized_i = codebook.decode(indices=indices_i) + quantized_i = 
codebook.decode(indices=indices_i, input_len=input_len) residual = residual - quantized_i index_list.append(indices_i) # [N, B, T] @@ -397,7 +423,7 @@ def decode(self, indices: Tensor, input_len: Tensor) -> Tensor: # [B, T, D] quantized = torch.zeros([indices.shape[1], indices.shape[2], self.codebook_dim], device=indices.device) for codebook_indices, codebook in zip(indices, self.codebooks): - quantized_i = codebook.decode(indices=codebook_indices) + quantized_i = codebook.decode(indices=codebook_indices, input_len=input_len) quantized = quantized + quantized_i quantized = rearrange(quantized, "B T D -> B D T") return quantized diff --git a/nemo/collections/tts/parts/utils/callbacks.py b/nemo/collections/tts/parts/utils/callbacks.py index 1701c84f594a..896773caf42b 100644 --- a/nemo/collections/tts/parts/utils/callbacks.py +++ b/nemo/collections/tts/parts/utils/callbacks.py @@ -270,9 +270,10 @@ class EnCodecArtifactGenerator(ArtifactGenerator): Generator for logging EnCodec model outputs. """ - def __init__(self, log_audio: bool = True, log_encoding: bool = False): + def __init__(self, log_audio: bool = True, log_encoding: bool = False, log_quantized: bool = False): self.log_audio = log_audio self.log_encoding = log_encoding + self.log_quantized = log_quantized def _generate_audio(self, model, audio_ids, audio, audio_len): if not self.log_audio: @@ -284,7 +285,7 @@ def _generate_audio(self, model, audio_ids, audio, audio_len): audio_artifacts = [] for i, audio_id in enumerate(audio_ids): - audio_pred_i = audio_pred[i][: audio_pred_len[i]].cpu().numpy() + audio_pred_i = audio_pred[i, : audio_pred_len[i]].cpu().numpy() audio_artifact = AudioArtifact( id=f"audio_{audio_id}", data=audio_pred_i, filename=f"{audio_id}.wav", sample_rate=model.sample_rate, ) @@ -292,25 +293,46 @@ def _generate_audio(self, model, audio_ids, audio, audio_len): return audio_artifacts - def _generate_encodings(self, model, audio_ids, audio, audio_len): - if not self.log_encoding: - return [] + def _generate_images(self, model, audio_ids, audio, audio_len): + image_artifacts = [] + + if not self.log_encoding and not self.log_quantized: + return image_artifacts with torch.no_grad(): # [B, D, T] encoded, encoded_len = model.encode_audio(audio=audio, audio_len=audio_len) - image_artifacts = [] + if self.log_encoding: + for i, audio_id in enumerate(audio_ids): + encoded_i = encoded[i, :, : encoded_len[i]].cpu().numpy() + encoded_artifact = ImageArtifact( + id=f"encoded_{audio_id}", + data=encoded_i, + filename=f"{audio_id}_encode.png", + x_axis="Audio Frames", + y_axis="Channels", + ) + image_artifacts.append(encoded_artifact) + + if not self.log_quantized: + return image_artifacts + + with torch.no_grad(): + # [B, D, T] + indices = model.quantize_encode(encoded=encoded, encoded_len=encoded_len) + quantized = model.quantize_decode(indices=indices, encoded_len=encoded_len) + for i, audio_id in enumerate(audio_ids): - encoded_i = encoded[i][:, : encoded_len[i]].cpu().numpy() - encoded_artifact = ImageArtifact( - id=f"encoded_{audio_id}", - data=encoded_i, - filename=f"{audio_id}_encode.png", + quantized_i = quantized[i, :, : encoded_len[i]].cpu().numpy() + quantized_artifact = ImageArtifact( + id=f"quantized_{audio_id}", + data=quantized_i, + filename=f"{audio_id}_quantized.png", x_axis="Audio Frames", y_axis="Channels", ) - image_artifacts.append(encoded_artifact) + image_artifacts.append(quantized_artifact) return image_artifacts @@ -325,7 +347,7 @@ def generate_artifacts( audio_len = batch_dict.get("audio_lens") 
audio_artifacts = self._generate_audio(model=model, audio_ids=audio_ids, audio=audio, audio_len=audio_len) - image_artifacts = self._generate_encodings(model=model, audio_ids=audio_ids, audio=audio, audio_len=audio_len) + image_artifacts = self._generate_images(model=model, audio_ids=audio_ids, audio=audio, audio_len=audio_len) return audio_artifacts, image_artifacts diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index b9ea0854e48c..72048882fe78 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -733,7 +733,10 @@ def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor): """ batch_size, *_, max_lengths = tensor.shape - if len(tensor.shape) == 3: + if len(tensor.shape) == 2: + mask = torch.ones(batch_size, max_lengths).cumsum(dim=-1).type_as(lengths) + mask = mask <= rearrange(lengths, "b -> b 1") + elif len(tensor.shape) == 3: mask = torch.ones(batch_size, 1, max_lengths).cumsum(dim=-1).type_as(lengths) mask = mask <= rearrange(lengths, "b -> b 1 1") elif len(tensor.shape) == 4: diff --git a/tests/collections/tts/losses/test_loss.py b/tests/collections/tts/losses/test_encodec_loss.py similarity index 89% rename from tests/collections/tts/losses/test_loss.py rename to tests/collections/tts/losses/test_encodec_loss.py index ba3f399566e1..ab5870e1a33c 100644 --- a/tests/collections/tts/losses/test_loss.py +++ b/tests/collections/tts/losses/test_encodec_loss.py @@ -15,14 +15,14 @@ import pytest import torch -from nemo.collections.tts.losses.loss import MaskedLoss +from nemo.collections.tts.losses.encodec_loss import MaskedMAELoss, MaskedMSELoss -class TestTTSLoss: +class TestEnCodecLoss: @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_masked_loss_l1(self): - loss_fn = MaskedLoss("l1") + loss_fn = MaskedMAELoss() target = torch.tensor([[[1.0], [2.0], [0.0]], [[3.0], [0.0], [0.0]]]).transpose(1, 2) predicted = torch.tensor([[[0.5], [1.0], [0.0]], [[4.5], [0.0], [0.0]]]).transpose(1, 2) target_len = torch.tensor([2, 1]) @@ -34,7 +34,7 @@ def test_masked_loss_l1(self): @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_masked_loss_l2(self): - loss_fn = MaskedLoss("l2") + loss_fn = MaskedMSELoss() target = torch.tensor([[[1.0], [2.0], [4.0]], [[3.0], [0.0], [0.0]]]).transpose(1, 2) predicted = torch.tensor([[[0.5], [1.0], [4.0]], [[4.5], [0.0], [0.0]]]).transpose(1, 2) target_len = torch.tensor([3, 1]) From f6794e3d9c3a5767c7a813d4a8558f9aa98a4d2a Mon Sep 17 00:00:00 2001 From: Ryan Date: Fri, 7 Jul 2023 13:26:14 -0700 Subject: [PATCH 3/5] [TTS] Rename EnCodec to AudioCodec Signed-off-by: Ryan --- examples/tts/{encodec.py => audio_codec.py} | 6 +- .../{encodec => audio_codec}/encodec.yaml | 20 ++++-- .../{encodec_loss.py => audio_codec_loss.py} | 63 ++++++++++++++++--- nemo/collections/tts/models/__init__.py | 4 +- .../tts/models/{encodec.py => audio_codec.py} | 26 +++----- ...odec_modules.py => audio_codec_modules.py} | 0 .../tts/modules/vector_quantization.py | 2 +- nemo/collections/tts/parts/utils/callbacks.py | 4 +- .../tts/losses/test_encodec_loss.py | 2 +- 9 files changed, 87 insertions(+), 40 deletions(-) rename examples/tts/{encodec.py => audio_codec.py} (83%) rename examples/tts/conf/{encodec => audio_codec}/encodec.yaml (87%) rename nemo/collections/tts/losses/{encodec_loss.py => audio_codec_loss.py} (81%) rename nemo/collections/tts/models/{encodec.py => audio_codec.py} (96%) rename nemo/collections/tts/modules/{encodec_modules.py => 
audio_codec_modules.py} (100%) diff --git a/examples/tts/encodec.py b/examples/tts/audio_codec.py similarity index 83% rename from examples/tts/encodec.py rename to examples/tts/audio_codec.py index 1443a1f54f40..ffc91cd98f01 100644 --- a/examples/tts/encodec.py +++ b/examples/tts/audio_codec.py @@ -14,16 +14,16 @@ import pytorch_lightning as pl -from nemo.collections.tts.models import EnCodecModel +from nemo.collections.tts.models import AudioCodecModel from nemo.core.config import hydra_runner from nemo.utils.exp_manager import exp_manager -@hydra_runner(config_path="conf/encodec", config_name="encodec") +@hydra_runner(config_path="conf/audio_codec", config_name="audio_codec") def main(cfg): trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) - model = EnCodecModel(cfg=cfg.model, trainer=trainer) + model = AudioCodecModel(cfg=cfg.model, trainer=trainer) trainer.fit(model) diff --git a/examples/tts/conf/encodec/encodec.yaml b/examples/tts/conf/audio_codec/encodec.yaml similarity index 87% rename from examples/tts/conf/encodec/encodec.yaml rename to examples/tts/conf/audio_codec/encodec.yaml index cbaae4259aa7..c06d38d9e075 100644 --- a/examples/tts/conf/encodec/encodec.yaml +++ b/examples/tts/conf/audio_codec/encodec.yaml @@ -1,8 +1,8 @@ -# This config contains the default values for training 24khz EnCodec +# This config contains the default values for training 24khz EnCodec model # If you want to train model on other dataset, you can change config values according to your dataset. # Most dataset-specific arguments are in the head of the config file, see below. -name: EnCodec +name: AudioCodec max_epochs: ??? # Adjust batch size based on GPU memory @@ -35,6 +35,8 @@ model: sample_rate: ${sample_rate} samples_per_frame: ${samples_per_frame} + gen_loss_scale: 3.0 + feature_loss_scale: 3.0 # Probability of updating the discriminator during each training step disc_update_prob: 0.67 @@ -82,7 +84,7 @@ model: log_wandb: false generators: - - _target_: nemo.collections.tts.parts.utils.callbacks.EnCodecArtifactGenerator + - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator log_audio: true log_encoding: true log_quantized: true @@ -101,11 +103,11 @@ model: num_workers: 2 audio_encoder: - _target_: nemo.collections.tts.modules.encodec_modules.SEANetEncoder + _target_: nemo.collections.tts.modules.audio_codec_modules.SEANetEncoder down_sample_rates: ${down_sample_rates} audio_decoder: - _target_: nemo.collections.tts.modules.encodec_modules.SEANetDecoder + _target_: nemo.collections.tts.modules.audio_codec_modules.SEANetDecoder up_sample_rates: ${up_sample_rates} vector_quantizer: @@ -113,9 +115,15 @@ model: num_codebooks: 8 discriminator: - _target_: nemo.collections.tts.modules.encodec_modules.MultiResolutionDiscriminatorSTFT + _target_: nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]] + generator_loss: + _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorHingedLoss + + discriminator_loss: + _target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorHingedLoss + optim: _target_: torch.optim.Adam lr: 3e-4 diff --git a/nemo/collections/tts/losses/encodec_loss.py b/nemo/collections/tts/losses/audio_codec_loss.py similarity index 81% rename from nemo/collections/tts/losses/encodec_loss.py rename to nemo/collections/tts/losses/audio_codec_loss.py index 7448b2decabe..d895a2c9e176 100644 
--- a/nemo/collections/tts/losses/encodec_loss.py +++ b/nemo/collections/tts/losses/audio_codec_loss.py @@ -164,7 +164,7 @@ def forward(self, audio_real, audio_gen, audio_len): class RelativeFeatureMatchingLoss(Loss): - def __init__(self, div_guard=1e-2): + def __init__(self, div_guard=1e-6): super(RelativeFeatureMatchingLoss, self).__init__() self.div_guard = div_guard @@ -200,11 +200,33 @@ def forward(self, fmaps_real, fmaps_gen): return loss -class GeneratorLoss(Loss): +class GeneratorHingedLoss(Loss): @property def input_types(self): return { - "disc_scores": [NeuralType(('B', 'C', 'T'), VoidType())], + "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, disc_scores_gen): + loss = 0.0 + for disc_score_gen in disc_scores_gen: + loss += torch.mean(F.relu(1 - disc_score_gen)) + + loss /= len(disc_scores_gen) + + return loss + + +class GeneratorSquaredLoss(Loss): + @property + def input_types(self): + return { + "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())], } @property @@ -212,17 +234,17 @@ def output_types(self): return {"loss": NeuralType(elements_type=LossType())} @typecheck() - def forward(self, disc_scores): + def forward(self, disc_scores_gen): loss = 0.0 - for disc_score in disc_scores: - loss += torch.mean(F.relu(1 - disc_score)) + for disc_score_gen in disc_scores_gen: + loss += torch.mean((1 - disc_score_gen) ** 2) - loss /= len(disc_scores) + loss /= len(disc_scores_gen) return loss -class DiscriminatorLoss(Loss): +class DiscriminatorHingedLoss(Loss): @property def input_types(self): return { @@ -245,3 +267,28 @@ def forward(self, disc_scores_real, disc_scores_gen): loss /= len(disc_scores_real) return loss + + +class DiscriminatorSquaredLoss(Loss): + @property + def input_types(self): + return { + "disc_scores_real": [NeuralType(('B', 'C', 'T'), VoidType())], + "disc_scores_gen": [NeuralType(('B', 'C', 'T'), VoidType())], + } + + @property + def output_types(self): + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward(self, disc_scores_real, disc_scores_gen): + loss = 0.0 + for disc_score_real, disc_score_gen in zip(disc_scores_real, disc_scores_gen): + loss_real = torch.mean((1 - disc_score_real) ** 2) + loss_gen = torch.mean(disc_score_gen ** 2) + loss += (loss_real + loss_gen) / 2 + + loss /= len(disc_scores_real) + + return loss diff --git a/nemo/collections/tts/models/__init__.py b/nemo/collections/tts/models/__init__.py index 31e649688521..4f01ea6a099e 100644 --- a/nemo/collections/tts/models/__init__.py +++ b/nemo/collections/tts/models/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. 
from nemo.collections.tts.models.aligner import AlignerModel -from nemo.collections.tts.models.encodec import EnCodecModel +from nemo.collections.tts.models.audio_codec import AudioCodecModel from nemo.collections.tts.models.fastpitch import FastPitchModel from nemo.collections.tts.models.fastpitch_ssl import FastPitchModel_SSL from nemo.collections.tts.models.hifigan import HifiGanModel @@ -29,6 +29,7 @@ __all__ = [ "AlignerModel", + "AudioCodecModel", "FastPitchModel", "FastPitchModel_SSL", "SSLDisentangler", @@ -43,5 +44,4 @@ "VitsModel", "WaveGlowModel", "SpectrogramEnhancerModel", - "EnCodecModel", ] diff --git a/nemo/collections/tts/models/encodec.py b/nemo/collections/tts/models/audio_codec.py similarity index 96% rename from nemo/collections/tts/models/encodec.py rename to nemo/collections/tts/models/audio_codec.py index b0480588044e..f289e778d474 100644 --- a/nemo/collections/tts/models/encodec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -13,7 +13,6 @@ # limitations under the License. import itertools -import math import random from pathlib import Path from typing import List, Tuple @@ -24,13 +23,7 @@ from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer -from nemo.collections.tts.losses.encodec_loss import ( - AudioLoss, - DiscriminatorLoss, - GeneratorLoss, - MultiResolutionMelLoss, - RelativeFeatureMatchingLoss, -) +from nemo.collections.tts.losses.audio_codec_loss import AudioLoss, MultiResolutionMelLoss, RelativeFeatureMatchingLoss from nemo.collections.tts.modules.common import GaussianDropout from nemo.collections.tts.parts.utils.callbacks import LoggingCallback from nemo.collections.tts.parts.utils.helpers import get_batch_size, get_num_workers @@ -44,9 +37,7 @@ @experimental -class EnCodecModel(ModelPT): - """EnCodec model (https://github.com/facebookresearch/encodec) that encodes and decodes audio.""" - +class AudioCodecModel(ModelPT): def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Convert to Hydra 1.0 compatible DictConfig cfg = model_utils.convert_model_config_to_dict_config(cfg) @@ -80,8 +71,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.audio_loss_scale = cfg.get("audio_loss_scale", 0.1) self.mel_loss_scale = cfg.get("mel_loss_scale", 1.0) mel_loss_l1_scale = cfg.get("mel_loss_l1_scale", 1.0) - self.gen_loss_scale = cfg.get("gen_loss_scale", 3.0) - self.feature_loss_scale = cfg.get("feature_loss_scale", 3.0) + self.gen_loss_scale = cfg.get("gen_loss_scale", 1.0) + self.feature_loss_scale = cfg.get("feature_loss_scale", 1.0) self.audio_loss_fn = AudioLoss() self.mel_loss_fn = MultiResolutionMelLoss( @@ -90,9 +81,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): resolutions=mel_loss_resolutions, l1_scale=mel_loss_l1_scale, ) - self.gen_loss_fn = GeneratorLoss() + self.gen_loss_fn = instantiate(cfg.generator_loss) + self.disc_loss_fn = instantiate(cfg.discriminator_loss) self.feature_loss_fn = RelativeFeatureMatchingLoss() - self.disc_loss_fn = DiscriminatorLoss() self.log_config = cfg.get("log_config", None) self.lr_schedule_interval = None @@ -225,8 +216,9 @@ def training_step(self, batch, batch_idx): audio_real=audio, audio_gen=audio_gen.detach() ) loss_disc = self.disc_loss_fn(disc_scores_real=disc_scores_real, disc_scores_gen=disc_scores_gen) + train_disc_loss = loss_disc - self.manual_backward(loss_disc) + self.manual_backward(train_disc_loss) optim_disc.step() loss_audio = self.audio_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) @@ -237,7 +229,7 
@@ def training_step(self, batch, batch_idx): _, disc_scores_gen, fmaps_real, fmaps_gen = self.discriminator(audio_real=audio, audio_gen=audio_gen) - loss_gen = self.gen_loss_fn(disc_scores=disc_scores_gen) + loss_gen = self.gen_loss_fn(disc_scores_gen=disc_scores_gen) train_loss_gen = self.gen_loss_scale * loss_gen loss_feature = self.feature_loss_fn(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen) diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py similarity index 100% rename from nemo/collections/tts/modules/encodec_modules.py rename to nemo/collections/tts/modules/audio_codec_modules.py diff --git a/nemo/collections/tts/modules/vector_quantization.py b/nemo/collections/tts/modules/vector_quantization.py index 3b4d9fa3ab53..ac4b3a3f9aa3 100644 --- a/nemo/collections/tts/modules/vector_quantization.py +++ b/nemo/collections/tts/modules/vector_quantization.py @@ -20,7 +20,7 @@ from einops import rearrange, repeat from torch import Tensor -from nemo.collections.tts.losses.encodec_loss import MaskedMSELoss +from nemo.collections.tts.losses.audio_codec_loss import MaskedMSELoss from nemo.collections.tts.parts.utils.distributed import broadcast_tensors from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor from nemo.core.classes.common import typecheck diff --git a/nemo/collections/tts/parts/utils/callbacks.py b/nemo/collections/tts/parts/utils/callbacks.py index 896773caf42b..0d408658d8ad 100644 --- a/nemo/collections/tts/parts/utils/callbacks.py +++ b/nemo/collections/tts/parts/utils/callbacks.py @@ -265,9 +265,9 @@ def generate_artifacts( return audio_artifacts, [] -class EnCodecArtifactGenerator(ArtifactGenerator): +class AudioCodecArtifactGenerator(ArtifactGenerator): """ - Generator for logging EnCodec model outputs. + Generator for logging Audio Codec model outputs. 
""" def __init__(self, log_audio: bool = True, log_encoding: bool = False, log_quantized: bool = False): diff --git a/tests/collections/tts/losses/test_encodec_loss.py b/tests/collections/tts/losses/test_encodec_loss.py index ab5870e1a33c..3138a1c97dd2 100644 --- a/tests/collections/tts/losses/test_encodec_loss.py +++ b/tests/collections/tts/losses/test_encodec_loss.py @@ -15,7 +15,7 @@ import pytest import torch -from nemo.collections.tts.losses.encodec_loss import MaskedMAELoss, MaskedMSELoss +from nemo.collections.tts.losses.audio_codec_loss import MaskedMAELoss, MaskedMSELoss class TestEnCodecLoss: From b1d4da0ba26a3368193445c39487f8ae1c7b0940 Mon Sep 17 00:00:00 2001 From: Ryan Date: Tue, 25 Jul 2023 08:13:48 -0700 Subject: [PATCH 4/5] [TTS] Add EnCodec unit tests Signed-off-by: Ryan --- examples/tts/conf/audio_codec/encodec.yaml | 8 +- .../tts/losses/audio_codec_loss.py | 2 +- ...codec_loss.py => test_audio_codec_loss.py} | 2 +- .../tts/modules/test_audio_codec_modules.py | 96 +++++++++++++++++++ 4 files changed, 102 insertions(+), 6 deletions(-) rename tests/collections/tts/losses/{test_encodec_loss.py => test_audio_codec_loss.py} (98%) create mode 100644 tests/collections/tts/modules/test_audio_codec_modules.py diff --git a/examples/tts/conf/audio_codec/encodec.yaml b/examples/tts/conf/audio_codec/encodec.yaml index c06d38d9e075..2b6731dc90ac 100644 --- a/examples/tts/conf/audio_codec/encodec.yaml +++ b/examples/tts/conf/audio_codec/encodec.yaml @@ -35,8 +35,6 @@ model: sample_rate: ${sample_rate} samples_per_frame: ${samples_per_frame} - gen_loss_scale: 3.0 - feature_loss_scale: 3.0 # Probability of updating the discriminator during each training step disc_update_prob: 0.67 @@ -118,11 +116,13 @@ model: _target_: nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT resolutions: [[128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024], [2048, 512, 2048]] + # The original EnCodec uses hinged loss, but squared-GAN loss is more stable + # and reduces the need to tune the loss weights or use a gradient balancer. 
generator_loss: - _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorHingedLoss + _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss discriminator_loss: - _target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorHingedLoss + _target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss optim: _target_: torch.optim.Adam diff --git a/nemo/collections/tts/losses/audio_codec_loss.py b/nemo/collections/tts/losses/audio_codec_loss.py index d895a2c9e176..adde30e53773 100644 --- a/nemo/collections/tts/losses/audio_codec_loss.py +++ b/nemo/collections/tts/losses/audio_codec_loss.py @@ -164,7 +164,7 @@ def forward(self, audio_real, audio_gen, audio_len): class RelativeFeatureMatchingLoss(Loss): - def __init__(self, div_guard=1e-6): + def __init__(self, div_guard=1e-3): super(RelativeFeatureMatchingLoss, self).__init__() self.div_guard = div_guard diff --git a/tests/collections/tts/losses/test_encodec_loss.py b/tests/collections/tts/losses/test_audio_codec_loss.py similarity index 98% rename from tests/collections/tts/losses/test_encodec_loss.py rename to tests/collections/tts/losses/test_audio_codec_loss.py index 3138a1c97dd2..0fe7991e92cb 100644 --- a/tests/collections/tts/losses/test_encodec_loss.py +++ b/tests/collections/tts/losses/test_audio_codec_loss.py @@ -18,7 +18,7 @@ from nemo.collections.tts.losses.audio_codec_loss import MaskedMAELoss, MaskedMSELoss -class TestEnCodecLoss: +class TestAudioCodecLoss: @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_masked_loss_l1(self): diff --git a/tests/collections/tts/modules/test_audio_codec_modules.py b/tests/collections/tts/modules/test_audio_codec_modules.py new file mode 100644 index 000000000000..948b1220f39c --- /dev/null +++ b/tests/collections/tts/modules/test_audio_codec_modules.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch + +from nemo.collections.tts.modules.audio_codec_modules import ( + Conv1dNorm, + ConvTranspose1dNorm, + get_down_sample_padding, + get_up_sample_padding, +) + + +class TestAudioCodecModules: + def setup_class(self): + self.in_channels = 8 + self.out_channels = 16 + self.batch_size = 2 + self.len1 = 4 + self.len2 = 8 + self.max_len = 10 + self.kernel_size = 3 + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_conv1d(self): + inputs = torch.rand([self.batch_size, self.in_channels, self.max_len]) + lengths = torch.tensor([self.len1, self.len2], dtype=torch.int32) + + conv = Conv1dNorm(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size) + out = conv(inputs, lengths) + + assert out.shape == (self.batch_size, self.out_channels, self.max_len) + assert torch.all(out[0, :, : self.len1] != 0.0) + assert torch.all(out[0, :, self.len1 :] == 0.0) + assert torch.all(out[1, :, : self.len2] != 0.0) + assert torch.all(out[1, :, self.len2 :] == 0.0) + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_conv1d_downsample(self): + stride = 2 + out_len = self.max_len // stride + out_len_1 = self.len1 // stride + out_len_2 = self.len2 // stride + inputs = torch.rand([self.batch_size, self.in_channels, self.max_len]) + lengths = torch.tensor([out_len_1, out_len_2], dtype=torch.int32) + + padding = get_down_sample_padding(kernel_size=self.kernel_size, stride=stride) + conv = Conv1dNorm( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=stride, + padding=padding, + ) + out = conv(inputs, lengths) + + assert out.shape == (self.batch_size, self.out_channels, out_len) + assert torch.all(out[0, :, :out_len_1] != 0.0) + assert torch.all(out[0, :, out_len_1:] == 0.0) + assert torch.all(out[1, :, :out_len_2] != 0.0) + assert torch.all(out[1, :, out_len_2:] == 0.0) + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_conv1d_transpose_upsample(self): + stride = 2 + out_len = self.max_len * stride + out_len_1 = self.len1 * stride + out_len_2 = self.len2 * stride + inputs = torch.rand([self.batch_size, self.in_channels, self.max_len]) + lengths = torch.tensor([out_len_1, out_len_2], dtype=torch.int32) + + conv = ConvTranspose1dNorm( + in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size, stride=stride + ) + out = conv(inputs, lengths) + + assert out.shape == (self.batch_size, self.out_channels, out_len) + assert torch.all(out[0, :, :out_len_1] != 0.0) + assert torch.all(out[0, :, out_len_1:] == 0.0) + assert torch.all(out[1, :, :out_len_2] != 0.0) + assert torch.all(out[1, :, out_len_2:] == 0.0) From 370715ac6db6ee8911114df88e459381a45ce654 Mon Sep 17 00:00:00 2001 From: Ryan Date: Mon, 31 Jul 2023 13:08:10 -0700 Subject: [PATCH 5/5] [TTS] Add copyright header to distributed.py Signed-off-by: Ryan --- examples/tts/conf/audio_codec/encodec.yaml | 3 ++- .../tts/losses/audio_codec_loss.py | 4 ++-- nemo/collections/tts/models/audio_codec.py | 20 +++++++++++-------- .../tts/parts/utils/distributed.py | 14 +++++++++++++ 4 files changed, 30 insertions(+), 11 deletions(-) diff --git a/examples/tts/conf/audio_codec/encodec.yaml b/examples/tts/conf/audio_codec/encodec.yaml index 2b6731dc90ac..e6e9f2e7876f 100644 --- a/examples/tts/conf/audio_codec/encodec.yaml +++ b/examples/tts/conf/audio_codec/encodec.yaml @@ -2,7 +2,7 @@ # If you want to train model on other dataset, you can change config values according to your dataset. 
# Most dataset-specific arguments are in the head of the config file, see below. -name: AudioCodec +name: EnCodec max_epochs: ??? # Adjust batch size based on GPU memory @@ -35,6 +35,7 @@ model: sample_rate: ${sample_rate} samples_per_frame: ${samples_per_frame} + time_domain_loss_scale: 0.1 # Probability of updating the discriminator during each training step disc_update_prob: 0.67 diff --git a/nemo/collections/tts/losses/audio_codec_loss.py b/nemo/collections/tts/losses/audio_codec_loss.py index adde30e53773..bde96fadb4c2 100644 --- a/nemo/collections/tts/losses/audio_codec_loss.py +++ b/nemo/collections/tts/losses/audio_codec_loss.py @@ -81,9 +81,9 @@ def __init__(self, loss_scale: float = 1.0): super(MaskedMSELoss, self).__init__(loss_fn=loss_fn, loss_scale=loss_scale) -class AudioLoss(Loss): +class TimeDomainLoss(Loss): def __init__(self): - super(AudioLoss, self).__init__() + super(TimeDomainLoss, self).__init__() self.loss_fn = MaskedMAELoss() @property diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index f289e778d474..30f74dc2be2a 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -23,7 +23,11 @@ from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer -from nemo.collections.tts.losses.audio_codec_loss import AudioLoss, MultiResolutionMelLoss, RelativeFeatureMatchingLoss +from nemo.collections.tts.losses.audio_codec_loss import ( + MultiResolutionMelLoss, + RelativeFeatureMatchingLoss, + TimeDomainLoss, +) from nemo.collections.tts.modules.common import GaussianDropout from nemo.collections.tts.parts.utils.callbacks import LoggingCallback from nemo.collections.tts.parts.utils.helpers import get_batch_size, get_num_workers @@ -68,13 +72,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): mel_loss_dim = cfg.get("mel_loss_dim", 64) mel_loss_resolutions = cfg.mel_loss_resolutions - self.audio_loss_scale = cfg.get("audio_loss_scale", 0.1) + self.time_domain_loss_scale = cfg.get("time_domain_loss_scale", 1.0) self.mel_loss_scale = cfg.get("mel_loss_scale", 1.0) mel_loss_l1_scale = cfg.get("mel_loss_l1_scale", 1.0) self.gen_loss_scale = cfg.get("gen_loss_scale", 1.0) self.feature_loss_scale = cfg.get("feature_loss_scale", 1.0) - self.audio_loss_fn = AudioLoss() + self.time_domain_loss_fn = TimeDomainLoss() self.mel_loss_fn = MultiResolutionMelLoss( sample_rate=self.sample_rate, mel_dim=mel_loss_dim, @@ -221,8 +225,8 @@ def training_step(self, batch, batch_idx): self.manual_backward(train_disc_loss) optim_disc.step() - loss_audio = self.audio_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) - train_loss_audio = self.audio_loss_scale * loss_audio + loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + train_loss_time_domain = self.time_domain_loss_scale * loss_time_domain loss_mel = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) train_loss_mel = self.mel_loss_scale * loss_mel @@ -235,7 +239,7 @@ def training_step(self, batch, batch_idx): loss_feature = self.feature_loss_fn(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen) train_loss_feature = self.feature_loss_scale * loss_feature - loss_gen_all = train_loss_audio + train_loss_mel + train_loss_gen + train_loss_feature + loss_gen_all = train_loss_time_domain + train_loss_mel + train_loss_gen + train_loss_feature if commit_loss is not None: loss_gen_all += commit_loss @@ -245,7 +249,7 @@ def 
training_step(self, batch, batch_idx): self.update_lr() metrics = { - "g_loss_audio": loss_audio, + "g_loss_time_domain": loss_time_domain, "g_loss_mel": loss_mel, "g_loss_gen": loss_gen, "g_loss_feature": loss_feature, @@ -268,7 +272,7 @@ def training_epoch_end(self, outputs): def validation_step(self, batch, batch_idx): audio, audio_len, audio_gen, _ = self._process_batch(batch) - loss_audio = self.audio_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) + loss_audio = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) loss_mel = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) metrics = {"val_loss": loss_audio + loss_mel, "val_loss_audio": loss_audio, "val_loss_mel": loss_mel} self.log_dict(metrics, on_epoch=True, sync_dist=True) diff --git a/nemo/collections/tts/parts/utils/distributed.py b/nemo/collections/tts/parts/utils/distributed.py index 75db69e67059..cbe102bcfdcd 100644 --- a/nemo/collections/tts/parts/utils/distributed.py +++ b/nemo/collections/tts/parts/utils/distributed.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Iterable import torch
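
For readers unfamiliar with the objective referenced by the config comment in PATCH 4/5 ("squared-GAN loss is more stable..."), the following is a minimal, self-contained PyTorch sketch of a least-squares ("squared") GAN loss pair over multi-resolution discriminator scores. It is illustrative only: the function names, the averaging over resolutions, and the list-of-tensors interface are assumptions, and this is not the NeMo GeneratorSquaredLoss / DiscriminatorSquaredLoss implementation.

    # Illustrative sketch (assumption), not the NeMo implementation.
    # Least-squares GAN losses computed over a list of discriminator score
    # tensors, e.g. one tensor per STFT resolution of the discriminator.
    from typing import List

    import torch


    def discriminator_squared_loss(
        disc_scores_real: List[torch.Tensor], disc_scores_gen: List[torch.Tensor]
    ) -> torch.Tensor:
        # Push scores on real audio toward 1 and scores on generated audio toward 0.
        loss = 0.0
        for score_real, score_gen in zip(disc_scores_real, disc_scores_gen):
            loss = loss + torch.mean((1.0 - score_real) ** 2) + torch.mean(score_gen ** 2)
        return loss / len(disc_scores_real)


    def generator_squared_loss(disc_scores_gen: List[torch.Tensor]) -> torch.Tensor:
        # Push scores on generated audio toward 1 so the generator fools the discriminator.
        loss = 0.0
        for score_gen in disc_scores_gen:
            loss = loss + torch.mean((1.0 - score_gen) ** 2)
        return loss / len(disc_scores_gen)


    # Toy usage with random "scores" from two hypothetical discriminator resolutions.
    scores_real = [torch.rand(2, 100), torch.rand(2, 50)]
    scores_gen = [torch.rand(2, 100), torch.rand(2, 50)]
    print(discriminator_squared_loss(scores_real, scores_gen))
    print(generator_squared_loss(scores_gen))

Compared with a hinge objective, the squared form keeps gradients non-zero even when the discriminator's scores saturate, which is consistent with the config comment's note that it reduces the need to tune loss weights or use a gradient balancer.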