diff --git a/TTS/auto_tts/complete_recipes.py b/TTS/auto_tts/complete_recipes.py
new file mode 100644
index 0000000000..c806340c9f
--- /dev/null
+++ b/TTS/auto_tts/complete_recipes.py
@@ -0,0 +1,309 @@
+import zipfile
+
+import requests
+import tqdm
+
+from TTS.auto_tts.model_hub import TtsModels, VocoderModels
+from TTS.auto_tts.utils import data_loader
+from TTS.trainer import Trainer, TrainingArgs, init_training
+
+
+class TtsAutoTrainer(TtsModels):
+    """
+    Args:
+
+        data_path (str):
+            The path to the dataset. Defaults to None.
+
+        dataset (str):
+            The dataset identifier, e.g. ljspeech would be "ljspeech". Defaults to None.
+            See auto_tts utils for specific dataset names.
+
+        batch_size (int):
+            The size of the batches you pass to the model. This will depend on GPU memory;
+            less than 32 is not recommended. Defaults to 32.
+
+        output_path (str):
+            The path where you want to save the model config and model weights. If it is None it will
+            use your current directory. Defaults to None.
+
+        mixed_precision (bool):
+            Enables mixed precision training. It can allow bigger batch sizes and faster training,
+            but could also make some runs unstable. Defaults to False.
+
+        learning_rate (float):
+            The learning rate for the model. Defaults to 1e-3.
+
+        epochs (int):
+            How many times you want the model to go through the entire dataset. This usually doesn't need changing.
+            Defaults to 1000.
+
+    Usage:
+        Python:
+            from TTS.auto_tts.complete_recipes import TtsAutoTrainer
+            trainer = TtsAutoTrainer(data_path='DEFINE THIS', dataset="DEFINE THIS", batch_size=32, learning_rate=0.001,
+                                     mixed_precision=False, output_path='DEFINE THIS', epochs=1000)
+            model = trainer.single_speaker_autotts("tacotron2", "double decoder consistency")
+            model.fit()
+
+        command line:
+            python single_speaker_autotts.py --data_path ../LJSpeech-1.1 --dataset ljspeech --batch_size 32 --mixed_precision
+            --model tacotron2 --tacotron2_model_type "double decoder consistency" --forward_attention
+            --location_attention
+
+    """
+
+    def __init__(
+        self,
+        data_path=None,
+        dataset=None,
+        batch_size=32,
+        output_path=None,
+        mixed_precision=False,
+        learning_rate=1e-3,
+        epochs=1000,
+    ):
+
+        super().__init__(batch_size, mixed_precision, learning_rate, epochs, output_path)
+        self.data_path = data_path
+        self.dataset_name = dataset
+
+    def single_speaker_autotts(
+        # NOTE: this will be renamed to autotts_recipes and split into a more generic
+        # single_speaker_autotts, since handling fine tuning in the same function would get
+        # too clunky. To be finished in the next commit.
+        self,
+        model_name,
+        stats_path=None,
+        tacotron2_model_type=None,
+        glow_tts_encoder=None,
+        forward_attention=False,
+        location_attention=True,
+        pretrained=False,
+    ):
+        """
+
+        Args:
+            model_name (str):
+                Name of the model you want to train.
+
+            stats_path (str):
+                Optional, stats path for the audio config if the model uses it. Defaults to None.
+
+            tacotron2_model_type (str):
+                Optional, type of tacotron2 model you want to train, either double decoder consistency
+                or dynamic convolution attention. Defaults to None.
+
+            glow_tts_encoder (str):
+                Optional, type of encoder to train glow tts with: either transformer, gated,
+                residual_bn, or time_depth. Defaults to None.
+
+            forward_attention (bool):
+                Optional, whether to use forward attention on tacotron2 models.
+                Usually makes the model align faster. Defaults to False.
+
+            location_attention (bool):
+                Optional, whether to use location attention on Tacotron2 models. Defaults to True.
+
+            pretrained (bool):
+                Whether to start from a pretrained model. This is recommended if you are training on
+                custom data. Defaults to False.
+
+        """
+
+        audio, dataset = data_loader(name=self.dataset_name, path=self.data_path, stats_path=stats_path)
+        if self.dataset_name == "ljspeech":
+            if model_name == "tacotron2":
+                if tacotron2_model_type == "double decoder consistency":
+                    model_config = self._single_speaker_tacotron2_DDC(
+                        audio, dataset, forward_attn=forward_attention, location_attn=location_attention
+                    )
+                elif tacotron2_model_type == "dynamic convolution attention":
+                    model_config = self._single_speaker_tacotron2_DCA(
+                        audio, dataset, forward_attn=forward_attention, location_attn=location_attention
+                    )
+                else:
+                    model_config = self._single_speaker_tacotron2_base(
+                        audio, dataset, forward_attn=forward_attention, location_attn=location_attention
+                    )
+            elif model_name == "glow tts":
+                model_config = self._single_speaker_glow_tts(audio, dataset, encoder=glow_tts_encoder)
+            elif model_name == "vits tts":
+                model_config = self._single_speaker_vits_tts(audio, dataset)
+            elif model_name == "fast pitch":
+                model_config = self._ljspeech_fast_fastpitch(audio, dataset)
+        elif self.dataset_name == "baker":
+            if model_name == "tacotron2":
+                if tacotron2_model_type == "double decoder consistency":
+                    model_config = self._single_speaker_tacotron2_DDC(
+                        audio,
+                        dataset,
+                        pla=0.5,
+                        dla=0.5,
+                        ga=0.0,
+                        forward_attn=forward_attention,
+                        location_attn=location_attention,
+                    )
+                elif tacotron2_model_type == "dynamic convolution attention":
+                    model_config = self._single_speaker_tacotron2_DCA(
+                        audio,
+                        dataset,
+                        pla=0.5,
+                        dla=0.5,
+                        ga=0.0,
+                        forward_attn=forward_attention,
+                        location_attn=location_attention,
+                    )
+        args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), model_config)
+        trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+        return trainer
+
+    def multi_speaker_autotts(
+        self, model_name, speaker_file, glowtts_encoder=None, r=2, forward_attn=True, location_attn=False
+    ):
+        """
+
+        Args:
+            model_name (str):
+                Name of the model you want to train.
+
+            speaker_file (str):
+                Path to either the d_vector file for glow tts or the speaker ids file for vits.
+
+            glowtts_encoder (str):
+                Optional, which encoder you want the glow tts model to use. Defaults to None.
+
+            r (int):
+                Set the reduction factor (r) for the tacotron2 model. Defaults to 2.
+
+            forward_attn (bool):
+                Enable forward attention for the tacotron2 model. Defaults to True.
+
+            location_attn (bool):
+                Enable location attention for the tacotron2 model. Defaults to False.
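+
+        Example (illustrative only; the dataset path and d_vector file below are placeholders):
+
+            >>> trainer = TtsAutoTrainer(data_path="../VCTK", dataset="vctk", batch_size=32)
+            >>> model = trainer.multi_speaker_autotts("glow tts", speaker_file="d_vectors.json")
+            >>> model.fit()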
+ + """ + audio, dataset = data_loader(name=self.dataset_name, path=self.data_path, stats_path=None) + if self.dataset_name == "vctk": + if model_name == "glow tts": + model_config = self._sc_glow_tts(audio, dataset, speaker_file, encoder=glowtts_encoder) + elif model_name == "vits tts": + model_config = self._vctk_vits_tts(audio, dataset, speaker_file) + elif model_name == "tacotron2": + model_config = self._multi_speaker_vctk_tacotron2( + audio, dataset, speaker_file, r=r, forward_attn=forward_attn, location_attn=location_attn + ) + args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), model_config) + trainer = Trainer(args, config, output_path, c_logger, tb_logger) + return trainer + + +class VocoderAutoTrainer(VocoderModels): + """ + Args: + + data_path (str): + The path to the dataset. Defaults to None. + + dataset (str): + The dataset identifier. ex: ljspeech would be "ljspeech". Defaults to None. + See auto_tts utils for specific dataset names. + + batch_size (int): + The size the batches you pass to the model. This will depend on gpu memory. + less than 32 is not recommended. Defaults to 32. + + output_path (str): + The path where you want to the model config and model weights. If it is None it will + use your current directory. Defaults to None + + mixed_precision (bool): + enables mixed precision training. can make batch sizes bigger and make training faster. + Could also make some trainings unstable. Defualts to False. + + learning_rate (List [float, float]): + The learning rate for the model. This should be a list with the generator rate being first + and discrimiator rate being second. Defaults to [1e-3, 1e-3]. + + epochs (int): + how many times you want to model to go through the entire dataset. This usually doesn't need changing. + Defaults to 1000. + + Usage: + Python: + From TTS.auto_tts.complete_recipes import VocoderAutoTrainer + trainer = VocoderAutoTrainer(data_path='DEFINE THIS', stats_path=None, dataset="DEFINE THIS", + batch_size=32, learning_rate=[1e-3, 1e-3], + mixed_precision=False, output_path='DEFINE THIS', epochs=1000) + model = trainer.single_speaker_autotts("hifigan") + model.fit() + + command line: + python vocoder_autotts.py --data_path ../LJSpeech-1.1 --dataset ljspeech --batch_size 32 --mixed_precision + --model hifigan + + """ + + def __init__( + self, + data_path=None, + dataset=None, + batch_size=32, + output_path=None, + mixed_precision=False, + learning_rate=None, + epochs=1000, + ): + if learning_rate is None: + learning_rate = [0.001, 0.001] + super().__init__( + batch_size, + mixed_precision, + generator_learning_rate=learning_rate[0], + discriminator_learning_rate=learning_rate[1], + epochs=epochs, + output_path=output_path, + ) + self.data_path: str = data_path + self.dataset_name: str = dataset + + def single_speaker_autotts(self, model_name, stats_path=None): + """ + Args: + + model_name (str): + name of the model you want to train. + + Stats_path (str): + Optional, Path to the stats file for the audio config. Defaults to None. 
+ + """ + if self.dataset_name == "ljspeech": + audio, _ = data_loader(name="ljspeech", path=self.data_path, stats_path=stats_path) + if model_name == "hifigan": + model_config = self._hifi_gan(audio, self.data_path) + elif model_name == "wavegrad": + model_config = self._wavegrad(audio, self.data_path) + elif model_name == "univnet": + model_config = self._univnet(audio, self.data_path) + elif model_name == "multiband melgan": + model_config = self._multiband_melgan(audio, self.data_path) + elif model_name == "wavernn": + model_config = self._wavernn(audio, self.data_path) + args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), model_config) + trainer = Trainer(args, config, output_path, c_logger, tb_logger) + return trainer + + def from_pretrained(model_name): + pass diff --git a/TTS/auto_tts/example.py b/TTS/auto_tts/example.py new file mode 100644 index 0000000000..5f87927434 --- /dev/null +++ b/TTS/auto_tts/example.py @@ -0,0 +1,12 @@ +from TTS.utils.manage import ModelManager + +manager = ModelManager() +model_path, config_path, x = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") + +print(model_path) + +print(config_path) + +print(x) + +manager.list_models() diff --git a/TTS/auto_tts/model_hub.py b/TTS/auto_tts/model_hub.py new file mode 100644 index 0000000000..ecdd61b784 --- /dev/null +++ b/TTS/auto_tts/model_hub.py @@ -0,0 +1,530 @@ +import os + +from recipes.ljspeech.fast_pitch.train_fast_pitch import config as fastpitch_config +from recipes.ljspeech.glow_tts.train_glowtts import config as glowtts_config +from recipes.ljspeech.hifigan.train_hifigan import config as hifigan_config +from recipes.ljspeech.multiband_melgan.train_multiband_melgan import config as multiband_melgan_config +from recipes.ljspeech.univnet.train import config as univnet_config +from recipes.ljspeech.vits_tts.train_vits import config as vits_config +from recipes.ljspeech.wavegrad.train_wavegrad import config as wavegrad_config +from recipes.ljspeech.wavernn.train_wavernn import config as waverrn_config +from TTS.auto_tts.utils import pick_forwardtts_decoder, pick_forwardtts_encoder, pick_glowtts_encoder +from TTS.trainer import TrainingArgs +from TTS.tts.configs.glow_tts_config import GlowTTSConfig +from TTS.tts.configs.tacotron2_config import Tacotron2Config +from TTS.tts.models.forward_tts import ForwardTTSArgs +from TTS.utils.manage import ModelManager + +# os.environ["CUDA_VISIBLE_DEVICES"] = "1" + + +class TtsModels: + def __init__( + self, + batch_size: int, + mixed_precision: bool, + learning_rate: float, + epochs: int, + output_path: str = os.path.dirname(os.path.abspath(__file__)), + ): + + self.batch_size = batch_size + self.output_path = output_path + self.mixed_precision = mixed_precision + self.learning_rate = learning_rate + self.epochs = epochs + self.manager = ModelManager() + + def _single_speaker_from_pretrained(self, model_name): + if model_name == "english glow-tts": + model_path, config_path, _ = self.manager.download_model("tts_models/en/ljspeech/glow-tts") + elif model_name == "english tacotron2-DDC": + model_path, config_path, _ = self.manager.download_model("tts_models/en/ljspeech/tacotron2-DDC") + elif model_name == " english tacotron2": + model_path, config_path, _ = self.manager.download_model("tts_models/en/ek1/tacotron2") + elif model_name == "english tacotron2-DCA": + model_path, config_path, _ = self.manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") + elif model_name == "english speedy speech": + model_path, config_path, 
_ = self.manager.download_model("tts_models/en/ljspeech/speedy-speech") + elif model_name == "english vits": + model_path, config_path, _ = self.manager.download_model("tts_models/en/ljspeech/vits") + elif model_name == "english fast speech": + model_path, config_path, _ = self.manager.download_model("tts_models/en/ljspeech/fast_pitch") + elif model_name == "spanish tacotron2-DDC": + model_path, config_path, _ = self.manager.download_model("tts_models/es/mai/tacotron2-DDC") + elif model_name == "french tacotron2-DDC": + model_path, config_path, _ = self.manager.download_model("tts_models/fr/mai/tacotron2-DDC") + elif model_name == "chinese tacotron2-DDC": + model_path, config_path, _ = self.manager.download_model("tts_models/zh-CN/baker/tacotron2-DDC-GST") + elif model_name == "german tacotron2-DCA": + model_path, config_path, _ = self.manager.download_model("tts_models/de/thorsten/tacotron2-DCA") + elif model_name == "japanese tacotron2-DDC": + model_path, config_path, _ = self.manager.download_model("tts_models/ja/kokoro/tacotron2-DDC") + + training_args = TrainingArgs(restore_path=model_path) + return model_path, config_path, training_args + + def _multi_speaker_from_pretrained(self, model_name): + if model_name == " english sc-glow-tts": + model_path, config_path, _ = self.manager.download_model("tts_models/en/vctk/sc-glow-tts") + elif model_name == "english vits": + model_path, config_path, _ = self.manager.download_model("tts_models/en/vctk/vits") + training_args = TrainingArgs(restore_path=model_path) + return model_path, config_path, training_args + + def _single_speaker_tacotron2_base( + self, audio, dataset, dla=0.25, pla=0.25, ga=5.0, forward_attn=True, location_attn=True + ): + config = Tacotron2Config( + run_name="single_speaker_tacotron2", + audio=audio, + batch_size=self.batch_size, + eval_batch_size=int(self.batch_size / 2), + r=2, + grad_clip=1, + lr=self.learning_rate, + decoder_loss_alpha=dla, + postnet_loss_alpha=pla, + postnet_diff_spec_alpha=0.25, + decoder_diff_spec_alpha=0.25, + decoder_ssim_alpha=0.25, + postnet_ssim_alpha=0.25, + ga_alpha=ga, + stopnet_pos_weight=15.0, + memory_size=-1, + phoneme_cache_path=os.path.join(self.output_path, "phoneme_cache"), + prenet_type="original", + prenet_dropout=True, + attention_type="original", + attention_heads=5, + attention_norm="sigmoid", + windowing=False, + use_forward_attn=forward_attn, + forward_attn_mask=False, + transition_agent=False, + location_attn=location_attn, + stopnet=True, + separate_stopnet=True, + print_step=25, + save_step=10000, + checkpoint=True, + text_cleaner="basic_cleaners", + num_loader_workers=6, + num_eval_loader_workers=6, + min_seq_len=6, + max_seq_len=150, + output_path=self.output_path, + datasets=[dataset], + ) + return config + + def _multi_speaker_vctk_tacotron2( + self, audio, dataset, speaker_file, ga=10.0, r=7, forward_attn=True, location_attn=True + ): + config = Tacotron2Config( + audio=audio, + run_name="mulit_speaker_tacotron2_vctk", + run_description="multi speaker tacotron2 trained on vctk dataset.", + batch_size=self.batch_size, + eval_batch_size=int(self.batch_size / 2), + mixed_precision=self.mixed_precision, + # gradual_training="", + r=r, + loss_masking=True, + ga_alpha=ga, + run_eval=True, + test_delay_epochs=-1, + grad_clip=1.0, + epochs=self.epochs, + lr=self.learning_rate, + seq_len_norm=True, + memory_size=-1, + attention_type="original", + attention_heads=4, + attention_norm="softmax", + windowing=False, + use_forward_attn=forward_attn, + location_attn=location_attn, + 
forward_attn_mask=False, + transition_agent=False, + ddc_r=r, + stopnet=True, + separate_stopnet=True, + print_step=25, + plot_step=100, + print_eval=False, + save_step=200, + checkpoint=True, + text_cleaner="english_cleaners", + phoneme_cache_path=os.path.join(self.output_path, "phoneme_cache"), + num_loader_workers=6, + num_eval_loader_workers=6, + max_decoder_steps=1000, + use_speaker_embedding=True, + use_d_vector_file=True, + d_vector_dim=512, + d_vector_file=speaker_file, + min_seq_len=6, + max_seq_len=190, + use_phonemes=True, + use_espeak_phonemes=True, + datasets=[dataset], + ) + return config + + def _single_speaker_tacotron2_DDC( + self, audio, dataset, dla=0.25, pla=0.25, ga=5.0, forward_attn=False, location_attn=True + ): + config = Tacotron2Config( + audio=audio, + run_name="ljspeech-ddc", + run_description="tacotron2 with double decoder consistency.", + batch_size=self.batch_size, + eval_batch_size=self.batch_size // 2, + mixed_precision=False, + loss_masking=True, + decoder_loss_alpha=dla, + postnet_loss_alpha=pla, + postnet_diff_spec_alpha=0.25, + decoder_diff_spec_alpha=0.25, + decoder_ssim_alpha=0.25, + postnet_ssim_alpha=0.25, + ga_alpha=ga, + stopnet_pos_weight=15.0, + run_eval=True, + test_delay_epochs=10, + max_decoder_steps=1000, + grad_clip=0.05, + epochs=self.epochs, + lr=self.learning_rate, + memory_size=-1, + prenet_type="original", + use_forward_attn=forward_attn, + prenet_dropout=True, + attention_type="original", + location_attn=location_attn, + double_decoder_consistency=True, + ddc_r=6, + attention_norm="sigmoid", + r=6, + gradual_training=[ + [0, 6, self.batch_size], + [10000, 4, self.batch_size // 2], + [50000, 3, self.batch_size // 2], + [100000, 2, self.batch_size // 2], + ], + stopnet=True, + separate_stopnet=True, + print_step=25, + print_eval=False, + plot_step=100, + save_step=10000, + checkpoint=True, + text_cleaner="phoneme_cleaners", + num_loader_workers=4, + num_eval_loader_workers=4, + batch_group_size=4, + min_seq_len=6, + max_seq_len=180, + compute_input_seq_cache=True, + phoneme_cache_path=os.path.join(self.output_path, "phoneme_cache"), + output_path=self.output_path, + use_phonemes=False, + phoneme_language="en-us", + datasets=[dataset], + ) + return config + + def _single_speaker_tacotron2_DCA( + self, audio, dataset, dla=0.25, pla=0.25, ga=5.0, forward_attn=False, location_attn=True + ): + config = Tacotron2Config( + audio=audio, + run_name="ljspeech-dca", + run_description="tacotron2 with dynamic conv attention.", + batch_size=self.batch_size, + eval_batch_size=self.batch_size // 2, + mixed_precision=True, + loss_masking=True, + decoder_loss_alpha=dla, + postnet_loss_alpha=pla, + postnet_diff_spec_alpha=0.25, + decoder_diff_spec_alpha=0.25, + decoder_ssim_alpha=0.25, + postnet_ssim_alpha=0.25, + ga_alpha=ga, + stopnet_pos_weight=15.0, + run_eval=True, + test_delay_epochs=10, + max_decoder_steps=1000, + grad_clip=0.05, + epochs=self.epochs, + lr=self.learning_rate, + memory_size=-1, + prenet_type="original", + use_forward_attn=forward_attn, + prenet_dropout=True, + attention_type="dynamic_convolution", + location_attn=location_attn, + attention_norm="sigmoid", + r=2, + stopnet=True, + separate_stopnet=True, + print_step=25, + plot_step=100, + print_eval=False, + save_step=10000, + checkpoint=True, + text_cleaner="phoneme_cleaners", + num_loader_workers=4, + num_eval_loader_workers=4, + batch_group_size=4, + min_seq_len=6, + max_seq_len=180, + compute_input_seq_cache=True, + output_path=self.output_path, + 
phoneme_cache_path=os.path.join(self.output_path, "phoneme_cache"), + use_phonemes=False, + phoneme_language="en-us", + datasets=[dataset], + ) + return config + + def _single_speaker_glow_tts(self, audio, dataset, encoder): + encoder_type = pick_glowtts_encoder(encoder) + glowtts_config.audio = audio + glowtts_config.batch_size = self.batch_size + glowtts_config.eval_batch_size = self.batch_size // 2 + glowtts_config.epochs = self.epochs + glowtts_config.output_path = self.output_path + glowtts_config.lr = self.learning_rate + glowtts_config.mixed_precision = self.mixed_precision + glowtts_config.encoder_type = encoder_type + glowtts_config.datasets = [dataset] + config = glowtts_config + return config + + def _sc_glow_tts(self, audio, dataset, speaker_file, encoder): + encoder_type = pick_glowtts_encoder(encoder) + config = GlowTTSConfig( + audio=audio, + run_name="multispeaker glow tts", + run_description="glow tts for multispeaker datasets", + batch_size=self.batch_size, + eval_batch_size=self.batch_size // 2, + mixed_precision=self.mixed_precision, + run_eval=True, + test_delay_epochs=-1, + print_eval=False, + print_step=25, + plot_step=100, + model_param_stats=False, + save_step=10000, + num_loader_workers=8, + num_eval_loader_workers=8, + use_noise_augment=False, + output_path=self.output_path, + use_phonemes=True, + use_espeak_phonemes=True, + phoneme_language="en", + compute_input_seq_cache=False, + test_sentences_file=None, + phoneme_cache_path=os.path.join(self.output_path, "phoneme_cache"), + batch_group_size=0, + loss_masking=True, + min_seq_len=2, + max_seq_len=500, + compute_f0=False, + add_blank=True, + use_speaker_embedding=True, + use_d_vector_file=True, + d_vector_dim=256, + encoder_type=encoder_type, + use_encoder_prenet=True, + hidden_channels_dp=256, + hidden_channels_dec=192, + hidden_channels_enc=192, + dropout_p_dp=0.1, + dropout_p_dec=0.05, + mean_only=True, + out_channels=80, + num_flow_blocks_dec=12, + inference_noise_scale=0.0, + kernel_size_dec=5, + dilation_rate=1, + num_block_layers=4, + num_speakers=0, + num_splits=4, + num_squeeze=2, + sigmoid_scale=False, + data_dep_init_steps=10, + style_wav_for_test=None, + length_scale=1.0, + d_vector_file=speaker_file, + grad_clip=5.0, + lr=self.learning_rate, + r=1, + datasets=[dataset], + ) + return config + + def _single_speaker_vits_tts(self, audio, dataset): + vits_config.audio = audio + vits_config.datasets = [dataset] + vits_config.lr_gen = self.learning_rate + vits_config.lr_disc = self.learning_rate + vits_config.batch_size = self.batch_size + vits_config.eval_batch_size = self.batch_size // 2 + vits_config.mixed_precision = self.mixed_precision + vits_config.output_path = self.output_path + config = vits_config + return config + + def _vctk_vits_tts(self, audio, dataset, speaker_file): + vits_config.audio = audio + vits_config.datasets = [dataset] + vits_config.lr_gen = self.learning_rate + vits_config.lr_disc = self.learning_rate + vits_config.batch_size = self.batch_size + vits_config.eval_batch_size = self.batch_size // 2 + vits_config.mixed_precision = self.mixed_precision + vits_config.output_path = self.output_path + vits_config.use_speaker_embedding = True + vits_config.num_speakers = 109 + vits_config.speaker_embedding_channels = 256 + vits_config.speakers_file = speaker_file + vits_config.num_chars = 179 + config = vits_config + return config + + def _ljspeech_fast_fastpitch(self, audio, dataset): + fastpitch_config.audio = audio + fastpitch_config.datasets = [dataset] + fastpitch_config.lr_gen = 
self.learning_rate + fastpitch_config.lr_disc = self.learning_rate + fastpitch_config.batch_size = self.batch_size + fastpitch_config.eval_batch_size = self.batch_size // 2 + fastpitch_config.mixed_precision = self.mixed_precision + fastpitch_config.output_path = self.output_path + config = fastpitch_config + return config + + def _forward_tts(self, audio, dataset, encoder, decoder): + encoder_type = pick_forwardtts_encoder(encoder) + decoder_type = pick_forwardtts_decoder(decoder) + model_args = ForwardTTSArgs(encoder_type=encoder_type, decoder_type=decoder_type) + pass + + +class VocoderModels: + def __init__( + self, + batch_size, + mixed_precision, + generator_learning_rate, + discriminator_learning_rate, + epochs, + output_path=os.path.dirname(os.path.abspath(__file__)), + ): + self.batch_size = batch_size + self.output_path = output_path + self.mixed_precision = mixed_precision + self.generator_lr = generator_learning_rate + self.discriminator_lr = discriminator_learning_rate + self.epochs = epochs + self.manager = ModelManager() + + def _single_speaker_from_pretrained(self, model_name: str): + if model_name == "universal-wavegrad": + model_path, config_path, _ = self.manager.download_model("vocoder_models/universal/libri-tts/wavegrad") + elif model_name == "universal-fullband-melgan": + model_path, config_path, _ = self.manager.download_model( + "vocoder_models/universal/libri-tts/fullband-melgan" + ) + elif model_name == "english-wavegrad": + model_path, config_path, _ = self.manager.download_model("vocoder_models/en/ek1/wavegrad") + elif model_name == "english-multiband-melgan": + model_path, config_path, _ = self.manager.download_model("vocoder_models/en/ljspeech/multiband-melgan") + elif model_name == "english-hifigan-v2": + model_path, config_path, _ = self.manager.download_model("vocoder_models/en/ljspeech/hifigan_v2") + elif model_name == "english-univnet": + model_path, config_path, _ = self.manager.download_model("vocoder_models/en/ljspeech/univnet") + elif model_name == "german-wavegrad": + model_path, config_path, _ = self.manager.download_model("vocoder_models/de/thorsten/wavegrad") + elif model_name == "german-fullband-melgan": + model_path, config_path, _ = self.manager.download_model("vocoder_models/de/thorsten/fullband-melgan") + elif model_name == "japanese-hifigan-v1": + model_path, config_path, _ = self.manager.download_model("vocoder_models/ja/kokoro/hifigan_v1") + training_args = TrainingArgs(restore_path=model_path) + return model_path, config_path, training_args + + def _multi_speaker_from_pretrained(self, model_name): + if model_name == "english-hifigan-v2": + model_path, config_path, _ = self.manager.download_model("vocoder_models/en/vctk/hifigan_v2") + training_args = TrainingArgs(restore_path=model_path) + return model_path, config_path, training_args + + def _hifi_gan(self, audio, data_path): + hifigan_config.data_path = data_path + hifigan_config.audio = audio + hifigan_config.batch_size = self.batch_size + hifigan_config.eval_batch_size = self.batch_size // 2 + hifigan_config.output_path = self.output_path + hifigan_config.mixed_precision = self.mixed_precision + hifigan_config.epochs = self.epochs + hifigan_config.lr_gen = self.generator_lr + hifigan_config.lr_disc = self.discriminator_lr + config = hifigan_config + return config + + def _wavegrad(self, audio, datapath): + wavegrad_config.audio = audio + wavegrad_config.data_path = datapath + wavegrad_config.batch_size = self.batch_size + wavegrad_config.eval_batch_size = self.batch_size // 2 + 
wavegrad_config.output_path = self.output_path + wavegrad_config.mixed_precision = self.mixed_precision + wavegrad_config.epochs = self.epochs + wavegrad_config.lr_gen = self.generator_lr + wavegrad_config.lr_disc = self.discriminator_lr + config = wavegrad_config + return config + + def _multiband_melgan(self, audio, data_path): + multiband_melgan_config.audio = audio + multiband_melgan_config.data_path = data_path + multiband_melgan_config.batch_size = self.batch_size + multiband_melgan_config.eval_batch_size = self.batch_size // 2 + multiband_melgan_config.output_path = self.output_path + multiband_melgan_config.mixed_precision = self.mixed_precision + multiband_melgan_config.epochs = self.epochs + multiband_melgan_config.lr_gen = self.generator_lr + multiband_melgan_config.lr_disc = self.discriminator_lr + config = multiband_melgan_config + return config + + def _univnet(self, audio, data_path): + univnet_config.audio = audio + univnet_config.data_path = data_path + univnet_config.batch_size = self.batch_size + univnet_config.eval_batch_size = self.batch_size // 2 + univnet_config.output_path = self.output_path + univnet_config.mixed_precision = self.mixed_precision + univnet_config.epochs = self.epochs + univnet_config.lr_gen = self.generator_lr + univnet_config.lr_disc = self.discriminator_lr + config = univnet_config + return config + + def _wavernn(self, audio, data_path): + waverrn_config.audio = audio + waverrn_config.data_path = data_path + waverrn_config.batch_size = self.batch_size + waverrn_config.eval_batch_size = self.batch_size // 2 + waverrn_config.output_path = self.output_path + waverrn_config.mixed_precision = self.mixed_precision + waverrn_config.epochs = self.epochs + waverrn_config.lr_gen = self.generator_lr + waverrn_config.lr_disc = self.discriminator_lr + config = waverrn_config + return config diff --git a/TTS/auto_tts/single_speaker_autotts.py b/TTS/auto_tts/single_speaker_autotts.py new file mode 100644 index 0000000000..4d7c9e1035 --- /dev/null +++ b/TTS/auto_tts/single_speaker_autotts.py @@ -0,0 +1,70 @@ +import argparse + +from TTS.auto_tts.complete_recipes import TtsAutoTrainer + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data_path", type=str, required=True, help="path to the dataset") + parser.add_argument("--dataset", type=str, required=True, help="name of the dataset") + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument( + "--output_path", type=str, default="./", help="path you want to store your model and config in." + ) + parser.add_argument( + "--mixed_precision", dest="mixed_precision", action="store_true", help="This turns on mixed precision training." + ) + parser.add_argument("--model", type=str, required=True, help="This is the model you want to train with, c") + parser.add_argument("--learning_rate", type=float, default=0.001) + parser.add_argument("--epochs", type=int, default=1000) + parser.add_argument( + "--tacotron2_model_type", + type=str, + help="this is the type of tactron model you want to train. options are 'double decoder consistency' and 'dynamic convolution attention'. 
This is only for tacotron2 models", + ) + parser.add_argument( + "--glow_tts_encoder", + type=str, + default=None, + help="the type of encoder glow tts will train with, defaults to transformer encoder.", + ) + parser.add_argument("--stats_path", type=str, default=None, help="stats path for audio config.") + parser.add_argument( + "--forward_attention", + dest="forward_attention", + action="store_true", + help="This turns on foward attention for tacotron2 models.", + ) + parser.add_argument( + "--location_attention", + dest="location_attention", + action="store_true", + help="This turns on location attention for tacotron2 models, recommended to turn on.", + ) + + parser.set_defaults(mixed_precision=False, forward_attention=False, location_attention=False) + + args = parser.parse_args() + args = vars(args) + trainer = TtsAutoTrainer( + args["data_path"], + args["dataset"], + args["batch_size"], + args["output_path"], + args["mixed_precision"], + args["learning_rate"], + args["epochs"], + ) + model = trainer.single_speaker_autotts( + args["model"], + args["stats_path"], + args["tacotron2_model_type"], + args["glow_tts_encoder"], + args["forward_attention"], + args["location_attention"], + ) + model.fit() + + +if __name__ == "__main__": + main() diff --git a/TTS/auto_tts/stats_path/scale_stats_dca.npy b/TTS/auto_tts/stats_path/scale_stats_dca.npy new file mode 100644 index 0000000000..1dc577a682 Binary files /dev/null and b/TTS/auto_tts/stats_path/scale_stats_dca.npy differ diff --git a/TTS/auto_tts/stats_path/scale_stats_ddc.npy b/TTS/auto_tts/stats_path/scale_stats_ddc.npy new file mode 100644 index 0000000000..1dc577a682 Binary files /dev/null and b/TTS/auto_tts/stats_path/scale_stats_ddc.npy differ diff --git a/TTS/auto_tts/utils.py b/TTS/auto_tts/utils.py new file mode 100644 index 0000000000..6c6f17ed72 --- /dev/null +++ b/TTS/auto_tts/utils.py @@ -0,0 +1,113 @@ +from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig + + +def data_loader(name, path, stats_path=None): + if name == "ljspeech": + dataset = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=path) + audio = BaseAudioConfig( + ref_level_db=0, trim_db=60, mel_fmin=50.0, mel_fmax=7600.0, spec_gain=1, stats_path=stats_path + ) + + elif name == "vctk": + dataset = BaseDatasetConfig( + name="vctk", + meta_file_train=["p225", "p234", "p238", "p245", "p248", "p261", "p294", "p302", "p326", "p335", "p347"], + meta_file_val=None, + path=path, + ) + audio = BaseAudioConfig( + sample_rate=22050, + preemphasis=0.98, + ref_level_db=20, + clip_norm=True, + mel_fmin=0.0, + mel_fmax=8000.0, + spec_gain=20, + do_trim_silence=False, + trim_db=60, + power=1.5, + num_mels=80, + resample=True, + ) + + elif name == "libri_tts": + dataset = BaseDatasetConfig(name="libri_tts", meta_file_train=None, meta_file_val=None, path=path) + audio = BaseAudioConfig( + resample=False, + sample_rate=24000, + preemphasis=0.98, + ref_level_db=20, + power=1.5, + signal_norm=True, + symmetric_norm=True, + max_norm=4.0, + clip_norm=True, + mel_fmax=8000.0, + spec_gain=20, + do_trim_silence=False, + trim_db=25, + ) + elif name == "baker": + dataset = BaseDatasetConfig(name=name, meta_file_train="metadata.csv", meta_file_val=None, path=path) + audio = BaseAudioConfig( + sample_rate=22050, + preemphasis=0.0, + ref_level_db=0, + do_trim_silence=True, + trim_db=60, + mel_fmin=50.0, + mel_fmax=7600.0, + spec_gain=1, + signal_norm=True, + symmetric_norm=True, + clip_norm=True, + stats_path=stats_path, + ) + return audio, dataset + + 
+def custom_data_loader(sr, audio_path): + dataset = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=audio_path) + pass + # this is for loading custom dataloader, it still takes the ljspeech format but the audio configs will differ + # with each users data so im thinking of a way to have users define their own audio params with this + + +def pick_glowtts_encoder(encoder_name: str): + if encoder_name == "transformer": + encoder_type = "rel_pos_transformer" + elif encoder_name == "gated": + encoder_type = "gated_conv" + elif encoder_name == "residual_bn": + encoder_type = "residual_conv_bn" + elif encoder_name == "time_depth": + encoder_type = "time_depth_separable" + else: + encoder_type = "rel_pos_transformer" + return encoder_type + + +def pick_forwardtts_encoder(encoder_name: str): + if encoder_name == "residual_bn": + encoder = "residual_conv_bn" + elif encoder_name == "fftransformer": + encoder = encoder_name + elif encoder_name == "position transformer": + encoder = "relative_position_transformer" + else: + print("please select an actual encoder. either residual_bn, fftransformer, or position transformer") + return encoder + + +def pick_forwardtts_decoder(decoder_name: str): + if decoder_name == "position transformer": + decoder = "relative_position_transformer" + elif decoder_name == " residual_bn": + decoder = "residual_conv_bn" + elif decoder_name == "wavenet": + decoder = decoder_name + elif decoder_name == "fftransformer": + decoder = decoder_name + else: + print("please select either position transformer, residual_bn, wavenet, or fftransformer") + return decoder diff --git a/TTS/auto_tts/vocoder_autotts.py b/TTS/auto_tts/vocoder_autotts.py new file mode 100644 index 0000000000..48fd54c287 --- /dev/null +++ b/TTS/auto_tts/vocoder_autotts.py @@ -0,0 +1,47 @@ +import argparse + +from TTS.auto_tts.complete_recipes import VocoderAutoTrainer + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data_path", type=str, required=True, help="path to the dataset") + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--dataset", type=str, required=True, help="name of the dataset you want to train.") + parser.add_argument( + "--output_path", type=str, default="./", help="path you want to store your model and config in." + ) + parser.add_argument( + "--mixed_precision", dest="mixed_precision", action="store_true", help="This turns on mixed precision training." 
+ ) + parser.add_argument("--model", type=str, required=True, help="This is the model you want to train with") + parser.add_argument("--learning_rate", type=float, default=[0.0001, 0.0001]) + parser.add_argument("--epochs", type=int, default=1000) + parser.add_argument( + "--location_attention", + dest="location_attention", + action="store_true", + help="This turns on location attention for tacotron2 models, recommended to turn on.", + ) + parser.add_argument("--stats_path", type=str, default=None, help="path to stats file for audio config.") + + parser.set_defaults(mixed_precision=False, forward_attention=False, location_attention=False) + + args = parser.parse_args() + args = vars(args) + print(args) + trainer = VocoderAutoTrainer( + args["data_path"], + args["dataset"], + args["batch_size"], + args["output_path"], + args["mixed_precision"], + args["learning_rate"], + args["epochs"], + ) + model = trainer.single_speaker_autotts(model_name=args["model"], stats_path=args["stats_path"]) + model.fit() + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/distribute.py b/TTS/bin/distribute.py index 06d5f388ac..9716e2a2f4 100644 --- a/TTS/bin/distribute.py +++ b/TTS/bin/distribute.py @@ -22,6 +22,11 @@ def main(): num_gpus = torch.cuda.device_count() group_id = time.strftime("%Y_%m_%d-%H%M%S") + assert num_gpus > 1, "distributed.py requires multiple available GPUs" + visible_gpus = ( + os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else list(range(0, num_gpus)) + ) + # set arguments for train.py folder_path = pathlib.Path(__file__).parent.absolute() if os.path.exists(os.path.join(folder_path, args.script)): @@ -38,9 +43,11 @@ def main(): # run processes processes = [] - for i in range(num_gpus): + gpus = visible_gpus.split(",") + for i, value in enumerate(gpus): my_env = os.environ.copy() my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i) + my_env["CUDA_VISIBLE_DEVICES"] = "{}".format(value) command[-1] = "--rank={}".format(i) # prevent stdout for processes with rank != 0 stdout = None diff --git a/TTS/tts/datasets/dataset_downloaders.py b/TTS/tts/datasets/dataset_downloaders.py new file mode 100644 index 0000000000..5937295640 --- /dev/null +++ b/TTS/tts/datasets/dataset_downloaders.py @@ -0,0 +1,256 @@ +import logging +import os +import tarfile +from os.path import expanduser + +import requests +from tqdm import tqdm + +from TTS.utils.generic_utils import get_user_data_dir + + +class DatasetDownloaders: + def __init__( + self, + dataset_name: str, + output_path: str = None, + libri_tts_subset: str = "all", + voxceleb_version: str = "both", + mailabs_language: str = "all", + ): + self.name = dataset_name + self.libri_tts_subset = libri_tts_subset + self.voxceleb_version = voxceleb_version + self.mailabs_language = mailabs_language + if output_path is None: + self.output_path = get_user_data_dir("tts/datasets") + else: + self.output_path = os.path.join(output_path, "tts/datasets") + + self.dataset_dict = { + "thorsten-de": ( + "https://www.openslr.org/resources/95/thorsten-de_v02.tgz", + "thorsten-de_v02.tgz", + "Thorsten-De", + ), + "ljspeech": ( + "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2", + "LJSpeech-1.1.tar.bz2", + "LJSpeech-1.1", + ), + "common-voice": (), + "tweb": ("bryanpark/the-world-english-bible-speech-dataset", "tweb.zip", "Tweb"), + "libri-tts-clean-100": ( + "http://www.openslr.org/resources/60/train-clean-100.tar.gz", + "train-clean-100.tar.tz", + "LibriTTS-100-hours", + ), + "libri-tts-clean-360": ( + 
"http://www.openslr.org/resources/60/train-clean-360.tar.gz", + "train-clean-360.tar.tz", + "LibriTTS-360-hours", + ), + "libri-tts-other-500": ( + "http://www.openslr.org/resources/60/train-other-500.tar.gz", + "train-other-500.tar.tz", + "LibriTTS-500-hours", + ), + "libri-tts-dev-clean": ( + "http://www.openslr.org/resources/60/dev-clean.tar.gz", + "dev-clean.tar.tz", + "LibriTTS-dev-clean", + ), + "libri-tts-dev-other": ( + "http://www.openslr.org/resources/60/dev-other.tar.gz", + "dev-other.tar.tz", + "LibriTTS-dev-other", + ), + "libri-tts-test-clean": ( + "http://www.openslr.org/resources/60/test-clean.tar.gz", + "test-clean.tar.tz", + "LibriTTS-test-clean", + ), + "libri-tts-test-other": ( + "http://www.openslr.org/resources/60/test-other.tar.gz", + "test-other.tar.tz", + "LibriTTS-test-other", + ), + "mailabs-english": ( + "https://data.solak.de/data/Training/stt_tts/en_US.tgz", + "en_US.tgz", + "MaiLabs-English", + ), + "mailabs-german": ("https://data.solak.de/data/Training/stt_tts/de_DE.tgz", "de_DE.tgz", "MaiLabs-German"), + "mailabs-french": ("https://data.solak.de/data/Training/stt_tts/fr_FR.tgz", "fr_FR.tgz", "MaiLabs-French"), + "mailabs-italian": ( + "https://data.solak.de/data/Training/stt_tts/it_IT.tgz", + "it_IT.tgz", + "MaiLabs-Italian", + ), + "mailabs-spanish": ( + "https://data.solak.de/data/Training/stt_tts/es_ES.tgz", + "es_ES.tgz", + "MaiLabs-Spanish", + ), + "vctk-kaggle": ("mfekadu/english-multispeaker-corpus-for-voice-cloning", "vctk.zip", "Vctk"), + "vctk": ("datashare.is.ed.ac.uk/download/DS_10283_3443.zip", "vctk.zip", "Vctk"), + } + + def download_dataset(self): + if self.name == "ljspeech": + dataset_path = self._download_ljspeech() + elif self.name == "libri-tts": + dataset_path = self._download_libri_tts() + elif self.name == "thorsten-de": + dataset_path = self._download_thorsten_german() + elif self.name == "mailabs": + dataset_path = self._download_mailabs() + elif self.name == "vctk": + dataset_path = self._download_vctk() + elif self.name == "tweb": + dataset_path = self._download_tweb() + return dataset_path + + def list_datasets(self): + data_list_dict = { + "ljspeech": ( + "ljspeech", + "24 hours of professional audio of a female reading audio books.This dataset is 2.76 gigs in size.", + ), + } + print(data_list_dict) + + def _download_ljspeech(self): + url, tar_file, data_name = self.dataset_dict["ljspeech"] + data_path = os.path.join(self.output_path, data_name) + self._download_tarred_data(url, tar_file, data_name) + return data_path + + def _download_thorsten_german(self): + url, tar_file, data_name = self.dataset_dict["thorsten-de"] + data_path = os.path.join(self.output_path, data_name) + self._download_tarred_data(url, tar_file, data_name) + return data_path + + def _download_libri_tts(self): + if self.libri_tts_subset == "all": + subset_names = [ + "libri-tts-clean-100", + "libri-tts-clean-360", + "libri-tts-other-500", + "libri-tts-dev-clean", + "libri-tts-dev-other", + "libri-tts-test-clean", + "libri-tts-test-other", + ] + for i in subset_names: + url, tar_file, subset_name = self.dataset_dict[i] + self._download_tarred_data(url, tar_file, subset_name) + print("finished downloading all subsets") + elif self.libri_tts_subset == "clean": + subset_names = ["libri-tts-clean-100", "libri-tts-clean-360", "libri-tts-dev-clean", "libri-tts-test-clean"] + for i in subset_names: + url, tar_file, subset_name = self.dataset_dict[i] + self._download_tarred_data(url, tar_file, subset_name) + print("finished downloading the clean subsets") + elif 
self.libri_tts_subset == "noisy": + subset_names = ["libri-tts-other-500", "libri-tts-dev-other", "libri-tts-test-other"] + for i in subset_names: + url, tar_file, subset_name = self.dataset_dict[i] + self._download_tarred_data(url, tar_file, subset_name) + print("finished downloading the noisy subsets") + dataset_path = os.path.join(self.output_path, "LibriTTS") + print(f'your dataset was downloaded to {os.path.join(self.output_path, "LibriTTS")}') + return dataset_path + + def _download_vctk(self, remove_silences=True, use_kaggle=True): + if use_kaggle: + url, tar_file, dataset_name = self.dataset_dict["vctk-kaggle"] + data_dir = self._download_kaggle_dataset(url, dataset_name) + dataset_path = os.path.join( + data_dir, + ) + return dataset_path + else: + pass + + def _download_tweb(self): + url, tar_file, dataset_name = self.dataset_dict["tweb"] + data_dir = self._download_kaggle_dataset(url, dataset_name) + return data_dir + + def _download_mailabs(self): + if self.mailabs_language == "all": + language_subsets = [ + "mailabs-english", + "mailabs-german", + "mailabs-french", + "mailabs-italian", + "mailabs-spanish", + ] + for subset in language_subsets: + url, tar_file, data_name = self.dataset_dict[subset] + self._download_tarred_data(url, tar_file, data_name, custom_dir_name="MaiLabs") + elif self.mailabs_language == "english": + url, tar_file, data_name = self.dataset_dict["mailabs-english"] + self._download_tarred_data(url, tar_file, data_name, custom_dir_name="MaiLabs") + elif self.mailabs_language == "french": + url, tar_file, data_name = self.dataset_dict["mailabs-french"] + self._download_tarred_data(url, tar_file, data_name) + elif self.mailabs_language == "german": + url, tar_file, data_name = self.dataset_dict["mailabs-german"] + self._download_tarred_data(url, tar_file, data_name) + elif self.mailabs_language == "italian": + url, tar_file, data_name = self.dataset_dict["mailabs-italian"] + self._download_tarred_data(url, tar_file, data_name) + elif self.mailabs_language == "spanish": + url, tar_file, data_name = self.dataset_dict["mailabs-spanish"] + self._download_tarred_data(url, tar_file, data_name) + print(f'your dataset was downloaded to {os.path.join(self.output_path, "MaiLabs")}') + return data_name + + def _download_voxceleb(self, version="both"): + pass + + def _download_common_voice(self, version): + pass + + def _download_tarred_data(self, url, tar_file, data_name, custom_dir_name=None): + if custom_dir_name is not None: + dataset_dir = os.path.join(self.output_path, custom_dir_name) + else: + dataset_dir = self.output_path + with open(os.path.join(self.output_path, tar_file), "wb") as data: + raw_bytes = requests.get(url, stream=True) + total_bytes = int(raw_bytes.headers["Content-Length"]) + progress_bar = tqdm(total=total_bytes, unit="MiM", unit_scale=True) + print(f"\ndownloading {data_name} data") + for chunk in raw_bytes.iter_content(chunk_size=1024): + data.write(chunk) + progress_bar.update(len(chunk)) + data.close() + progress_bar.close() + print("\nextracting data") + tar_path = tarfile.open(os.path.join(self.output_path, tar_file)) + tar_path.extractall(dataset_dir) + tar_path.close() + os.remove(os.path.join(self.output_path, tar_file)) + + def _download_kaggle_dataset(self, dataset_path, dataset_name): + data_path = os.path.join(self.output_path, dataset_name) + try: + import kaggle + + kaggle.api.authenticate() + print( + f"""\nThe {dataset_name} dataset is being download and untarred via kaggle api. 
This may take a minute depending on the dataset size""" + ) + kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True) + print(f"dataset is downloaded and stored in {data_path}") + return data_path + except OSError: + logging.warning( + f"""in order to download kaggle datasets, you need to have a kaggle api token stored in your +{os.path.join(expanduser('~'), '.kaggle/kaggle.json')} +If you don't have a kaggle account you can easily make one for free and generate a token in the account settings tab.""" + ) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index b77c1e2315..63f98a08b1 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -225,6 +225,14 @@ def get_data_loader( d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, ) + if ( + config.use_phonemes + and config.compute_input_seq_cache + and not os.path.exists(dataset.phoneme_cache_path) + ): + # precompute phonemes to have a better estimate of sequence lengths. + dataset.compute_input_seq(config.num_loader_workers) + # pre-compute phonemes if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]: if hasattr(self, "eval_data_items") and is_eval:
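For reference, a minimal end-to-end sketch of how these pieces are intended to fit together; paths and the output directory are placeholders, and the dataset name is assumed to be one registered in TTS.auto_tts.utils:

    from TTS.auto_tts.complete_recipes import TtsAutoTrainer
    from TTS.tts.datasets.dataset_downloaders import DatasetDownloaders

    # Download LJSpeech into the default user data directory and get its path back.
    dataset_path = DatasetDownloaders(dataset_name="ljspeech").download_dataset()

    # Build a Trainer for the Tacotron2 DDC recipe on that dataset and start training.
    trainer = TtsAutoTrainer(data_path=dataset_path, dataset="ljspeech", batch_size=32, output_path="./runs")
    model = trainer.single_speaker_autotts("tacotron2", tacotron2_model_type="double decoder consistency")
    model.fit()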