diff --git a/train_ms.py b/train_ms.py index 2dcef509a..1f1708d8e 100644 --- a/train_ms.py +++ b/train_ms.py @@ -191,6 +191,8 @@ def run(): optim_g.param_groups[0]["initial_lr"] = g_resume_lr if not optim_d.param_groups[0].get("initial_lr"): optim_d.param_groups[0]["initial_lr"] = d_resume_lr + if not optim_dur_disc.param_groups[0].get("initial_lr"): + optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr epoch_str = max(epoch_str, 1) global_step = (epoch_str - 1) * len(train_loader)