Training the model end2end #11

Closed
erogol opened this issue Jan 31, 2023 · 7 comments

erogol commented Jan 31, 2023

This is just a heads-up about the discussion in #7.

I tried training the model end-to-end in different ways, but the F0 and energy predictors always underfit, even though their eval losses kept going down. They were never able to predict useful values for inference.

Here is roughly my forward pass. I can also share the branch if that would be useful. Happy to hear any feedback.

    @typechecked
    def forward_all(
        self,
        texts: TensorType["B", "T_text"],
        input_lengths: TensorType["B"],
        mels: TensorType["B", "C_mel", "T_mel"],
        mel_input_length: TensorType["B"],
        F0_real: TensorType["B", 1, "T_mel"],
    ):
        # TODO: use Pitch Extractor (maybe torchcrepe)
        # mask = length_to_mask(mel_input_length // (2 ** model.text_aligner.n_down)).to(self.device)
        text_mask = self.lengths_to_mask(input_lengths).to(self.device)
        mel_mask = self.lengths_to_mask(mel_input_length).to(self.device)

        ##### --> TEXT ENCODER
        t_en, t_emb = self.text_encoder(texts, input_lengths, length_to_mask(input_lengths))

        ##### --> ALIGNER
        _, aligner_soft, aligner_logprob, aligner_hard = self._forward_aligner(
            x=t_emb.detach().transpose(1, 2),
            y=mels,
            x_mask=text_mask,
            y_mask=mel_mask,
            attn_priors=None,
        )

        ##### --> EXPAND
        t_en_ex = t_en @ aligner_hard.squeeze(1)

        ##### --> PRUNE THE BATCH BY THE SHORTEST MEL LENGTH
        mel_len = int(mel_input_length.min().item())
        t_en_ex_clipped = []
        mels_clipped = []
        F0s = []
        idxs = []
        for bib in range(len(mel_input_length)):
            mel_length = int(mel_input_length[bib].item()) + 1

            random_start = np.random.randint(0, mel_length - mel_len)
            idxs.append(random_start)
            t_en_ex_clipped.append(t_en_ex[bib, :, random_start : random_start + mel_len])
            mels_clipped.append(mels[bib, :, random_start : random_start + mel_len])
            F0s.append(F0_real[bib, :, random_start : random_start + mel_len])

        t_en_ex_clipped = torch.stack(t_en_ex_clipped)
        mels_clipped = torch.stack(mels_clipped).detach()
        F0_real = torch.stack(F0s).detach()

        ##### --> CALCULATE REAL ENERGY
        N_real = log_norm(mels_clipped.unsqueeze(1)).squeeze(1).detach()
        # F0_real, _, _ = self.pitch_extractor(gt.unsqueeze(1))

        ##### --> STYLE ENCODER
        s = self.style_encoder(mels_clipped.unsqueeze(1))

        ##### --> DURATION PREDICTOR & PROSODY ENCODER
        dur_aligner = aligner_hard.sum(axis=-1).detach()
        dur_pred, prosody_en_ex = self.predictor(
            t_en.detach(), s.detach(), input_lengths, aligner_hard.squeeze(1), length_to_mask(input_lengths)
        )  # [B, T_en]
        d_align_mask = self.lengths_to_mask(input_lengths) * self.lengths_to_mask(mel_input_length).transpose(
            1, 2
        )  # [B, 1, T_enc] * [B, T_dec, 1]
        dur_alignment = generate_path(dur_pred, d_align_mask.squeeze(1).transpose(1, 2)).detach()  # [B, T_dec, T_enc]

        # Strip prosody tensor to match the mel length
        p_en = []
        for bib, start in enumerate(idxs):
            p_en.append(prosody_en_ex[bib, :, start : (start + mel_len)])
        p_en = torch.stack(p_en)

        ##### --> Pitch and Energy Predictor
        F0_fake, N_fake = self.predictor.F0Ntrain(p_en, s.detach())

        ##### --> DECODER
        mel_rec = self.decoder(t_en_ex_clipped, F0_real.squeeze(1), N_real, s)
        return {
            "mel_rec": mel_rec,
            "gt": mels_clipped,
            "aligner_logprob": aligner_logprob,
            "aligner_hard": aligner_hard.squeeze(1),
            "aligner_soft": aligner_soft,
            "F0_real": F0_real.squeeze(1),
            "F0_fake": F0_fake,
            "N_real": N_real,
            "N_fake": N_fake,
            "d": dur_pred,
            "d_gt": dur_aligner.squeeze(1),
            "d_alignment": dur_alignment,
        }

yl4579 commented Jan 31, 2023

I'm not sure what you mean by never being able to predict useful values for inference. It looks like you never use F0_fake and N_fake anywhere in your code. Do you mean that the F0 and energy produced when training E2E are much worse than when training in two stages?

erogol commented Feb 1, 2023

Sorry that what I mean is unclear from the code. Let me try to explain.

This is just the training pass. It is much the same idea as FastSpeech-like models: I try to optimize the predictors with an L1 loss alongside the rest of the model, instead of doing a separate second-stage training.

Inference is pretty much the same as your code.

What I mean is that the predicted F0 and energy values do not really approximate the correct values. When you actually plot them, they look quite random.

I can get you the other parts of the training code if you are interested.
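
Roughly, the predictor part of the objective looks like this (a simplified sketch: loss weights and the aligner terms are omitted, and only the outputs dict returned by forward_all above is assumed, so treat the details as illustrative rather than the exact training code):

import torch.nn.functional as F

def predictor_losses(outputs):
    # simplified sketch - loss weights and the aligner/adversarial terms are omitted
    loss_mel = F.l1_loss(outputs["mel_rec"], outputs["gt"])       # decoder reconstruction
    loss_f0  = F.l1_loss(outputs["F0_fake"], outputs["F0_real"])  # pitch predictor
    loss_n   = F.l1_loss(outputs["N_fake"], outputs["N_real"])    # energy predictor
    loss_dur = F.l1_loss(outputs["d"], outputs["d_gt"])           # duration predictor
    return loss_mel + loss_f0 + loss_n + loss_dur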

yl4579 commented Feb 1, 2023

How does this compare to the values predicted when you trained in two stages? Did you see the same problem when training stages 1 and 2 separately? How does the gradient flow to the predictors through your F0_fake and N_fake? And what was the loss function for these two outputs, i.e. did you use anything other than L1 loss?

erogol commented Feb 1, 2023

Training in two stages worked way better, although I did not wait for the second stage to converge. It still produced speech with reasonable quality.

BTW how do you decide when to stop training each stage?

yl4579 commented Feb 1, 2023

Then there could be some problem with how you train them E2E. Can you try to include all of the second-stage objectives, not just the L1 loss? FastSpeech 2 uses a binned representation, but here we use the exact F0 curve, so using L1 loss on the curve alone may not be sufficient.
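
For comparison, this is roughly what the two target representations look like (illustrative values only; the bin count and frequency range are placeholders, not code from either repo):

import math
import torch

f0 = torch.tensor([110.0, 0.0, 220.5, 233.1])  # Hz per frame, 0.0 = unvoiced

# FastSpeech 2-style binned target: quantize log-F0 into a fixed number of bins
# and treat prediction as classification / embedding lookup.
bins = torch.linspace(math.log(80.0), math.log(600.0), steps=256)
binned_target = torch.bucketize(torch.log(f0.clamp(min=1.0)), bins)

# StyleTTS-style target: regress the exact continuous curve, e.g. with an L1 loss.
exact_target = f0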

As for how I decide when to stop training, it is simple for the first stage: it is just like training a vocoder for reconstruction, so you can simply watch the validation mel loss. For the second stage it can be a little more difficult, but I usually focus on the duration loss. Once the duration loss converges, the quality of the model is pretty much fixed. The F0 might still change a little, but the effect is negligible, so you can stop at any time after the duration loss converges. This is likely because once the duration loss has converged, the representation learned from the text no longer changes, and F0 prediction depends on those representations learned for duration prediction.
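
In code form, that stopping rule is basically a plateau check on the validation duration loss, something like this (a minimal sketch; the window size and tolerance are arbitrary placeholders):

def duration_loss_converged(val_dur_losses, window=10, tol=1e-3):
    # Treat the duration loss as converged once the best value over the last
    # `window` evaluations stops improving on the earlier best by more than `tol`.
    if len(val_dur_losses) < 2 * window:
        return False
    recent = min(val_dur_losses[-window:])
    previous = min(val_dur_losses[:-window])
    return previous - recent < tol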

yl4579 commented Jun 14, 2023

I think I can close this issue because I have managed to train it end-to-end in StyleTTS 2: https://github.com/yl4579/StyleTTS2. Please follow that repo. I will clean up the code and make it public around July or August.

yl4579 closed this as completed on Jun 14, 2023

yl4579 commented Nov 20, 2023

Hi Eren, I'm not sure if you are still interested in StyleTTS (and StyleTTS 2). StyleTTS 2 has now gotten some attention, and people are interested in multilingual support: yl4579/StyleTTS2#41. It would be greatly appreciated if you could help integrate this model into Coqui with multilingual support. You can email me at [email protected] if you have any further questions.
