diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py index 515233851a1d..7846da12fdbd 100755 --- a/deepspeed/runtime/lr_schedules.py +++ b/deepspeed/runtime/lr_schedules.py @@ -706,8 +706,8 @@ def __init__(self, self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr") self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr") self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)] - self.warmup_num_steps = warmup_num_steps - self.inverse_log_warm_up = 1.0 / math.log(warmup_num_steps) + self.warmup_num_steps = max(2, warmup_num_steps) + self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps) self.last_batch_iteration = last_batch_iteration def get_lr(self):