Skip to content

Commit

Permalink
[TTS][refactor] Part 7 - move module from model file. (NVIDIA#6098)
Browse files Browse the repository at this point in the history
* [TTS] move module from model file.

Signed-off-by: Xuesong Yang <[email protected]>

* update copyright header

Signed-off-by: Xuesong Yang <[email protected]>

* update copyright header

Signed-off-by: Xuesong Yang <[email protected]>

---------

Signed-off-by: Xuesong Yang <[email protected]>
  • Loading branch information
XuesongYang authored and titu1994 committed Mar 24, 2023
1 parent fd7d3c1 commit b0537a0
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 35 deletions.
6 changes: 5 additions & 1 deletion nemo/collections/tts/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,10 +14,12 @@

from nemo.collections.tts.models.aligner import AlignerModel
from nemo.collections.tts.models.fastpitch import FastPitchModel
from nemo.collections.tts.models.fastpitch_ssl import FastPitchModel_SSL
from nemo.collections.tts.models.hifigan import HifiGanModel
from nemo.collections.tts.models.mixer_tts import MixerTTSModel
from nemo.collections.tts.models.radtts import RadTTSModel
from nemo.collections.tts.models.spectrogram_enhancer import SpectrogramEnhancerModel
from nemo.collections.tts.models.ssl_tts import SSLDisentangler
from nemo.collections.tts.models.tacotron2 import Tacotron2Model
from nemo.collections.tts.models.two_stages import GriffinLimModel, MelPsuedoInverseModel, TwoStagesModel
from nemo.collections.tts.models.univnet import UnivNetModel
Expand All @@ -27,6 +29,8 @@
__all__ = [
"AlignerModel",
"FastPitchModel",
"FastPitchModel_SSL",
"SSLDisentangler",
"GriffinLimModel",
"HifiGanModel",
"MelPsuedoInverseModel",
Expand Down
46 changes: 12 additions & 34 deletions nemo/collections/tts/models/ssl_tts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,13 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Union
from typing import Iterable, Optional

import editdistance
import librosa
import torch
import torch.nn as nn
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
Expand All @@ -27,6 +25,7 @@

import nemo.collections.tts.torch.data as TTSData
from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss
from nemo.collections.tts.modules.ssl_tts import GreedyCTCDecoder
from nemo.collections.tts.torch.tts_tokenizers import BaseTokenizer, EnglishCharsTokenizer
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
Expand All @@ -35,27 +34,6 @@
from nemo.utils.decorators import experimental


class GreedyCTCDecoder(torch.nn.Module):
def __init__(self, labels, blank=0):
super().__init__()
self.labels = labels
self.blank = blank

def forward(self, emission):
"""Given a sequence emission over labels, get the best path
Args:
emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
Returns:
List[str]: The resulting transcript
"""
indices = torch.argmax(emission, dim=-1) # [num_seq,]
indices = torch.unique_consecutive(indices, dim=-1)
indices = [i for i in indices if i != self.blank]
joined = "".join([self.labels[i] for i in indices])
return indices, joined


@experimental
class SSLDisentangler(ModelPT):
"""
Expand All @@ -74,34 +52,34 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self._text_tokenizer = EnglishCharsTokenizer(add_blank_at="last")
self._tb_logger = None

self.downstream_nets = nn.ModuleDict()
self.downstream_nets = torch.nn.ModuleDict()
for task in self._cfg.downstream_heads.task_names:

if task == 'speaker_verification':
# setting up downstream heads and loss functions for speaker verification task
in_dim = self._cfg.encoder.d_model
out_dim = self._cfg.downstream_heads.speaker_embed_size
num_speakers = self._cfg.downstream_heads.num_speakers
self.downstream_nets[task] = nn.Linear(in_dim, out_dim)
self.sv_linear = nn.Linear(out_dim, num_speakers)
self.downstream_nets[task] = torch.nn.Linear(in_dim, out_dim)
self.sv_linear = torch.nn.Linear(out_dim, num_speakers)
self.sv_loss = AngularSoftmaxLoss(scale=30, margin=0.4)

elif task == 'content':
# setting up downstream heads and loss functions for text/content recognition task
in_dim = self._cfg.encoder.d_model
out_dim = self._cfg.downstream_heads.content_embed_size
num_chars = len(self._text_tokenizer.tokens) # list of english tokens
self.downstream_nets[task] = nn.Linear(in_dim, out_dim)
self.content_linear = nn.Linear(out_dim, num_chars)
self.ctc_loss = nn.CTCLoss(blank=self._text_tokenizer.blank, zero_infinity=True)
self.downstream_nets[task] = torch.nn.Linear(in_dim, out_dim)
self.content_linear = torch.nn.Linear(out_dim, num_chars)
self.ctc_loss = torch.nn.CTCLoss(blank=self._text_tokenizer.blank, zero_infinity=True)
self.pitch_augment = self._cfg.get('pitch_augment', False)
self.augment_ctc = self._cfg.get('augment_ctc', False)
self.aug_loss_type = self._cfg.get('aug_loss_type', 'mse')
self.stop_gradient = self._cfg.get('stop_gradient', False)
assert (
self.stop_gradient and self.augment_ctc
) == False, "stop_gradient and augment_ctc cannot be true at the same time"
self.mse_loss = nn.MSELoss()
self.mse_loss = torch.nn.MSELoss()

self.ctc_decoder = GreedyCTCDecoder(self._text_tokenizer.tokens, self._text_tokenizer.blank)

Expand Down Expand Up @@ -380,7 +358,7 @@ def training_step(self, batch, batch_idx):
sim_loss = self.mse_loss(content_embedding, content_embedding_aug)
elif self.aug_loss_type == "cosine":

cosine_similarity = nn.functional.cosine_similarity(
cosine_similarity = torch.nn.functional.cosine_similarity(
content_embedding, content_embedding_aug, dim=-1
).mean()

Expand Down Expand Up @@ -471,7 +449,7 @@ def validation_step(self, batch, batch_idx):
if self.aug_loss_type == "mse":
sim_loss = self.mse_loss(content_embedding, content_embedding_aug)
elif self.aug_loss_type == "cosine":
cosine_similarity = nn.functional.cosine_similarity(
cosine_similarity = torch.nn.functional.cosine_similarity(
content_embedding, content_embedding_aug, dim=-1
).mean()
sim_loss = 1.0 - cosine_similarity
Expand Down
35 changes: 35 additions & 0 deletions nemo/collections/tts/modules/ssl_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder.

    Takes the argmax label at every timestep, collapses consecutive
    repeats, and removes blanks — the standard greedy CTC decoding rule.
    """

    def __init__(self, labels, blank: int = 0):
        """
        Args:
            labels: sequence of token strings; ``labels[i]`` is the token
                for class index ``i`` (used to render the transcript).
            blank: class index reserved for the CTC blank symbol.
        """
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor):
        """Decode one emission matrix along the greedy (best) path.

        Args:
            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
        Returns:
            Tuple ``(indices, joined)``: ``indices`` is the list of kept
            label indices (0-d tensors, since they come from iterating a
            1-D tensor) after collapsing repeats and dropping blanks;
            ``joined`` is the decoded transcript string built from those
            indices via ``labels``.
        """
        indices = torch.argmax(emission, dim=-1)  # per-timestep best class, [num_seq,]
        # Collapse runs of identical labels (CTC many-to-one mapping).
        indices = torch.unique_consecutive(indices, dim=-1)
        # Drop the blank symbol; what remains maps to real tokens.
        indices = [i for i in indices if i != self.blank]
        joined = "".join([self.labels[i] for i in indices])
        return indices, joined

0 comments on commit b0537a0

Please sign in to comment.