From 57dddb1c2869a809c8874a178ce4adc356d6d8db Mon Sep 17 00:00:00 2001
From: begeekmyfriend
Date: Tue, 25 Feb 2020 11:10:47 +0800
Subject: [PATCH] AMP training level no more than O1

Signed-off-by: begeekmyfriend
---
 hparams.py                 | 1 -
 models/fatchord_version.py | 2 +-
 train_wavernn.py           | 2 +-
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/hparams.py b/hparams.py
index 4597e03c..ff846d66 100644
--- a/hparams.py
+++ b/hparams.py
@@ -12,7 +12,6 @@
 # set this to True if you are only interested in WaveRNN
 ignore_tts = True
 amp = True
-amp_level = 'O2'
 
 # DSP --------------------------------------------------------------------------------------------------------------#
 
diff --git a/models/fatchord_version.py b/models/fatchord_version.py
index 0b379e8f..4a0e3cb6 100644
--- a/models/fatchord_version.py
+++ b/models/fatchord_version.py
@@ -229,7 +229,7 @@ def generate(self, mels, save_path, batched, target, overlap, mu_law):
         output = output[0]
 
         end = time.time()
-        print(f'{end - start}')
+        print(f'Elapsed {end - start} seconds')
 
         return save_wav(output[:wave_len], save_path)
 
diff --git a/train_wavernn.py b/train_wavernn.py
index 5a6c1ced..70ebaf49 100644
--- a/train_wavernn.py
+++ b/train_wavernn.py
@@ -33,7 +33,7 @@ def voc_train_loop(model, loss_func, optimizer, train_set, test_set, init_lr, fi
     epochs = int((total_steps - model.get_step()) // total_iters + 1)
 
     if hp.amp:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=hp.amp_level)
+        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
 
     torch.backends.cudnn.benchmark = True
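
For context, below is a minimal standalone sketch of the apex AMP pattern
this patch pins to 'O1'. It is not part of the patch: it assumes NVIDIA's
apex package and a CUDA device, and the tiny linear model, optimizer, and
random data are placeholders rather than anything from the WaveRNN repo.

    import torch
    from apex import amp

    model = torch.nn.Linear(16, 1).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_func = torch.nn.MSELoss()

    # 'O1' keeps fp32 weights and casts individual ops to fp16 through
    # apex's allowlist; 'O2' casts the model itself to fp16, which is
    # the level this patch backs away from.
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    for _ in range(10):
        x = torch.randn(8, 16).cuda()
        y = torch.randn(8, 1).cuda()
        optimizer.zero_grad()
        loss = loss_func(model(x), y)
        # scale_loss scales the loss so fp16 gradients do not underflow
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

Note that opt_level only changes what amp.initialize does; the training
loop itself is written the same way for 'O1' and 'O2'.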