Commit

…o new_fs2
liangym committed Oct 31, 2022
2 parents 96b8e42 + 06383d5 commit d92852a
Showing 10 changed files with 96 additions and 161 deletions.
2 changes: 1 addition & 1 deletion examples/hey_snips/README.md
@@ -2,7 +2,7 @@
## Metrics

We measure FRRs at a fixed number of false alarms in one hour:

The released model: https://paddlespeech.bj.bcebos.com/kws/heysnips/kws0_mdtc_heysnips_ckpt.tar.gz
|Model|False Alarm| False Reject Rate|
|--|--|--|
|MDTC| 1| 0.003559 |
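
As an illustration of this protocol (a sketch with made-up scores, not part of the release), one can sweep the detection threshold until the negative audio yields at most one false alarm per hour, then report the miss rate at that threshold:

```python
import numpy as np

def frr_at_fixed_fa(pos_scores, neg_scores, neg_hours, max_fa_per_hour=1):
    """Pick the threshold that allows at most `max_fa_per_hour` false alarms
    on the negative audio, then return the false reject rate (FRR)."""
    allowed_fa = int(max_fa_per_hour * neg_hours)
    neg_sorted = np.sort(neg_scores)[::-1]  # descending
    threshold = neg_sorted[allowed_fa]      # scores > threshold fire an alarm
    frr = np.mean(pos_scores <= threshold)  # fraction of keywords rejected
    return frr, threshold

# Toy usage with synthetic scores (all numbers hypothetical).
rng = np.random.default_rng(0)
pos = rng.normal(0.9, 0.05, 2000)    # scores on keyword utterances
neg = rng.normal(0.2, 0.10, 36000)   # scores on one hour of negative windows
frr, thr = frr_at_fixed_fa(pos, neg, neg_hours=1.0)
print(f"FRR = {frr:.6f} at threshold {thr:.3f}")
```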
2 changes: 2 additions & 0 deletions examples/zh_en_tts/tts3/README.md
@@ -116,6 +116,8 @@ optional arguments:
5. `--phones-dict` is the path of the phone vocabulary file.
6. `--speaker-dict` is the path of the speaker id map file when training a multi-speaker FastSpeech2.

We have **added a speaker classifier module** following [Learning to Speak Fluently in a Foreign Language: Multilingual Speech Synthesis and Cross-Language Voice Cloning](https://arxiv.org/pdf/1907.04448.pdf). The main configuration parameters are `config["model"]["enable_speaker_classifier"]`, `config["model"]["hidden_sc_dim"]`, and `config["updater"]["spk_loss_scale"]` in `conf/default.yaml`. Current experimental results show that this module can decouple text information from speaker information; more experiments are still being organized. The module is disabled by default; if you are interested, you can try it yourself, as in the sketch below.
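
As a rough sketch of how these options interact (names and numbers below are illustrative, not the actual updater code): the classifier's cross-entropy loss is added to the usual TTS losses, scaled by `spk_loss_scale`.

```python
import paddle

# Illustrative only: combining an adversarial speaker-classification loss
# with the TTS loss. `spk_loss_scale` mirrors config["updater"]["spk_loss_scale"].
spk_loss_scale = 0.02
spk_num, frames, batch = 10, 100, 8

tts_loss = paddle.to_tensor(1.5)                     # stand-in for the usual losses
spk_logits = paddle.randn([batch, frames, spk_num])  # from the speaker classifier
spk_ids = paddle.randint(0, spk_num, [batch * frames])

ce = paddle.nn.CrossEntropyLoss()
spk_loss = ce(paddle.reshape(spk_logits, [-1, spk_num]), spk_ids)
loss = tts_loss + spk_loss_scale * spk_loss
print(float(loss))
```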


### Synthesizing
We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1) as the default neural vocoder.
2 changes: 1 addition & 1 deletion examples/zh_en_tts/tts3/conf/default.yaml
@@ -74,7 +74,7 @@ model:
stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder
spk_embed_dim: 256 # speaker embedding dimension
spk_embed_integration_type: concat # speaker embedding integration type
-  enable_speaker_classifier: True # Whether to use speaker classifier module
+  enable_speaker_classifier: False # Whether to use speaker classifier module
hidden_sc_dim: 256 # The hidden layer dim of speaker classifier


2 changes: 1 addition & 1 deletion examples/zh_en_tts/tts3/local/train.sh
@@ -8,6 +8,6 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
-    --ngpu=1 \
+    --ngpu=2 \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
6 changes: 3 additions & 3 deletions examples/zh_en_tts/tts3/run.sh
@@ -3,9 +3,9 @@
set -e
source path.sh

-gpus=0
-stage=1
-stop_stage=1
+gpus=0,1
+stage=0
+stop_stage=100

datasets_root_dir=~/datasets
mfa_root_dir=./mfa_results/
146 changes: 18 additions & 128 deletions paddlespeech/t2s/models/fastspeech2/fastspeech2.py
@@ -25,6 +25,8 @@
from paddle import nn
from typeguard import check_argument_types

+from paddlespeech.t2s.modules.adversarial_loss.gradient_reversal import GradientReversalLayer
+from paddlespeech.t2s.modules.adversarial_loss.speaker_classifier import SpeakerClassifier
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
@@ -37,8 +39,6 @@
from paddlespeech.t2s.modules.transformer.encoder import CNNPostnet
from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder
from paddlespeech.t2s.modules.transformer.encoder import TransformerEncoder
-from paddlespeech.t2s.modules.multi_speakers.speaker_classifier import SpeakerClassifier
-from paddlespeech.t2s.modules.multi_speakers.gradient_reversal import GradientReversalLayer


class FastSpeech2(nn.Layer):
@@ -140,10 +140,10 @@ def __init__(
# training related
init_type: str="xavier_uniform",
init_enc_alpha: float=1.0,
-            init_dec_alpha: float=1.0,
+            init_dec_alpha: float=1.0,
# speaker classifier
enable_speaker_classifier: bool=False,
-            hidden_sc_dim: int=256,):
+            hidden_sc_dim: int=256, ):
"""Initialize FastSpeech2 module.
Args:
idim (int):
@@ -388,7 +388,8 @@ def __init__(
if self.spk_num and self.enable_speaker_classifier:
# set lambda = 1
self.grad_reverse = GradientReversalLayer(1)
-            self.speaker_classifier = SpeakerClassifier(idim=adim, hidden_sc_dim=self.hidden_sc_dim, spk_num=spk_num)
+            self.speaker_classifier = SpeakerClassifier(
+                idim=adim, hidden_sc_dim=self.hidden_sc_dim, spk_num=spk_num)

# define duration predictor
self.duration_predictor = DurationPredictor(
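
A gradient reversal layer behaves as the identity in the forward pass and negates (scaled by lambda) the incoming gradient in the backward pass, so the encoder learns to strip speaker information while the classifier tries to recover it. Below is a minimal sketch of such a layer, assuming the usual Paddle `PyLayer` mechanics; the actual `gradient_reversal.py` module may differ.

```python
import paddle
from paddle.autograd import PyLayer


class _ReverseGrad(PyLayer):
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()            # identity in the forward pass

    @staticmethod
    def backward(ctx, grad):
        return -ctx.lambda_ * grad  # flip and scale the gradient


class GradientReversalLayer(paddle.nn.Layer):
    def __init__(self, lambda_: float = 1.0):
        super().__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        return _ReverseGrad.apply(x, self.lambda_)
```

Setting lambda = 1, as the code above does, applies the full reversed gradient; smaller values soften the adversarial pressure on the encoder.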
@@ -601,7 +602,7 @@ def _forward(self,
# (B, Tmax, adim)
hs, _ = self.encoder(xs, x_masks)

-        if self.spk_num and self.enable_speaker_classifier:
+        if self.spk_num and self.enable_speaker_classifier and not is_inference:
hs_for_spk_cls = self.grad_reverse(hs)
spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens)
else:
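
The `SpeakerClassifier` applied above maps the per-frame encoder states `hs` of shape `(B, Tmax, adim)` to per-frame speaker logits. A minimal hypothetical version, assuming a simple MLP head (the real `speaker_classifier.py` may be more elaborate):

```python
import paddle
from paddle import nn


class SpeakerClassifier(nn.Layer):
    """Illustrative sketch: per-frame speaker logits from encoder states."""

    def __init__(self, idim: int, hidden_sc_dim: int, spk_num: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(idim, hidden_sc_dim),
            nn.ReLU(),
            nn.Linear(hidden_sc_dim, spk_num))

    def forward(self, hs: paddle.Tensor, ilens: paddle.Tensor) -> paddle.Tensor:
        # hs: (B, Tmax, idim) -> (B, Tmax, spk_num);
        # padded frames are masked out later in the loss.
        return self.mlp(hs)
```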
@@ -794,7 +795,7 @@ def inference(
es = e.unsqueeze(0) if e is not None else None

# (1, L, odim)
-            _, outs, d_outs, p_outs, e_outs = self._inference(
+            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
xs,
ilens,
ds=ds,
@@ -806,7 +807,7 @@
is_inference=True)
else:
# (1, L, odim)
-            _, outs, d_outs, p_outs, e_outs = self._inference(
+            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
xs,
ilens,
is_inference=True,
@@ -815,121 +816,8 @@
spk_id=spk_id,
tone_id=tone_id)


return outs[0], d_outs[0], p_outs[0], e_outs[0]

-    def _inference(self,
-                   xs: paddle.Tensor,
-                   ilens: paddle.Tensor,
-                   olens: paddle.Tensor=None,
-                   ds: paddle.Tensor=None,
-                   ps: paddle.Tensor=None,
-                   es: paddle.Tensor=None,
-                   is_inference: bool=False,
-                   return_after_enc=False,
-                   alpha: float=1.0,
-                   spk_emb=None,
-                   spk_id=None,
-                   tone_id=None) -> Sequence[paddle.Tensor]:
-        # forward encoder
-        x_masks = self._source_mask(ilens)
-        # (B, Tmax, adim)
-        hs, _ = self.encoder(xs, x_masks)
-
-        # integrate speaker embedding
-        if self.spk_embed_dim is not None:
-            # spk_emb has a higher priority than spk_id
-            if spk_emb is not None:
-                hs = self._integrate_with_spk_embed(hs, spk_emb)
-            elif spk_id is not None:
-                spk_emb = self.spk_embedding_table(spk_id)
-                hs = self._integrate_with_spk_embed(hs, spk_emb)
-
-        # integrate tone embedding
-        if self.tone_embed_dim is not None:
-            if tone_id is not None:
-                tone_embs = self.tone_embedding_table(tone_id)
-                hs = self._integrate_with_tone_embed(hs, tone_embs)
-        # forward duration predictor and variance predictors
-        d_masks = make_pad_mask(ilens)
-
-        if self.stop_gradient_from_pitch_predictor:
-            p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1))
-        else:
-            p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1))
-        if self.stop_gradient_from_energy_predictor:
-            e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1))
-        else:
-            e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1))
-
-        if is_inference:
-            # (B, Tmax)
-            if ds is not None:
-                d_outs = ds
-            else:
-                d_outs = self.duration_predictor.inference(hs, d_masks)
-            if ps is not None:
-                p_outs = ps
-            if es is not None:
-                e_outs = es
-
-            # use prediction in inference
-            # (B, Tmax, 1)
-
-            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
-                (0, 2, 1))
-            e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
-                (0, 2, 1))
-            hs = hs + e_embs + p_embs
-
-            # (B, Lmax, adim)
-            hs = self.length_regulator(hs, d_outs, alpha, is_inference=True)
-        else:
-            d_outs = self.duration_predictor(hs, d_masks)
-            # use groundtruth in training
-            p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
-                (0, 2, 1))
-            e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
-                (0, 2, 1))
-            hs = hs + e_embs + p_embs
-
-            # (B, Lmax, adim)
-            hs = self.length_regulator(hs, ds, is_inference=False)
-
-        # forward decoder
-        if olens is not None and not is_inference:
-            if self.reduction_factor > 1:
-                olens_in = paddle.to_tensor(
-                    [olen // self.reduction_factor for olen in olens.numpy()])
-            else:
-                olens_in = olens
-            # (B, 1, T)
-            h_masks = self._source_mask(olens_in)
-        else:
-            h_masks = None
-        if return_after_enc:
-            return hs, h_masks
-
-        if self.decoder_type == 'cnndecoder':
-            # remove output masks for dygraph to static graph
-            zs = self.decoder(hs, h_masks)
-            before_outs = zs
-        else:
-            # (B, Lmax, adim)
-            zs, _ = self.decoder(hs, h_masks)
-            # (B, Lmax, odim)
-            before_outs = self.feat_out(zs).reshape(
-                (paddle.shape(zs)[0], -1, self.odim))
-
-        # postnet -> (B, Lmax//r * r, odim)
-        if self.postnet is None:
-            after_outs = before_outs
-        else:
-            after_outs = before_outs + self.postnet(
-                before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
-
-        return before_outs, after_outs, d_outs, p_outs, e_outs

def _integrate_with_spk_embed(self, hs, spk_emb):
"""Integrate speaker embedding with hidden states.
@@ -1212,7 +1100,8 @@ def forward(
olens: paddle.Tensor,
spk_logits: paddle.Tensor=None,
spk_ids: paddle.Tensor=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,]:
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
+               paddle.Tensor, ]:
"""Calculate forward propagation.
Args:
@@ -1249,7 +1138,7 @@ def forward(
"""
+        speaker_loss = 0.0

# apply mask to remove padded part
if self.use_masking:
out_masks = make_non_pad_mask(olens).unsqueeze(-1)
@@ -1273,12 +1162,13 @@

if spk_logits is not None and spk_ids is not None:
batch_size = spk_ids.shape[0]
-            spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1], None)
-            spk_logits = paddle.reshape(spk_logits, [-1, spk_logits.shape[-1]])
-            mask_index = spk_logits.abs().sum(axis=1)!=0
+            spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1],
+                                               None)
+            spk_logits = paddle.reshape(spk_logits,
+                                        [-1, spk_logits.shape[-1]])
+            mask_index = spk_logits.abs().sum(axis=1) != 0
spk_ids = spk_ids[mask_index]
spk_logits = spk_logits[mask_index]


# calculate loss
l1_loss = self.l1_criterion(before_outs, ys)
Expand All @@ -1289,7 +1179,7 @@ def forward(
energy_loss = self.mse_criterion(e_outs, es)

if spk_logits is not None and spk_ids is not None:
-            speaker_loss = self.ce_criterion(spk_logits, spk_ids)/batch_size
+            speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size

# make weighted mask and apply it
if self.use_weighted_masking:
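
To make the speaker-loss bookkeeping above concrete, here is a small self-contained toy (illustrative only, not repository code): speaker ids are repeated once per frame, logits are flattened to `(B * Tmax, spk_num)`, frames whose logits are all zero (padding) are dropped, and the cross-entropy is normalized by batch size.

```python
import paddle

# Toy numbers: 2 utterances, 3 frames, 4 speakers.
B, T, S = 2, 3, 4
spk_logits = paddle.randn([B, T, S])
spk_logits[1, 2, :] = 0.0              # pretend the last frame is padding
spk_ids = paddle.to_tensor([0, 3])     # one speaker id per utterance

batch_size = spk_ids.shape[0]
spk_ids = paddle.repeat_interleave(spk_ids, T, None)  # (B*T,)
spk_logits = paddle.reshape(spk_logits, [-1, S])      # (B*T, S)
mask_index = spk_logits.abs().sum(axis=1) != 0        # drop padded frames
spk_ids = spk_ids[mask_index]
spk_logits = spk_logits[mask_index]

ce = paddle.nn.CrossEntropyLoss()
speaker_loss = ce(spk_logits, spk_ids) / batch_size
print(float(speaker_loss))
```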