Skip to content

Commit

Permalink
Fix up
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed May 17, 2022
1 parent 8e915b7 commit 2d29e82
Showing 1 changed file with 32 additions and 16 deletions.
48 changes: 32 additions & 16 deletions TTS/tts/models/forward_tts_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def __init__(
def _compute_and_save_pitch(audio_config, wav_file, pitch_file=None):
wav, _ = load_audio(wav_file)
f0 = compute_f0(
x=wav.numpy()[0], sample_rate=audio_config.sample_rate, hop_length=audio_config.hop_length, pitch_fmax=audio_config.pitch_fmax
x=wav.numpy()[0],
sample_rate=audio_config.sample_rate,
hop_length=audio_config.hop_length,
pitch_fmax=audio_config.pitch_fmax,
)
# skip the last F0 value to align with the spectrogram
if wav.shape[1] % audio_config.hop_length != 0:
Expand All @@ -104,7 +107,9 @@ def compute_or_load(self, wav_file):
"""
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
if not os.path.exists(pitch_file):
pitch = self._compute_and_save_pitch(audio_config=self.audio_config, wav_file=wav_file, pitch_file=pitch_file)
pitch = self._compute_and_save_pitch(
audio_config=self.audio_config, wav_file=wav_file, pitch_file=pitch_file
)
else:
pitch = np.load(pitch_file)
return pitch.astype(np.float32)
Expand Down Expand Up @@ -300,6 +305,9 @@ class ForwardTTSE2eArgs(ForwardTTSArgs):
upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
upsample_initial_channel_decoder: int = 512
upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
# discriminator
upsampling_rates_discriminator: List[int] = field(default_factory=lambda: [4, 4, 4, 4])
periods_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])
# multi-speaker params
use_speaker_embedding: bool = False
num_speakers: int = 0
Expand Down Expand Up @@ -359,7 +367,18 @@ def __init__(

# use Vits Discriminator for limiting VRAM use
if self.args.init_discriminator:
self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_discriminator)
self.disc = VitsDiscriminator(
use_spectral_norm=self.args.use_spectral_norm_discriminator,
periods=self.args.periods_discriminator,
upsampling_rates=self.args.upsampling_rates_discriminator,
)

# def check_model_args(self):
# upsample_rate = torch.prod(torch.as_tensor(self.args.upsample_rates_decoder)).item()
# if s
# assert (
# upsample_rate == self.config.audio.hop_length
# ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"

def init_multispeaker(self, config: Coqpit):
"""Init for multi-speaker training.
Expand Down Expand Up @@ -440,9 +459,10 @@ def forward(
let_short_samples=True,
pad_short=True,
)

vocoder_output = self.waveform_decoder(
x=o_en_ex_slices.detach() if self.args.detach_vocoder_input else o_en_ex_slices,
g=encoder_outputs["g"],
g=encoder_outputs["spk_emb"],
)
wav_seg = segment(
waveform,
Expand All @@ -461,7 +481,7 @@ def forward(
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=True)
o_en_ex = encoder_outputs["o_en_ex"]
vocoder_output = self.waveform_decoder(x=o_en_ex, g=encoder_outputs["g"])
vocoder_output = self.waveform_decoder(x=o_en_ex, g=encoder_outputs["spk_emb"])
model_outputs = {**encoder_outputs}
model_outputs["model_outputs"] = vocoder_output
return model_outputs
Expand Down Expand Up @@ -860,9 +880,11 @@ def format_batch_on_device(self, batch):
center=False,
)

assert (
batch["pitch"].shape[2] == batch["mel_input"].shape[2]
), f"{batch['pitch'].shape[2]}, {batch['mel'].shape[2]}"
# TODO: Align pitch properly
# assert (
# batch["pitch"].shape[2] == batch["mel_input"].shape[2]
# ), f"{batch['pitch'].shape[2]}, {batch['mel_input'].shape[2]}"
batch["pitch"] = batch["pitch"][:, :, : batch["mel_input"].shape[2]]
batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int()

# zero the padding frames
Expand Down Expand Up @@ -990,9 +1012,7 @@ def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List
# language_manager = LanguageManager.init_from_config(config)
return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager)

def load_checkpoint(
self, config, checkpoint_path, eval=False
):
def load_checkpoint(self, config, checkpoint_path, eval=False):
"""Load model from a checkpoint created by the 👟"""
# pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
Expand All @@ -1003,11 +1023,7 @@ def load_checkpoint(

def get_state_dict(self):
"""Custom state dict of the model with all the necessary components for inference."""
save_state = {
"config": self.config.to_dict(),
"args": self.args.to_dict(),
"model": self.state_dict
}
save_state = {"config": self.config.to_dict(), "args": self.args.to_dict(), "model": self.state_dict}

if hasattr(self, "emb_g"):
save_state["speaker_ids"] = self.speaker_manager.speaker_ids
Expand Down

0 comments on commit 2d29e82

Please sign in to comment.