diff --git a/examples/hey_snips/README.md b/examples/hey_snips/README.md index ba263906abf..6311ad928a5 100644 --- a/examples/hey_snips/README.md +++ b/examples/hey_snips/README.md @@ -2,7 +2,7 @@ ## Metrics We mesure FRRs with fixing false alarms in one hour: - +the release model: https://paddlespeech.bj.bcebos.com/kws/heysnips/kws0_mdtc_heysnips_ckpt.tar.gz |Model|False Alarm| False Reject Rate| |--|--|--| |MDTC| 1| 0.003559 | diff --git a/examples/zh_en_tts/tts3/README.md b/examples/zh_en_tts/tts3/README.md index b4b683089cd..f63d5d8fe37 100644 --- a/examples/zh_en_tts/tts3/README.md +++ b/examples/zh_en_tts/tts3/README.md @@ -116,6 +116,8 @@ optional arguments: 5. `--phones-dict` is the path of the phone vocabulary file. 6. `--speaker-dict` is the path of the speaker id map file when training a multi-speaker FastSpeech2. +We have **added module speaker classifier** with reference to [Learning to Speak Fluently in a Foreign Language: Multilingual Speech Synthesis and Cross-Language Voice Cloning](https://arxiv.org/pdf/1907.04448.pdf). The main parameter configuration: config["model"]["enable_speaker_classifier"], config["model"]["hidden_sc_dim"] and config["updater"]["spk_loss_scale"] in `conf/default.yaml`. The current experimental results show that this module can decouple text information and speaker information, and more experiments are still being sorted out. This module is currently not enabled by default, if you are interested, you can try it yourself. + ### Synthesizing We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1) as the default neural vocoder. diff --git a/examples/zh_en_tts/tts3/conf/default.yaml b/examples/zh_en_tts/tts3/conf/default.yaml index 06bf1fcc1d7..efa8b3ea2aa 100644 --- a/examples/zh_en_tts/tts3/conf/default.yaml +++ b/examples/zh_en_tts/tts3/conf/default.yaml @@ -74,7 +74,7 @@ model: stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder spk_embed_dim: 256 # speaker embedding dimension spk_embed_integration_type: concat # speaker embedding integration type - enable_speaker_classifier: True # Whether to use speaker classifier module + enable_speaker_classifier: False # Whether to use speaker classifier module hidden_sc_dim: 256 # The hidden layer dim of speaker classifier diff --git a/examples/zh_en_tts/tts3/local/train.sh b/examples/zh_en_tts/tts3/local/train.sh index 3a5076505dd..1da72f11796 100755 --- a/examples/zh_en_tts/tts3/local/train.sh +++ b/examples/zh_en_tts/tts3/local/train.sh @@ -8,6 +8,6 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ --config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=1 \ + --ngpu=2 \ --phones-dict=dump/phone_id_map.txt \ --speaker-dict=dump/speaker_id_map.txt diff --git a/examples/zh_en_tts/tts3/run.sh b/examples/zh_en_tts/tts3/run.sh index a0c58f35caf..12f99081af8 100755 --- a/examples/zh_en_tts/tts3/run.sh +++ b/examples/zh_en_tts/tts3/run.sh @@ -3,9 +3,9 @@ set -e source path.sh -gpus=0 -stage=1 -stop_stage=1 +gpus=0,1 +stage=0 +stop_stage=100 datasets_root_dir=~/datasets mfa_root_dir=./mfa_results/ diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index 34a2ff98a13..0eb44beb6d2 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -25,6 +25,8 @@ from paddle import nn from typeguard import check_argument_types +from paddlespeech.t2s.modules.adversarial_loss.gradient_reversal import GradientReversalLayer +from paddlespeech.t2s.modules.adversarial_loss.speaker_classifier import SpeakerClassifier from paddlespeech.t2s.modules.nets_utils import initialize from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask from paddlespeech.t2s.modules.nets_utils import make_pad_mask @@ -37,8 +39,6 @@ from paddlespeech.t2s.modules.transformer.encoder import CNNPostnet from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder from paddlespeech.t2s.modules.transformer.encoder import TransformerEncoder -from paddlespeech.t2s.modules.multi_speakers.speaker_classifier import SpeakerClassifier -from paddlespeech.t2s.modules.multi_speakers.gradient_reversal import GradientReversalLayer class FastSpeech2(nn.Layer): @@ -140,10 +140,10 @@ def __init__( # training related init_type: str="xavier_uniform", init_enc_alpha: float=1.0, - init_dec_alpha: float=1.0, + init_dec_alpha: float=1.0, # speaker classifier enable_speaker_classifier: bool=False, - hidden_sc_dim: int=256,): + hidden_sc_dim: int=256, ): """Initialize FastSpeech2 module. Args: idim (int): @@ -388,7 +388,8 @@ def __init__( if self.spk_num and self.enable_speaker_classifier: # set lambda = 1 self.grad_reverse = GradientReversalLayer(1) - self.speaker_classifier = SpeakerClassifier(idim=adim, hidden_sc_dim=self.hidden_sc_dim, spk_num=spk_num) + self.speaker_classifier = SpeakerClassifier( + idim=adim, hidden_sc_dim=self.hidden_sc_dim, spk_num=spk_num) # define duration predictor self.duration_predictor = DurationPredictor( @@ -601,7 +602,7 @@ def _forward(self, # (B, Tmax, adim) hs, _ = self.encoder(xs, x_masks) - if self.spk_num and self.enable_speaker_classifier: + if self.spk_num and self.enable_speaker_classifier and not is_inference: hs_for_spk_cls = self.grad_reverse(hs) spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens) else: @@ -794,7 +795,7 @@ def inference( es = e.unsqueeze(0) if e is not None else None # (1, L, odim) - _, outs, d_outs, p_outs, e_outs = self._inference( + _, outs, d_outs, p_outs, e_outs, _ = self._forward( xs, ilens, ds=ds, @@ -806,7 +807,7 @@ def inference( is_inference=True) else: # (1, L, odim) - _, outs, d_outs, p_outs, e_outs = self._inference( + _, outs, d_outs, p_outs, e_outs, _ = self._forward( xs, ilens, is_inference=True, @@ -815,121 +816,8 @@ def inference( spk_id=spk_id, tone_id=tone_id) - return outs[0], d_outs[0], p_outs[0], e_outs[0] - def _inference(self, - xs: paddle.Tensor, - ilens: paddle.Tensor, - olens: paddle.Tensor=None, - ds: paddle.Tensor=None, - ps: paddle.Tensor=None, - es: paddle.Tensor=None, - is_inference: bool=False, - return_after_enc=False, - alpha: float=1.0, - spk_emb=None, - spk_id=None, - tone_id=None) -> Sequence[paddle.Tensor]: - # forward encoder - x_masks = self._source_mask(ilens) - # (B, Tmax, adim) - hs, _ = self.encoder(xs, x_masks) - - # integrate speaker embedding - if self.spk_embed_dim is not None: - # spk_emb has a higher priority than spk_id - if spk_emb is not None: - hs = self._integrate_with_spk_embed(hs, spk_emb) - elif spk_id is not None: - spk_emb = self.spk_embedding_table(spk_id) - hs = self._integrate_with_spk_embed(hs, spk_emb) - - # integrate tone embedding - if self.tone_embed_dim is not None: - if tone_id is not None: - tone_embs = self.tone_embedding_table(tone_id) - hs = self._integrate_with_tone_embed(hs, tone_embs) - # forward duration predictor and variance predictors - d_masks = make_pad_mask(ilens) - - if self.stop_gradient_from_pitch_predictor: - p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1)) - else: - p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1)) - if self.stop_gradient_from_energy_predictor: - e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1)) - else: - e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1)) - - if is_inference: - # (B, Tmax) - if ds is not None: - d_outs = ds - else: - d_outs = self.duration_predictor.inference(hs, d_masks) - if ps is not None: - p_outs = ps - if es is not None: - e_outs = es - - # use prediction in inference - # (B, Tmax, 1) - - p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose( - (0, 2, 1)) - e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose( - (0, 2, 1)) - hs = hs + e_embs + p_embs - - # (B, Lmax, adim) - hs = self.length_regulator(hs, d_outs, alpha, is_inference=True) - else: - d_outs = self.duration_predictor(hs, d_masks) - # use groundtruth in training - p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose( - (0, 2, 1)) - e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose( - (0, 2, 1)) - hs = hs + e_embs + p_embs - - # (B, Lmax, adim) - hs = self.length_regulator(hs, ds, is_inference=False) - - # forward decoder - if olens is not None and not is_inference: - if self.reduction_factor > 1: - olens_in = paddle.to_tensor( - [olen // self.reduction_factor for olen in olens.numpy()]) - else: - olens_in = olens - # (B, 1, T) - h_masks = self._source_mask(olens_in) - else: - h_masks = None - if return_after_enc: - return hs, h_masks - - if self.decoder_type == 'cnndecoder': - # remove output masks for dygraph to static graph - zs = self.decoder(hs, h_masks) - before_outs = zs - else: - # (B, Lmax, adim) - zs, _ = self.decoder(hs, h_masks) - # (B, Lmax, odim) - before_outs = self.feat_out(zs).reshape( - (paddle.shape(zs)[0], -1, self.odim)) - - # postnet -> (B, Lmax//r * r, odim) - if self.postnet is None: - after_outs = before_outs - else: - after_outs = before_outs + self.postnet( - before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) - - return before_outs, after_outs, d_outs, p_outs, e_outs - def _integrate_with_spk_embed(self, hs, spk_emb): """Integrate speaker embedding with hidden states. @@ -1212,7 +1100,8 @@ def forward( olens: paddle.Tensor, spk_logits: paddle.Tensor=None, spk_ids: paddle.Tensor=None, - ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,]: + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, + paddle.Tensor, ]: """Calculate forward propagation. Args: @@ -1249,7 +1138,7 @@ def forward( """ speaker_loss = 0.0 - + # apply mask to remove padded part if self.use_masking: out_masks = make_non_pad_mask(olens).unsqueeze(-1) @@ -1273,12 +1162,13 @@ def forward( if spk_logits is not None and spk_ids is not None: batch_size = spk_ids.shape[0] - spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1], None) - spk_logits = paddle.reshape(spk_logits, [-1, spk_logits.shape[-1]]) - mask_index = spk_logits.abs().sum(axis=1)!=0 + spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1], + None) + spk_logits = paddle.reshape(spk_logits, + [-1, spk_logits.shape[-1]]) + mask_index = spk_logits.abs().sum(axis=1) != 0 spk_ids = spk_ids[mask_index] spk_logits = spk_logits[mask_index] - # calculate loss l1_loss = self.l1_criterion(before_outs, ys) @@ -1289,7 +1179,7 @@ def forward( energy_loss = self.mse_criterion(e_outs, es) if spk_logits is not None and spk_ids is not None: - speaker_loss = self.ce_criterion(spk_logits, spk_ids)/batch_size + speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size # make weighted mask and apply it if self.use_weighted_masking: diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py index 7690a9ceaa8..2b25b6a6224 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py @@ -14,6 +14,7 @@ import logging from pathlib import Path +from paddle import DataParallel from paddle import distributed as dist from paddle.io import DataLoader from paddle.nn import Layer @@ -23,6 +24,7 @@ from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator from paddlespeech.t2s.training.reporter import report from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater + logging.basicConfig( format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', datefmt='[%Y-%m-%d %H:%M:%S]') @@ -43,7 +45,8 @@ def __init__(self, super().__init__(model, optimizer, dataloader, init_state=None) self.criterion = FastSpeech2Loss( - use_masking=use_masking, use_weighted_masking=use_weighted_masking,) + use_masking=use_masking, + use_weighted_masking=use_weighted_masking, ) log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) self.filehandler = logging.FileHandler(str(log_file)) @@ -62,7 +65,21 @@ def update_core(self, batch): if spk_emb is not None: spk_id = None - with self.model.no_sync(): + if type( + self.model + ) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier: + with self.model.no_sync(): + before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model( + text=batch["text"], + text_lengths=batch["text_lengths"], + speech=batch["speech"], + speech_lengths=batch["speech_lengths"], + durations=batch["durations"], + pitch=batch["pitch"], + energy=batch["energy"], + spk_id=spk_id, + spk_emb=spk_emb) + else: before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model( text=batch["text"], text_lengths=batch["text_lengths"], @@ -87,7 +104,7 @@ def update_core(self, batch): ilens=batch["text_lengths"], olens=olens, spk_logits=spk_logits, - spk_ids=spk_id,) + spk_ids=spk_id, ) loss = l1_loss + duration_loss + pitch_loss + energy_loss + self.spk_loss_scale * speaker_loss @@ -101,16 +118,20 @@ def update_core(self, batch): report("train/duration_loss", float(duration_loss)) report("train/pitch_loss", float(pitch_loss)) report("train/energy_loss", float(energy_loss)) - report("train/speaker_loss", float(speaker_loss)) - report("train/scale_speaker_loss", float(self.spk_loss_scale * speaker_loss)) + if speaker_loss != 0.0: + report("train/speaker_loss", float(speaker_loss)) + report("train/scale_speaker_loss", + float(self.spk_loss_scale * speaker_loss)) losses_dict["l1_loss"] = float(l1_loss) losses_dict["duration_loss"] = float(duration_loss) losses_dict["pitch_loss"] = float(pitch_loss) losses_dict["energy_loss"] = float(energy_loss) losses_dict["energy_loss"] = float(energy_loss) - losses_dict["speaker_loss"] = float(speaker_loss) - losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale * speaker_loss) + if speaker_loss != 0.0: + losses_dict["speaker_loss"] = float(speaker_loss) + losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale * + speaker_loss) losses_dict["loss"] = float(loss) self.msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_dict.items()) @@ -145,7 +166,21 @@ def evaluate_core(self, batch): if spk_emb is not None: spk_id = None - with self.model.no_sync(): + if type( + self.model + ) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier: + with self.model.no_sync(): + before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model( + text=batch["text"], + text_lengths=batch["text_lengths"], + speech=batch["speech"], + speech_lengths=batch["speech_lengths"], + durations=batch["durations"], + pitch=batch["pitch"], + energy=batch["energy"], + spk_id=spk_id, + spk_emb=spk_emb) + else: before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model( text=batch["text"], text_lengths=batch["text_lengths"], @@ -168,9 +203,9 @@ def evaluate_core(self, batch): ps=batch["pitch"], es=batch["energy"], ilens=batch["text_lengths"], - olens=olens, + olens=olens, spk_logits=spk_logits, - spk_ids=spk_id,) + spk_ids=spk_id, ) loss = l1_loss + duration_loss + pitch_loss + energy_loss + self.spk_loss_scale * speaker_loss report("eval/loss", float(loss)) @@ -178,15 +213,19 @@ def evaluate_core(self, batch): report("eval/duration_loss", float(duration_loss)) report("eval/pitch_loss", float(pitch_loss)) report("eval/energy_loss", float(energy_loss)) - report("train/speaker_loss", float(speaker_loss)) - report("train/scale_speaker_loss", float(self.spk_loss_scale * speaker_loss)) + if speaker_loss != 0.0: + report("train/speaker_loss", float(speaker_loss)) + report("train/scale_speaker_loss", + float(self.spk_loss_scale * speaker_loss)) losses_dict["l1_loss"] = float(l1_loss) losses_dict["duration_loss"] = float(duration_loss) losses_dict["pitch_loss"] = float(pitch_loss) losses_dict["energy_loss"] = float(energy_loss) - losses_dict["speaker_loss"] = float(speaker_loss) - losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale * speaker_loss) + if speaker_loss != 0.0: + losses_dict["speaker_loss"] = float(speaker_loss) + losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale * + speaker_loss) losses_dict["loss"] = float(loss) self.msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_dict.items()) diff --git a/paddlespeech/t2s/modules/multi_speakers/__init__.py b/paddlespeech/t2s/modules/adversarial_loss/__init__.py similarity index 100% rename from paddlespeech/t2s/modules/multi_speakers/__init__.py rename to paddlespeech/t2s/modules/adversarial_loss/__init__.py diff --git a/paddlespeech/t2s/modules/multi_speakers/gradient_reversal.py b/paddlespeech/t2s/modules/adversarial_loss/gradient_reversal.py similarity index 99% rename from paddlespeech/t2s/modules/multi_speakers/gradient_reversal.py rename to paddlespeech/t2s/modules/adversarial_loss/gradient_reversal.py index 5250f1df163..e98758099f6 100644 --- a/paddlespeech/t2s/modules/multi_speakers/gradient_reversal.py +++ b/paddlespeech/t2s/modules/adversarial_loss/gradient_reversal.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import paddle -from paddle.autograd import PyLayer import paddle.nn as nn +from paddle.autograd import PyLayer + class GradientReversalFunction(PyLayer): """Gradient Reversal Layer from: @@ -57,4 +57,3 @@ def forward(self, x): """Forward in networks """ return GradientReversalFunction.apply(x, self.lambda_) - diff --git a/paddlespeech/t2s/modules/multi_speakers/speaker_classifier.py b/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py similarity index 78% rename from paddlespeech/t2s/modules/multi_speakers/speaker_classifier.py rename to paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py index a64f6d5b851..d731b2d27ae 100644 --- a/paddlespeech/t2s/modules/multi_speakers/speaker_classifier.py +++ b/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from Cross-Lingual-Voice-Cloning(https://github.com/deterministic-algorithms-lab/Cross-Lingual-Voice-Cloning) - -from paddle import nn import paddle +from paddle import nn from typeguard import check_argument_types + class SpeakerClassifier(nn.Layer): - - def __init__(self, idim: int, hidden_sc_dim: int, spk_num: int, ): + def __init__( + self, + idim: int, + hidden_sc_dim: int, + spk_num: int, ): assert check_argument_types() super().__init__() # store hyperparameters @@ -27,11 +30,13 @@ def __init__(self, idim: int, hidden_sc_dim: int, spk_num: int, ): self.hidden_sc_dim = hidden_sc_dim self.spk_num = spk_num - self.model = nn.Sequential(nn.Linear(self.idim, self.hidden_sc_dim), - nn.Linear(self.hidden_sc_dim, self.spk_num)) - + self.model = nn.Sequential( + nn.Linear(self.idim, self.hidden_sc_dim), + nn.Linear(self.hidden_sc_dim, self.spk_num)) + def parse_outputs(self, out, text_lengths): - mask = paddle.arange(out.shape[1]).expand([out.shape[0], out.shape[1]]) < text_lengths.unsqueeze(1) + mask = paddle.arange(out.shape[1]).expand( + [out.shape[0], out.shape[1]]) < text_lengths.unsqueeze(1) out = paddle.transpose(out, perm=[2, 0, 1]) out = out * mask out = paddle.transpose(out, perm=[1, 2, 0]) @@ -44,7 +49,7 @@ def forward(self, encoder_outputs, text_lengths): log probabilities of speaker classification = [batch_size, seq_len, spk_num] """ - - out = self.model(encoder_outputs) + + out = self.model(encoder_outputs) out = self.parse_outputs(out, text_lengths) return out