diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py index 9e2b7415d8..a8c4240ec9 100755 --- a/TTS/bin/compute_statistics.py +++ b/TTS/bin/compute_statistics.py @@ -8,6 +8,7 @@ import numpy as np from tqdm import tqdm +from TTS.utils.config_manager import ConfigManager from TTS.tts.datasets.preprocess import load_meta_data from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config @@ -15,26 +16,33 @@ def main(): """Run preprocessing process.""" - parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.") - parser.add_argument( - "--config_path", type=str, required=True, help="TTS config file path to define audio processin parameters." - ) - parser.add_argument("--out_path", type=str, required=True, help="save path (directory and filename).") + CONFIG = ConfigManager() + + parser = argparse.ArgumentParser( + description="Compute mean and variance of spectrogtram features.") + parser.add_argument("config_path", type=str, + help="TTS config file path to define audio processin parameters.") + parser.add_argument("out_path", type=str, + help="save path (directory and filename).") + parser.add_argument("--data_path", type=str, required=False, + help="folder including the target set of wavs overriding dataset config.") + parser = CONFIG.init_argparse(parser) args = parser.parse_args() + CONFIG.parse_argparse(args) # load config - CONFIG = load_config(args.config_path) - CONFIG.audio["signal_norm"] = False # do not apply earlier normalization - CONFIG.audio["stats_path"] = None # discard pre-defined stats + CONFIG.load_config(args.config_path) + CONFIG.audio_config.signal_norm = False # do not apply earlier normalization + CONFIG.audio_config.stats_path = None # discard pre-defined stats # load audio processor - ap = AudioProcessor(**CONFIG.audio) + ap = AudioProcessor(**CONFIG.audio_config.to_dict()) # load the meta data of target dataset - if "data_path" in CONFIG.keys(): - dataset_items 
= glob.glob(os.path.join(CONFIG.data_path, "**", "*.wav"), recursive=True) + if args.data_path: + dataset_items = glob.glob(os.path.join(args.data_path, '**', '*.wav'), recursive=True) else: - dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data + dataset_items = load_meta_data(CONFIG.dataset_config)[0] # take only train data print(f" > There are {len(dataset_items)} files.") mel_sum = 0 @@ -73,14 +81,14 @@ def main(): print(f" > Avg lienar spec scale: {linear_scale.mean()}") # set default config values for mean-var scaling - CONFIG.audio["stats_path"] = output_file_path - CONFIG.audio["signal_norm"] = True # remove redundant values - del CONFIG.audio["max_norm"] - del CONFIG.audio["min_level_db"] - del CONFIG.audio["symmetric_norm"] - del CONFIG.audio["clip_norm"] - stats["audio_config"] = CONFIG.audio + CONFIG.audio_config.stats_path = output_file_path + CONFIG.audio_config.signal_norm = True + del CONFIG.audio_config.max_norm + del CONFIG.audio_config.min_level_db + del CONFIG.audio_config.symmetric_norm + del CONFIG.audio_config.clip_norm + stats['audio_config'] = CONFIG.audio_config.to_dict() np.save(output_file_path, stats, allow_pickle=True) print(f" > stats saved to {output_file_path}") diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py index 28865c7573..9ec068b34f 100644 --- a/TTS/bin/train_tacotron.py +++ b/TTS/bin/train_tacotron.py @@ -10,11 +10,10 @@ import numpy as np import torch from torch.utils.data import DataLoader - from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.layers.losses import TacotronLoss from TTS.tts.configs.tacotron_config import TacotronConfig from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.measures import alignment_diagonal_score @@ -24,8 +22,11 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from
TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor -from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor -from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict +from TTS.utils.config_manager import ConfigManager +from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce, + init_distributed, reduce_tensor) +from TTS.utils.generic_utils import (KeepAverage, count_parameters, + remove_experiment_folder, set_init_dict) from TTS.utils.radam import RAdam from TTS.utils.training import ( NoamLR, @@ -739,7 +740,10 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts") + c = TacotronConfig() + args = c.init_argparse(args) + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( + args, c, tb_prefix='tacotron') try: main(args) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 1f889b8add..0e80111bc3 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -2,10 +2,7 @@ import re from collections import Counter -import numpy as np -import torch - -from TTS.utils.generic_utils import check_argument +from TTS.utils.generic_utils import find_module def split_dataset(items): @@ -39,17 +36,9 @@ def sequence_mask(sequence_length, max_len=None): return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) -def to_camel(text): - text = text.capitalize() - text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) - text = text.replace("Tts", "TTS") - return text - - def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): print(" > Using model: {}".format(c.model)) - MyModel = importlib.import_module("TTS.tts.models."
+ c.model.lower()) - MyModel = getattr(MyModel, to_camel(c.model)) + MyModel = find_module("TTS.tts.models", c.model.lower()) if c.model.lower() in "tacotron": model = MyModel( num_chars=num_chars + getattr(c, "add_blank", False), @@ -164,189 +153,156 @@ def is_tacotron(c): return "tacotron" in c["model"].lower() -def check_config_tts(c): - check_argument( - "model", - c, - enum_list=["tacotron", "tacotron2", "glow_tts", "speedy_speech", "align_tts"], - restricted=True, - val_type=str, - ) - check_argument("run_name", c, restricted=True, val_type=str) - check_argument("run_description", c, val_type=str) - - # AUDIO - # check_argument('audio', c, restricted=True, val_type=dict) - - # audio processing parameters - # check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) - # check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) - # check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) - # check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') - # check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') - # check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) - # check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) - # check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) - # check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) - # check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) - - # vocabulary parameters - check_argument("characters", c, restricted=False, val_type=dict) - check_argument( - "pad", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in
c.keys(), val_type=str - ) - check_argument( - "eos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str - ) - check_argument( - "bos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str - ) - check_argument( - "characters", - c["characters"] if "characters" in c.keys() else {}, - restricted="characters" in c.keys(), - val_type=str, - ) - check_argument( - "phonemes", - c["characters"] if "characters" in c.keys() else {}, - restricted="characters" in c.keys() and c["use_phonemes"], - val_type=str, - ) - check_argument( - "punctuations", - c["characters"] if "characters" in c.keys() else {}, - restricted="characters" in c.keys(), - val_type=str, - ) - - # normalization parameters - # check_argument('signal_norm', c['audio'], restricted=True, val_type=bool) - # check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool) - # check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000) - # check_argument('clip_norm', c['audio'], restricted=True, val_type=bool) - # check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000) - # check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0) - # check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100) - # check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) - # check_argument('trim_db', c['audio'], restricted=True, val_type=int) - - # training parameters - # check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) - # check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) - # check_argument('r', c, restricted=True, val_type=int, min_val=1) - # check_argument('gradual_training', c, restricted=False, val_type=list) - # check_argument('mixed_precision', c, restricted=False, val_type=bool) - # 
check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100) - - # loss parameters - # check_argument('loss_masking', c, restricted=True, val_type=bool) - # if c['model'].lower() in ['tacotron', 'tacotron2']: - # check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0) - # check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0) - # check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0) - # check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0) - # check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0) - # check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0) - # check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0) - if c['model'].lower in ["speedy_speech", "align_tts"]: - check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0) - - # validation parameters - # check_argument('run_eval', c, restricted=True, val_type=bool) - # check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0) - # check_argument('test_sentences_file', c, restricted=False, val_type=str) - - # optimizer - check_argument("noam_schedule", c, restricted=False, val_type=bool) - check_argument("grad_clip", c, restricted=True, val_type=float, min_val=0.0) - check_argument("epochs", c, restricted=True, val_type=int, min_val=1) - check_argument("lr", c, restricted=True, val_type=float, min_val=0) - check_argument("wd", c, restricted=is_tacotron(c), val_type=float, min_val=0) - check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0) - check_argument("seq_len_norm", c, restricted=is_tacotron(c), val_type=bool) - - # tacotron prenet - # check_argument('memory_size', c, 
restricted=is_tacotron(c), val_type=int, min_val=-1) - # check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn']) - # check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool) - - # attention - check_argument( - "attention_type", - c, - restricted=is_tacotron(c), - val_type=str, - enum_list=["graves", "original", "dynamic_convolution"], - ) - check_argument("attention_heads", c, restricted=is_tacotron(c), val_type=int) - check_argument("attention_norm", c, restricted=is_tacotron(c), val_type=str, enum_list=["sigmoid", "softmax"]) - check_argument("windowing", c, restricted=is_tacotron(c), val_type=bool) - check_argument("use_forward_attn", c, restricted=is_tacotron(c), val_type=bool) - check_argument("forward_attn_mask", c, restricted=is_tacotron(c), val_type=bool) - check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool) - check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool) - check_argument("location_attn", c, restricted=is_tacotron(c), val_type=bool) - check_argument("bidirectional_decoder", c, restricted=is_tacotron(c), val_type=bool) - check_argument("double_decoder_consistency", c, restricted=is_tacotron(c), val_type=bool) - check_argument("ddc_r", c, restricted="double_decoder_consistency" in c.keys(), min_val=1, max_val=7, val_type=int) - - if c["model"].lower() in ["tacotron", "tacotron2"]: - # stopnet - # check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool) - # check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool) - - # Model Parameters for non-tacotron models - if c["model"].lower in ["speedy_speech", "align_tts"]: - check_argument("positional_encoding", c, restricted=True, val_type=type) - check_argument("encoder_type", c, restricted=True, val_type=str) - check_argument("encoder_params", c, restricted=True, val_type=dict) - check_argument("decoder_residual_conv_bn_params", c, restricted=True, 
val_type=dict) - - # GlowTTS parameters - check_argument("encoder_type", c, restricted=not is_tacotron(c), val_type=str) - - # tensorboard - # check_argument('print_step', c, restricted=True, val_type=int, min_val=1) - # check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1) - # check_argument('save_step', c, restricted=True, val_type=int, min_val=1) - # check_argument('checkpoint', c, restricted=True, val_type=bool) - # check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) - - # dataloading - # pylint: disable=import-outside-toplevel - from TTS.tts.utils.text import cleaners - # check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners)) - # check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool) - # check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0) - # check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0) - # check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0) - # check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0) - # check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10) - # check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool) - - # paths - # check_argument('output_path', c, restricted=True, val_type=str) - - # multi-speaker and gst - # check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) - # check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool) - # check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str) - if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']: - # check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool) - # check_argument('gst', c, restricted=is_tacotron(c), val_type=dict) - # check_argument('gst_style_input', c['gst'], 
restricted=is_tacotron(c), val_type=[str, dict]) - # check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000) - # check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool) - # check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10) - # check_argument('gst_num_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000) - - # datasets - checking only the first entry - # check_argument('datasets', c, restricted=True, val_type=list) - # for dataset_entry in c['datasets']: - # check_argument('name', dataset_entry, restricted=True, val_type=str) - # check_argument('path', dataset_entry, restricted=True, val_type=str) - # check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list]) - # check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) +# def check_config_tts(c): +# check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str) +# check_argument('run_name', c, restricted=True, val_type=str) +# check_argument('run_description', c, val_type=str) + +# # AUDIO +# # check_argument('audio', c, restricted=True, val_type=dict) + +# # audio processing parameters +# # check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) +# # check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) +# # check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) +# # check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') +# # check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') +# # check_argument('preemphasis', c['audio'], 
restricted=True, val_type=float, min_val=0, max_val=1) +# # check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) +# # check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) +# # check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) +# # check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) + +# # vocabulary parameters +# check_argument('characters', c, restricted=False, val_type=dict) +# check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) +# check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) +# check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) +# check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) +# check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys() and c['use_phonemes'], val_type=str) +# check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + +# # normalization parameters +# # check_argument('signal_norm', c['audio'], restricted=True, val_type=bool) +# # check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool) +# # check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000) +# # check_argument('clip_norm', c['audio'], restricted=True, val_type=bool) +# # check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000) +# # check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0) +# # check_argument('spec_gain', c['audio'], 
restricted=True, val_type=[int, float], min_val=1, max_val=100) +# # check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) +# # check_argument('trim_db', c['audio'], restricted=True, val_type=int) + +# # training parameters +# # check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) +# # check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) +# # check_argument('r', c, restricted=True, val_type=int, min_val=1) +# # check_argument('gradual_training', c, restricted=False, val_type=list) +# # check_argument('mixed_precision', c, restricted=False, val_type=bool) +# # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100) + +# # loss parameters +# # check_argument('loss_masking', c, restricted=True, val_type=bool) +# # if c['model'].lower() in ['tacotron', 'tacotron2']: +# # check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0) +# # check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0) +# # check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0) +# # check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0) +# # check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0) +# # check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0) +# # check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0) +# if c['model'].lower in ["speedy_speech", "align_tts"]: +# check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0) +# check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0) +# check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0) + +# # validation parameters +# # check_argument('run_eval', c, restricted=True, val_type=bool) +# # check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0) +# # 
check_argument('test_sentences_file', c, restricted=False, val_type=str) + +# # optimizer +# check_argument('noam_schedule', c, restricted=False, val_type=bool) +# check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0) +# check_argument('epochs', c, restricted=True, val_type=int, min_val=1) +# check_argument('lr', c, restricted=True, val_type=float, min_val=0) +# check_argument('wd', c, restricted=is_tacotron(c), val_type=float, min_val=0) +# check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) +# check_argument('seq_len_norm', c, restricted=is_tacotron(c), val_type=bool) + +# # tacotron prenet +# # check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1) +# # check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn']) +# # check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool) + +# # attention +# check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original', 'dynamic_convolution']) +# check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int) +# check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax']) +# check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool) +# check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool) +# check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool) +# check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool) +# check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool) +# check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool) +# check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool) +# check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool) +# check_argument('ddc_r', c, 
restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int) + +# if c['model'].lower() in ['tacotron', 'tacotron2']: +# # stopnet +# # check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool) +# # check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool) + +# # Model Parameters for non-tacotron models +# if c['model'].lower in ["speedy_speech", "align_tts"]: +# check_argument('positional_encoding', c, restricted=True, val_type=type) +# check_argument('encoder_type', c, restricted=True, val_type=str) +# check_argument('encoder_params', c, restricted=True, val_type=dict) +# check_argument('decoder_residual_conv_bn_params', c, restricted=True, val_type=dict) + +# # GlowTTS parameters +# check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str) + +# # tensorboard +# # check_argument('print_step', c, restricted=True, val_type=int, min_val=1) +# # check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1) +# # check_argument('save_step', c, restricted=True, val_type=int, min_val=1) +# # check_argument('checkpoint', c, restricted=True, val_type=bool) +# # check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) + +# # dataloading +# # pylint: disable=import-outside-toplevel +# from TTS.tts.utils.text import cleaners +# # check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners)) +# # check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool) +# # check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0) +# # check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0) +# # check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0) +# # check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0) +# # check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10) +# # check_argument('compute_input_seq_cache', c, 
restricted=True, val_type=bool) + +# # paths +# # check_argument('output_path', c, restricted=True, val_type=str) + +# # multi-speaker and gst +# # check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) +# # check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool) +# # check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str) +# if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']: +# # check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool) +# # check_argument('gst', c, restricted=is_tacotron(c), val_type=dict) +# # check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict]) +# # check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000) +# # check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool) +# # check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10) +# # check_argument('gst_num_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000) + +# # datasets - checking only the first entry +# # check_argument('datasets', c, restricted=True, val_type=list) +# # for dataset_entry in c['datasets']: +# # check_argument('name', dataset_entry, restricted=True, val_type=str) +# # check_argument('path', dataset_entry, restricted=True, val_type=str) +# # check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list]) +# # check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) diff --git a/TTS/utils/arguments.py b/TTS/utils/arguments.py index 4f6e231762..364baaf9b7 100644 --- a/TTS/utils/arguments.py +++ b/TTS/utils/arguments.py @@ -8,12 +8,10 @@ import os import re -import torch - from TTS.tts.utils.text.symbols import parse_symbols from 
TTS.utils.console_logger import ConsoleLogger from TTS.utils.generic_utils import create_experiment_folder, get_git_branch -from TTS.utils.io import copy_model_files, load_config +from TTS.utils.io import copy_model_files from TTS.utils.tensorboard_logger import TensorboardLogger @@ -140,11 +138,11 @@ def process_args(args, config, tb_prefix): if not args.best_path: args.best_path = best_model # setup output paths and read configs - c = config.load_json(args.config_path) - if c.mixed_precision: + config.load_json(args.config_path) + if config.mixed_precision: print(" > Mixed precision mode is ON") - if not os.path.exists(c.output_path): - out_path = create_experiment_folder(c.output_path, c.run_name, + if not os.path.exists(config.output_path): + out_path = create_experiment_folder(config.output_path, config.run_name, args.debug) audio_path = os.path.join(out_path, "test_audios") # setup rank 0 process in distributed training @@ -157,7 +155,7 @@ def process_args(args, config, tb_prefix): # if model characters are not set in the config file # save the default set to the config file for future # compatibility. 
- if c.has('characters_config'): + if config.has('characters_config'): used_characters = parse_symbols() new_fields["characters"] = used_characters - copy_model_files(c, args.config_path, out_path, new_fields) + copy_model_files(config, args.config_path, out_path, new_fields) @@ -166,6 +164,6 @@ def process_args(args, config, tb_prefix): log_path = out_path tb_logger = TensorboardLogger(log_path, model_name=tb_prefix) # write model desc to tensorboard - tb_logger.tb_add_text("model-description", c["run_description"], 0) + tb_logger.tb_add_text("model-description", config["run_description"], 0) c_logger = ConsoleLogger() - return c, out_path, audio_path, c_logger, tb_logger + return config, out_path, audio_path, c_logger, tb_logger diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index a3a604df2b..8730703288 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -1,6 +1,10 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- import datetime import glob +import importlib import os +import re import shutil import subprocess import sys @@ -67,6 +71,20 @@ def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) +def to_camel(text): + text = text.capitalize() + text = re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) + text = text.replace('Tts', 'TTS') + return text + + +def find_module(module_path: str, module_name: str) -> object: + module_name = module_name.lower() + module = importlib.import_module(module_path+'.'+module_name) + class_name = to_camel(module_name) + return getattr(module, class_name) + + def get_user_data_dir(appname): if sys.platform == "win32": import winreg # pylint: disable=import-outside-toplevel @@ -139,32 +157,3 @@ def update_values(self, value_dict): for key, value in value_dict.items(): self.update_value(key, value) - -def check_argument(name, - c, - prerequest=None, - enum_list=None, - max_val=None, - min_val=None, - restricted=False, - alternative=None, - allow_none=False): - if isinstance(prerequest, List()): - if any([f not in c.keys() for f in prerequest]): - return -
else: - if prerequest not in c.keys(): - return - if alternative in c.keys() and c[alternative] is not None: - return - if allow_none and c[name] is None: - return - if restricted: - assert name in c.keys(), f" [!] {name} not defined in config.json" - if name in c.keys(): - if max_val: - assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}" - if min_val: - assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}" - if enum_list: - assert c[name].lower() in enum_list, f' [!] {name} is not a valid value' diff --git a/TTS/utils/io.py b/TTS/utils/io.py index 12745459f0..2d5662eb93 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -3,6 +3,7 @@ import pickle as pickle_tts import re from shutil import copyfile +from TTS.utils.generic_utils import find_module import yaml @@ -23,32 +24,37 @@ def __init__(self, *args, **kwargs): self.__dict__ = self -# def read_json_with_comments(json_path): -# # fallback to json -# with open(json_path, "r", encoding="utf-8") as f: -# input_str = f.read() -# # handle comments -# input_str = re.sub(r'\\\n', '', input_str) -# input_str = re.sub(r'//.*\n', '\n', input_str) -# data = json.loads(input_str) -# return data +def read_json_with_comments(json_path): + """DEPRECATED""" + # fallback to json + with open(json_path, "r", encoding="utf-8") as f: + input_str = f.read() + # handle comments + input_str = re.sub(r'\\\n', '', input_str) + input_str = re.sub(r'//.*\n', '\n', input_str) + data = json.loads(input_str) + return data -# def load_config(config_path: str) -> AttrDict: -# """Load config files and discard comments +def load_config(config_path: str) -> AttrDict: + """DEPRECATED: Load config files and discard comments -# Args: -# config_path (str): path to config file. 
-# """ -# config = AttrDict() - -# ext = os.path.splitext(config_path)[1] -# # if ext in (".yml", ".yaml"): -# # with open(config_path, "r", encoding="utf-8") as f: -# # data = yaml.safe_load(f) -# # else: -# data = read_json_with_comments(config_path) -# config.update(data) -# return config + Args: + config_path (str): path to config file. + """ + config_dict = AttrDict() + ext = os.path.splitext(config_path)[1] + if ext in (".yml", ".yaml"): + with open(config_path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + else: + with open(config_path, "r", encoding="utf-8") as f: + input_str = f.read() + data = json.loads(input_str) + config_dict.update(data) + config_class = find_module('TTS.tts.configs', config_dict.model.lower()+'_config') + config = config_class() + config.from_dict(config_dict) + return config def copy_model_files(c, config_file, out_path, new_fields):